Task

unimol_tools.tasks handles model-related tasks such as training and prediction.

Trainer

unimol_tools.tasks.trainer.py contains the Trainer class, which manages the training, validation, and testing phases.

class unimol_tools.tasks.trainer.Trainer(save_path=None, **params)

The Trainer class is responsible for initializing the model and managing its training, validation, and testing phases.

__init__(save_path=None, **params)
Parameters:
  • save_path – Path for saving the training outputs. Defaults to None.

  • params – Additional parameters for training.

_init_trainer(**params)

Initializes the trainer with the arguments used to train the model.

Parameters:

params – Training arguments.

decorate_batch(batch, feature_name=None)

Prepares a batch of data for processing by the model. This method is a wrapper that delegates to a specific batch decoration method based on the data type.

Parameters:
  • batch – The batch of data to be processed.

  • feature_name – (str, optional) Name of the feature used in batch decoration. Defaults to None.

Returns:

The decorated batch ready for processing by the model.

decorate_graph_batch(batch)

Prepares a graph-based batch of data for processing by the model. Specifically handles graph-based data structures.

Parameters:

batch – The batch of graph-based data to be processed.

Returns:

A tuple of (net_input, net_target) for model processing.

decorate_torch_batch(batch)

Prepares a standard PyTorch batch of data for processing by the model. Handles tensor-based data structures.

Parameters:

batch – The batch of tensor-based data to be processed.

Returns:

A tuple of (net_input, net_target) for model processing.
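The delegation described for decorate_batch can be sketched in plain Python. The source only says the wrapper dispatches on the data type, so the dispatch rule here (feature_name is None means a tensor batch, anything else a graph batch) and both batch layouts are assumptions; device placement and tensor conversion are omitted. BatchDecoratorSketch is a hypothetical name, not part of unimol_tools.

```python
from typing import Any, Dict, Optional, Tuple


class BatchDecoratorSketch:
    """Hypothetical stand-in for the Trainer's batch-decoration methods."""

    def decorate_batch(self, batch: Any, feature_name: Optional[str] = None) -> Tuple[Any, Any]:
        # Assumption: the data type is signalled by feature_name; None means a plain tensor batch.
        if feature_name is None:
            return self.decorate_torch_batch(batch)
        return self.decorate_graph_batch(batch)

    def decorate_torch_batch(self, batch: Tuple[Dict[str, Any], Any]) -> Tuple[Dict[str, Any], Any]:
        # Assumption: a tensor batch arrives as an (inputs, targets) pair.
        net_input, net_target = batch
        return net_input, net_target

    def decorate_graph_batch(self, batch: Dict[str, Any]) -> Tuple[Dict[str, Any], Any]:
        # Assumption: a graph batch is a dict whose "target" entry holds the labels.
        net_input = {k: v for k, v in batch.items() if k != "target"}
        return net_input, batch.get("target")
```

Either path ends in the same (net_input, net_target) shape, which is what lets the training loop stay agnostic of the data type.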

fit_predict(model, train_dataset, valid_dataset, loss_func, activation_fn, dump_dir, fold, target_scaler, feature_name=None)

Trains the model on the given training dataset and evaluates it on the validation dataset.

Parameters:
  • model – The model to be trained and evaluated.

  • train_dataset – Dataset used for training the model.

  • valid_dataset – Dataset used for validating the model.

  • loss_func – The loss function used during training.

  • activation_fn – The activation function applied to the model’s output.

  • dump_dir – Directory where the best model state is saved.

  • fold – The fold number in a cross-validation setting.

  • target_scaler – Scaler used for scaling the target variable.

  • feature_name – (optional) Name of the feature used in data loading. Defaults to None.

Returns:

Predictions made by the model on the validation dataset.

_early_stop_choice(wait, loss, min_loss, metric_score, max_score, model, dump_dir, fold, patience, epoch)

Determines whether the early-stopping criteria are met, based on either loss improvement or a custom metric score.

Parameters:
  • wait – Number of epochs waited since the last improvement in loss or metric score.

  • loss – The current loss value.

  • min_loss – The minimum loss value observed so far.

  • metric_score – Current metric score.

  • max_score – The maximum metric score observed so far.

  • model – The model being trained.

  • dump_dir – Directory to save the best model state.

  • fold – The fold number in cross-validation.

  • patience – Number of epochs to wait for an improvement before stopping.

  • epoch – The current epoch number.

Returns:

A tuple (is_early_stop, min_val_loss, wait, max_score) indicating whether the early-stopping criteria are met, the minimum validation loss, the updated wait counter, and the maximum metric score.
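A minimal sketch of how such a choice could work, assuming the metric branch is used whenever a metric score is supplied and that higher metric values are better (as for AUC). The function name early_stop_choice_sketch is hypothetical, checkpointing to dump_dir is left out, and this is not the unimol_tools implementation.

```python
from typing import Optional, Tuple


def early_stop_choice_sketch(
    wait: int,
    loss: float,
    min_loss: float,
    metric_score: Optional[float],
    max_score: Optional[float],
    patience: int,
) -> Tuple[bool, float, int, Optional[float]]:
    """Hypothetical sketch: use the metric when one is given, otherwise the loss."""
    if metric_score is None:
        # Loss branch: lower is better; any improvement resets the wait counter.
        if loss < min_loss:
            return False, loss, 0, max_score
        wait += 1
        return wait >= patience, min_loss, wait, max_score
    # Metric branch (assumed higher-is-better); the loss is still tracked for reporting.
    min_loss = min(min_loss, loss)
    if max_score is None or metric_score > max_score:
        return False, min_loss, 0, metric_score
    wait += 1
    return wait >= patience, min_loss, wait, max_score
```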

_judge_early_stop_loss(wait, loss, min_loss, model, dump_dir, fold, patience, epoch)

Determines whether early stopping should be triggered based on the loss comparison.

Parameters:
  • wait – The number of epochs elapsed since min_loss last improved.

  • loss – The current loss value of the model.

  • min_loss – The minimum loss value observed so far.

  • model – The neural network model being trained.

  • dump_dir – Directory to save the model state.

  • fold – The current fold number in a cross-validation setting.

  • patience – The number of epochs to wait for an improvement before stopping.

  • epoch – The current epoch number.

Returns:

A tuple (is_early_stop, min_loss, wait), where is_early_stop is a boolean indicating whether early stopping should occur, min_loss is the updated minimum loss, and wait is the updated wait counter.
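The loss-based rule can be illustrated with a small stand-alone function driven by a toy loss curve. This is a sketch under assumptions: ties count as an improvement, the optional save_best callback stands in for writing the model state to dump_dir, and the name judge_early_stop_loss_sketch is hypothetical.

```python
from typing import Callable, Optional, Tuple


def judge_early_stop_loss_sketch(
    wait: int,
    loss: float,
    min_loss: float,
    patience: int,
    save_best: Optional[Callable[[], None]] = None,
) -> Tuple[bool, float, int]:
    """Loss-based early stopping: reset on improvement, stop after `patience` misses."""
    if loss <= min_loss:
        # New best loss: checkpoint, remember it, and reset the wait counter.
        if save_best is not None:
            save_best()  # stands in for saving the model state to dump_dir
        return False, loss, 0
    wait += 1
    return wait >= patience, min_loss, wait


# Toy loop: the loss stops improving after epoch 1, so training halts once patience runs out.
losses = [1.0, 0.8, 0.85, 0.9, 0.95]
min_loss, wait, stopped_at = float("inf"), 0, None
for epoch, loss in enumerate(losses):
    stop, min_loss, wait = judge_early_stop_loss_sketch(wait, loss, min_loss, patience=3)
    if stop:
        stopped_at = epoch
        break
```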

predict(model, dataset, loss_func, activation_fn, dump_dir, fold, target_scaler=None, epoch=1, load_model=False, feature_name=None)

Runs prediction on the given dataset using the specified model.

Parameters:
  • model – The model to be used for predictions.

  • dataset – The dataset to perform predictions on.

  • loss_func – The loss function used to compute the validation loss.

  • activation_fn – The activation function applied to the model’s output.

  • dump_dir – Directory where the model state is saved.

  • fold – The fold number in cross-validation.

  • target_scaler – (optional) Scaler used to inverse-transform the model’s output. Defaults to None.

  • epoch – (int) The current epoch number. Defaults to 1.

  • load_model – (bool) Whether to load the model from a saved state. Defaults to False.

  • feature_name – (str, optional) Name of the feature for data processing. Defaults to None.

Returns:

A tuple (y_preds, val_loss, metric_score), where y_preds are the predicted outputs, val_loss is the validation loss, and metric_score is the calculated metric score.
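When target_scaler is given, predictions are mapped back to the original target units. A plausible post-processing order, assumed here rather than confirmed by the source, is to apply activation_fn first and then the scaler's inverse transform. StandardScalerSketch and postprocess_predictions are hypothetical helpers that operate on plain lists instead of tensors.

```python
import math
from typing import Callable, List, Optional


class StandardScalerSketch:
    """Hypothetical scaler: targets were standardized as (y - mean) / std before training."""

    def __init__(self, mean: float, std: float) -> None:
        self.mean = mean
        self.std = std

    def inverse_transform(self, values: List[float]) -> List[float]:
        # Undo the standardization: y = z * std + mean.
        return [v * self.std + self.mean for v in values]


def postprocess_predictions(
    raw_outputs: List[float],
    activation_fn: Callable[[float], float],
    target_scaler: Optional[StandardScalerSketch] = None,
) -> List[float]:
    # Apply the output activation first (identity for regression, sigmoid for binary tasks).
    preds = [activation_fn(x) for x in raw_outputs]
    # Then map back to the original target units, if a scaler was used during training.
    if target_scaler is not None:
        preds = target_scaler.inverse_transform(preds)
    return preds


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))
```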

inference(model, dataset, return_repr=False, return_atomic_reprs=False, feature_name=None)

Runs inference on the given dataset using the provided model. This method can return various representations based on the model’s output.

Parameters:
  • model – The neural network model to be used for inference.

  • dataset – The dataset on which inference is to be performed.

  • return_repr – (bool, optional) If True, returns CLS-token (molecule-level) representations. Defaults to False.

  • return_atomic_reprs – (bool, optional) If True, returns atomic-level representations. Defaults to False.

  • feature_name – (str, optional) Name of the feature used for data loading. Defaults to None.

Returns:

A dictionary containing different types of representations based on the model’s output and the specified parameters. This can include CLS-token (molecule-level) representations, atomic coordinates, atomic representations, and atomic symbols.
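The flag-dependent output can be pictured as gathering per-batch outputs into one dictionary. The key names below (cls_repr, atomic_coords, atomic_reprs, atomic_symbol) are illustrative assumptions, not a documented contract, and collect_representations is a hypothetical helper.

```python
from typing import Any, Dict, Iterable, List


def collect_representations(
    batch_outputs: Iterable[Dict[str, List[Any]]],
    return_repr: bool = False,
    return_atomic_reprs: bool = False,
) -> Dict[str, List[Any]]:
    """Hypothetical sketch: gather per-batch model outputs into one dictionary."""
    wanted: List[str] = []
    if return_repr:
        wanted.append("cls_repr")  # assumed key for the molecule-level representation
    if return_atomic_reprs:
        wanted += ["atomic_coords", "atomic_reprs", "atomic_symbol"]  # assumed keys
    result: Dict[str, List[Any]] = {}
    for outputs in batch_outputs:
        # Concatenate each requested field across batches, preserving sample order.
        for key in wanted:
            result.setdefault(key, []).extend(outputs[key])
    return result
```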

set_seed(seed)

Sets a random seed for torch and numpy to ensure reproducibility.

Parameters:

seed – (int) The seed number to be set.
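A common way to seed these libraries is sketched below. It also seeds Python's built-in random module, which goes beyond the description above, and it imports NumPy and PyTorch only if they are installed. set_seed_sketch is a hypothetical name, not the unimol_tools method.

```python
import importlib
import random


def set_seed_sketch(seed: int) -> None:
    """Seed Python's RNG and, when available, NumPy's and PyTorch's."""
    random.seed(seed)
    try:
        np = importlib.import_module("numpy")
        np.random.seed(seed)  # seeds NumPy's global RNG
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        torch = importlib.import_module("torch")
        torch.manual_seed(seed)  # seeds the RNG on all devices, CPU and CUDA
    except ImportError:
        pass  # PyTorch not installed; nothing to seed
```

Re-seeding with the same value reproduces the same random sequence, which is what makes repeated training runs comparable.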

unimol_tools.tasks.trainer.NNDataLoader(feature_name=None, dataset=None, batch_size=None, shuffle=False, collate_fn=None, drop_last=False)

Creates a DataLoader for neural network training or inference. This function is a wrapper around the standard PyTorch DataLoader, allowing for custom feature handling and additional configuration.

Parameters:
  • feature_name – (str, optional) Name of the feature used for data loading. This can be used to specify a particular type of data processing. Defaults to None.

  • dataset – (Dataset, optional) The dataset from which to load the data. Defaults to None.

  • batch_size – (int, optional) Number of samples per batch to load. Defaults to None.

  • shuffle – (bool, optional) Whether to shuffle the data at every epoch. Defaults to False.

  • collate_fn – (callable, optional) Merges a list of samples to form a mini-batch. Defaults to None.

  • drop_last – (bool, optional) Set to True to drop the last incomplete batch. Defaults to False.

Returns:

DataLoader configured according to the provided parameters.
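Because NNDataLoader wraps the PyTorch DataLoader, its batch_size, shuffle, collate_fn, and drop_last parameters should behave as they do there. The pure-Python generator below only illustrates those semantics; it is not how DataLoader is implemented, and iterate_batches is a hypothetical name. Unlike DataLoader's default collation, it yields raw lists of samples when no collate_fn is given.

```python
import random
from typing import Any, Callable, Iterator, List, Optional, Sequence


def iterate_batches(
    dataset: Sequence[Any],
    batch_size: int,
    shuffle: bool = False,
    collate_fn: Optional[Callable[[List[Any]], Any]] = None,
    drop_last: bool = False,
    rng: Optional[random.Random] = None,
) -> Iterator[Any]:
    """Illustrates batch_size / shuffle / collate_fn / drop_last semantics."""
    indices = list(range(len(dataset)))
    if shuffle:
        # Reshuffle the sample order on every call (i.e. every epoch).
        (rng or random.Random()).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        chunk = [dataset[i] for i in indices[start:start + batch_size]]
        if drop_last and len(chunk) < batch_size:
            break  # discard the trailing incomplete batch
        # collate_fn merges the list of samples into one mini-batch object.
        yield collate_fn(chunk) if collate_fn is not None else chunk
```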

unimol_tools.tasks.trainer.get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1)

Creates a learning-rate schedule that increases the learning rate linearly from 0 to the initial lr set in the optimizer during a warmup period, then decreases it linearly to 0.

Parameters:
  • optimizer – (torch.optim.Optimizer) The optimizer for which to schedule the learning rate.

  • num_warmup_steps – (int) The number of steps for the warmup phase.

  • num_training_steps – (int) The total number of training steps.

  • last_epoch – (int, optional) The index of the last epoch when resuming training. Defaults to -1.

Returns:

A torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
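The schedule's learning-rate multiplier can be written as a plain function of the step index: step / num_warmup_steps during warmup, then (num_training_steps - step) / (num_training_steps - num_warmup_steps), clipped at 0. This sketch mirrors the widely used Hugging Face Transformers formulation rather than code taken from unimol_tools; the function name linear_warmup_decay_factor is hypothetical.

```python
def linear_warmup_decay_factor(step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    """LR multiplier: rises 0 -> 1 during warmup, then falls 1 -> 0 by num_training_steps."""
    if step < num_warmup_steps:
        # Warmup: linear ramp from 0 up to the optimizer's initial lr.
        return step / max(1, num_warmup_steps)
    # Decay: linear ramp down to 0, never going negative past the end of training.
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))
```

With PyTorch, such a factor would be passed as the lr_lambda of torch.optim.lr_scheduler.LambdaLR, for example LambdaLR(optimizer, lambda step: linear_warmup_decay_factor(step, 10, 100)); the scheduler multiplies the optimizer's initial lr by this factor at each step.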