# Source code for unimol_tools.data.datareader

# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function

import os
import pathlib

import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import PandasTools
from rdkit.Chem.Scaffolds import MurckoScaffold

from ..utils import logger


class MolDataReader(object):
    '''A class to read Mol Data.

    Converts molecular input (a ``.csv``/``.sdf`` file, a dict of columns,
    a DataFrame, or a list/array of SMILES strings) into a dict of
    row-aligned per-sample lists plus target metadata.
    '''

    def read_data(self, data=None, is_train=True, **params):
        """
        Reads and preprocesses molecular data for model training or prediction.

        Parsing target columns (skipped entirely when ``task == 'repr'``):
        1. if target_cols is not None, use target_cols as target columns.
        2. if target_cols is None, use all columns whose name starts with
           ``target_col_prefix`` as target columns.
        3. for prediction (``is_train=False``), a target column that is
           missing or contains any null is filled with the placeholder -1.0.

        Invalid SMILES are dropped (non-strict training) or rejected with
        ``ValueError`` (otherwise) *before* targets are extracted. As a result,
        ``raw_data``, ``raw_target``, ``smiles``, ``scaffolds``, ``group``,
        ``atoms``, ``coordinates`` and ``mols`` always describe the same rows.

        :param data: The input molecular data. It can be any of:
            - a ``.csv``/``.sdf`` file path (str or os.PathLike, extension
              matched case-insensitively);
            - a dict of columns, where an optional ``'target'`` entry is
              expanded into ``target_col_prefix`` column(s) and the caller's
              dict is not modified;
            - a DataFrame, which is copied;
            - a list/ndarray of SMILES strings.
            For dict and list inputs the SMILES column is named ``'SMILES'``.
        :param is_train: (bool) Whether the data is read for training. This
            enables anomaly cleaning, multiclass counting and lenient SMILES
            filtering. Prediction mode instead adds target placeholders.
        :param params: Additional options:
            - ``task``;
            - ``target_cols`` (list/tuple or comma-separated str);
            - ``smiles_col`` (default ``'SMILES'``);
            - ``target_col_prefix`` (default ``'TARGET'``);
            - ``anomaly_clean`` (default False);
            - ``smi_strict`` (default False);
            - ``split_group_col`` (default ``'scaffold'``).
        :return: A dict with keys ``raw_data``, ``raw_target``, ``num_classes``,
            ``target_cols``, ``multiclass_cnt``, ``smiles``, ``scaffolds`` and
            ``group``. It also contains ``atoms``/``coordinates`` when both
            columns exist, and ``mols`` when a ``ROMol`` column exists.
        :raises ValueError: If the input type, file extension or
            ``target_cols`` type is unsupported. Also raised if a SMILES string
            is invalid outside non-strict training.
        """
        # TO DO
        # 1. add anomaly detection & outlier removal.
        # 2. add support for other file format.
        # 3. add support for multi tasks.
        task = params.get('task', None)
        target_cols = params.get('target_cols', None)
        smiles_col = params.get('smiles_col', 'SMILES')
        target_col_prefix = params.get('target_col_prefix', 'TARGET')
        anomaly_clean = params.get('anomaly_clean', False)
        smi_strict = params.get('smi_strict', False)
        split_group_col = params.get('split_group_col', 'scaffold')

        data, smiles_col = self._load_data(data, smiles_col, target_col_prefix)

        num_classes = None
        multiclass_cnt = None
        if task == 'repr':
            # Representation extraction needs no labels.
            target_cols = None
        else:
            target_cols = self._parse_target_cols(
                data, target_cols, target_col_prefix
            )
            if is_train:
                if anomaly_clean:
                    data = self.anomaly_clean(data, task, target_cols)
                if task == 'multiclass':
                    # Assumes labels are 0-based class indices, so the class
                    # count is the largest label + 1. The two max() calls
                    # reduce over rows, then over columns, skipping NaN labels.
                    multiclass_cnt = int(data[target_cols].max().max() + 1)
            else:
                # Prediction: guarantee every target column exists so that
                # raw_target keeps its shape. -1.0 is only a placeholder.
                for col in target_cols:
                    if col not in data.columns or data[col].isnull().any():
                        data[col] = -1.0
            num_classes = len(target_cols)

        if smiles_col in data.columns:
            # Filter invalid SMILES before extracting targets, so that every
            # per-sample output below stays aligned row-for-row.
            keep = np.array(
                [
                    self.check_smiles(smi, is_train, smi_strict)
                    for smi in data[smiles_col]
                ],
                dtype=bool,
            )
            data = data.loc[keep]
            smiles = data[smiles_col].tolist()
            scaffolds = [self.smi2scaffold(smi) for smi in smiles]
        else:
            smiles = None
            scaffolds = None

        if target_cols is None:
            targets = None
        else:
            targets = data[target_cols].values.tolist()

        dd = {
            'raw_data': data,
            'raw_target': targets,
            'num_classes': num_classes,
            'target_cols': target_cols,
            'multiclass_cnt': multiclass_cnt,
            'smiles': smiles,
            'scaffolds': scaffolds,
        }
        if split_group_col in data.columns:
            dd['group'] = data[split_group_col].tolist()
        elif split_group_col == 'scaffold':
            dd['group'] = scaffolds
        else:
            dd['group'] = None

        if 'atoms' in data.columns and 'coordinates' in data.columns:
            dd['atoms'] = data['atoms'].tolist()
            dd['coordinates'] = data['coordinates'].tolist()
        if 'ROMol' in data.columns:
            dd['mols'] = data['ROMol'].tolist()
        return dd

    def _load_data(self, data, smiles_col, target_col_prefix):
        """
        Converts a supported input into a new DataFrame.

        :param data: A file path (str/os.PathLike), dict, DataFrame, or
            list/ndarray of SMILES strings.
        :param smiles_col: (str) The SMILES column name requested by the caller.
        :param target_col_prefix: (str) Prefix for columns built from a dict's
            ``'target'`` entry.
        :return: (DataFrame, str) The data and the name of its SMILES column.
            Dict and list/array inputs always use ``'SMILES'``.
        :raises ValueError: If the input type or file extension is unsupported.
        """
        if isinstance(data, os.PathLike):
            data = os.fspath(data)
        if isinstance(data, str):
            # load from file
            self.data_path = data
            lowered = data.lower()
            if lowered.endswith('.sdf'):
                return PandasTools.LoadSDF(data), smiles_col
            if lowered.endswith('.csv'):
                return pd.read_csv(data), smiles_col
            raise ValueError('Unknown file type: {}'.format(data))
        if isinstance(data, pd.DataFrame):
            # Copy so that placeholder target columns never leak into the
            # caller's frame.
            return data.copy(), smiles_col
        if isinstance(data, dict):
            # Shallow copy: expanding 'target' must not mutate the caller's dict.
            data = dict(data)
            if 'target' in data:
                label = np.array(data.pop('target'))
                if label.ndim == 1 or label.shape[1] == 1:
                    data[target_col_prefix] = label.reshape(-1)
                else:
                    # One column per label dimension: TARGET0, TARGET1, ...
                    for i in range(label.shape[1]):
                        data[target_col_prefix + str(i)] = label[:, i]
            frame = pd.DataFrame(data).rename(columns={smiles_col: 'SMILES'})
            return frame, 'SMILES'
        if isinstance(data, (list, np.ndarray)):
            # load from smiles list
            return pd.DataFrame(data, columns=['SMILES']), 'SMILES'
        raise ValueError('Unknown data type: {}'.format(type(data)))

    @staticmethod
    def _parse_target_cols(data, target_cols, target_col_prefix):
        """
        Resolves the list of target column names.

        :param data: (DataFrame) The loaded data.
        :param target_cols: None, a comma-separated str, or a list/tuple of names.
        :param target_col_prefix: (str) Prefix used to discover target columns
            when ``target_cols`` is None.
        :return: (list) The target column names.
        :raises ValueError: If ``target_cols`` has an unsupported type.
        """
        if target_cols is None:
            # Non-string column labels (e.g. integers) cannot match a prefix.
            return [
                col
                for col in data.columns
                if isinstance(col, str) and col.startswith(target_col_prefix)
            ]
        if isinstance(target_cols, str):
            return [col.strip() for col in target_cols.split(',')]
        if isinstance(target_cols, (list, tuple)):
            return list(target_cols)
        raise ValueError('Unknown target_cols type: {}'.format(type(target_cols)))

    def check_smiles(self, smi, is_train, smi_strict):
        """
        Validates a SMILES string and decides whether its row is kept.

        :param smi: The SMILES value to check. Non-string values, such as NaN
            from an empty CSV cell, are treated as invalid.
        :param is_train: (bool) Indicates if this check is happening during training.
        :param smi_strict: (bool) Only matters during training. If true, an
            invalid SMILES raises an error; otherwise it is logged and skipped.
            Outside training an invalid SMILES always raises.
        :return: (bool) True if the SMILES string is valid, False if it should be
            skipped.
        :raises ValueError: If the SMILES is invalid and either ``is_train`` is
            False or ``smi_strict`` is True.
        """
        # Guard non-strings so that RDKit never sees them; they follow the same
        # skip/raise policy as unparsable strings.
        if isinstance(smi, str) and Chem.MolFromSmiles(smi) is not None:
            return True
        if is_train and not smi_strict:
            logger.info(f'Illegal SMILES clean: {smi}')
            return False
        raise ValueError(f'SMILES rule is illegal: {smi}')

    def smi2scaffold(self, smi):
        """
        Converts a SMILES string to its Murcko scaffold SMILES (chirality kept).

        :param smi: (str) The SMILES string to convert.
        :return: (str) The scaffold SMILES, or ``smi`` unchanged if RDKit cannot
            compute a scaffold.
        """
        try:
            return MurckoScaffold.MurckoScaffoldSmiles(
                smiles=smi, includeChirality=True
            )
        except Exception:
            # Best-effort: fall back to the input itself as its own group key.
            # Unlike a bare except, this lets KeyboardInterrupt/SystemExit through.
            return smi

    def anomaly_clean(self, data, task, target_cols):
        """
        Performs anomaly cleaning on the data based on the specified task.

        :param data: (DataFrame) The dataset to be cleaned.
        :param task: (str) The type of task, which determines the cleaning strategy.
        :param target_cols: (list) The target columns to consider for cleaning.
        :return: (DataFrame) The cleaned dataset. For classification and
            multilabel tasks this is ``data`` itself, unchanged.
        :raises ValueError: If the provided task is not recognized.
        """
        if task == 'regression':
            return self.anomaly_clean_regression(data, target_cols)
        if task in (
            'classification',
            'multiclass',
            'multilabel_classification',
            'multilabel_regression',
        ):
            # No outlier rule is defined for these tasks; pass data through.
            return data
        raise ValueError('Unknown task: {}'.format(task))

    def anomaly_clean_regression(self, data, target_cols):
        """
        Drops rows whose first target lies outside the open 3-sigma window.

        Rows with a NaN target are dropped as well, because NaN fails both
        bound comparisons. Cleaning is skipped when the std is 0 or undefined,
        i.e. constant labels or fewer than two labels. In those cases the
        window would be empty and every row would be removed.

        :param data: (DataFrame) The dataset to be cleaned.
        :param target_cols: (list) Target columns; only the first one is used.
        :return: (DataFrame) The cleaned dataset.
        """
        if not target_cols:
            return data
        sz = data.shape[0]
        target_col = target_cols[0]
        _mean, _std = data[target_col].mean(), data[target_col].std()
        if pd.isna(_std) or _std == 0:
            logger.info(
                'Anomaly clean skipped: target std is {} ({} rows kept)'.format(
                    _std, sz
                )
            )
            return data
        lower, upper = _mean - 3 * _std, _mean + 3 * _std
        data = data[(data[target_col] > lower) & (data[target_col] < upper)]
        logger.info(
            'Anomaly clean with 3 sigma threshold: {} -> {}'.format(sz, data.shape[0])
        )
        return data