"""Utility functions for splitting datasets.
"""
import os, sys
import numpy as np
import pandas as pd
from tqdm import tqdm
from .misc import print_sys
def create_fold(df, fold_seed, frac):
    """Create a random train/valid/test split.

    The test rows are sampled first from ``df``; the validation rows are then
    sampled from the remaining rows, and whatever is left becomes the training
    set.

    Args:
        df (pd.DataFrame): dataset dataframe
        fold_seed (int): the random seed used to sample the test rows
        frac (list): a list of train/valid/test fractions

    Returns:
        dict: a dictionary of split dataframes, where keys are train/valid/test
        and values correspond to each dataframe (each with a fresh 0..n-1 index)
    """
    _, val_frac, test_frac = frac

    test = df.sample(frac=test_frac, replace=False, random_state=fold_seed)
    train_val = df[~df.index.isin(test.index)]

    remaining_frac = 1 - test_frac
    if remaining_frac > 0:
        # The validation fraction is rescaled because it is drawn from the
        # rows left over after the test rows were removed.
        # NOTE(review): this draw uses a fixed random_state of 1 rather than
        # fold_seed; kept unchanged so existing splits stay reproducible —
        # confirm whether that is intended.
        val = train_val.sample(
            frac=val_frac / remaining_frac, replace=False, random_state=1
        )
    else:
        # Everything went to the test set; avoid dividing by zero.
        val = train_val.iloc[:0]

    train = train_val[~train_val.index.isin(val.index)]

    return {
        "train": train.reset_index(drop=True),
        "valid": val.reset_index(drop=True),
        "test": test.reset_index(drop=True),
    }
def create_fold_setting_cold(df, fold_seed, frac, entities):
    """Create a cold split on one or multiple entity columns.

    For every entity column, a fraction of its unique values is sampled for
    the test set; a row goes to the test set only if its value is in the
    sampled set for *every* entity column. Rows whose value was sampled for the
    test set in *any* entity column are removed before the same procedure is
    repeated to carve the validation set out of the remainder.

    Args:
        df (pd.DataFrame): dataset dataframe
        fold_seed (int): the random seed
        frac (list): a list of train/valid/test fractions
        entities (Union[str, List[str]]): either a single "cold" entity or a list of
            "cold" entities on which the split is done

    Returns:
        dict: a dictionary of split dataframes, where keys are train/valid/test
        and values correspond to each dataframe

    Raises:
        ValueError: if the test or the validation partition ends up empty.
    """
    if isinstance(entities, str):
        entities = [entities]

    _, val_frac, test_frac = frac

    def _sample_instances(frame, sample_frac):
        # One array of sampled unique values per entity column.
        return [
            frame[e]
            .drop_duplicates()
            .sample(frac=sample_frac, replace=False, random_state=fold_seed)
            .values
            for e in entities
        ]

    def _keep_all(frame, instances_per_entity):
        # Rows whose value is in the sampled set for every entity.
        for entity, instances in zip(entities, instances_per_entity):
            frame = frame[frame[entity].isin(instances)]
        return frame

    def _drop_any(frame, instances_per_entity):
        # Rows whose value is in none of the sampled sets.
        for entity, instances in zip(entities, instances_per_entity):
            frame = frame[~frame[entity].isin(instances)]
        return frame

    # Test partition
    test_entity_instances = _sample_instances(df, test_frac)
    test = _keep_all(df, test_entity_instances)
    if len(test) == 0:
        raise ValueError(
            "No test samples found. Try another seed, increasing the test frac or a "
            "less stringent splitting strategy."
        )

    # Validation partition, drawn from rows not touching any test entity.
    # The fraction is rescaled because it applies to the remaining entities.
    train_val = _drop_any(df, test_entity_instances)
    val_entity_instances = _sample_instances(train_val, val_frac / (1 - test_frac))
    val = _keep_all(train_val, val_entity_instances)
    if len(val) == 0:
        raise ValueError(
            "No validation samples found. Try another seed, increasing the valid frac "
            "or a less stringent splitting strategy."
        )

    train = _drop_any(train_val, val_entity_instances)

    return {
        "train": train.reset_index(drop=True),
        "valid": val.reset_index(drop=True),
        "test": test.reset_index(drop=True),
    }
