Source code for unimol_tools.models.nnmodel

# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function

import os
import torch
import torch.nn as nn
from torch.nn import functional as F
import joblib
from torch.utils.data import Dataset
import numpy as np
from ..utils import logger
from .unimol import UniMolModel
from .unimolv2 import UniMolV2Model
from .loss import GHMC_Loss, FocalLossWithLogits, myCrossEntropyLoss, MAEwithNan


NNMODEL_REGISTER = {
    'unimolv1': UniMolModel,
    'unimolv2': UniMolV2Model,
}

LOSS_RREGISTER = {
    'classification': myCrossEntropyLoss,
    'multiclass': myCrossEntropyLoss,
    'regression': nn.MSELoss(),
    'multilabel_classification': {
        'bce': nn.BCEWithLogitsLoss(),
        'ghm': GHMC_Loss(bins=10, alpha=0.5),
        'focal': FocalLossWithLogits,
    },
    'multilabel_regression': MAEwithNan,
}
ACTIVATION_FN = {
    # predict prob shape should be (N, K), especially for binary classification, K equals to 1.
    'classification': lambda x: F.softmax(x, dim=-1)[:, 1:],
    # softmax is used for multiclass classification
    'multiclass': lambda x: F.softmax(x, dim=-1),
    'regression': lambda x: x,
    # sigmoid is used for multilabel classification
    'multilabel_classification': lambda x: F.sigmoid(x),
    # no activation function is used for multilabel regression
    'multilabel_regression': lambda x: x,
}
OUTPUT_DIM = {
    'classification': 2,
    'regression': 1,
}


