Source code for tdc.single_pred.develop

# -*- coding: utf-8 -*-
# Author: TDC Team
# License: MIT

import sys
import warnings


from . import single_pred_dataset
from ..utils import print_sys
from ..metadata import dataset_names

[docs]class Develop(single_pred_dataset.DataLoader): """Data loader class to load datasets in Develop task. More info: Args: name (str): the dataset name. path (str, optional): The path to save the data file, defaults to './data' label_name (str, optional): For multi-label dataset, specify the label name, defaults to None print_stats (bool, optional): Whether to print basic statistics of the dataset, defaults to False convert_format (str, optional): Automatic conversion of SMILES to other molecular formats in MolConvert class. Stored as separate column in dataframe, defaults to None """ def __init__( self, name, path="./data", label_name=None, print_stats=False, convert_format=None, ): """Create a Developability prediction dataloader object.""" super().__init__( name, path, label_name, print_stats, dataset_names=dataset_names["Develop"], convert_format=convert_format, ) self.entity1_name = "Antibody" if print_stats: self.print_stats() print("Done!", flush=True, file=sys.stderr)
[docs] def graphein( self, graph="distance", node_feature=["amino_acid_one_hot"], distance_threshold=6, config=None, convertor=None, ): from typing import Dict import torch from functools import partial import graphein.protein as gp from import GraphFormatConvertor from import InMemoryProteinGraphDataset, ProteinGraphDataset from graphein.protein.utils import get_obsolete_mapping import warnings warnings.filterwarnings("ignore") import logging logging.getLogger("graphein").setLevel("ERROR") try: split = self.split except: raise ValueError("Please define data.split() first!") print_sys( "Note that this task has obsolete PDBs, thus, we remove it in default. For fair comparison using other models, please retrieve the data.split object for the post removal splits!" ) from graphein.protein.utils import get_obsolete_mapping obs = get_obsolete_mapping() train_obs = [t for t in split["train"]["Antibody_ID"] if t in obs.keys()] valid_obs = [t for t in split["valid"]["Antibody_ID"] if t in obs.keys()] test_obs = [t for t in split["test"]["Antibody_ID"] if t in obs.keys()] split["train"] = split["train"].loc[ ~split["train"]["Antibody_ID"].isin(train_obs) ] split["test"] = split["test"].loc[~split["test"]["Antibody_ID"].isin(test_obs)] split["valid"] = split["valid"].loc[ ~split["valid"]["Antibody_ID"].isin(valid_obs) ] self.split = split def get_label_map(split_name: str) -> Dict[str, torch.Tensor]: return dict( zip( split[split_name].Antibody_ID, split[split_name].Y.apply(torch.tensor), ) ) train_labels = get_label_map("train") valid_labels = get_label_map("valid") test_labels = get_label_map("test") if config is None: if graph == "distance": edge_fct = [ partial( gp.add_distance_threshold, threshold=distance_threshold, long_interaction_threshold=0, ) ] node_feature_fct = [] for i in node_feature: if i == "amino_acid_one_hot": node_feature_fct.append(gp.amino_acid_one_hot) from functools import partial graphein_config = gp.ProteinGraphConfig( node_metadata_functions=node_feature_fct, edge_construction_functions=edge_fct, ) convertor = GraphFormatConvertor( src_format="nx", dst_format="pyg", columns=["coords", "edge_index"] + node_feature, ) else: graphein_config = config convertor = convertor train_ds = InMemoryProteinGraphDataset( root="./data/", name="train", pdb_codes=split["train"]["Antibody_ID"], graph_label_map=train_labels, graphein_config=graphein_config, graph_format_convertor=convertor, ) valid_ds = InMemoryProteinGraphDataset( root="./data/", name="valid", pdb_codes=split["valid"]["Antibody_ID"], graph_label_map=valid_labels, graphein_config=graphein_config, graph_format_convertor=convertor, ) test_ds = InMemoryProteinGraphDataset( root="./data/", name="test", pdb_codes=split["test"]["Antibody_ID"], graph_label_map=test_labels, graphein_config=graphein_config, graph_format_convertor=convertor, ) return {"train": train_ds, "valid": valid_ds, "test": test_ds}