
# -*- coding: utf-8 -*-
# Author: TDC Team
# License: MIT
"""
This file contains a base data loader object that specific data loaders can inherit from.
"""

import pandas as pd
import numpy as np
import sys
import warnings

warnings.filterwarnings("ignore")

from . import utils


class DataLoader:
    """Base data loader class that contains functions shared by almost all data loader classes."""

    def __init__(self):
        """Empty data loader class, to be overwritten."""
        pass

    def get_data(self, format='df'):
        """
        Arguments:
            format (str, optional): the dataset format

        Returns:
            pd.DataFrame/dict/tuple: when format is df/dict/DeepPurpose

        Raises:
            AttributeError: format not supported
        """
        if format == 'df':
            return pd.DataFrame({self.entity1_name + '_ID': self.entity1_idx,
                                 self.entity1_name: self.entity1,
                                 'Y': self.y})
        elif format == 'dict':
            return {self.entity1_name + '_ID': self.entity1_idx,
                    self.entity1_name: self.entity1,
                    'Y': self.y}
        elif format == 'DeepPurpose':
            return self.entity1, self.y
        else:
            raise AttributeError("Please use the correct format input")
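
    # Usage sketch (hypothetical, not part of this module): a concrete loader that
    # inherits from DataLoader exposes this interface. The loader class and dataset
    # name below (tdc.single_pred.ADME, 'Caco2_Wang') are illustrative examples only.
    #
    #   from tdc.single_pred import ADME
    #   data = ADME(name='Caco2_Wang')
    #   df = data.get_data(format='df')              # pd.DataFrame with Drug_ID, Drug, Y
    #   d = data.get_data(format='dict')             # same columns as a plain dict
    #   X, y = data.get_data(format='DeepPurpose')   # tuple of entities and labels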

    def print_stats(self):
        """Print statistics."""
        print('There are ' + str(len(np.unique(self.entity1))) + ' unique ' +
              self.entity1_name.lower() + 's',
              flush=True, file=sys.stderr)

    def get_split(self, method='random', seed=42, frac=[0.7, 0.1, 0.2]):
        """
        Split function, overwritten by single_pred/multi_pred/generation for more specific splits.

        Arguments:
            method: splitting scheme
            seed: random seed
            frac: train/val/test split fractions

        Returns:
            dict: a dictionary of train/valid/test dataframes

        Raises:
            AttributeError: split method not supported
        """
        df = self.get_data(format='df')

        if method == 'random':
            return utils.create_fold(df, seed, frac)
        elif method == 'cold_' + self.entity1_name.lower():
            return utils.create_fold_setting_cold(df, seed, frac, self.entity1_name)
        else:
            raise AttributeError("Please specify the correct splitting method")
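
    # Split sketch (hypothetical, continuing the illustrative ADME loader above):
    #
    #   splits = data.get_split(method='random', seed=42, frac=[0.7, 0.1, 0.2])
    #   train, valid, test = splits['train'], splits['valid'], splits['test']
    #   # the cold-split key follows 'cold_' + entity name, e.g. method='cold_drug'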

    def label_distribution(self):
        """Visualize the distribution of labels."""
        utils.label_dist(self.y, self.name)

    def binarize(self, threshold=None, order='descending'):
        """Binarize the labels.

        Args:
            threshold (float, optional): the threshold to binarize the label.
            order (str, optional): the order of binarization; for 'descending'
                (default), smaller values map to 1, and for 'ascending', larger
                values map to 1.

        Returns:
            DataLoader: data loader class with updated labels

        Raises:
            AttributeError: no threshold specified for binarization
        """
        if threshold is None:
            raise AttributeError(
                "Please specify the threshold to binarize the data by "
                "'binarize(threshold = N)'!")

        if len(np.unique(self.y)) == 2:
            print("The data is already binarized!", flush=True, file=sys.stderr)
        else:
            print("Binarization using threshold " + str(threshold) +
                  "; by default, we assume the smaller values are 1 "
                  "and the larger ones are 0. You can change the order "
                  "by 'binarize(order = 'ascending')'",
                  flush=True, file=sys.stderr)
            if np.unique(self.y).reshape(-1, ).shape[0] < 2:
                raise AttributeError(
                    "Adjust your threshold, there is only one class.")
            self.y = utils.binarize(self.y, threshold, order)
        return self
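
    # Binarization sketch (hypothetical threshold, continuing the example above):
    #
    #   data.binarize(threshold=0.5)
    #   # with the default order='descending', smaller values map to 1 and larger
    #   # values to 0; flip this with data.binarize(threshold=0.5, order='ascending')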

    def __len__(self):
        """Get the number of data points.

        Returns:
            int: number of data points
        """
        return len(self.get_data(format='df'))

    def convert_to_log(self, form='standard'):
        """Convert labels to log-scale.

        Args:
            form (str, optional): standard log-transformation or binding
                nM <-> p transformation.
        """
        print('To log space...', flush=True, file=sys.stderr)
        self.log_flag = True
        if form == 'binding':
            self.y = utils.convert_to_log(self.y)
        elif form == 'standard':
            self.sign = np.sign(self.y)
            self.y = self.sign * np.log(abs(self.y) + 1e-10)

    def convert_from_log(self, form='standard'):
        """Convert labels from log-scale.

        Args:
            form (str, optional): standard log-transformation or binding
                nM <-> p transformation.
        """
        print('Convert back to original space...', flush=True, file=sys.stderr)
        if form == 'binding':
            self.y = utils.convert_back_log(self.y)
        elif form == 'standard':
            self.y = self.sign * (np.exp(self.sign * self.y) - 1e-10)
        self.log_flag = False
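
    # Log-transform sketch (hypothetical, continuing the example above):
    #
    #   data.convert_to_log(form='standard')    # y -> sign(y) * log(|y| + 1e-10)
    #   data.convert_from_log(form='standard')  # approximately restores the original y
    #   # form='binding' instead applies the nM <-> p conversion from tdc.utils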

    def get_label_meaning(self, output_format='dict'):
        """Get the biomedical meaning of labels.

        Args:
            output_format (str, optional): dict/df/array for label

        Returns:
            dict/pd.DataFrame/np.array: when output_format is dict/df/array
        """
        return utils.get_label_map(self.name, self.path, self.target,
                                   file_format=self.file_format,
                                   output_format=output_format)

    def balanced(self, oversample=False, seed=42):
        """Balance the negative/positive label ratio.

        Args:
            oversample (bool, optional): whether to oversample the minority class
                or subsample the majority class to match the ratio
            seed (int, optional): random seed

        Returns:
            pd.DataFrame: the updated dataframe with a balanced dataset

        Raises:
            AttributeError: the data must be binarized first, since continuous
                labels cannot be balanced
        """
        if len(np.unique(self.y)) > 2:
            raise AttributeError(
                "You should binarize the data first by calling "
                "data.binarize(threshold)")

        val = self.get_data()

        class_ = val.Y.value_counts().keys().values
        major_class = class_[0]
        minor_class = class_[1]

        if not oversample:
            print(" Subsampling the majority class; if you want to oversample "
                  "the minority class instead, set 'balanced(oversample = True)'. ",
                  flush=True, file=sys.stderr)
            val = pd.concat(
                [val[val.Y == major_class].sample(
                    n=len(val[val.Y == minor_class]),
                    replace=False, random_state=seed),
                 val[val.Y == minor_class]]).sample(
                frac=1, replace=False, random_state=seed).reset_index(drop=True)
        else:
            print(" Oversampling the minority class. ", flush=True, file=sys.stderr)
            val = pd.concat(
                [val[val.Y == minor_class].sample(
                    n=len(val[val.Y == major_class]),
                    replace=True, random_state=seed),
                 val[val.Y == major_class]]).sample(
                frac=1, replace=False, random_state=seed).reset_index(drop=True)
        return val
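
    # Balancing sketch (hypothetical, assumes the labels were binarized first,
    # continuing the illustrative loader above):
    #
    #   data.binarize(threshold=0.5)
    #   df_sub = data.balanced(oversample=False, seed=42)   # subsample the majority class
    #   df_over = data.balanced(oversample=True, seed=42)   # oversample the minority class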