[docs] class NNModel(object): """A :class:`NNModel` class is responsible for initializing the model"""
[docs] def __init__(self, data, trainer, **params): """ Initializes the neural network model with the given data and parameters. :param data: (dict) Contains the dataset information, including features and target scaling. :param trainer: (object) An instance of a training class, responsible for managing training processes. :param params: Various additional parameters used for model configuration. The model is configured based on the task type and specific parameters provided. """ self.data = data self.num_classes = self.data['num_classes'] self.target_scaler = self.data['target_scaler'] self.features = data['unimol_input'] self.model_name = params.get('model_name', 'unimolv1') self.data_type = params.get('data_type', 'molecule') self.loss_key = params.get('loss_key', None) self.trainer = trainer #self.splitter = self.trainer.splitter self.model_params = params.copy() self.task = params['task'] if self.task in OUTPUT_DIM: self.model_params['output_dim'] = OUTPUT_DIM[self.task] elif self.task == 'multiclass': self.model_params['output_dim'] = self.data['multiclass_cnt'] else: self.model_params['output_dim'] = self.num_classes self.model_params['device'] = self.trainer.device self.cv = dict() self.metrics = self.trainer.metrics if self.task == 'multilabel_classification': if self.loss_key is None: self.loss_key = 'focal' self.loss_func = LOSS_RREGISTER[self.task][self.loss_key] else: self.loss_func = LOSS_RREGISTER[self.task] self.activation_fn = ACTIVATION_FN[self.task] self.save_path = self.trainer.save_path self.trainer.set_seed(self.trainer.seed) self.model = self._init_model(**self.model_params)
[docs] def _init_model(self, model_name, **params): """ Initializes the neural network model based on the provided model name and parameters. :param model_name: (str) The name of the model to initialize. :param params: Additional parameters for model configuration. :return: An instance of the specified neural network model. :raises ValueError: If the model name is not recognized. """ freeze_layers = params.get('freeze_layers', None) freeze_layers_reversed = params.get('freeze_layers_reversed', False) if model_name in NNMODEL_REGISTER: model = NNMODEL_REGISTER[model_name](**params) if isinstance(freeze_layers, str): freeze_layers = freeze_layers.replace(' ', '').split(',') if isinstance(freeze_layers, list): for layer_name, layer_param in model.named_parameters(): should_freeze = any(layer_name.startswith(freeze_layer) for freeze_layer in freeze_layers) layer_param.requires_grad = not (freeze_layers_reversed ^ should_freeze) else: raise ValueError('Unknown model: {}'.format(self.model_name)) return model
[docs] def collect_data(self, X, y, idx): """ Collects and formats the training or validation data. :param X: (np.ndarray or dict) The input features, either as a numpy array or a dictionary of tensors. :param y: (np.ndarray) The target values as a numpy array. :param idx: Indices to select the specific data samples. :return: A tuple containing processed input data and target values. :raises ValueError: If X is neither a numpy array nor a dictionary. """ assert isinstance(y, np.ndarray), 'y must be numpy array' if isinstance(X, np.ndarray): return torch.from_numpy(X[idx]).float(), torch.from_numpy(y[idx]) elif isinstance(X, list): return {k: v[idx] for k, v in X.items()}, torch.from_numpy(y[idx]) else: raise ValueError('X must be numpy array or dict')
[docs] def run(self): """ Executes the training process of the model. This involves data preparation, model training, validation, and computing metrics for each fold in cross-validation. """ logger.info("start training Uni-Mol:{}".format(self.model_name)) X = np.asarray(self.features) y = np.asarray(self.data['target']) group = np.asarray(self.data['group']) if self.data['group'] is not None else None if self.task == 'classification': y_pred = np.zeros_like( y.reshape(y.shape[0], self.num_classes)).astype(float) else: y_pred = np.zeros((y.shape[0], self.model_params['output_dim'])) for fold, (tr_idx, te_idx) in enumerate(self.data['split_nfolds']): X_train, y_train = X[tr_idx], y[tr_idx] X_valid, y_valid = X[te_idx], y[te_idx] traindataset = NNDataset(X_train, y_train) validdataset = NNDataset(X_valid, y_valid) if fold > 0: # need to initalize model for next fold training self.model = self._init_model(**self.model_params) if self.model_params.get('load_model_dir', None) is not None: load_model_path = os.path.join(self.model_params['load_model_dir'], f'model_{fold}.pth') model_dict = torch.load(load_model_path, map_location=self.model_params['device'])["model_state_dict"] if model_dict['classification_head.out_proj.weight'].shape[0] != self.model.output_dim: current_model_dict = self.model.state_dict() model_dict = {k: v for k, v in model_dict.items() if k in current_model_dict and 'classification_head.out_proj' not in k} current_model_dict.update(model_dict) logger.info("The output_dim of the model is different from the loaded model, only load the common part of the model") self.model.load_state_dict(model_dict, strict=False) else: self.model.load_state_dict(model_dict) logger.info("load model success from {}".format(load_model_path)) _y_pred = self.trainer.fit_predict( self.model, traindataset, validdataset, self.loss_func, self.activation_fn, self.save_path, fold, self.target_scaler) y_pred[te_idx] = _y_pred if 'multiclass_cnt' in self.data: label_cnt = self.data['multiclass_cnt'] else: label_cnt = None logger.info("fold {0}, result {1}".format( fold, self.metrics.cal_metric( self.data['target_scaler'].inverse_transform(y_valid), self.data['target_scaler'].inverse_transform(_y_pred), label_cnt=label_cnt ) ) ) self.cv['pred'] = y_pred self.cv['metric'] = self.metrics.cal_metric(self.data['target_scaler'].inverse_transform( y), self.data['target_scaler'].inverse_transform(self.cv['pred'])) self.dump(self.cv['pred'], self.save_path, 'cv.data') self.dump(self.cv['metric'], self.save_path, 'metric.result') logger.info("Uni-Mol metrics score: \n{}".format(self.cv['metric'])) logger.info("Uni-Mol & Metric result saved!")
[docs] def dump(self, data, dir, name): """ Saves the specified data to a file. :param data: The data to be saved. :param dir: (str) The directory where the data will be saved. :param name: (str) The name of the file to save the data. """ path = os.path.join(dir, name) if not os.path.exists(dir): os.makedirs(dir) joblib.dump(data, path)
[docs] def evaluate(self, trainer=None, checkpoints_path=None): """ Evaluates the model by making predictions on the test set and averaging the results. :param trainer: An optional trainer instance to use for prediction. :param checkpoints_path: (str) The path to the saved model checkpoints. """ logger.info("start predict NNModel:{}".format(self.model_name)) testdataset = NNDataset(self.features, np.asarray(self.data['target'])) for fold in range(self.data['kfold']): model_path = os.path.join(checkpoints_path, f'model_{fold}.pth') self.model.load_state_dict(torch.load( model_path, map_location=self.trainer.device)['model_state_dict']) _y_pred, _, __ = trainer.predict(self.model, testdataset, self.loss_func, self.activation_fn, self.save_path, fold, self.target_scaler, epoch=1, load_model=True) if fold == 0: y_pred = np.zeros_like(_y_pred) y_pred += _y_pred y_pred /= self.data['kfold'] self.cv['test_pred'] = y_pred
[docs] def count_parameters(self, model): """ Counts the number of trainable parameters in the model. :param model: The model whose parameters are to be counted. :return: (int) The number of trainable parameters. """ return sum(p.numel() for p in model.parameters() if p.requires_grad)
[docs] def NNDataset(data, label=None): """ Creates a dataset suitable for use with PyTorch models. :param data: The input data. :param label: Optional labels corresponding to the input data. :return: An instance of TorchDataset. """ return TorchDataset(data, label)
[docs] class TorchDataset(Dataset): """ A custom dataset class for PyTorch that handles data and labels. This class is compatible with PyTorch's Dataset interface and can be used with a DataLoader for efficient batch processing. It's designed to work with both numpy arrays and PyTorch tensors. """
[docs] def __init__(self, data, label=None): """ Initializes the dataset with data and labels. :param data: The input data. :param label: The target labels for the input data. """ self.data = data self.label = label if label is not None else np.zeros((len(data), 1))
def __getitem__(self, idx): """ Retrieves the data item and its corresponding label at the specified index. :param idx: (int) The index of the data item to retrieve. :return: A tuple containing the data item and its label. """ return self.data[idx], self.label[idx] def __len__(self): """ Returns the total number of items in the dataset. :return: (int) The size of the dataset. """ return len(self.data)