Source code for tdc.generation.generation_dataset

# -*- coding: utf-8 -*-
# Author: TDC Team
# License: MIT

import pandas as pd
import numpy as np
import os, sys, json
import warnings

warnings.filterwarnings("ignore")

from .. import base_dataset
from ..utils import (
    distribution_dataset_load,
    generation_paired_dataset_load,
    three_dim_dataset_load,
    print_sys,
)
from ..utils import create_fold


class DataLoader(base_dataset.DataLoader):
    """A base dataset loader class for single-molecule generation tasks.

    Attributes:
        dataset_names: the known dataset names, taken from
            ``single_molecule_dataset_names`` in the package metadata.
        name (str): the name of the dataset.
        path (str): the path to save the data file.
        smiles_lst (list): a list of SMILES strings used as training data
            for distribution learning.
    """

    def __init__(self, name, path, print_stats, column_name):
        """Create a base dataloader object that each generation task can inherit from.

        Args:
            name (str): the name of the dataset.
            path (str): the path to save the data file.
            print_stats (bool): whether to print the basic statistics of
                the dataset.
            column_name (str): the name of the column containing SMILES
                strings.
        """
        # Imported lazily, inside the constructor, as in the original code.
        from ..metadata import single_molecule_dataset_names

        # The loader call includes the fuzzy search over dataset names.
        self.smiles_lst = distribution_dataset_load(
            name, path, single_molecule_dataset_names, column_name=column_name
        )
        self.name = name
        self.path = path
        self.dataset_names = single_molecule_dataset_names
        if print_stats:
            self.print_stats()
        print_sys("Done!")

    def print_stats(self):
        """Print the number of molecules in the dataset to stderr."""
        print(
            f"There are {len(self.smiles_lst)} molecules ",
            flush=True,
            file=sys.stderr,
        )

    def get_data(self, format="df"):
        """Return the data from the whole dataset.

        Args:
            format (str, optional): the desired output format, either
                "df" or "dict".

        Returns:
            pandas.DataFrame or dict: the SMILES strings under the key /
            column ``"smiles"``.

        Raises:
            AttributeError: if ``format`` is neither "df" nor "dict".
        """
        if format == "df":
            return pd.DataFrame({"smiles": self.smiles_lst})
        if format == "dict":
            return {"smiles": self.smiles_lst}
        raise AttributeError("Please use the correct format input")

    def get_split(self, method="random", seed=42, frac=None):
        """Return the data split into train, valid and test sets.

        Args:
            method (str): splitting scheme. Only "random" is implemented
                by this class.
            seed (int): random seed, default 42.
            frac (list of float, optional): ratio of the train/val/test
                split. Defaults to ``[0.7, 0.1, 0.2]``.

        Returns:
            The result of ``create_fold`` applied to the dataset
            dataframe (presumably the train/valid/test DataFrames --
            confirm against ``utils.create_fold``).

        Raises:
            AttributeError: if ``method`` is not "random".
        """
        if frac is None:
            # Build the default per call so no list object is shared
            # between invocations (avoids the mutable-default pitfall).
            frac = [0.7, 0.1, 0.2]
        if method != "random":
            raise AttributeError("Please use the correct split method")
        return create_fold(self.get_data(format="df"), seed, frac)
class PairedDataLoader(base_dataset.DataLoader):
    """A basic class for generation of biomedical entities conditioned on
    other entities, such as reaction prediction.

    Attributes:
        dataset_names: the known dataset names, taken from
            ``paired_dataset_names`` in the package metadata.
        name (str): the name of the dataset.
        path (str): the path to save the data file.
        input_smiles_lst: the input entities returned by the paired loader.
        output_smiles_lst: the output entities returned by the paired loader.
    """

    def __init__(self, name, path, print_stats, input_name, output_name):
        """Create an object for paired biomedical entities generation.

        Args:
            name (str): fuzzy name of the generation dataset,
                e.g., uspto50k, qed, drd, ...
            path (str): directory path that stores the dataset, e.g., ./data
            print_stats (bool): whether to print the stats.
            input_name (str): the column name of input biomedical entities.
            output_name (str): the column name of output biomedical entities.
        """
        # Imported lazily, inside the constructor, as in the original code.
        from ..metadata import paired_dataset_names

        # The loader call includes the fuzzy search over dataset names.
        loaded = generation_paired_dataset_load(
            name, path, paired_dataset_names, input_name, output_name
        )
        self.input_smiles_lst, self.output_smiles_lst = loaded
        self.name = name
        self.path = path
        self.dataset_names = paired_dataset_names
        if print_stats:
            self.print_stats()
        print_sys("Done!")

    def print_stats(self):
        """Print the number of paired samples in the dataset to stderr."""
        print(
            f"There are {len(self.input_smiles_lst)} paired samples",
            flush=True,
            file=sys.stderr,
        )

    def get_data(self, format="df"):
        """Return the data from the whole dataset.

        Args:
            format (str, optional): the desired output format, either
                "df" or "dict".

        Returns:
            pandas.DataFrame or dict: the paired entities under the keys /
            columns ``"input"`` and ``"output"``.

        Raises:
            AttributeError: if ``format`` is neither "df" nor "dict".
        """
        if format == "df":
            return pd.DataFrame(
                {"input": self.input_smiles_lst, "output": self.output_smiles_lst}
            )
        if format == "dict":
            return {"input": self.input_smiles_lst, "output": self.output_smiles_lst}
        raise AttributeError("Please use the correct format input")

    def get_split(self, method="random", seed=42, frac=None):
        """Return the data split into train, valid and test sets.

        Args:
            method (str): splitting scheme. Only "random" is implemented
                by this class.
            seed (int): random seed, default 42.
            frac (list of float, optional): ratio of the train/val/test
                split. Defaults to ``[0.7, 0.1, 0.2]``.

        Returns:
            The result of ``create_fold`` applied to the dataset
            dataframe (presumably the train/valid/test DataFrames --
            confirm against ``utils.create_fold``).

        Raises:
            AttributeError: if ``method`` is not "random".
        """
        if frac is None:
            # Build the default per call so no list object is shared
            # between invocations (avoids the mutable-default pitfall).
            frac = [0.7, 0.1, 0.2]
        if method != "random":
            raise AttributeError("Please use the correct split method")
        return create_fold(self.get_data(format="df"), seed, frac)