def create_scaffold_split(df, seed, frac, entity):
    """Create a scaffold split.

    A Murcko scaffold is generated for each molecule and whole scaffold groups
    are then assigned to train, valid or test, so that a scaffold never spans
    two partitions.
    reference: https://github.com/chemprop/chemprop/blob/master/chemprop/data/scaffold.py

    Args:
        df (pd.DataFrame): dataset dataframe
        seed (int): the random seed used to shuffle the scaffold groups
        frac (list): a list of train/valid/test fractions
        entity (str): the name of the column holding the molecule SMILES

    Returns:
        dict: a dictionary of split dataframes, where keys are train/valid/test
        and values correspond to each dataframe

    Raises:
        ImportError: if rdkit is not installed.
    """
    try:
        from rdkit import Chem
        from rdkit import RDLogger
        from rdkit.Chem.Scaffolds import MurckoScaffold
    except ImportError as err:
        raise ImportError(
            "Please install rdkit by 'conda install -c conda-forge rdkit'! "
        ) from err
    RDLogger.DisableLog("rdApp.*")

    from collections import defaultdict
    from random import Random

    rng = Random(seed)

    smiles_values = df[entity].values
    # scaffold SMILES -> positional indices of the rows sharing that scaffold
    scaffolds = defaultdict(set)
    error_smiles = 0

    for i, smiles in tqdm(enumerate(smiles_values), total=len(smiles_values)):
        try:
            scaffold = MurckoScaffold.MurckoScaffoldSmiles(
                mol=Chem.MolFromSmiles(smiles), includeChirality=False
            )
        except Exception:
            # Molecules that cannot be processed are left out of every
            # partition. The f-string keeps the handler itself from failing
            # when the entry is not a string (e.g. a missing value).
            print_sys(f"{smiles} returns RDKit error and is thus omitted...")
            error_smiles += 1
        else:
            scaffolds[scaffold].add(i)

    n_valid = len(df) - error_smiles
    train_size = int(n_valid * frac[0])
    val_size = int(n_valid * frac[1])
    test_size = n_valid - train_size - val_size

    # Scaffold groups larger than half of the valid or test budget are placed
    # ahead of the small ones; each of the two lists is shuffled separately.
    big_index_sets = []
    small_index_sets = []
    for index_set in scaffolds.values():
        if len(index_set) > val_size / 2 or len(index_set) > test_size / 2:
            big_index_sets.append(index_set)
        else:
            small_index_sets.append(index_set)
    rng.shuffle(big_index_sets)
    rng.shuffle(small_index_sets)
    index_sets = big_index_sets + small_index_sets

    # With a zero test fraction, everything that does not fit in train goes to
    # valid; otherwise the overflow of valid goes to test.
    use_test = frac[2] != 0
    train, val, test = [], [], []
    for index_set in index_sets:
        if len(train) + len(index_set) <= train_size:
            train.extend(index_set)
        elif not use_test or len(val) + len(index_set) <= val_size:
            val.extend(index_set)
        else:
            test.extend(index_set)

    return {
        "train": df.iloc[train].reset_index(drop=True),
        "valid": df.iloc[val].reset_index(drop=True),
        "test": df.iloc[test].reset_index(drop=True),
    }
def create_combination_generation_split(dict1, dict2, seed, frac):
    """Create a random split of paired protein/ligand records.

    The positions ``0..len(dict1["coord"]) - 1`` are shuffled with ``seed`` and
    cut into three consecutive chunks; the same positions are used to pick
    entries from both ``dict1`` (stored under ``protein_*`` keys) and ``dict2``
    (stored under ``ligand_*`` keys).

    Args:
        dict1 (dict): data dict with "coord" and "atom_type" sequences
        dict2 (dict): data dict with "coord" and "atom_type" sequences,
            indexed by the same positions as ``dict1``
        seed (int): the random seed
        frac (list): a list of train/valid/test fractions

    Returns:
        dict: a dictionary with keys train/valid/test, each mapping to a dict
        with keys protein_coord, protein_atom_type, ligand_coord and
        ligand_atom_type
    """
    train_frac, val_frac, _ = frac
    length = len(dict1["coord"])

    # A local generator makes the split reproducible for a given seed without
    # touching numpy's global random state.
    indices = np.random.RandomState(seed).permutation(length)

    train_end = int(length * train_frac)
    val_end = int(length * (train_frac + val_frac))
    partitions = {
        "train": indices[:train_end],
        "valid": indices[train_end:val_end],
        "test": indices[val_end:],
    }

    def _gather(idx):
        return {
            "protein_coord": [dict1["coord"][i] for i in idx],
            "protein_atom_type": [dict1["atom_type"][i] for i in idx],
            "ligand_coord": [dict2["coord"][i] for i in idx],
            "ligand_atom_type": [dict2["atom_type"][i] for i in idx],
        }

    return {name: _gather(idx) for name, idx in partitions.items()}
