# Source code for tdc.base_dataset

# -*- coding: utf-8 -*-
# Author: TDC Team
# License: MIT
"""
This file contains the base data loader class that specific data loaders inherit from.
"""

import pandas as pd
import numpy as np
import sys
import warnings

# NOTE: installs a global "ignore" filter with no category or module
# restriction, so importing this module silences all warnings.
warnings.filterwarnings("ignore")

from . import utils


class DataLoader:
    """Base data loader with the functions shared by almost all data loaders.

    The methods below read attributes that this class never assigns itself
    (``entity1``, ``entity1_idx``, ``entity1_name``, ``y``, ``name``,
    ``path``, ``target``, ``file_format``); they must be set elsewhere,
    e.g. by a subclass, before the corresponding method is called.
    """

    def __init__(self):
        """Empty data loader class, to be overwritten."""
        pass

    def get_data(self, format="df"):
        """Return the dataset in the requested format.

        Arguments:
            format (str, optional): the dataset format, one of ``"df"``,
                ``"dict"`` or ``"DeepPurpose"``

        Returns:
            pd.DataFrame/dict/tuple: a dataframe for ``"df"``, a dictionary
            with the same three columns for ``"dict"``, and the pair
            ``(self.entity1, self.y)`` for ``"DeepPurpose"``

        Raises:
            AttributeError: format not supported
        """
        if format == "DeepPurpose":
            return self.entity1, self.y
        if format not in ("df", "dict"):
            raise AttributeError("Please use the correct format input")

        # "df" and "dict" expose exactly the same three columns.
        columns = {
            self.entity1_name + "_ID": self.entity1_idx,
            self.entity1_name: self.entity1,
            "Y": self.y,
        }
        return pd.DataFrame(columns) if format == "df" else columns

    def print_stats(self):
        """Print the number of unique entities to stderr."""
        n_unique = len(np.unique(self.entity1))
        print(
            f"There are {n_unique} unique {self.entity1_name.lower()}s",
            flush=True,
            file=sys.stderr,
        )

    # NOTE: the list default is kept for signature compatibility; this
    # method itself never mutates ``frac``, it only forwards it to utils.
    def get_split(self, method="random", seed=42, frac=[0.7, 0.1, 0.2]):
        """Split the dataset.

        Overwritten by single_pred/multi_pred/generation for more specific
        splits.

        Arguments:
            method: splitting scheme, ``"random"`` or
                ``"cold_" + self.entity1_name.lower()``
            seed: random seed
            frac: train/val/test split fractions

        Returns:
            dict: a dictionary of train/valid/test dataframes

        Raises:
            AttributeError: split method not supported
        """
        df = self.get_data(format="df")

        if method == "random":
            return utils.create_fold(df, seed, frac)
        if method == "cold_" + self.entity1_name.lower():
            return utils.create_fold_setting_cold(df, seed, frac, self.entity1_name)
        raise AttributeError("Please specify the correct splitting method")

    def label_distribution(self):
        """Visualize the distribution of labels (via ``utils.label_dist``)."""
        utils.label_dist(self.y, self.name)

    def binarize(self, threshold=None, order="descending"):
        """Binarize the labels in place.

        Args:
            threshold (float, optional): the threshold to binarize the
                label. Required despite the ``None`` default.
            order (str, optional): the order of binarization, forwarded to
                ``utils.binarize``; with ``"ascending"`` the larger values
                become 1, and vice versa for ``"descending"``

        Returns:
            DataLoader: data loader class with updated label

        Raises:
            AttributeError: no threshold specified for binarization, or the
                labels hold fewer than two distinct values
        """
        if threshold is None:
            raise AttributeError(
                "Please specify the threshold to binarize the data by "
                "'binarize(threshold = N)'!"
            )

        n_distinct = len(np.unique(self.y))
        if n_distinct == 2:
            print("The data is already binarized!", flush=True, file=sys.stderr)
            return self

        print(
            "Binariztion using threshold "
            + str(threshold)
            + ", default, we assume the smaller values are 1 "
            "and larger ones is 0, you can change the order "
            "by 'binarize(order = 'ascending')'",
            flush=True,
            file=sys.stderr,
        )
        # NOTE(review): this guard inspects the labels *before* thresholding,
        # so it only fires for constant/empty labels; it does not detect a
        # threshold that maps every label to the same class.
        if n_distinct < 2:
            raise AttributeError("Adjust your threshold, there is only one class.")
        self.y = utils.binarize(self.y, threshold, order)
        return self

    def __len__(self):
        """Get the number of data points.

        Returns:
            int: number of data points
        """
        return len(self.get_data(format="df"))

    def convert_to_log(self, form="standard"):
        """Convert labels to log-scale.

        Sets ``self.log_flag`` to True. Any ``form`` other than the two
        documented values leaves the labels untouched.

        Args:
            form (str, optional): ``"standard"`` log-transformation or
                ``"binding"`` nM <-> p transformation.
        """
        print("To log space...", flush=True, file=sys.stderr)
        self.log_flag = True
        if form == "binding":
            self.y = utils.convert_to_log(self.y)
        elif form == "standard":
            # Signed log: keep each label's sign so convert_from_log can
            # restore it; the 1e-10 offset avoids log(0).
            self.sign = np.sign(self.y)
            self.y = self.sign * np.log(abs(self.y) + 1e-10)

    def convert_from_log(self, form="standard"):
        """Convert labels back from log-scale.

        Sets ``self.log_flag`` to False. The ``"standard"`` form relies on
        ``self.sign`` as stored by ``convert_to_log(form="standard")``.

        Args:
            form (str, optional): ``"standard"`` log-transformation or
                ``"binding"`` nM <-> p transformation.
        """
        print("Convert Back To Original space...", flush=True, file=sys.stderr)
        if form == "binding":
            self.y = utils.convert_back_log(self.y)
        elif form == "standard":
            # Inverse of sign * log(|y| + 1e-10).
            self.y = self.sign * (np.exp(self.sign * self.y) - 1e-10)
        self.log_flag = False

    def get_label_meaning(self, output_format="dict"):
        """Get the biomedical meaning of the label.

        Args:
            output_format (str, optional): dict/df/array for label

        Returns:
            dict/pd.DataFrame/np.array: when output_format is dict/df/array
        """
        return utils.get_label_map(
            self.name,
            self.path,
            self.target,
            file_format=self.file_format,
            output_format=output_format,
        )

    def balanced(self, oversample=False, seed=42):
        """Balance the label neg-pos ratio.

        Args:
            oversample (bool, optional): if True, oversample the minority
                class (with replacement) up to the majority size; otherwise
                subsample the majority class down to the minority size
            seed (int, optional): random seed

        Returns:
            pd.DataFrame: a shuffled, re-indexed dataframe with balanced
            classes; the loader's own attributes are not modified

        Raises:
            AttributeError: alert to binarize the data first as continuous
                values cannot do balancing
            IndexError: the labels contain a single class only
        """
        if len(np.unique(self.y)) > 2:
            # Exceptions do not accept print()'s flush/file keywords; passing
            # them would raise TypeError instead of this AttributeError.
            raise AttributeError(
                "You should binarize the data first by calling "
                "data.binarize(threshold)"
            )

        frame = self.get_data()
        # value_counts() sorts by frequency, most common class first.
        classes = frame.Y.value_counts().keys().values
        major_class = classes[0]
        minor_class = classes[1]
        majority = frame[frame.Y == major_class]
        minority = frame[frame.Y == minor_class]

        if not oversample:
            print(
                " Subsample the majority class is used, if you want to do "
                "oversample the minority class, set 'balanced(oversample = True)'. ",
                flush=True,
                file=sys.stderr,
            )
            resampled = majority.sample(
                n=len(minority), replace=False, random_state=seed
            )
            untouched = minority
        else:
            print(
                " Oversample of minority class is used. ", flush=True, file=sys.stderr
            )
            resampled = minority.sample(
                n=len(majority), replace=True, random_state=seed
            )
            untouched = majority

        # Shuffle the combined rows so the classes are interleaved.
        return (
            pd.concat([resampled, untouched])
            .sample(frac=1, replace=False, random_state=seed)
            .reset_index(drop=True)
        )