# Source code for tdc.generation.retrosyn

# -*- coding: utf-8 -*-
# Author: TDC Team
# License: MIT

import warnings


from . import generation_dataset
from ..metadata import dataset_names
from ..utils import create_fold

class RetroSyn(generation_dataset.PairedDataLoader):
    """Data loader class accessing the retro-synthetic prediction task."""

    def __init__(
        self,
        name,
        path="./data",
        print_stats=False,
        input_name="product",
        output_name="reactant",
    ):
        """Create a data loader object for the retrosynthesis prediction task.

        The goal is to predict the reactant(s) given a reaction product.

        Args:
            name (str): the name of the dataset
            path (str, optional): the path to the saved data file.
            print_stats (bool, optional): whether to print the basic statistics
            input_name (str, optional): the name of the column containing
                input molecular data (product)
            output_name (str, optional): the name of the column containing
                output molecular data (reactant)
        """
        super().__init__(name, path, print_stats, input_name, output_name)

    def get_split(
        self,
        method="random",
        seed=42,
        frac=None,
        include_reaction_type=False,
    ):
        """Return the data split into train, valid and test sets.

        Arguments:
            method (str): splitting scheme; only "random" is supported.
            seed (int): random seed, default 42
            frac (list of float, optional): ratio of the train/val/test split;
                defaults to [0.7, 0.1, 0.2] when None.
            include_reaction_type (bool): whether or not to add a
                "reaction_type" column to the data before splitting

        Returns:
            the result of ``create_fold(df, seed, frac)`` -- presumably a dict
            of split name -> pandas DataFrame; confirm against create_fold.

        Raises:
            AttributeError: if ``method`` is not a supported split method.
            ValueError: if reaction types are requested but cannot be
                loaded or attached for this dataset.
        """
        # Reject unsupported methods before doing any (possibly expensive)
        # data or reaction-type loading.
        if method != "random":
            raise AttributeError("Please use the correct split method")

        if frac is None:
            # Build a fresh list per call instead of sharing a mutable default.
            frac = [0.7, 0.1, 0.2]

        df = self.get_data(format="df")

        if include_reaction_type:
            from ..utils import get_reaction_type

            try:
                # NOTE(review): the call arguments were lost in the source
                # extraction; (self.name, self.path) is a reconstruction --
                # confirm against get_reaction_type's signature and that the
                # base loader sets self.name / self.path.
                rt = get_reaction_type(self.name, self.path)
                # Kept inside the try so a label/row mismatch is also reported
                # as "reaction type unavailable".
                df["reaction_type"] = rt
            except Exception as err:
                raise ValueError(
                    "Reaction Type Unavailable for "
                    + str(self.name)
                    + "! Please turn include_reaction_type to be false!"
                ) from err

        return create_fold(df, seed, frac)