Task
unimol_tools.tasks oversees the tasks related to the model, such as training and prediction.
Trainer
unimol_tools.tasks.trainer.py contains the Trainer class, which manages the training, validation, and testing phases.
- class unimol_tools.tasks.trainer.Trainer(save_path=None, **params)
The Trainer class is responsible for initializing the model and managing its training, validation, and testing phases.
- __init__(save_path=None, **params)
- Parameters:
save_path – Path for saving the training outputs. Defaults to None.
params – Additional parameters for training.
- _init_trainer(**params)
Initializes the trainer with the arguments used to train the model.
- Parameters:
params – Training arguments.
- decorate_batch(batch, feature_name=None)
Prepares a batch of data for processing by the model. This method is a wrapper that delegates to a specific batch decoration method based on the data type.
- Parameters:
batch – The batch of data to be processed.
feature_name – (str, optional) Name of the feature used in batch decoration. Defaults to None.
- Returns:
The decorated batch ready for processing by the model.
- decorate_graph_batch(batch)
Prepares a graph-based batch of data for processing by the model. Specifically handles graph-based data structures.
- Parameters:
batch – The batch of graph-based data to be processed.
- Returns:
A tuple of (net_input, net_target) for model processing.
- decorate_torch_batch(batch)
Prepares a standard PyTorch batch of data for processing by the model. Handles tensor-based data structures.
- Parameters:
batch – The batch of tensor-based data to be processed.
- Returns:
A tuple of (net_input, net_target) for model processing.
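As a rough illustration of what preparing a standard PyTorch batch can involve, the sketch below assumes each batch is a `(features, target)` pair in which `features` is either a tensor or a dict of tensors. It moves everything to a device and wraps a bare tensor in a dict so the model can always be called with keyword inputs. The pair layout, the `"net_input"` key, the `"src_tokens"` example key, and the name `decorate_torch_batch_sketch` are assumptions for illustration, not the library's actual implementation.

```python
import torch


def decorate_torch_batch_sketch(batch, device=torch.device("cpu")):
    """Split a (features, target) batch into (net_input, net_target) on `device`.

    Assumption: `features` is a tensor or a dict of tensors.
    """
    features, target = batch
    if isinstance(features, dict):
        # Move every named input tensor to the target device.
        net_input = {k: v.to(device) for k, v in features.items()}
    else:
        # Wrap a bare tensor so the model can always be called as model(**net_input).
        net_input = {"net_input": features.to(device)}
    # Assumption: regression-style float targets; classification would typically use .long().
    net_target = target.to(device).float()
    return net_input, net_target


# Hypothetical token input with the key "src_tokens".
batch = ({"src_tokens": torch.zeros(2, 5, dtype=torch.long)}, torch.tensor([1, 0]))
net_input, net_target = decorate_torch_batch_sketch(batch)
print(sorted(net_input), net_target.dtype)  # ['src_tokens'] torch.float32
```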
- fit_predict(model, train_dataset, valid_dataset, loss_func, activation_fn, dump_dir, fold, target_scaler, feature_name=None)
Trains the model on the given training dataset and evaluates it on the validation dataset.
- Parameters:
model – The model to be trained and evaluated.
train_dataset – Dataset used for training the model.
valid_dataset – Dataset used for validating the model.
loss_func – The loss function used during training.
activation_fn – The activation function applied to the model’s output.
dump_dir – Directory where the best model state is saved.
fold – The fold number in a cross-validation setting.
target_scaler – Scaler used for scaling the target variable.
feature_name – (optional) Name of the feature used in data loading. Defaults to None.
- Returns:
Predictions made by the model on the validation dataset.
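A training loop of the kind described above can be sketched in plain PyTorch. This is not the library's implementation: it assumes full-batch `(X, y)` tensor pairs for the datasets, uses Adam with a fixed learning rate, and keeps the best weights in memory instead of writing them to `dump_dir` for each fold. The function name `fit_predict_sketch` and its `epochs`, `patience`, and `lr` arguments are hypothetical.

```python
import copy

import torch
from torch import nn


def fit_predict_sketch(model, train_xy, valid_xy, loss_func, epochs=50, patience=5, lr=1e-2):
    """Train with early stopping on validation loss; return validation predictions.

    Hypothetical simplification: datasets are (X, y) tensor pairs and the best
    state is kept in memory instead of being saved to a dump directory.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    x_tr, y_tr = train_xy
    x_va, y_va = valid_xy
    best_loss, best_state, wait = float("inf"), None, 0
    for _ in range(epochs):
        # One full-batch training step.
        model.train()
        optimizer.zero_grad()
        loss = loss_func(model(x_tr), y_tr)
        loss.backward()
        optimizer.step()

        # Validation pass without gradient tracking.
        model.eval()
        with torch.no_grad():
            val_loss = loss_func(model(x_va), y_va).item()
        if val_loss < best_loss:
            # Improvement: remember the best weights and reset the counter.
            best_loss, best_state, wait = val_loss, copy.deepcopy(model.state_dict()), 0
        else:
            wait += 1
            if wait >= patience:
                break
    # Restore the best weights before predicting on the validation set.
    model.load_state_dict(best_state)
    model.eval()
    with torch.no_grad():
        return model(x_va)


torch.manual_seed(0)
x = torch.randn(64, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]])
model = nn.Linear(3, 1)
preds = fit_predict_sketch(model, (x[:48], y[:48]), (x[48:], y[48:]), nn.MSELoss())
print(preds.shape)  # torch.Size([16, 1])
```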
- _early_stop_choice(wait, loss, min_loss, metric_score, max_score, model, dump_dir, fold, patience, epoch)
Determines whether the early stopping criteria are met, based on either loss improvement or a custom metric score.
- Parameters:
wait – Number of epochs waited since the last improvement in loss or metric score.
loss – The current loss value.
min_loss – The minimum loss value observed so far.
metric_score – Current metric score.
max_score – The maximum metric score observed so far.
model – The model being trained.
dump_dir – Directory to save the best model state.
fold – The fold number in cross-validation.
patience – Number of epochs to wait for an improvement before stopping.
epoch – The current epoch number.
- Returns:
A tuple (is_early_stop, min_val_loss, wait, max_score) indicating whether the early stopping criteria are met, the minimum validation loss, the updated wait counter, and the maximum metric score.
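The metric-based branch of this decision can be sketched as follows. The sketch assumes a higher metric score is better and takes a hypothetical `save_fn` callback in place of saving the model to `dump_dir`; for brevity it returns only `(is_early_stop, max_score, wait)`, whereas the method above also returns the minimum validation loss.

```python
def judge_early_stop_metric(wait, metric_score, max_score, patience, save_fn):
    """Return (is_early_stop, max_score, wait) for a higher-is-better metric.

    `save_fn` is a hypothetical callback standing in for saving the model state.
    """
    if metric_score > max_score:
        # New best score: save the model and reset the patience counter.
        save_fn()
        return False, metric_score, 0
    wait += 1
    # Stop once the score has failed to improve for `patience` consecutive epochs.
    return wait >= patience, max_score, wait


stop, best, wait = judge_early_stop_metric(wait=2, metric_score=0.80, max_score=0.85,
                                           patience=3, save_fn=lambda: None)
print(stop, best, wait)  # True 0.85 3
```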
- _judge_early_stop_loss(wait, loss, min_loss, model, dump_dir, fold, patience, epoch)
Determines whether early stopping should be triggered based on the loss comparison.
- Parameters:
wait – The number of epochs elapsed since min_loss last improved.
loss – The current loss value of the model.
min_loss – The minimum loss value observed so far.
model – The neural network model being trained.
dump_dir – Directory to save the model state.
fold – The current fold number in a cross-validation setting.
patience – The number of epochs to wait for an improvement before stopping.
epoch – The current epoch number.
- Returns:
A tuple (is_early_stop, min_loss, wait), where is_early_stop is a boolean indicating whether early stopping should occur, min_loss is the updated minimum loss, and wait is the updated wait counter.
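The loss-based rule and the behavior of the wait counter can be traced with a short pure-Python sketch. It assumes a hypothetical `save_fn` callback in place of writing the model state to `dump_dir`, and uses an invented sequence of validation losses purely to show when stopping triggers.

```python
def judge_early_stop_loss(wait, loss, min_loss, patience, save_fn):
    """Return (is_early_stop, min_loss, wait) for a lower-is-better loss.

    `save_fn` is a hypothetical callback standing in for saving the model to dump_dir.
    """
    if loss < min_loss:
        # Loss improved: checkpoint the model and reset the wait counter.
        save_fn()
        return False, loss, 0
    wait += 1
    # Trigger early stopping after `patience` epochs without improvement.
    return wait >= patience, min_loss, wait


losses = [0.9, 0.6, 0.65, 0.62, 0.7]  # illustrative per-epoch validation losses
wait, best, checkpoints = 0, float("inf"), []
for epoch, loss in enumerate(losses):
    stop, best, wait = judge_early_stop_loss(wait, loss, best, 3, lambda: checkpoints.append(epoch))
    if stop:
        break
print(epoch, best, checkpoints)  # 4 0.6 [0, 1]
```

The counter resets only on a strict improvement, so epochs 2 to 4 each add one to `wait` and training stops at epoch 4 with the epoch-1 checkpoint as the best.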
- predict(model, dataset, loss_func, activation_fn, dump_dir, fold, target_scaler=None, epoch=1, load_model=False, feature_name=None)
Runs prediction on the given dataset using the specified model.
- Parameters:
model – The model to be used for predictions.
dataset – The dataset to perform predictions on.
loss_func – The loss function used during training.
activation_fn – The activation function applied to the model’s output.
dump_dir – Directory where the model state is saved.
fold – The fold number in cross-validation.
target_scaler – (optional) Scaler to inverse transform the model’s output. Defaults to None.
epoch – (int) The current epoch number. Defaults to 1.
load_model – (bool) Whether to load the model from a saved state. Defaults to False.
feature_name – (str, optional) Name of the feature for data processing. Defaults to None.
- Returns:
A tuple (y_preds, val_loss, metric_score), where y_preds are the predicted outputs, val_loss is the validation loss, and metric_score is the calculated metric score.
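A single prediction pass of this kind can be sketched as below: evaluation mode, no gradient tracking, an optional activation, and an inverse transform back to the original target units. This sketch assumes full-batch tensors, a hypothetical `MeanStdScaler` standing in for `target_scaler` (any object with an sklearn-style `inverse_transform` would do), and omits the metric score that the method above also returns.

```python
import torch
from torch import nn


class MeanStdScaler:
    """Hypothetical stand-in for `target_scaler` with an sklearn-style inverse_transform."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def inverse_transform(self, y):
        return y * self.std + self.mean


def predict_sketch(model, x, y, loss_func, activation_fn=None, target_scaler=None):
    """Return (y_preds, val_loss) for one full-batch prediction pass."""
    model.eval()  # disable dropout / batch-norm statistic updates
    with torch.no_grad():
        outputs = model(x)
        val_loss = loss_func(outputs, y).item()  # loss on the scaled targets
        preds = activation_fn(outputs) if activation_fn is not None else outputs
    if target_scaler is not None:
        preds = target_scaler.inverse_transform(preds)  # back to original units
    return preds, val_loss


model = nn.Linear(2, 1)
with torch.no_grad():
    model.weight.fill_(1.0)
    model.bias.zero_()
x = torch.tensor([[1.0, 2.0], [0.0, -1.0]])
y = torch.tensor([[3.0], [-1.0]])
preds, loss = predict_sketch(model, x, y, nn.MSELoss(), target_scaler=MeanStdScaler(10.0, 2.0))
print(preds.squeeze(1).tolist(), loss)  # [16.0, 8.0] 0.0
```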
- inference(model, dataset, return_repr=False, return_atomic_reprs=False, feature_name=None)
Runs inference on the given dataset using the provided model. This method can return various representations based on the model’s output.
- Parameters:
model – The neural network model to be used for inference.
dataset – The dataset on which inference is to be performed.
return_repr – (bool, optional) If True, returns class-level representations. Defaults to False.
return_atomic_reprs – (bool, optional) If True, returns atomic-level representations. Defaults to False.
feature_name – (str, optional) Name of the feature used for data loading. Defaults to None.
- Returns:
A dictionary containing different types of representations based on the model’s output and the specified parameters. This can include class-level representations, atomic coordinates, atomic representations, and atomic symbols.
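The way the flags shape the returned dictionary can be sketched by merging per-batch outputs. The key names `"cls_repr"` and `"atomic_reprs"`, and the layout of each batch output, are hypothetical and chosen only for illustration: molecule-level vectors have a fixed size and can be concatenated, while atom-level outputs differ in length per molecule and are kept as a list.

```python
import torch


def inference_sketch(batch_outputs, return_repr=False, return_atomic_reprs=False):
    """Merge per-batch model outputs into one dict of requested representations.

    Assumption: each element of `batch_outputs` is a dict with hypothetical keys
    "cls_repr" (tensor, one row per molecule) and "atomic_reprs" (list with one
    tensor per molecule).
    """
    result = {}
    if return_repr:
        # Molecule-level vectors have a fixed size, so they can be concatenated.
        result["cls_repr"] = torch.cat([out["cls_repr"] for out in batch_outputs])
    if return_atomic_reprs:
        # Atom-level outputs vary in length per molecule, so keep them as a flat list.
        result["atomic_reprs"] = [r for out in batch_outputs for r in out["atomic_reprs"]]
    return result


outs = [
    {"cls_repr": torch.zeros(2, 4), "atomic_reprs": [torch.zeros(3, 4), torch.zeros(5, 4)]},
    {"cls_repr": torch.ones(1, 4), "atomic_reprs": [torch.ones(2, 4)]},
]
res = inference_sketch(outs, return_repr=True, return_atomic_reprs=True)
print(res["cls_repr"].shape, len(res["atomic_reprs"]))  # torch.Size([3, 4]) 3
```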
- unimol_tools.tasks.trainer.NNDataLoader(feature_name=None, dataset=None, batch_size=None, shuffle=False, collate_fn=None, drop_last=False)
Creates a DataLoader for neural network training or inference. This function is a wrapper around the standard PyTorch DataLoader, allowing for custom feature handling and additional configuration.
- Parameters:
feature_name – (str, optional) Name of the feature used for data loading. This can be used to specify a particular type of data processing. Defaults to None.
dataset – (Dataset, optional) The dataset from which to load the data. Defaults to None.
batch_size – (int, optional) Number of samples per batch to load. Defaults to None.
shuffle – (bool, optional) Whether to shuffle the data at every epoch. Defaults to False.
collate_fn – (callable, optional) Merges a list of samples to form a mini-batch. Defaults to None.
drop_last – (bool, optional) Set to True to drop the last incomplete batch. Defaults to False.
- Returns:
DataLoader configured according to the provided parameters.
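A minimal sketch of such a wrapper simply forwards the arguments to `torch.utils.data.DataLoader`. The name `nn_dataloader_sketch` is hypothetical, and it assumes `feature_name` only selects special handling elsewhere, so the sketch ignores it.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def nn_dataloader_sketch(feature_name=None, dataset=None, batch_size=None,
                         shuffle=False, collate_fn=None, drop_last=False):
    """Forward the arguments to torch.utils.data.DataLoader.

    Assumption: feature_name only selects special handling elsewhere; this
    sketch ignores it and builds a plain DataLoader.
    """
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      collate_fn=collate_fn, drop_last=drop_last)


ds = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))
loader = nn_dataloader_sketch(dataset=ds, batch_size=4, drop_last=True)
print([xb.shape[0] for xb, yb in loader])  # [4, 4]
```

In PyTorch, leaving `batch_size=None` disables automatic batching (samples are yielded one at a time), and combining it with `drop_last=True` raises a `ValueError`, so callers of a wrapper like this would normally pass an explicit batch size.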
- unimol_tools.tasks.trainer.get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1)
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
- Parameters:
optimizer – (torch.optim.Optimizer) The optimizer for which to schedule the learning rate.
num_warmup_steps – (int) The number of steps for the warmup phase.
num_training_steps – (int) The total number of training steps.
last_epoch – (int, optional) The index of the last epoch when resuming training. Defaults to -1.
- Returns:
torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
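The schedule is a multiplier applied to the optimizer's initial learning rate at each step. The pure-Python sketch below shows the piecewise-linear factor commonly used for this schedule (for example in Hugging Face Transformers); the function name `linear_warmup_factor` is hypothetical, and it is an illustration of the technique rather than this module's source.

```python
def linear_warmup_factor(step, num_warmup_steps, num_training_steps):
    """Learning-rate multiplier: linear 0 -> 1 during warmup, then linear 1 -> 0."""
    if step < num_warmup_steps:
        # Warmup phase: ramp up linearly from 0.
        return step / max(1, num_warmup_steps)
    # Decay phase: fall linearly to 0 at num_training_steps, never below 0.
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


print([linear_warmup_factor(s, 10, 100) for s in (0, 5, 10, 55, 100)])
# [0.0, 0.5, 1.0, 0.5, 0.0]
```

To drive a real optimizer, a factor like this would be passed as the `lr_lambda` of `torch.optim.lr_scheduler.LambdaLR`, for instance `LambdaLR(optimizer, lambda s: linear_warmup_factor(s, 10, 100))`, with `scheduler.step()` called once per training step.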