Source code for unimol_tools.data.datahub

# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function

import os
import numpy as np
from rdkit.Chem import PandasTools

from ..utils import logger
from .conformer import ConformerGen, UniMolV2Feature
from .datareader import MolDataReader
from .datascaler import TargetScaler
from .split import Splitter


[docs] class DataHub(object): """ The DataHub class is responsible for storing and preprocessing data for machine learning tasks. It initializes with configuration options to handle different types of tasks such as regression, classification, and others. It also supports data scaling and handling molecular data. """
[docs] def __init__(self, data=None, is_train=True, save_path=None, **params): """ Initializes the DataHub instance with data and configuration for the ML task. :param data: Initial dataset to be processed. :param is_train: (bool) Indicates if the DataHub is being used for training. :param save_path: (str) Path to save any necessary files, like scalers. :param params: Additional parameters for data preprocessing and model configuration. """ self.raw_data = data self.is_train = is_train self.save_path = save_path self.task = params.get('task', None) self.target_cols = params.get('target_cols', None) self.multiclass_cnt = params.get('multiclass_cnt', None) self.ss_method = params.get('target_normalize', 'none') self.conf_cache_level = params.get('conf_cache_level', 1) self._init_data(**params) self._init_split(**params)
[docs] def _init_data(self, **params): """ Initializes and preprocesses the data based on the task and parameters provided. This method handles reading raw data, scaling targets, and transforming data for use with molecular inputs. It tailors the preprocessing steps based on the task type, such as regression or classification. :param params: Additional parameters for data processing. :raises ValueError: If the task type is unknown. """ self.data = MolDataReader().read_data(self.raw_data, self.is_train, **params) self.data['target_scaler'] = TargetScaler( self.ss_method, self.task, self.save_path ) if self.task == 'regression': target = np.array(self.data['raw_target']).reshape(-1, 1).astype(np.float32) if self.is_train: self.data['target_scaler'].fit(target, self.save_path) self.data['target'] = self.data['target_scaler'].transform(target) else: self.data['target'] = target elif self.task == 'classification': target = np.array(self.data['raw_target']).reshape(-1, 1).astype(np.int32) self.data['target'] = target elif self.task == 'multiclass': target = np.array(self.data['raw_target']).reshape(-1, 1).astype(np.int32) self.data['target'] = target if not self.is_train: self.data['multiclass_cnt'] = self.multiclass_cnt elif self.task == 'multilabel_regression': target = ( np.array(self.data['raw_target']) .reshape(-1, self.data['num_classes']) .astype(np.float32) ) if self.is_train: self.data['target_scaler'].fit(target, self.save_path) self.data['target'] = self.data['target_scaler'].transform(target) else: self.data['target'] = target elif self.task == 'multilabel_classification': target = ( np.array(self.data['raw_target']) .reshape(-1, self.data['num_classes']) .astype(np.int32) ) self.data['target'] = target elif self.task == 'repr': self.data['target'] = self.data['raw_target'] else: raise ValueError('Unknown task: {}'.format(self.task)) if params.get('model_name', None) == 'unimolv1': if 'mols' in self.data: no_h_list = ConformerGen(**params).transform_mols(self.data['mols']) mols = None elif 'atoms' in self.data and 'coordinates' in self.data: no_h_list = ConformerGen(**params).transform_raw( self.data['atoms'], self.data['coordinates'] ) mols = None else: smiles_list = self.data["smiles"] no_h_list, mols = ConformerGen(**params).transform(smiles_list) elif params.get('model_name', None) == 'unimolv2': if 'mols' in self.data: no_h_list = UniMolV2Feature(**params).transform_mols(self.data['mols']) mols = None elif 'atoms' in self.data and 'coordinates' in self.data: no_h_list = UniMolV2Feature(**params).transform_raw( self.data['atoms'], self.data['coordinates'] ) mols = None else: smiles_list = self.data["smiles"] no_h_list, mols = UniMolV2Feature(**params).transform(smiles_list) self.data['unimol_input'] = no_h_list if mols is not None: self.save_mol2sdf(self.data['raw_data'], mols, params)
def _init_split(self, **params): self.split_method = params.get('split_method', '5fold_random') kfold, method = ( int(self.split_method.split('fold')[0]), self.split_method.split('_')[-1], ) # Nfold_xxxx self.kfold = params.get('kfold', kfold) self.method = params.get('split', method) self.split_seed = params.get('split_seed', 42) self.data['kfold'] = self.kfold if not self.is_train: return self.splitter = Splitter(self.method, self.kfold, seed=self.split_seed) split_nfolds = self.splitter.split(**self.data) if self.kfold == 1: logger.info(f"Kfold is 1, all data is used for training.") else: logger.info(f"Split method: {self.method}, fold: {self.kfold}") nfolds = np.zeros(len(split_nfolds[0][0]) + len(split_nfolds[0][1]), dtype=int) for enu, (tr_idx, te_idx) in enumerate(split_nfolds): nfolds[te_idx] = enu self.data['split_nfolds'] = split_nfolds return split_nfolds
[docs] def save_mol2sdf(self, data, mols, params): """ Save the conformers to a SDF file. :param data: DataFrame containing the raw data. :param mols: List of RDKit molecule objects. """ if isinstance(self.raw_data, str): base_name = os.path.splitext(os.path.basename(self.raw_data))[0] elif isinstance(self.raw_data, list) or isinstance(self.raw_data, np.ndarray): # If the raw_data is a list of smiles, we can use a default name. base_name = 'unimol_conformers' else: logger.warning('Warning: raw_data is not a path or list, cannot save sdf.') return if params.get('sdf_save_path') is None: if self.save_path is not None: params['sdf_save_path'] = self.save_path else: return save_path = os.path.join(params.get('sdf_save_path'), f"{base_name}.sdf") if self.conf_cache_level == 0: logger.warning(f"conf_cache_level is 0, do not save conformers.") return elif self.conf_cache_level == 1 and os.path.exists(save_path): logger.warning(f"conf_cache_level is 1, but {save_path} exists, so do not save conformers.") return elif self.conf_cache_level == 2 or not os.path.exists(save_path): logger.info(f"conf_cache_level is {self.conf_cache_level}, saving conformers to {save_path}.") else: logger.warning(f"Unknown conf_cache_level: {self.conf_cache_level}, do not saving conformers.") return sdf_result = data.copy() sdf_result['ROMol'] = mols os.makedirs(os.path.dirname(save_path), exist_ok=True) try: PandasTools.WriteSDF( sdf_result, save_path, properties=list(sdf_result.columns), idName='RowID', ) logger.info(f"Successfully saved sdf file to {save_path}") except Exception as e: logger.warning(f"Failed to write sdf file: {e}") pass