# Source code for unimol_tools.predictor

# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

from .data import DataHub
from .models import UniMolModel, UniMolV2Model
from .tasks import Trainer


class MolDataset(Dataset):
    """
    Minimal :class:`torch.utils.data.Dataset` pairing molecular inputs with labels.

    Each item is a ``(sample, label)`` tuple taken from the same position of
    ``data`` and ``label``.
    """

    def __init__(self, data, label=None):
        """
        :param data: indexable, sized collection of model inputs.
        :param label: optional indexable collection of per-sample labels. When
            omitted, a zero-filled array of shape ``(len(data), 1)`` is used as a
            placeholder.
        """
        self.data = data
        if label is None:
            # Placeholder targets so __getitem__ can always return a pair.
            label = np.zeros((len(data), 1))
        self.label = label

    def __len__(self):
        """Return the number of samples (the length of ``data``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        sample = self.data[idx]
        target = self.label[idx]
        return sample, target
class UniMolRepr(object):
    """
    Interface for computing molecular representations with a Uni-Mol model.

    The model is built once in ``__init__`` and put in eval mode; each call to
    :meth:`get_repr` featurizes the input through ``DataHub`` and runs inference
    through ``Trainer``.
    """

    def __init__(
        self,
        data_type='molecule',
        batch_size=32,
        remove_hs=False,
        model_name='unimolv1',
        model_size='84m',
        use_cuda=True,
        use_ddp=False,
        use_gpu='all',
        save_path=None,
        **kwargs,
    ):
        """
        Initialize a :class:`UniMolRepr` class.

        :param data_type: str, default='molecule', currently support molecule, oled.
        :param batch_size: int, default=32, batch size used when running inference.
        :param remove_hs: bool, default=False, whether to remove hydrogens in molecules.
        :param model_name: str, default='unimolv1', currently support unimolv1, unimolv2.
        :param model_size: str, default='84m', model size of unimolv2 (only used when
            ``model_name='unimolv2'``). Available: 84m, 164m, 310m, 570m, 1.1B.
        :param use_cuda: bool, default=True, whether to use gpu. Falls back to cpu
            when CUDA is not available.
        :param use_ddp: bool, default=False, whether to use distributed data parallel.
        :param use_gpu: str, default='all', which gpu to use.
        :param save_path: str or None, default=None, forwarded (via ``self.params``)
            to ``DataHub`` and ``Trainer``.
        :param kwargs: accepted for call-site compatibility; currently ignored.
        :raises ValueError: if ``model_name`` is not ``'unimolv1'`` or ``'unimolv2'``.
        """
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() and use_cuda else "cpu"
        )
        if model_name == 'unimolv1':
            self.model = UniMolModel(
                output_dim=1, data_type=data_type, remove_hs=remove_hs
            ).to(self.device)
        elif model_name == 'unimolv2':
            self.model = UniMolV2Model(output_dim=1, model_size=model_size).to(
                self.device
            )
        else:
            raise ValueError('Unknown model name: {}'.format(model_name))
        # The model is only used for inference in this class.
        self.model.eval()
        # Shared configuration forwarded to DataHub and Trainer on every call.
        self.params = {
            'data_type': data_type,
            'batch_size': batch_size,
            'remove_hs': remove_hs,
            'model_name': model_name,
            'model_size': model_size,
            'use_cuda': use_cuda,
            'use_ddp': use_ddp,
            'use_gpu': use_gpu,
            'save_path': save_path,
        }

    @staticmethod
    def _normalize_input(data):
        """
        Validate ``data`` and convert it to the form passed to ``DataHub``.

        :param data: see :meth:`get_repr`.
        :return: the sdf path string, the conformer dict unchanged, or a 1-D
            numpy array of SMILES strings.
        :raises ValueError: for unsupported types, empty or non-string sequences,
            a csv file without a ``SMILES`` column, or a dict missing
            ``atoms``/``coordinates``.
        """
        if isinstance(data, str):
            if data.endswith('.sdf'):
                # DataHub reads the sdf file itself, so the path is passed through.
                return data
            if data.endswith('.csv'):
                frame = pd.read_csv(data)
                if 'SMILES' not in frame.columns:
                    raise ValueError(
                        "CSV file {} must contain a 'SMILES' column.".format(data)
                    )
                return frame['SMILES'].values
            # Any other string is treated as a single SMILES.
            return np.array([data])

        if isinstance(data, dict):
            # Custom conformers: must provide both atoms and coordinates.
            missing = [key for key in ('atoms', 'coordinates') if key not in data]
            if missing:
                raise ValueError(
                    'Custom conformer dict is missing key(s): {}'.format(missing)
                )
            return data

        if isinstance(data, (list, tuple, np.ndarray)):
            if isinstance(data, np.ndarray) and data.ndim != 1:
                raise ValueError(
                    'Expected a 1-D array of SMILES strings, got ndim={}'.format(
                        data.ndim
                    )
                )
            if len(data) == 0:
                raise ValueError('Input sequence of SMILES strings is empty.')
            # Check every element, not just one, so bad items fail here
            # instead of deep inside featurization.
            if not all(isinstance(item, str) for item in data):
                raise ValueError('All items of the input sequence must be SMILES strings.')
            return np.array(data)

        raise ValueError('Unknown data type: {}'.format(type(data)))

    def get_repr(self, data=None, return_atomic_reprs=False):
        """
        Get molecular representation by unimol.

        :param data: input data for unimol, default=None. Supported forms:

            - str ending with ``.sdf``: path to an sdf file, handed to ``DataHub`` as-is.
            - str ending with ``.csv``: path to a csv file with a ``SMILES`` column.
            - any other str: a single SMILES string.
            - dict: custom conformers; must contain ``atoms`` and ``coordinates`` keys.
            - list, tuple or 1-D numpy array: non-empty sequence of SMILES strings.
        :param return_atomic_reprs: bool, default=False, whether to return atomic
            representations.
        :return: the result of ``Trainer.inference(..., return_repr=True)``. It is
            presumably a dict of molecular representations; confirm against ``Trainer``.
        :raises ValueError: if ``data`` is of an unsupported type or malformed.
        """
        data = self._normalize_input(data)
        datahub = DataHub(
            data=data,
            task='repr',
            is_train=False,
            **self.params,
        )
        dataset = MolDataset(datahub.data['unimol_input'])
        self.trainer = Trainer(task='repr', **self.params)
        repr_output = self.trainer.inference(
            self.model,
            return_repr=True,
            return_atomic_reprs=return_atomic_reprs,
            dataset=dataset,
        )
        return repr_output