# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function
import os
import numpy as np
import torch
from torch.utils.data import DataLoader as TorchDataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from functools import partial
from torch.nn.utils import clip_grad_norm_
# from transformers.optimization import get_linear_schedule_with_warmup
from ..utils import Metrics
from ..utils import logger
from tqdm import tqdm
import time
[docs]
class Trainer(object):
    """A :class:`Trainer` class is responsible for initializing the model, and managing its training, validation, and testing phases.

    It owns the run configuration (device, batch size, learning rate, early-stopping
    patience, optional CUDA mixed precision) and exposes the train / predict / inference loops.
    """

    # Tasks whose targets are cast to ``long``; targets of every other
    # supervised task are cast to ``float``.
    _LONG_TARGET_TASKS = ('classification', 'multiclass', 'multilabel_classification')

    def __init__(self, save_path=None, **params):
        """
        :param save_path: Path for saving the training outputs. Defaults to None.
        :param params: Additional parameters for training. ``params['metrics']`` is
            required unless ``params['task'] == 'repr'``.
        :raises KeyError: If ``task`` is not ``'repr'`` and ``'metrics'`` is missing.
        """
        self.save_path = save_path
        self.task = params.get('task', None)
        # Always define the metric attributes so that code reading them
        # (e.g. ``_early_stop_choice``) never hits an AttributeError for 'repr'.
        self.metrics_str = None
        self.metrics = None
        if self.task != 'repr':
            self.metrics_str = params['metrics']
            self.metrics = Metrics(self.task, self.metrics_str)
        self._init_trainer(**params)

    def _init_trainer(self, **params):
        """
        Initializing the trainer class to train model.

        Seeds the random number generators and reads the training hyper-parameters
        from ``params``, falling back to the defaults shown below.

        :param params: Containing training arguments.
        """
        ### init common params ###
        self.split_method = params.get('split_method', '5fold_random')
        self.split_seed = params.get('split_seed', 42)
        self.seed = params.get('seed', 42)
        self.set_seed(self.seed)
        self.logger_level = int(params.get('logger_level', 1))
        ### init NN trainer params ###
        self.learning_rate = float(params.get('learning_rate', 1e-4))
        self.batch_size = params.get('batch_size', 32)
        self.max_epochs = params.get('epochs', 50)
        self.warmup_ratio = params.get('warmup_ratio', 0.1)
        self.patience = params.get('patience', 10)
        self.max_norm = params.get('max_norm', 1.0)
        self.cuda = params.get('cuda', False)
        self.amp = params.get('amp', False)
        # CUDA is used only when it was requested AND is actually available.
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() and self.cuda else "cpu")
        # Deliberately an equality test rather than truthiness: a config value
        # such as the string "false" must not switch mixed precision on.
        use_amp = self.device.type == 'cuda' and self.amp == True  # noqa: E712
        self.scaler = torch.cuda.amp.GradScaler() if use_amp else None

    def decorate_batch(self, batch, feature_name=None):
        """
        Prepares a batch of data for processing by the model. This method is a wrapper that
        currently always delegates to :meth:`decorate_torch_batch`.

        :param batch: The batch of data to be processed.
        :param feature_name: (str, optional) Accepted for interface compatibility; not used here. Defaults to None.
        :return: The decorated batch ready for processing by the model.
        """
        return self.decorate_torch_batch(batch)

    def decorate_graph_batch(self, batch):
        """
        Prepares a graph-based batch of data for processing by the model.

        The whole batch object is moved to the trainer device and passed as ``net_input``;
        the target is read from ``batch.y``.

        :param batch: The batch of graph-based data to be processed.
        :return: A tuple of (net_input, net_target) for model processing.
        """
        net_input = {'net_input': batch.to(self.device)}
        net_target = batch.y.to(self.device)
        if self.task in self._LONG_TARGET_TASKS:
            net_target = net_target.long()
        else:
            net_target = net_target.float()
        return net_input, net_target

    def decorate_torch_batch(self, batch):
        """
        Prepares a standard PyTorch batch of data for processing by the model. Handles tensor-based data structures.

        :param batch: A ``(net_input, net_target)`` pair. ``net_input`` is either a dict of
            tensors or a single tensor (which is wrapped as ``{'net_input': tensor}``).
        :return: A tuple of (net_input, net_target) for model processing. ``net_target`` is
            ``None`` when the task is ``'repr'``.
        """
        net_input, net_target = batch
        if isinstance(net_input, dict):
            net_input = {k: v.to(self.device) for k, v in net_input.items()}
        else:
            net_input = {'net_input': net_input.to(self.device)}
        if self.task == 'repr':
            # Representation extraction has no target: skip the device copy
            # entirely so a missing (None) target cannot crash the batch.
            return net_input, None
        net_target = net_target.to(self.device)
        if self.task in self._LONG_TARGET_TASKS:
            net_target = net_target.long()
        else:
            net_target = net_target.float()
        return net_input, net_target

    def fit_predict(self, model, train_dataset, valid_dataset, loss_func, activation_fn, dump_dir, fold, target_scaler, feature_name=None):
        """
        Trains the model on the given training dataset and evaluates it on the validation dataset.

        After every epoch the validation set is scored and the early-stopping logic decides
        whether to continue. When training ends, the checkpoint saved for this fold is reloaded
        and used to produce the returned predictions.

        :param model: The model to be trained and evaluated.
        :param train_dataset: Dataset used for training the model.
        :param valid_dataset: Dataset used for validating the model.
        :param loss_func: The loss function used during training.
        :param activation_fn: The activation function applied to the model's output.
        :param dump_dir: Directory where the best model state is saved.
        :param fold: The fold number in a cross-validation setting.
        :param target_scaler: Scaler used for scaling the target variable.
        :param feature_name: (optional) Name of the feature used in data loading. Defaults to None.
        :return: Predictions made by the model on the validation dataset.
        """
        model = model.to(self.device)
        train_dataloader = NNDataLoader(
            feature_name=feature_name,
            dataset=train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=model.batch_collate_fn,
            # remove last batch, bs=1 can not work on batchnorm1d
            drop_last=True,
        )
        min_val_loss = float("inf")
        max_score = float("-inf")
        wait = 0
        ### init optimizer ###
        # The scheduler is stepped once per batch, so its horizon is in batches.
        num_training_steps = len(train_dataloader) * self.max_epochs
        num_warmup_steps = int(num_training_steps * self.warmup_ratio)
        optimizer = Adam(model.parameters(), lr=self.learning_rate, eps=1e-6)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
        use_amp = self.scaler is not None and self.device.type == 'cuda'

        for epoch in range(self.max_epochs):
            model = model.train()
            start_time = time.time()
            trn_loss = []
            # The context manager closes the progress bar even if a batch raises.
            with tqdm(total=len(train_dataloader), dynamic_ncols=True,
                      leave=False, position=0, desc='Train', ncols=5) as batch_bar:
                for i, batch in enumerate(train_dataloader):
                    net_input, net_target = self.decorate_batch(
                        batch, feature_name)
                    optimizer.zero_grad()  # Zero gradients
                    if use_amp:
                        with torch.cuda.amp.autocast():
                            outputs = model(**net_input)
                            loss = loss_func(outputs, net_target)
                    else:
                        with torch.set_grad_enabled(True):
                            outputs = model(**net_input)
                            loss = loss_func(outputs, net_target)
                    trn_loss.append(loss.item())
                    # Running mean of the loss so far, shown next to the bar.
                    batch_bar.set_postfix(
                        Epoch="Epoch {}/{}".format(epoch+1, self.max_epochs),
                        loss="{:.04f}".format(float(sum(trn_loss) / (i + 1))),
                        lr="{:.04f}".format(float(optimizer.param_groups[0]['lr'])))
                    if use_amp:
                        # This is a replacement for loss.backward()
                        self.scaler.scale(loss).backward()
                        # Gradients must be unscaled before clipping so that
                        # max_norm applies to their true magnitude.
                        self.scaler.unscale_(optimizer)
                        clip_grad_norm_(model.parameters(), self.max_norm)
                        # This is a replacement for optimizer.step()
                        self.scaler.step(optimizer)
                        self.scaler.update()
                    else:
                        loss.backward()
                        clip_grad_norm_(model.parameters(), self.max_norm)
                        optimizer.step()
                    scheduler.step()
                    batch_bar.update()

            total_trn_loss = np.mean(trn_loss)
            y_preds, val_loss, metric_score = self.predict(
                model, valid_dataset, loss_func, activation_fn, dump_dir, fold, target_scaler, epoch, load_model=False, feature_name=feature_name)
            end_time = time.time()
            total_val_loss = np.mean(val_loss)
            # Only the first reported metric is shown in the log line.
            _metric, _score = list(metric_score.items())[0]
            message = 'Epoch [{}/{}] train_loss: {:.4f}, val_loss: {:.4f}, val_{}: {:.4f}, lr: {:.6f}, ' \
                '{:.1f}s'.format(epoch+1, self.max_epochs,
                                 total_trn_loss, total_val_loss,
                                 _metric, _score,
                                 optimizer.param_groups[0]['lr'],
                                 (end_time - start_time))
            logger.info(message)
            is_early_stop, min_val_loss, wait, max_score = self._early_stop_choice(
                wait, total_val_loss, min_val_loss, metric_score, max_score, model, dump_dir, fold, self.patience, epoch)
            if is_early_stop:
                break

        # Reload the checkpoint written during training and predict with it.
        y_preds, _, _ = self.predict(model, valid_dataset, loss_func, activation_fn,
                                     dump_dir, fold, target_scaler, epoch, load_model=True, feature_name=feature_name)
        return y_preds

    def _early_stop_choice(self, wait, loss, min_loss, metric_score, max_score, model, dump_dir, fold, patience, epoch):
        """
        Determines if early stopping criteria are met, based on either loss improvement or custom metric score.

        The loss-based rule is used when ``self.metrics_str`` is not a string or is one of
        ``'loss'``, ``'none'``, ``''``; otherwise the decision is delegated to
        ``self.metrics._early_stop_choice``.

        :param wait: Number of epochs waited since the last improvement in loss or metric score.
        :param loss: The current loss value.
        :param min_loss: The minimum loss value observed so far.
        :param metric_score: Current metric score.
        :param max_score: The maximum metric score observed so far.
        :param model: The model being trained.
        :param dump_dir: Directory to save the best model state.
        :param fold: The fold number in cross-validation.
        :param patience: Number of epochs to wait for an improvement before stopping.
        :param epoch: The current epoch number.
        :return: A tuple (is_early_stop, min_val_loss, wait, max_score) indicating if early stopping criteria are met, the minimum validation loss, the updated wait time, and the maximum metric score.
        """
        if not isinstance(self.metrics_str, str) or self.metrics_str in ['loss', 'none', '']:
            is_early_stop, min_val_loss, wait = self._judge_early_stop_loss(
                wait, loss, min_loss, model, dump_dir, fold, patience, epoch)
        else:
            is_early_stop, min_val_loss, wait, max_score = self.metrics._early_stop_choice(
                wait, min_loss, metric_score, max_score, model, dump_dir, fold, patience, epoch)
        return is_early_stop, min_val_loss, wait, max_score

    def _judge_early_stop_loss(self, wait, loss, min_loss, model, dump_dir, fold, patience, epoch):
        """
        Determines whether early stopping should be triggered based on the loss comparison.

        When the loss is at least as good as the best seen so far, the model state is saved to
        ``<dump_dir>/model_<fold>.pth`` and the wait counter is reset. Otherwise the counter is
        incremented and stopping is signalled once it equals ``patience``.

        :param wait: The number of epochs to wait after min_loss has stopped improving.
        :param loss: The current loss value of the model.
        :param min_loss: The minimum loss value observed so far.
        :param model: The neural network model being trained.
        :param dump_dir: Directory to save the model state.
        :param fold: The current fold number in a cross-validation setting.
        :param patience: The number of epochs to wait for an improvement before stopping.
        :param epoch: The current epoch number.
        :return: A tuple (is_early_stop, min_loss, wait), where is_early_stop is a boolean indicating
                 whether early stopping should occur, min_loss is the updated minimum loss,
                 and wait is the updated wait counter.
        """
        is_early_stop = False
        if loss <= min_loss:
            min_loss = loss
            wait = 0
            info = {'model_state_dict': model.state_dict()}
            os.makedirs(dump_dir, exist_ok=True)
            torch.save(info, os.path.join(dump_dir, f'model_{fold}.pth'))
        else:
            # A plain ``else`` (not ``elif loss >= min_loss``) so that a NaN
            # loss, which fails every comparison, still counts as "no
            # improvement" instead of silently bypassing the patience counter.
            wait += 1
            # Honour the ``patience`` argument rather than ``self.patience``.
            if wait == patience:
                logger.warning(f'Early stopping at epoch: {epoch+1}')
                is_early_stop = True
        return is_early_stop, min_loss, wait

    def predict(self, model, dataset, loss_func, activation_fn, dump_dir, fold, target_scaler=None, epoch=1, load_model=False, feature_name=None):
        """
        Executes the prediction on a given dataset using the specified model.

        With ``load_model=False`` (validation during training) the loss and the metric are
        computed. With ``load_model=True`` the saved checkpoint for ``fold`` is loaded first
        and only predictions are produced.

        :param model: The model to be used for predictions.
        :param dataset: The dataset to perform predictions on.
        :param loss_func: The loss function used during training.
        :param activation_fn: The activation function applied to the model's output.
        :param dump_dir: Directory where the model state is saved.
        :param fold: The fold number in cross-validation.
        :param target_scaler: (optional) Scaler whose ``inverse_transform`` is applied to predictions
            and targets before the metric is computed. Defaults to None.
        :param epoch: (int) The current epoch number. Defaults to 1.
        :param load_model: (bool) Whether to load the model from a saved state. Defaults to False.
        :param feature_name: (str, optional) Name of the feature for data processing. Defaults to None.
        :return: A tuple (y_preds, val_loss, metric_score), where y_preds are the predicted outputs
                 (as produced by ``activation_fn``, not inverse-transformed), val_loss is the list of
                 per-batch validation losses (empty when ``load_model`` is set), and metric_score is
                 the calculated metric score (``None`` when ``load_model`` is set).
        """
        model = model.to(self.device)
        if load_model:
            load_model_path = os.path.join(dump_dir, f'model_{fold}.pth')
            # NOTE(security): torch.load unpickles the file; only checkpoints
            # written by this trainer should ever be placed in dump_dir.
            model_dict = torch.load(load_model_path, map_location=self.device)[
                "model_state_dict"]
            model.load_state_dict(model_dict)
            logger.info("load model success!")
        dataloader = NNDataLoader(
            feature_name=feature_name,
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=model.batch_collate_fn,
        )
        model = model.eval()
        val_loss = []
        y_preds = []
        y_truths = []
        with tqdm(total=len(dataloader), dynamic_ncols=True,
                  position=0, leave=False, desc='val', ncols=5) as batch_bar:
            for i, batch in enumerate(dataloader):
                net_input, net_target = self.decorate_batch(batch, feature_name)
                with torch.no_grad():
                    outputs = model(**net_input)
                    if not load_model:
                        loss = loss_func(outputs, net_target)
                        val_loss.append(loss.item())
                    y_preds.append(activation_fn(outputs).cpu().numpy())
                if not load_model:
                    # Targets are only needed to score the predictions.
                    y_truths.append(net_target.detach().cpu().numpy())
                    batch_bar.set_postfix(
                        Epoch="Epoch {}/{}".format(epoch+1, self.max_epochs),
                        loss="{:.04f}".format(float(np.sum(val_loss) / (i + 1))))
                batch_bar.update()

        y_preds = np.concatenate(y_preds)
        if load_model:
            # No metric is reported for the reloaded model.
            return y_preds, val_loss, None

        y_truths = np.concatenate(y_truths)
        # Models without an ``output_dim`` attribute simply report no label count.
        label_cnt = getattr(model, 'output_dim', None)
        if target_scaler is not None:
            inverse_y_preds = target_scaler.inverse_transform(y_preds)
            inverse_y_truths = target_scaler.inverse_transform(y_truths)
            metric_score = self.metrics.cal_metric(
                inverse_y_truths, inverse_y_preds, label_cnt=label_cnt)
        else:
            metric_score = self.metrics.cal_metric(
                y_truths, y_preds, label_cnt=label_cnt)
        return y_preds, val_loss, metric_score

    def inference(self, model, dataset, return_repr=False, return_atomic_reprs=False, feature_name=None):
        """
        Runs inference on the given dataset using the provided model. This method can return
        various representations based on the model's output.

        :param model: The neural network model to be used for inference. Its forward call must
            return a dict containing at least ``"cls_repr"``.
        :param dataset: The dataset on which inference is to be performed.
        :param return_repr: (bool, optional) Forwarded to the model. Defaults to False.
        :param return_atomic_reprs: (bool, optional) If True, atomic symbols, coordinates and
            representations are collected as well. Defaults to False.
        :param feature_name: (str, optional) Name of the feature used for data loading. Defaults to None.
        :return: A dictionary with the keys ``"cls_repr"``, ``"atomic_coords"``, ``"atomic_reprs"``
                 and ``"atomic_symbol"``; the atomic entries stay empty unless
                 ``return_atomic_reprs`` is True.
        """
        model = model.to(self.device)
        dataloader = NNDataLoader(
            feature_name=feature_name,
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=model.batch_collate_fn,
        )
        model = model.eval()
        repr_dict = {"cls_repr": [], "atomic_coords": [], "atomic_reprs": [], "atomic_symbol": []}
        for batch in tqdm(dataloader):
            net_input, _ = self.decorate_batch(batch, feature_name)
            with torch.no_grad():
                outputs = model(**net_input,
                                return_repr=return_repr,
                                return_atomic_reprs=return_atomic_reprs)
                assert isinstance(outputs, dict), "model must return a dict of representations"
                repr_dict["cls_repr"].extend(item.cpu().numpy() for item in outputs["cls_repr"])
                if return_atomic_reprs:
                    repr_dict["atomic_symbol"].extend(outputs["atomic_symbol"])
                    repr_dict['atomic_coords'].extend(item.cpu().numpy() for item in outputs['atomic_coords'])
                    repr_dict['atomic_reprs'].extend(item.cpu().numpy() for item in outputs['atomic_reprs'])
        return repr_dict

    def set_seed(self, seed):
        """
        Sets a random seed for torch and numpy to ensure reproducibility.

        :param seed: (int) The seed number to be set.
        """
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
[docs]
def NNDataLoader(feature_name=None, dataset=None, batch_size=None, shuffle=False, collate_fn=None, drop_last=False):
    """
    Build a PyTorch ``DataLoader`` for neural network training or inference.

    This is a thin wrapper: every argument except ``feature_name`` is forwarded
    unchanged to the PyTorch ``DataLoader`` constructor.

    :param feature_name: (str, optional) Accepted so callers can pass the feature name
        through a uniform interface; it is not used by this function. Defaults to None.
    :param dataset: (Dataset, optional) The dataset from which to load the data. Defaults to None.
    :param batch_size: (int, optional) Number of samples per batch to load. Defaults to None.
    :param shuffle: (bool, optional) Whether to shuffle the data at every epoch. Defaults to False.
    :param collate_fn: (callable, optional) Merges a list of samples to form a mini-batch. Defaults to None.
    :param drop_last: (bool, optional) Set to True to drop the last incomplete batch. Defaults to False.
    :return: DataLoader configured according to the provided parameters.
    """
    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'collate_fn': collate_fn,
        'drop_last': drop_last,
    }
    return TorchDataLoader(dataset, **loader_options)
# source from https://github.com/huggingface/transformers/blob/main/src/transformers/optimization.py#L108C1-L132C54
def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int):
    """Return the learning-rate multiplier for ``current_step``.

    The multiplier rises linearly from 0 to 1 over the warmup steps, then falls
    linearly to 0 at ``num_training_steps`` and is clamped at 0 afterwards.
    """
    still_warming_up = current_step < num_warmup_steps
    if still_warming_up:
        warmup_span = float(max(1, num_warmup_steps))
        return float(current_step) / warmup_span
    steps_remaining = float(num_training_steps - current_step)
    # max(1, ...) guards against a zero-length decay phase.
    decay_span = float(max(1, num_training_steps - num_warmup_steps))
    return max(0.0, steps_remaining / decay_span)


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    schedule_bounds = {
        'num_warmup_steps': num_warmup_steps,
        'num_training_steps': num_training_steps,
    }
    # Bind the bounds with ``partial`` on the module-level helper rather than
    # using a closure or lambda.
    multiplier_fn = partial(_get_linear_schedule_with_warmup_lr_lambda, **schedule_bounds)
    return LambdaLR(optimizer, multiplier_fn, last_epoch)