def create_combination_split(df, seed, frac):
    """Split a drug combination dataset so that no combination is shared across splits.

    A combination is the pair (Drug1_ID, Drug2_ID). Only combinations that
    occur in every cell line are candidates for the valid and test sets; all
    rows of a chosen combination go to the same partition. The input dataframe
    is not modified and the returned frames have the same columns as ``df``.

    Args:
        df (pd.DataFrame): dataset to split; must contain the columns
            Drug1_ID, Drug2_ID and Cell_Line_ID
        seed (int): random seed
        frac (list): split fraction as a list of train/valid/test fractions

    Returns:
        dict: a dictionary of split dataframes, where keys are train/valid/test
        and values correspond to each dataframe
    """
    test_size = int(len(df) * frac[2])
    train_size = int(len(df) * frac[0])
    val_size = len(df) - train_size - test_size

    rng = np.random.RandomState(seed)

    # Combination name per row, kept as a local Series so the caller's
    # dataframe does not gain an extra column.
    combo_key = df["Drug1_ID"] + "," + df["Drug2_ID"]

    # Identify the combinations shared across all cell lines
    cell_lines = df["Cell_Line_ID"].unique()
    n_cell_lines = len(cell_lines)
    per_cell_line = [
        set(combo_key[df["Cell_Line_ID"] == c].values) for c in cell_lines
    ]
    shared = per_cell_line[0].intersection(*per_cell_line[1:])

    # Sets of strings have no stable iteration order between interpreter
    # runs, so the candidates are sorted before sampling to make the split
    # depend on the seed only.
    test_choices = rng.choice(
        sorted(shared, key=str), int(test_size / n_cell_lines), replace=False
    )
    remaining = shared.difference(test_choices)
    val_choices = rng.choice(
        sorted(remaining, key=str), int(val_size / n_cell_lines), replace=False
    )

    in_test = combo_key.isin(test_choices)
    in_val = combo_key.isin(val_choices)

    return {
        "train": df[~in_test & ~in_val].reset_index(drop=True),
        "valid": df[in_val].reset_index(drop=True),
        "test": df[in_test].reset_index(drop=True),
    }
# create time split
def create_fold_time(df, frac, date_column):
    """Create train/valid/test splits based on time.

    Rows are sorted by ``date_column``. Two cut-off values are read from the
    sorted data: ``split_date`` (last value within the first train+valid share
    of rows) and ``split_date_valid`` (last value within the train share of the
    rows that precede ``split_date``). Rows are then assigned by comparing
    their date with these cut-offs, so rows sharing a date always land in the
    same partition:

    * test: date >= split_date
    * valid: split_date_valid < date < split_date
    * train: date <= split_date_valid

    Args:
        df (pd.DataFrame): the dataset dataframe
        frac (list): list of train/valid/test fractions
        date_column (str): the name of the column that contains the time info

    Returns:
        dict: a dictionary with the split dataframes under the keys
        train/valid/test, plus a "split_time" entry holding the
        (start, end) time frame of each partition
    """
    df = df.sort_values(by=date_column).reset_index(drop=True)
    train_frac, val_frac = frac[0], frac[1]

    n_train_val = int(len(df) * (train_frac + val_frac))
    split_date = df[:n_train_val].iloc[-1][date_column]

    test = df[df[date_column] >= split_date].reset_index(drop=True)
    train_val = df[df[date_column] < split_date]

    n_train = int(len(train_val) * train_frac / (train_frac + val_frac))
    split_date_valid = train_val[:n_train].iloc[-1][date_column]

    is_train = train_val[date_column] <= split_date_valid
    train = train_val[is_train].reset_index(drop=True)
    valid = train_val[~is_train].reset_index(drop=True)

    return {
        "train": train,
        "valid": valid,
        "test": test,
        "split_time": {
            "train_time_frame": (df.iloc[0][date_column], split_date_valid),
            "valid_time_frame": (split_date_valid, split_date),
            "test_time_frame": (split_date, df.iloc[-1][date_column]),
        },
    }
def create_group_split(train_val, seed, holdout_frac, group_column):
    """Split train+valid data into train and valid within each group.

    For every distinct value of ``group_column`` a random mask is drawn with
    the same seed, so each row of a group lands in the validation set with
    probability ``holdout_frac``. The per-group pieces are concatenated in the
    order in which the groups first appear.

    Args:
        train_val (pd.DataFrame): the train+valid dataframe to split on
        seed (int): the random seed
        holdout_frac (float): the fraction of validation
        group_column (str): the name of the group column

    Returns:
        dict: a dictionary of split dataframes, where keys are train/valid
        and values correspond to each dataframe
    """
    train_parts = []
    val_parts = []

    for group in train_val[group_column].unique():
        group_rows = train_val[train_val[group_column] == group]
        # A fresh generator per group reproduces the draw obtained by
        # reseeding before each group, without altering numpy's global state.
        keep = np.random.RandomState(seed).rand(len(group_rows)) < (1 - holdout_frac)
        train_parts.append(group_rows[keep])
        val_parts.append(group_rows[~keep])

    if not train_parts:
        # No groups at all: pd.concat rejects an empty list, so hand back
        # empty frames that still carry the input's columns.
        empty = train_val.iloc[:0]
        return {
            "train": empty.reset_index(drop=True),
            "valid": empty.reset_index(drop=True),
        }

    # DataFrame.append no longer exists in pandas 2.x; concatenate once instead.
    return {
        "train": pd.concat(train_parts).reset_index(drop=True),
        "valid": pd.concat(val_parts).reset_index(drop=True),
    }