Source code for tdc.utils.split

"""Utilities functions for splitting dataset 
"""
import os, sys
import numpy as np
import pandas as pd
from tqdm import tqdm

# stderr logging helper used by create_scaffold_split below; in the TDC package
# it is defined in the sibling utils module (assumed import path)
from .misc import print_sys

[docs]def create_fold(df, fold_seed, frac):
    """create random split

    Args:
        df (pd.DataFrame): dataset dataframe
        fold_seed (int): the random seed
        frac (list): a list of train/valid/test fractions

    Returns:
        dict: a dictionary of split dataframes, where keys are train/valid/test and values correspond to each dataframe
    """
    train_frac, val_frac, test_frac = frac
    test = df.sample(frac=test_frac, replace=False, random_state=fold_seed)
    train_val = df[~df.index.isin(test.index)]
    # validation rows are drawn from the remainder with a fixed random_state
    val = train_val.sample(frac=val_frac / (1 - test_frac), replace=False, random_state=1)
    train = train_val[~train_val.index.isin(val.index)]

    return {'train': train.reset_index(drop=True),
            'valid': val.reset_index(drop=True),
            'test': test.reset_index(drop=True)}
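# Illustrative usage sketch (not part of the module); the toy dataframe and its
# column names are made up:
#
#   toy = pd.DataFrame({'Drug': ['mol%d' % i for i in range(100)],
#                       'Y': np.random.rand(100)})
#   splits = create_fold(toy, fold_seed=42, frac=[0.7, 0.1, 0.2])
#   print({k: len(v) for k, v in splits.items()})  # {'train': 70, 'valid': 10, 'test': 20}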
[docs]def create_fold_setting_cold(df, fold_seed, frac, entities):
    """create cold-split where given one or multiple columns, it first splits based on
    entities in the columns and then maps all associated data points to the partition

    Args:
        df (pd.DataFrame): dataset dataframe
        fold_seed (int): the random seed
        frac (list): a list of train/valid/test fractions
        entities (Union[str, List[str]]): either a single "cold" entity or a list of
            "cold" entities on which the split is done

    Returns:
        dict: a dictionary of split dataframes, where keys are train/valid/test and values correspond to each dataframe
    """
    if isinstance(entities, str):
        entities = [entities]
    train_frac, val_frac, test_frac = frac

    # For each entity, sample the instances belonging to the test set
    test_entity_instances = [
        df[e].drop_duplicates().sample(
            frac=test_frac, replace=False, random_state=fold_seed
        ).values for e in entities
    ]

    # Select samples where all entities are in the test set
    test = df.copy()
    for entity, instances in zip(entities, test_entity_instances):
        test = test[test[entity].isin(instances)]

    if len(test) == 0:
        raise ValueError(
            'No test samples found. Try another seed, increasing the test frac or a '
            'less stringent splitting strategy.'
        )

    # Proceed with validation data
    train_val = df.copy()
    for i, e in enumerate(entities):
        train_val = train_val[~train_val[e].isin(test_entity_instances[i])]

    val_entity_instances = [
        train_val[e].drop_duplicates().sample(
            frac=val_frac / (1 - test_frac), replace=False, random_state=fold_seed
        ).values for e in entities
    ]
    val = train_val.copy()
    for entity, instances in zip(entities, val_entity_instances):
        val = val[val[entity].isin(instances)]

    if len(val) == 0:
        raise ValueError(
            'No validation samples found. Try another seed, increasing the test frac '
            'or a less stringent splitting strategy.'
        )

    train = train_val.copy()
    for i, e in enumerate(entities):
        train = train[~train[e].isin(val_entity_instances[i])]

    return {'train': train.reset_index(drop=True),
            'valid': val.reset_index(drop=True),
            'test': test.reset_index(drop=True)}
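# Illustrative usage sketch (not part of the module): a cold split on a made-up
# 'Target' column, so no target seen in training leaks into valid or test.
#
#   splits = create_fold_setting_cold(df, fold_seed=42,
#                                     frac=[0.7, 0.1, 0.2], entities='Target')
#   assert set(splits['train']['Target']).isdisjoint(splits['test']['Target'])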
[docs]def create_scaffold_split(df, seed, frac, entity):
    """create scaffold split. it first generates the molecular scaffold for each molecule
    and then splits based on scaffolds
    reference: https://github.com/chemprop/chemprop/blob/master/chemprop/data/scaffold.py

    Args:
        df (pd.DataFrame): dataset dataframe
        seed (int): the random seed
        frac (list): a list of train/valid/test fractions
        entity (str): the column name where the molecules are stored

    Returns:
        dict: a dictionary of split dataframes, where keys are train/valid/test and values correspond to each dataframe
    """
    try:
        from rdkit import Chem
        from rdkit.Chem.Scaffolds import MurckoScaffold
        from rdkit import RDLogger
        RDLogger.DisableLog('rdApp.*')
    except ImportError:
        raise ImportError("Please install rdkit by 'conda install -c conda-forge rdkit'!")

    from tqdm import tqdm
    from random import Random
    from collections import defaultdict

    random = Random(seed)

    s = df[entity].values
    scaffolds = defaultdict(set)
    idx2mol = dict(zip(list(range(len(s))), s))

    error_smiles = 0
    for i, smiles in tqdm(enumerate(s), total=len(s)):
        try:
            scaffold = MurckoScaffold.MurckoScaffoldSmiles(
                mol=Chem.MolFromSmiles(smiles), includeChirality=False)
            scaffolds[scaffold].add(i)
        except Exception:
            print_sys(smiles + ' returns RDKit error and is thus omitted...')
            error_smiles += 1

    train, val, test = [], [], []
    train_size = int((len(df) - error_smiles) * frac[0])
    val_size = int((len(df) - error_smiles) * frac[1])
    test_size = (len(df) - error_smiles) - train_size - val_size
    train_scaffold_count, val_scaffold_count, test_scaffold_count = 0, 0, 0

    # index_sets = sorted(list(scaffolds.values()), key=lambda i: len(i), reverse=True)
    index_sets = list(scaffolds.values())
    big_index_sets = []
    small_index_sets = []
    for index_set in index_sets:
        if len(index_set) > val_size / 2 or len(index_set) > test_size / 2:
            big_index_sets.append(index_set)
        else:
            small_index_sets.append(index_set)
    random.seed(seed)
    random.shuffle(big_index_sets)
    random.shuffle(small_index_sets)
    index_sets = big_index_sets + small_index_sets

    if frac[2] == 0:
        for index_set in index_sets:
            if len(train) + len(index_set) <= train_size:
                train += index_set
                train_scaffold_count += 1
            else:
                val += index_set
                val_scaffold_count += 1
    else:
        for index_set in index_sets:
            if len(train) + len(index_set) <= train_size:
                train += index_set
                train_scaffold_count += 1
            elif len(val) + len(index_set) <= val_size:
                val += index_set
                val_scaffold_count += 1
            else:
                test += index_set
                test_scaffold_count += 1

    return {'train': df.iloc[train].reset_index(drop=True),
            'valid': df.iloc[val].reset_index(drop=True),
            'test': df.iloc[test].reset_index(drop=True)}
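# Illustrative usage sketch (not part of the module): assumes a dataframe with a
# SMILES column named 'Drug' and rdkit installed; molecules sharing a Bemis-Murcko
# scaffold land in the same partition.
#
#   splits = create_scaffold_split(df, seed=1, frac=[0.7, 0.1, 0.2], entity='Drug')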
[docs]def create_combination_generation_split(df1, df2, seed, frac):
    """create random split for paired pocket/ligand data

    Args:
        df1 (list): list of pocket data
        df2 (list): list of ligand data, aligned with df1
        seed (int): the random seed
        frac (list): a list of train/valid/test fractions

    Returns:
        dict: a dictionary of splits, where keys are train/valid/test and each value
            is a dictionary with "pocket" and "ligand" lists
    """
    train_frac, val_frac, test_frac = frac
    length = len(df1)
    np.random.seed(seed)  # seed the permutation so the split is reproducible
    indices = np.random.permutation(length)
    train_idx = indices[:int(length * train_frac)]
    val_idx = indices[int(length * train_frac):int(length * (train_frac + val_frac))]
    test_idx = indices[int(length * (train_frac + val_frac)):]

    return {'train': {"pocket": [df1[i] for i in train_idx], "ligand": [df2[i] for i in train_idx]},
            'valid': {"pocket": [df1[i] for i in val_idx], "ligand": [df2[i] for i in val_idx]},
            'test': {"pocket": [df1[i] for i in test_idx], "ligand": [df2[i] for i in test_idx]}}
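# Illustrative usage sketch (not part of the module): df1/df2 are parallel
# sequences of pockets and ligands, so shared indexing keeps each pair aligned.
#
#   pockets = ['p%d' % i for i in range(10)]
#   ligands = ['l%d' % i for i in range(10)]
#   splits = create_combination_generation_split(pockets, ligands, seed=0,
#                                                frac=[0.8, 0.1, 0.1])
#   assert len(splits['train']['pocket']) == len(splits['train']['ligand'])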
[docs]def create_combination_split(df, seed, frac):
    """
    Function for splitting a drug combination dataset such that no combinations are
    shared across the splits

    Args:
        df (pd.DataFrame): dataset to split
        seed (int): random seed
        frac (list): split fraction as a list

    Returns:
        dict: a dictionary of split dataframes, where keys are train/valid/test and values correspond to each dataframe
    """
    test_size = int(len(df) * frac[2])
    train_size = int(len(df) * frac[0])
    val_size = len(df) - train_size - test_size
    np.random.seed(seed)

    # Work on a copy so the helper column below does not leak into the caller's dataframe
    df = df.copy()

    # Create a new column for combination names
    df['concat'] = df['Drug1_ID'] + ',' + df['Drug2_ID']

    # Identify drug combinations shared across all cell lines
    combinations = []
    for c in df['Cell_Line_ID'].unique():
        df_cell = df[df['Cell_Line_ID'] == c]
        combinations.append(set(df_cell['concat'].values))

    intxn = combinations[0]
    for c in combinations:
        intxn = intxn.intersection(c)

    # Split combinations into train, val and test
    test_choices = np.random.choice(list(intxn),
                                    int(test_size / len(df['Cell_Line_ID'].unique())),
                                    replace=False)
    trainval_intxn = intxn.difference(test_choices)
    val_choices = np.random.choice(list(trainval_intxn),
                                   int(val_size / len(df['Cell_Line_ID'].unique())),
                                   replace=False)

    # Create train, valid and test sets; the helper column is dropped from each
    test_set = df[df['concat'].isin(test_choices)].drop(columns=['concat'])
    val_set = df[df['concat'].isin(val_choices)].drop(columns=['concat'])
    train_set = df[~df['concat'].isin(test_choices)].reset_index(drop=True)
    train_set = train_set[~train_set['concat'].isin(val_choices)].drop(columns=['concat'])

    return {'train': train_set.reset_index(drop=True),
            'valid': val_set.reset_index(drop=True),
            'test': test_set.reset_index(drop=True)}
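# Illustrative usage sketch (not part of the module): expects a drug combination
# dataframe with 'Drug1_ID', 'Drug2_ID' and 'Cell_Line_ID' columns; each drug
# pair ends up in exactly one partition across all cell lines.
#
#   splits = create_combination_split(df, seed=42, frac=[0.7, 0.1, 0.2])
#   train_pairs = set(zip(splits['train']['Drug1_ID'], splits['train']['Drug2_ID']))
#   test_pairs = set(zip(splits['test']['Drug1_ID'], splits['test']['Drug2_ID']))
#   assert train_pairs.isdisjoint(test_pairs)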
# create time split
[docs]def create_fold_time(df, frac, date_column):
    """create splits based on time

    Args:
        df (pd.DataFrame): the dataset dataframe
        frac (list): list of train/valid/test fractions
        date_column (str): the name of the column that contains the time info

    Returns:
        dict: a dictionary of split dataframes, where keys are train/valid/test and values correspond to each dataframe
    """
    df = df.sort_values(by=date_column).reset_index(drop=True)
    train_frac, val_frac, test_frac = frac[0], frac[1], frac[2]

    split_date = df[:int(len(df) * (train_frac + val_frac))].iloc[-1][date_column]
    test = df[df[date_column] >= split_date].reset_index(drop=True)
    train_val = df[df[date_column] < split_date]

    split_date_valid = train_val[:int(len(train_val) * train_frac / (train_frac + val_frac))].iloc[-1][date_column]
    train = train_val[train_val[date_column] <= split_date_valid].reset_index(drop=True)
    valid = train_val[train_val[date_column] > split_date_valid].reset_index(drop=True)

    return {'train': train, 'valid': valid, 'test': test,
            'split_time': {'train_time_frame': (df.iloc[0][date_column], split_date_valid),
                           'valid_time_frame': (split_date_valid, split_date),
                           'test_time_frame': (split_date, df.iloc[-1][date_column])}}
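# Illustrative usage sketch (not part of the module): 'Year' is a made-up date
# column; earlier records go to train, later ones to valid and test, and the
# returned 'split_time' entry records the cutoff dates.
#
#   toy = pd.DataFrame({'Year': np.arange(2000, 2020).repeat(5),
#                       'Y': np.random.rand(100)})
#   splits = create_fold_time(toy, frac=[0.7, 0.1, 0.2], date_column='Year')
#   print(splits['split_time'])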
[docs]def create_group_split(train_val, seed, holdout_frac, group_column):
    """split within each stratification defined by the group column for training/validation split

    Args:
        train_val (pd.DataFrame): the train+valid dataframe to split on
        seed (int): the random seed
        holdout_frac (float): the fraction of validation
        group_column (str): the name of the group column

    Returns:
        dict: a dictionary of split dataframes, where keys are train/valid and values correspond to each dataframe
    """
    train_frames = []
    val_frames = []
    for i in train_val[group_column].unique():
        train_val_temp = train_val[train_val[group_column] == i]
        np.random.seed(seed)
        msk = np.random.rand(len(train_val_temp)) < (1 - holdout_frac)
        train_frames.append(train_val_temp[msk])
        val_frames.append(train_val_temp[~msk])
    # DataFrame.append was removed in pandas 2.0, so the per-group frames are
    # collected in lists and concatenated once
    train_df = pd.concat(train_frames)
    val_df = pd.concat(val_frames)

    return {'train': train_df.reset_index(drop=True),
            'valid': val_df.reset_index(drop=True)}
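# Illustrative usage sketch (not part of the module): holds out ~20% within each
# group of a made-up 'Patient_ID' column, so every group contributes rows to
# both train and valid.
#
#   splits = create_group_split(train_val, seed=42, holdout_frac=0.2,
#                               group_column='Patient_ID')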