# Source code for unimol_tools.predictor

# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function

import numpy as np
import torch
from torch.utils.data import Dataset
from .data import DataHub
from .models import UniMolModel, UniMolV2Model
from .tasks import Trainer

class MolDataset(Dataset):
    """
    Map-style dataset that pairs every molecular input with a label.

    If no labels are given, a zero-filled array of shape ``(len(data), 1)``
    stands in for them, so indexing always yields an ``(input, label)`` pair.
    """

    def __init__(self, data, label=None):
        """
        :param data: indexable, sized collection of model inputs.
        :param label: optional indexable collection of labels aligned with
            ``data``; when ``None``, zeros of shape ``(len(data), 1)`` are used.
        """
        self.data = data
        if label is None:
            label = np.zeros((len(data), 1))
        self.label = label

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair stored at position ``idx``."""
        sample = self.data[idx]
        target = self.label[idx]
        return sample, target

    def __len__(self):
        """Number of samples, taken from ``data``."""
        return len(self.data)
class UniMolRepr(object):
    """
    A :class:`UniMolRepr` class is responsible for interface of molecular
    representation by unimol.
    """

    def __init__(self, data_type='molecule', remove_hs=False, model_name='unimolv1',
                 model_size='84m', use_gpu=True):
        """
        Initialize a :class:`UniMolRepr` class.

        :param data_type: str, default='molecule', currently support molecule, oled.
            Passed to the model constructor only for ``unimolv1``.
        :param remove_hs: bool, default=False, whether to remove hydrogens in molecular.
            Passed to the model constructor only for ``unimolv1``.
        :param model_name: str, default='unimolv1', currently support unimolv1, unimolv2.
        :param model_size: str, default='84m', model size of unimolv2.
            Passed to the model constructor only for ``unimolv2``.
        :param use_gpu: bool, default=True, whether to use gpu.
        :raises ValueError: if ``model_name`` is neither ``'unimolv1'`` nor ``'unimolv2'``.
        """
        # Use the first CUDA device only if one exists AND the caller allows it.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() and use_gpu else "cpu")
        if model_name == 'unimolv1':
            self.model = UniMolModel(output_dim=1, data_type=data_type, remove_hs=remove_hs).to(self.device)
        elif model_name == 'unimolv2':
            self.model = UniMolV2Model(output_dim=1, model_size=model_size).to(self.device)
        else:
            raise ValueError('Unknown model name: {}'.format(model_name))
        # This class only runs inference (see get_repr), so switch to eval mode once.
        self.model.eval()
        # Forwarded as keyword arguments to DataHub in get_repr.
        self.params = {
            'data_type': data_type,
            'remove_hs': remove_hs,
            'model_name': model_name,
            'model_size': model_size,
        }

    @staticmethod
    def _prepare_input(data):
        """
        Validate ``data`` and normalize it into the form passed to DataHub.

        Explicit exceptions are used instead of ``assert`` so the checks
        still run under ``python -O``.

        :param data: a SMILES string, a list/tuple of SMILES strings, or a
            dict of custom conformers with ``atoms`` and ``coordinates`` keys.
        :return: ``np.ndarray`` of SMILES strings, or the dict unchanged.
        :raises ValueError: for an unsupported type, an empty sequence, a
            sequence containing a non-string item, or a dict lacking
            ``atoms``/``coordinates``.
        """
        if isinstance(data, str):
            # A single SMILES string becomes a one-element array.
            return np.array([data])
        if isinstance(data, dict):
            # Custom conformers must provide both atoms and coordinates.
            missing = [key for key in ('atoms', 'coordinates') if key not in data]
            if missing:
                raise ValueError(
                    'Custom conformer dict is missing key(s): {}'.format(', '.join(missing)))
            return data
        if isinstance(data, (list, tuple)):
            if not data:
                raise ValueError('Expected a non-empty list of SMILES strings.')
            # Check every item, not just the last one, so mixed lists fail early.
            if not all(isinstance(item, str) for item in data):
                raise ValueError('Expected every item of the list to be a SMILES string.')
            return np.array(data)
        raise ValueError('Unknown data type: {}'.format(type(data)))

    def get_repr(self, data=None, return_atomic_reprs=False):
        """
        Get molecular representation by unimol.

        :param data: str, dict, list or tuple, default=None, input data for unimol.
            - str: a single SMILES string (wrapped into a one-element array).
              NOTE(review): older docs also allowed a path to a SMILES file;
              this method does not special-case paths, so confirm whether
              DataHub handles that before relying on it.
            - dict: custom conformers, must contain ``atoms`` and ``coordinates``.
            - list/tuple: non-empty sequence of SMILES strings.
        :param return_atomic_reprs: bool, default=False, whether to return atomic representations.
        :return: dict of molecular representation, as produced by ``Trainer.inference``.
        :raises ValueError: if ``data`` is missing, empty, or of an unsupported form.
        """
        data = self._prepare_input(data)
        datahub = DataHub(data=data, task='repr', is_train=False, **self.params)
        dataset = MolDataset(datahub.data['unimol_input'])
        # Kept on the instance, as in earlier versions, so callers can inspect it.
        self.trainer = Trainer(task='repr', cuda=self.device)
        repr_output = self.trainer.inference(
            self.model,
            return_repr=True,
            return_atomic_reprs=return_atomic_reprs,
            dataset=dataset,
        )
        return repr_output