Source code for tdc.generation.generation_dataset

# -*- coding: utf-8 -*-
# Author: TDC Team
# License: MIT

import pandas as pd
import numpy as np
import os, sys, json 
import warnings
warnings.filterwarnings("ignore")

from .. import base_dataset
from ..utils import distribution_dataset_load, generation_paired_dataset_load, three_dim_dataset_load, print_sys
from ..utils import create_fold

[docs]class DataLoader(base_dataset.DataLoader): """A base dataset loader class. Attributes: dataset_names (str): name of the dataset. name (str): The name fo the dataset. path (str): the path to save the data file. smiles_lst (list): a list of smiles strings as training data for distribution learning. """ def __init__(self, name, path, print_stats, column_name): """To create a base dataloader object that each generation task can inherit from. Args: name (str): the name of the dataset. path (str): the path to save the data file. print_stats (bool): whether to print the basic statistics of the dataset. column_name (str): The name of the column containing smiles strings. """ from ..metadata import single_molecule_dataset_names self.smiles_lst = distribution_dataset_load(name, path, single_molecule_dataset_names, column_name = column_name) ### including fuzzy-search self.name = name self.path = path self.dataset_names = single_molecule_dataset_names if print_stats: self.print_stats() print_sys('Done!')
[docs] def print_stats(self): """Print the basic statistics of the dataset. """ print("There are " + str(len(self.smiles_lst)) + ' molecules ', flush = True, file = sys.stderr)
[docs] def get_data(self, format = 'df'): """Return the data from the whole dataset. Args: format (str, optional): the desired format for molecular data. Returns: pandas DataFrame/dict: a dataframe of the dataset/a distionary for information Raises: AttributeError: Use the correct format as input (df, dict) """ if format == 'df': return pd.DataFrame({'smiles': self.smiles_lst}) elif format == 'dict': return {'smiles': self.smiles_lst} else: raise AttributeError("Please use the correct format input")
[docs] def get_split(self, method = 'random', seed = 42, frac = [0.7, 0.1, 0.2]): '''Return the data splitted as train, valid, test sets. Arguments: method (str): splitting schemes: random, scaffold seed (int): random seed, default 42 frac (list of float): ratio of train/val/test split Returns: pandas DataFrame/dict: a dataframe of the dataset Raises: AttributeError: Use the correct split method as input (random, scaffold) ''' df = self.get_data(format = 'df') if method == 'random': return create_fold(df, seed, frac) else: raise AttributeError("Please use the correct split method")
[docs]class PairedDataLoader(base_dataset.DataLoader): """A basic class for generation of biomedical entities conditioned on other entities, such as reaction prediction. Attributes: dataset_names (str): the name fo the dataset. name (str): the name of the dataset. path (str): the path to save the data file. """ def __init__(self, name, path, print_stats, input_name, output_name): '''To create a object for paired biomedical entities generation. Arguments: name (str): fuzzy name of the generation dataset. e.g., uspto50k, qed, drd, ... path (str): directory path that stores the dataset, e.g., ./data print_stats (bool): whether print the stats. input_name (str): The column name of input biomedical entities. output_name (str): The column name of output biomedical entities. ''' from ..metadata import paired_dataset_names self.input_smiles_lst, self.output_smiles_lst = generation_paired_dataset_load(name, path, paired_dataset_names, input_name, output_name) ### including fuzzy-search self.name = name self.path = path self.dataset_names = paired_dataset_names if print_stats: self.print_stats() print_sys('Done!')
[docs] def print_stats(self): """Print the statistics of the dataset. """ print("There are " + str(len(self.input_smiles_lst)) + ' paired samples', flush = True, file = sys.stderr)
[docs] def get_data(self, format = 'df'): """Return the data from the whole dataset. Args: format (str, optional): the desired format for molecular data. Returns: pandas DataFrame/dict: a dataframe of the dataset/a distionary for information Raises: AttributeError: Use the correct format as input (df, dict) """ if format == 'df': return pd.DataFrame({'input': self.input_smiles_lst, 'output':self.output_smiles_lst}) elif format == 'dict': return {'input': self.input_smiles_lst, 'output':self.output_smiles_lst} else: raise AttributeError("Please use the correct format input")
[docs] def get_split(self, method = 'random', seed = 42, frac = [0.7, 0.1, 0.2]): '''Return the data splitted as train, valid, test sets. Arguments: method (str): splitting schemes: random, scaffold seed (int): random seed, default 42 frac (list of float): ratio of train/val/test split Returns: pandas DataFrame/dict: a dataframe of the dataset Raises: AttributeError: Use the correct split method as input (random, scaffold) ''' df = self.get_data(format = 'df') if method == 'random': return create_fold(df, seed, frac) else: raise AttributeError("Please use the correct split method")
[docs]class DataLoader3D(base_dataset.DataLoader): """A basic class for generation of 3D biomedical entities. (under construction) Attributes: df (str): the dataset in pandas DataFrame format. name (str): the name of the dataset. path (str): the path to save the data file. """ ### locally, unzip a folder, with the main file the dataframe with SMILES, Mol Object for various kinds of entities. ### also, for each column, contains a sdf file. def __init__(self, name, path, print_stats, dataset_names, column_name): """To create an object for 3D biomedical entities generation. Args: name (str): the name of the dataset. path (str): the path to save the data file. print_stats (bool): whether to print the basic statistics of the dataset. column_name (str): The name of the column containing smiles strings. """ self.df, self.path, self.name = three_dim_dataset_load(name, path, dataset_names) if print_stats: self.print_stats() print_sys('Done!')
[docs] def print_stats(self): """Print the basic statistics of the dataset. """ print("There are " + str(len(self.df)) + ' data points ', flush = True, file = sys.stderr)
[docs] def get_data(self, format = 'df', more_features = 'None'): """Return the data from the whole dataset. Args: format (str, optional): the desired format for molecular data. more_features (str, optional): 3D feature format, choose from [Graph3D, Coulumb] Returns: pandas DataFrame/dict: a dataframe of the dataset/a distionary for information Raises: AttributeError: Use the correct format as input (df, dict) ImportError: Please install rdkit by 'conda install -c conda-forge rdkit' """ if more_features in ['None', 'SMILES']: pass elif more_features in ['Graph3D', 'Coulumb', 'SELFIES']: # why SELFIES here? try: from rdkit.Chem.PandasTools import LoadSDF from rdkit import rdBase rdBase.DisableLog('rdApp.error') except: raise ImportError("Please install rdkit by 'conda install -c conda-forge rdkit'! ") from ..chem_utils import MolConvert from ..metadata import sdf_file_names convert = MolConvert(src = 'SDF', dst = more_features) for i in sdf_file_names[self.name]: self.df[i + '_' + more_features] = convert(self.path + i + '.sdf') if format == 'df': return self.df else: raise AttributeError("Please use the correct format input")
[docs] def get_split(self, method = 'random', seed = 42, frac = [0.7, 0.1, 0.2]): '''Return the data splitted as train, valid, test sets. Arguments: method (str): splitting schemes: random, scaffold seed (int): random seed, default 42 frac (list of float): ratio of train/val/test split Returns: pandas DataFrame/dict: a dataframe of the dataset Raises: AttributeError: Use the correct split method as input (random, scaffold) ''' df = self.get_data(format = 'df') if method == 'random': return create_fold(df, seed, frac) else: raise AttributeError("Please use the correct split method")