Source code for tdc.multi_pred.drugres

# -*- coding: utf-8 -*-
# Author: TDC Team
# License: MIT

import warnings
import sys

from ..utils import print_sys
from ..utils.load import download_wrapper, pd_load
from . import bi_pred_dataset, multi_pred_dataset
from ..metadata import dataset_names

[docs]class DrugRes(bi_pred_dataset.DataLoader): """Data loader class to load datasets in Drug Response Prediction Task. More info: Task Description: Regression. Given the gene expression of cell lines and the SMILES of drug, predict the drug sensitivity level. Args: name (str): the dataset name. path (str, optional): The path to save the data file, defaults to './data' label_name (str, optional): For multi-label dataset, specify the label name, defaults to None print_stats (bool, optional): Whether to print basic statistics of the dataset, defaults to False """ def __init__(self, name, path='./data', label_name=None, print_stats=False): """Create Drug Response Prediction dataloader object """ super().__init__(name, path, label_name, print_stats, dataset_names=dataset_names["DrugRes"]) self.entity1_name = 'Drug' self.entity2_name = 'Cell Line' self.two_types = True self.path = path if print_stats: self.print_stats() print('Done!', flush=True, file=sys.stderr)
[docs] def get_gene_symbols(self): """ Retrieve the gene symbols for the cell line gene expression """ path = self.path name = download_wrapper('gdsc_gene_symbols', path, ['gdsc_gene_symbols']) print_sys('Loading...') import pandas as pd import os df = pd.read_csv(os.path.join(path, name + '.tab'), sep = '\t') return df.values.reshape(-1, )