# -*- coding: utf-8 -*-
# Author: TDC Team
# License: MIT
import pandas as pd
import numpy as np
import os, sys, json
import warnings
warnings.filterwarnings("ignore")
from .. import base_dataset
from ..utils import dataset2target_lists, \
property_dataset_load,\
create_fold, \
create_fold_setting_cold,\
create_scaffold_split, \
print_sys
[docs]class DataLoader(base_dataset.DataLoader):
"""A base data loader class.
Args:
name (str): the dataset name.
path (str): The path to save the data file
label_name (str): For multi-label dataset, specify the label name
print_stats (bool): Whether to print basic statistics of the dataset
dataset_names (list): A list of dataset names available for a task
convert_format (str): Automatic conversion of SMILES to other molecular formats in MolConvert class. Stored as separate column in dataframe
Attributes:
convert_format (str): conversion format of an entity
convert_result (list): a placeholder for a list of conversion outputs
entity1 (Pandas Series): a list of the single entites
entity1_idx (Pandas Series): a list of the single entites index
entity1_name (Pandas Series): a list of the single entites names
file_format (str): the format of the downloaded dataset
label_name (str): for multi-label dataset, the label name of interest
name (str): dataset name
path (str): path to save and retrieve the dataset
y (Pandas Series): a list of the single entities label
"""
def __init__(self, name, path, label_name, print_stats, dataset_names, convert_format, raw_format = 'SMILES'):
"""Create a base dataloader object that each single instance prediction task dataloader class can inherit from.
Raises:
ValueError: for a dataset with multiple labels, specify the label. Use tdc.utils.retrieve_label_name_list to see the available label names
"""
if name.lower() in dataset2target_lists.keys():
if label_name is None:
raise ValueError("Please select a label name. You can use tdc.utils.retrieve_label_name_list('" + name.lower() + "') to retrieve all available label names.")
entity1, y, entity1_idx = property_dataset_load(name, path, label_name, dataset_names)
self.entity1 = entity1
self.y = y
self.entity1_idx = entity1_idx
self.name = name
self.entity1_name = 'Drug'
self.path = path
self.file_format = 'csv'
self.label_name = label_name
self.convert_format = convert_format
self.convert_result = None
self.raw_format = raw_format ### 'SMILES' for most data, 'Raw3D' for QM9, ...
[docs] def get_data(self, format = 'df'):
'''
Arguments:
format (str, optional): the returning dataset format, defaults to 'df'
Returns:
pandas DataFrame/dict: a dataframe of a dataset/a dictionary for key information in the dataset
Raises:
AttributeError: Use the correct format input (df, dict, DeepPurpose)
'''
if (self.convert_format is not None) and (self.convert_result is None):
from ..chem_utils import MolConvert
converter = MolConvert(src = self.raw_format, dst = self.convert_format)
convert_result = converter(self.entity1.values)
self.convert_result = [i for i in convert_result]
if format == 'df':
if self.convert_format is not None:
return pd.DataFrame({self.entity1_name + '_ID': self.entity1_idx, self.entity1_name: self.entity1, self.entity1_name + '_' + self.convert_format: self.convert_result, 'Y': self.y})
else:
return pd.DataFrame({self.entity1_name + '_ID': self.entity1_idx, self.entity1_name: self.entity1, 'Y': self.y})
elif format == 'dict':
if self.convert_format is not None:
return {self.entity1_name + '_ID': self.entity1_idx.values, self.entity1_name: self.entity1.values, self.entity1_name + '_' + self.convert_format: self.convert_result, 'Y': self.y.values}
else:
return {self.entity1_name + '_ID': self.entity1_idx.values, self.entity1_name: self.entity1.values, 'Y': self.y.values}
elif format == 'DeepPurpose':
return self.entity1.values, self.y.values
else:
raise AttributeError("Please use the correct format input")
[docs] def get_split(self, method = 'random', seed = 42, frac = [0.7, 0.1, 0.2]):
'''
Arguments:
method: splitting schemes, choose from random, cold_{entity}, scaffold, defaults to 'random'
seed: the random seed for splitting dataset, defaults to '42'
frac: train/val/test split fractions, defaults to '[0.7, 0.1, 0.2]'
Returns:
dict: a dictionary with three keys ('train', 'valid', 'test'), each value is a pandas dataframe object of the splitted dataset
Raises:
AttributeError: the input split method is not available.
'''
df = self.get_data(format = 'df')
if method == 'random':
return create_fold(df, seed, frac)
elif method == 'cold_' + self.entity1_name.lower():
return create_fold_setting_cold(df, seed, frac, self.entity1_name)
elif method == 'scaffold':
return create_scaffold_split(df, seed, frac, self.entity1_name)
else:
raise AttributeError("Please specify the correct splitting method")
[docs] def print_stats(self):
"""Print basic data statistics.
"""
print_sys('--- Dataset Statistics ---')
try:
x = np.unique(self.entity1)
except:
x = np.unique(self.entity1_idx)
print(str(len(x)) + ' unique ' + self.entity1_name.lower() + 's.', flush = True, file = sys.stderr)
print_sys('--------------------------')