# Source code for tdc.single_pred.single_pred_dataset

# -*- coding: utf-8 -*-
# Author: TDC Team
# License: MIT

import pandas as pd
import numpy as np
import os, sys, json 
import warnings
warnings.filterwarnings("ignore")

from .. import base_dataset
from ..utils import dataset2target_lists, \
					property_dataset_load,\
					create_fold, \
					create_fold_setting_cold,\
					create_scaffold_split, \
					print_sys

class DataLoader(base_dataset.DataLoader):
    """A base data loader class for single-instance prediction tasks.

    Args:
        name (str): the dataset name.
        path (str): the path to save the data file.
        label_name (str): for a multi-label dataset, the label name of interest.
        print_stats (bool): whether to print basic statistics of the dataset.
        dataset_names (list): a list of dataset names available for a task.
        convert_format (str): automatic conversion of SMILES to other molecular
            formats via the MolConvert class; stored as a separate column in
            the dataframe.
        raw_format (str): source format handed to MolConvert when
            ``convert_format`` is set; defaults to 'SMILES'.

    Attributes:
        convert_format (str): conversion format of an entity.
        convert_result (list): a placeholder for a list of conversion outputs;
            ``None`` until ``get_data`` performs the conversion.
        entity1 (Pandas Series): the single entities.
        entity1_idx (Pandas Series): the indices of the single entities.
        entity1_name (str): the entity column name, always 'Drug' here.
        file_format (str): the format of the downloaded dataset ('csv').
        label_name (str): for a multi-label dataset, the label name of interest.
        name (str): dataset name.
        path (str): path to save and retrieve the dataset.
        raw_format (str): source format of ``entity1`` used for conversion.
        y (Pandas Series): the labels of the single entities.
    """

    def __init__(self, name, path, label_name, print_stats, dataset_names,
                 convert_format, raw_format='SMILES'):
        """Create a base dataloader object that each single instance prediction
        task dataloader class can inherit from.

        Note: ``print_stats`` is accepted for the subclass-facing signature but
        is not used by this initializer.

        Raises:
            ValueError: for a dataset with multiple labels, ``label_name`` must
                be specified. Use tdc.utils.retrieve_label_name_list to see the
                available label names.
        """
        # Datasets registered in dataset2target_lists have several label
        # columns, so the caller has to choose which one to load.
        if name.lower() in dataset2target_lists.keys():
            if label_name is None:
                raise ValueError("Please select a label name. You can use tdc.utils.retrieve_label_name_list('" + name.lower() + "') to retrieve all available label names.")
        entity1, y, entity1_idx = property_dataset_load(name, path, label_name, dataset_names)
        self.entity1 = entity1
        self.y = y
        self.entity1_idx = entity1_idx
        self.name = name
        self.entity1_name = 'Drug'
        self.path = path
        self.file_format = 'csv'
        self.label_name = label_name
        self.convert_format = convert_format
        # Filled lazily by get_data() the first time a conversion is needed.
        self.convert_result = None
        self.raw_format = raw_format  ### 'SMILES' for most data, 'Raw3D' for QM9, ...

    def get_data(self, format='df'):
        """Return the dataset in the requested format.

        When ``convert_format`` is set, the entities are converted from
        ``raw_format`` with MolConvert on the first call and the result is
        cached in ``convert_result`` for later calls.

        Arguments:
            format (str, optional): the returning dataset format, one of 'df',
                'dict' or 'DeepPurpose'; defaults to 'df'.

        Returns:
            pandas DataFrame / dict / tuple: a dataframe of the dataset ('df'),
            a dictionary of numpy arrays for key information ('dict'), or a
            ``(entity values, label values)`` tuple ('DeepPurpose').

        Raises:
            AttributeError: Use the correct format input (df, dict, DeepPurpose)
        """
        if (self.convert_format is not None) and (self.convert_result is None):
            # Imported lazily so the conversion dependencies are only needed
            # when a conversion is actually requested.
            from ..chem_utils import MolConvert
            converter = MolConvert(src=self.raw_format, dst=self.convert_format)
            self.convert_result = list(converter(self.entity1.values))

        if format == 'df':
            if self.convert_format is not None:
                return pd.DataFrame({self.entity1_name + '_ID': self.entity1_idx,
                                     self.entity1_name: self.entity1,
                                     self.entity1_name + '_' + self.convert_format: self.convert_result,
                                     'Y': self.y})
            else:
                return pd.DataFrame({self.entity1_name + '_ID': self.entity1_idx,
                                     self.entity1_name: self.entity1,
                                     'Y': self.y})
        elif format == 'dict':
            if self.convert_format is not None:
                return {self.entity1_name + '_ID': self.entity1_idx.values,
                        self.entity1_name: self.entity1.values,
                        self.entity1_name + '_' + self.convert_format: self.convert_result,
                        'Y': self.y.values}
            else:
                return {self.entity1_name + '_ID': self.entity1_idx.values,
                        self.entity1_name: self.entity1.values,
                        'Y': self.y.values}
        elif format == 'DeepPurpose':
            return self.entity1.values, self.y.values
        else:
            raise AttributeError("Please use the correct format input")

    def get_split(self, method='random', seed=42, frac=None):
        """Split the dataset into train/valid/test dataframes.

        Arguments:
            method (str): splitting scheme, one of 'random',
                'cold_{entity}' (e.g. 'cold_drug') or 'scaffold'; defaults to
                'random'.
            seed (int): the random seed for splitting the dataset; defaults
                to 42.
            frac (list of float, optional): train/valid/test split fractions;
                defaults to [0.7, 0.1, 0.2] when omitted or None.

        Returns:
            dict: a dictionary with three keys ('train', 'valid', 'test'), each
            value is a pandas dataframe object of the split dataset

        Raises:
            AttributeError: the input split method is not available.
        """
        # Build a fresh default per call instead of sharing one mutable default
        # list between calls (it is handed to external split helpers).
        if frac is None:
            frac = [0.7, 0.1, 0.2]
        df = self.get_data(format='df')

        if method == 'random':
            return create_fold(df, seed, frac)
        elif method == 'cold_' + self.entity1_name.lower():
            return create_fold_setting_cold(df, seed, frac, self.entity1_name)
        elif method == 'scaffold':
            return create_scaffold_split(df, seed, frac, self.entity1_name)
        else:
            raise AttributeError("Please specify the correct splitting method")

    def print_stats(self):
        """Print basic data statistics (the number of unique entities)."""
        print_sys('--- Dataset Statistics ---')
        try:
            x = np.unique(self.entity1)
        except Exception:
            # np.unique sorts its input, which fails when the entities are not
            # mutually comparable; fall back to counting unique indices.
            # (A bare except here would also swallow KeyboardInterrupt.)
            x = np.unique(self.entity1_idx)

        print(str(len(x)) + ' unique ' + self.entity1_name.lower() + 's.', flush=True, file=sys.stderr)
        print_sys('--------------------------')