# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from .data import DataHub
from .models import UniMolModel, UniMolV2Model
from .tasks import Trainer
[docs]
class MolDataset(Dataset):
"""
A :class:`MolDataset` class is responsible for interface of molecular dataset.
"""
[docs]
def __init__(self, data, label=None):
self.data = data
self.label = label if label is not None else np.zeros((len(data), 1))
def __getitem__(self, idx):
return self.data[idx], self.label[idx]
def __len__(self):
return len(self.data)
[docs]
class UniMolRepr(object):
"""
A :class:`UniMolRepr` class is responsible for interface of molecular representation by unimol
"""
[docs]
def __init__(
self,
data_type='molecule',
batch_size=32,
remove_hs=False,
model_name='unimolv1',
model_size='84m',
use_cuda=True,
use_ddp=False,
use_gpu='all',
save_path=None,
**kwargs,
):
"""
Initialize a :class:`UniMolRepr` class.
:param data_type: str, default='molecule', currently support molecule, oled.
:param batch_size: int, default=32, batch size for training.
:param remove_hs: bool, default=False, whether to remove hydrogens in molecular.
:param model_name: str, default='unimolv1', currently support unimolv1, unimolv2.
:param model_size: str, default='84m', model size of unimolv2. Avaliable: 84m, 164m, 310m, 570m, 1.1B.
:param use_cuda: bool, default=True, whether to use gpu.
:param use_ddp: bool, default=False, whether to use distributed data parallel.
:param use_gpu: str, default='all', which gpu to use.
"""
self.device = torch.device(
"cuda:0" if torch.cuda.is_available() and use_cuda else "cpu"
)
if model_name == 'unimolv1':
self.model = UniMolModel(
output_dim=1, data_type=data_type, remove_hs=remove_hs
).to(self.device)
elif model_name == 'unimolv2':
self.model = UniMolV2Model(output_dim=1, model_size=model_size).to(
self.device
)
else:
raise ValueError('Unknown model name: {}'.format(model_name))
self.model.eval()
self.params = {
'data_type': data_type,
'batch_size': batch_size,
'remove_hs': remove_hs,
'model_name': model_name,
'model_size': model_size,
'use_cuda': use_cuda,
'use_ddp': use_ddp,
'use_gpu': use_gpu,
'save_path': save_path,
}
[docs]
def get_repr(self, data=None, return_atomic_reprs=False):
"""
Get molecular representation by unimol.
:param data: str, dict or list, default=None, input data for unimol.
- str: smiles string or path to a smiles file.
- dict: custom conformers, should take atoms and coordinates as input.
- list: list of smiles strings.
:param return_atomic_reprs: bool, default=False, whether to return atomic representations.
:return: dict of molecular representation.
"""
if isinstance(data, str):
if data.endswith('.sdf'):
# Datahub will process sdf file.
pass
elif data.endswith('.csv'):
# read csv file.
data = pd.read_csv(data)
assert 'SMILES' in data.columns
data = data['SMILES'].values
else:
# single smiles string.
data = [data]
data = np.array(data)
elif isinstance(data, dict):
# custom conformers, should take atoms and coordinates as input.
assert 'atoms' in data and 'coordinates' in data
elif isinstance(data, list):
# list of smiles strings.
assert isinstance(data[-1], str)
data = np.array(data)
else:
raise ValueError('Unknown data type: {}'.format(type(data)))
datahub = DataHub(
data=data,
task='repr',
is_train=False,
**self.params,
)
dataset = MolDataset(datahub.data['unimol_input'])
self.trainer = Trainer(task='repr', **self.params)
repr_output = self.trainer.inference(
self.model,
return_repr=True,
return_atomic_reprs=return_atomic_reprs,
dataset=dataset,
)
return repr_output