class DataLoader3D(base_dataset.DataLoader):
    """A basic class for generation of 3D biomedical entities. (under construction)

    Attributes:
        df: the dataset as returned by ``three_dim_dataset_load``
            (documented upstream as a pandas DataFrame).
        name (str): the name of the dataset.
        path (str): the path to save the data file.
    """

    ### locally, unzip a folder, with the main file the dataframe with SMILES,
    ### Mol Object for various kinds of entities.
    ### also, for each column, contains a sdf file.

    def __init__(self, name, path, print_stats, dataset_names, column_name):
        """Create an object for 3D biomedical entities generation.

        Args:
            name (str): the name of the dataset.
            path (str): the path to save the data file.
            print_stats (bool): whether to print the basic statistics of
                the dataset.
            dataset_names: the known dataset names, forwarded to
                ``three_dim_dataset_load``.
            column_name (str): the name of the column containing SMILES
                strings. Currently unused by this constructor; kept so
                the signature stays compatible with callers.
        """
        # The loader returns the resolved path and name along with the data.
        self.df, self.path, self.name = three_dim_dataset_load(
            name, path, dataset_names
        )
        if print_stats:
            self.print_stats()
        print_sys("Done!")

    def print_stats(self):
        """Print the number of data points in the dataset to stderr."""
        print(
            f"There are {len(self.df)} data points ",
            flush=True,
            file=sys.stderr,
        )

    def get_data(self, format="df", more_features="None"):
        """Return the data from the whole dataset.

        When ``more_features`` is one of "Graph3D", "Coulumb" or "SELFIES",
        one extra column named ``<sdf name>_<more_features>`` is added to
        ``self.df`` for every entry of ``sdf_file_names[self.name]``.
        "None", "SMILES" and any other value leave the dataframe untouched.

        Args:
            format (str, optional): the desired output format; only "df"
                is supported.
            more_features (str, optional): 3D feature format, choose from
                [Graph3D, Coulumb].

        Returns:
            The dataset held in ``self.df`` (mutated in place when extra
            features are requested).

        Raises:
            AttributeError: if ``format`` is not "df".
            ImportError: if rdkit is not installed; install it with
                'conda install -c conda-forge rdkit'.
        """
        # Validate up front so an invalid format fails before the costly
        # SDF conversion runs and before self.df is modified.
        if format != "df":
            raise AttributeError("Please use the correct format input")

        if more_features in ("Graph3D", "Coulumb", "SELFIES"):  # why SELFIES here?
            try:
                # LoadSDF is imported only to check that rdkit is usable.
                from rdkit.Chem.PandasTools import LoadSDF  # noqa: F401
                from rdkit import rdBase
            except ImportError as err:
                # Catch only import failures and keep the original cause.
                raise ImportError(
                    "Please install rdkit by 'conda install -c conda-forge rdkit'! "
                ) from err
            rdBase.DisableLog("rdApp.error")

            from ..chem_utils import MolConvert
            from ..metadata import sdf_file_names

            convert = MolConvert(src="SDF", dst=more_features)
            for i in sdf_file_names[self.name]:
                # NOTE(review): plain string concatenation -- this assumes
                # self.path already ends with a path separator; confirm
                # against three_dim_dataset_load.
                self.df[i + "_" + more_features] = convert(self.path + i + ".sdf")

        return self.df

    def get_split(self, method="random", seed=42, frac=None):
        """Return the data split into train, valid and test sets.

        Args:
            method (str): splitting scheme. Only "random" is implemented
                by this class.
            seed (int): random seed, default 42.
            frac (list of float, optional): ratio of the train/val/test
                split. Defaults to ``[0.7, 0.1, 0.2]``.

        Returns:
            The result of ``create_fold`` applied to the dataset
            dataframe (presumably the train/valid/test DataFrames --
            confirm against ``utils.create_fold``).

        Raises:
            AttributeError: if ``method`` is not "random".
        """
        if frac is None:
            # Build the default per call so no list object is shared
            # between invocations (avoids the mutable-default pitfall).
            frac = [0.7, 0.1, 0.2]
        if method != "random":
            raise AttributeError("Please use the correct split method")
        return create_fold(self.get_data(format="df"), seed, frac)