# Source code for unimol_tools.models.unimol

# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function

import os

# import argparse
import pathlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from addict import Dict

from ..config import MODEL_CONFIG
from ..data import Dictionary
from ..utils import logger, pad_1d_tokens, pad_2d, pad_coords
from ..weights import WEIGHT_DIR, weight_download
from .transformers import TransformerEncoderWithPair

# Registry of encoder backbones, keyed by the ``backbone`` name stored in an
# architecture config; UniMolModel.__init__ looks the encoder class up here.
BACKBONE = {'transformer': TransformerEncoderWithPair}


class UniMolModel(nn.Module):
    """
    Uni-Mol encoder with a linear prediction head.

    The architecture hyper-parameters are selected from ``data_type``
    ('molecule', 'oled', 'protein' or 'crystal'); the matching pretrained
    weight file and token dictionary are looked up in ``MODEL_CONFIG`` and
    downloaded into ``WEIGHT_DIR`` when they are not already present.

    Attributes:
        - args: Architecture configuration returned by the ``*_architecture`` function.
        - output_dim: Output size of the prediction head.
        - data_type: The data type string the model was built for.
        - remove_hs: Value of the ``remove_hs`` keyword (default False); selects the
          'no_h' / 'all_h' weight name for molecules.
        - pretrain_path: Path of the pretrained weight file.
        - dictionary: Token dictionary loaded from the downloaded dictionary file.
        - mask_idx: Index returned when adding the "[MASK]" symbol to the dictionary.
        - padding_idx: Index returned by ``dictionary.pad()``.
        - embed_tokens: Token embedding layer.
        - encoder: Encoder built from ``BACKBONE``.
        - gbf, gbf_proj: Pair-feature kernel (GaussianLayer or NumericalEmbed) and its
          projection to one value per attention head.
        - classification_head: Prediction head applied to the first token representation.
    """

    def __init__(self, output_dim=2, data_type='molecule', **params):
        """
        Build the model, then load the pretrained weights.

        :param output_dim: (int) Output size of the prediction head.
        :param data_type: (str) One of 'molecule', 'oled', 'protein', 'crystal'.
        :param params: Optional settings. Recognised keys: ``remove_hs`` (bool) and
            ``pooler_dropout`` (overrides the architecture default).
        :raises ValueError: If ``data_type`` is not one of the supported values.
        """
        super().__init__()
        if data_type == 'molecule':
            self.args = molecule_architecture()
        elif data_type == 'oled':
            self.args = oled_architecture()
        elif data_type == 'protein':
            self.args = protein_architecture()
        elif data_type == 'crystal':
            self.args = crystal_architecture()
        else:
            raise ValueError('Current not support data type: {}'.format(data_type))

        self.output_dim = output_dim
        self.data_type = data_type
        self.remove_hs = params.get('remove_hs', False)

        # Molecule weights come in two flavours (with / without hydrogens);
        # every other data type has a single entry in MODEL_CONFIG.
        if data_type == 'molecule':
            name = data_type + '_' + ("no_h" if self.remove_hs else "all_h")
        else:
            name = data_type

        # Fetch the weight file first, then the dictionary file, if missing.
        for kind in ('weight', 'dict'):
            filename = MODEL_CONFIG[kind][name]
            if not os.path.exists(os.path.join(WEIGHT_DIR, filename)):
                weight_download(filename, WEIGHT_DIR)

        self.pretrain_path = os.path.join(WEIGHT_DIR, MODEL_CONFIG['weight'][name])
        self.dictionary = Dictionary.load(
            os.path.join(WEIGHT_DIR, MODEL_CONFIG['dict'][name])
        )
        self.mask_idx = self.dictionary.add_symbol("[MASK]", is_special=True)
        self.padding_idx = self.dictionary.pad()

        self.embed_tokens = nn.Embedding(
            len(self.dictionary), self.args.encoder_embed_dim, self.padding_idx
        )
        self.encoder = BACKBONE[self.args.backbone](
            encoder_layers=self.args.encoder_layers,
            embed_dim=self.args.encoder_embed_dim,
            ffn_embed_dim=self.args.encoder_ffn_embed_dim,
            attention_heads=self.args.encoder_attention_heads,
            emb_dropout=self.args.emb_dropout,
            dropout=self.args.dropout,
            attention_dropout=self.args.attention_dropout,
            activation_dropout=self.args.activation_dropout,
            max_seq_len=self.args.max_seq_len,
            activation_fn=self.args.activation_fn,
            no_final_head_layer_norm=self.args.delta_pair_repr_norm_loss < 0,
        )

        K = 128  # number of kernel features per token pair
        # One edge type per ordered pair of dictionary symbols.
        n_edge_type = len(self.dictionary) * len(self.dictionary)
        self.gbf_proj = NonLinearHead(
            K, self.args.encoder_attention_heads, self.args.activation_fn
        )
        if self.args.kernel == 'gaussian':
            self.gbf = GaussianLayer(K, n_edge_type)
        else:
            self.gbf = NumericalEmbed(K, n_edge_type)

        # The previous two-layer ClassificationHead is no longer built here;
        # load_pretrained_weights() swaps it back in when a checkpoint saved
        # with that head is detected.
        if 'pooler_dropout' in params:
            self.args.pooler_dropout = params['pooler_dropout']
        self.classification_head = LinearHead(
            input_dim=self.args.encoder_embed_dim,
            num_classes=self.output_dim,
            pooler_dropout=self.args.pooler_dropout,
        )
        self.load_pretrained_weights(path=self.pretrain_path)

    def load_pretrained_weights(self, path, strict=False):
        """
        Load pretrained weights into the model.

        If loading fails with a RuntimeError and the checkpoint contains
        ``classification_head.dense.weight``, the head is replaced by the
        previous ``ClassificationHead`` and loading is retried.

        :param path: (str or None) Path to the weight file; nothing happens when None.
        :param strict: (bool) Forwarded to ``load_state_dict``.
        :raises RuntimeError: If loading fails and the checkpoint is not in the
            previous-head format.
        """
        if path is None:
            return

        logger.info("Loading pretrained weights from {}".format(path))
        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
        state_dict = torch.load(path, map_location=lambda storage, loc: storage)
        # Checkpoints may wrap the parameters under one of these keys.
        if 'model' in state_dict:
            state_dict = state_dict['model']
        elif 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']

        try:
            self.load_state_dict(state_dict, strict=strict)
        except RuntimeError:
            if 'classification_head.dense.weight' not in state_dict:
                # Not the known legacy layout: re-raise with the original traceback.
                raise
            self.classification_head = ClassificationHead(
                input_dim=self.args.encoder_embed_dim,
                inner_dim=self.args.encoder_embed_dim,
                num_classes=self.output_dim,
                activation_fn=self.args.pooler_activation_fn,
                pooler_dropout=self.args.pooler_dropout,
            )
            self.load_state_dict(state_dict, strict=strict)
            logger.warning(
                "This model is trained with the previous version. The classification_head is reset to previous version to load the model. This will be deprecated in the future. We recommend using the latest version of the model."
            )

    @classmethod
    def build_model(cls, args):
        """
        Build a new instance of the model.

        :param args: Passed as the first positional argument of the constructor.
        :return: An instance of UniMolModel.
        """
        # NOTE(review): the first positional parameter of __init__ is
        # ``output_dim``, so ``args`` ends up there - confirm this is intended.
        return cls(args)

    def forward(
        self,
        src_tokens,
        src_distance,
        src_coord,
        src_edge_type,
        return_repr=False,
        return_atomic_reprs=False,
        **kwargs
    ):
        """
        Run the encoder and return either head outputs or representations.

        :param src_tokens: Token indices; looked up in ``embed_tokens`` and compared
            with ``padding_idx`` to build the padding mask.
        :param src_distance: Pairwise distances fed to the pair-feature kernel.
        :param src_coord: Per-token coordinates; only used to return
            ``atomic_coords`` when atomic representations are requested.
        :param src_edge_type: Pairwise edge-type indices fed to the pair-feature kernel.
        :param return_repr: If True, return representations instead of head outputs.
        :param return_atomic_reprs: With ``return_repr``, also return per-atom data.
        :param kwargs: Ignored.
        :return: The ``classification_head`` output when ``return_repr`` is False;
            otherwise a dict with ``cls_repr`` and, when ``return_atomic_reprs`` is
            True, also ``atomic_symbol``, ``atomic_coords`` and ``atomic_reprs``
            (one list entry per sample).
        """
        padding_mask = src_tokens.eq(self.padding_idx)
        if not padding_mask.any():
            padding_mask = None
        x = self.embed_tokens(src_tokens)

        # Pair features -> one bias value per head; the last axis is moved to
        # position 1 and the leading two axes are merged, giving
        # (-1, n_node, n_node) for the encoder's attention mask.
        n_node = src_distance.size(-1)
        pair_features = self.gbf_proj(self.gbf(src_distance, src_edge_type))
        graph_attn_bias = (
            pair_features.permute(0, 3, 1, 2).contiguous().view(-1, n_node, n_node)
        )

        encoder_rep, _, _, _, _ = self.encoder(
            x, padding_mask=padding_mask, attn_mask=graph_attn_bias
        )
        cls_repr = encoder_rep[:, 0, :]  # representation of the first token

        if not return_repr:
            return self.classification_head(cls_repr)
        if not return_atomic_reprs:
            return {'cls_repr': cls_repr}

        atomic_symbols = []
        atomic_coords = []
        atomic_reprs = []
        for i, (tokens, coord) in enumerate(zip(src_tokens, src_coord)):
            # Token ids 0, 1 and 2 are treated as special tokens and dropped.
            # NOTE(review): the ids are hard-coded; confirm they match the
            # special-symbol indices of the loaded dictionary.
            keep = (tokens != 0) & (tokens != 1) & (tokens != 2)
            kept_tokens = tokens[keep]
            atomic_coords.append(coord[keep])
            atomic_symbols.append(
                [self.dictionary.symbols[idx] for idx in kept_tokens.tolist()]
            )
            # Position 0 is skipped; the next len(kept_tokens) positions are
            # taken as the atom representations.
            atomic_reprs.append(encoder_rep[i, 1 : len(kept_tokens) + 1, :])

        return {
            'cls_repr': cls_repr,
            'atomic_symbol': atomic_symbols,
            'atomic_coords': atomic_coords,
            'atomic_reprs': atomic_reprs,
        }

    def batch_collate_fn(self, samples):
        """
        Collate a list of ``(features, label)`` samples into a padded batch.

        :param samples: A list of samples; ``s[0]`` is a dict of per-sample
            features and ``s[1]`` is the label.
        :return: A tuple ``(batch, label)``. ``batch`` maps each recognised feature
            key to a padded tensor; ``label`` is a tensor, or None when the labels
            cannot be converted to a tensor.
        """
        batch = {}
        for k in samples[0][0]:
            if k == 'src_coord':
                v = pad_coords(
                    [torch.tensor(s[0][k]).float() for s in samples], pad_idx=0.0
                )
            elif k == 'src_edge_type':
                v = pad_2d(
                    [torch.tensor(s[0][k]).long() for s in samples],
                    pad_idx=self.padding_idx,
                )
            elif k == 'src_distance':
                v = pad_2d(
                    [torch.tensor(s[0][k]).float() for s in samples], pad_idx=0.0
                )
            elif k == 'src_tokens':
                v = pad_1d_tokens(
                    [torch.tensor(s[0][k]).long() for s in samples],
                    pad_idx=self.padding_idx,
                )
            else:
                # Unknown feature: skip it rather than storing the previous
                # key's tensor under this name (or failing on an unbound name).
                continue
            batch[k] = v

        try:
            label = torch.tensor([s[1] for s in samples])
        except Exception:
            # Best effort: labels that cannot form a tensor (e.g. None at
            # inference time) yield no label. KeyboardInterrupt/SystemExit
            # are deliberately not swallowed.
            label = None
        return batch, label
class LinearHead(nn.Module):
    """Prediction head made of dropout followed by one linear projection."""

    def __init__(self, input_dim, num_classes, pooler_dropout):
        """
        Create the projection and dropout layers.

        :param input_dim: Dimension of input features.
        :param num_classes: Output dimension of the projection.
        :param pooler_dropout: Dropout probability applied to the input features.
        """
        super().__init__()
        self.out_proj = nn.Linear(input_dim, num_classes)
        self.dropout = nn.Dropout(p=pooler_dropout)

    def forward(self, features, **kwargs):
        """
        Apply dropout, then the linear projection.

        :param features: Input features.
        :param kwargs: Ignored.
        :return: Projected features.
        """
        return self.out_proj(self.dropout(features))
class ClassificationHead(nn.Module):
    """Two-layer prediction head: dropout, dense, activation, dropout, projection."""

    def __init__(self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout):
        """
        Create the layers of the head.

        :param input_dim: Dimension of input features.
        :param inner_dim: Dimension of the inner (dense) layer.
        :param num_classes: Output dimension of the final projection.
        :param activation_fn: Activation function name, resolved by ``get_activation_fn``.
        :param pooler_dropout: Dropout probability (the same layer is used twice).
        """
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.activation_fn = get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, features, **kwargs):
        """
        Run the head on ``features``.

        :param features: Input features.
        :param kwargs: Ignored.
        :return: Output of the final projection.
        """
        hidden = self.activation_fn(self.dense(self.dropout(features)))
        return self.out_proj(self.dropout(hidden))
class NonLinearHead(nn.Module):
    """
    Two linear layers with an activation function in between.

    Attributes:
        - linear1: First linear layer (input_dim -> hidden).
        - linear2: Second linear layer (hidden -> out_dim).
        - activation_fn: Activation applied between the two layers.
    """

    def __init__(self, input_dim, out_dim, activation_fn, hidden=None):
        """
        Create the two linear layers and resolve the activation.

        :param input_dim: Dimension of the input features.
        :param out_dim: Dimension of the output.
        :param activation_fn: Activation function name, resolved by ``get_activation_fn``.
        :param hidden: Dimension of the hidden layer; ``input_dim`` is used when this
            is None (or any other falsy value).
        """
        super().__init__()
        hidden_dim = hidden or input_dim
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, out_dim)
        self.activation_fn = get_activation_fn(activation_fn)

    def forward(self, x):
        """
        Apply linear1, the activation, then linear2.

        :param x: Input tensor.
        :return: Output tensor of the second linear layer.
        """
        return self.linear2(self.activation_fn(self.linear1(x)))
@torch.jit.script
def gaussian(x, mean, std):
    """
    Evaluate a Gaussian density element-wise on PyTorch tensors.

    :param x: The input tensor.
    :param mean: The mean of the Gaussian.
    :param std: The standard deviation of the Gaussian.
    :return: ``exp(-0.5 * ((x - mean) / std) ** 2) / (sqrt(2 * pi) * std)``,
        using the truncated constant pi = 3.14159.
    """
    pi = 3.14159
    normalizer = (2 * pi) ** 0.5
    z = (x - mean) / std
    return torch.exp(-0.5 * (z ** 2)) / (normalizer * std)
def get_activation_fn(activation):
    """
    Return the activation function corresponding to `activation`.

    :param activation: One of "relu", "gelu", "tanh" or "linear" (identity).
    :return: A callable applying the activation.
    :raises RuntimeError: If `activation` is not one of the supported names.
    """
    supported = (
        ("relu", F.relu),
        ("gelu", F.gelu),
        ("tanh", torch.tanh),
        ("linear", lambda x: x),
    )
    for name, fn in supported:
        if activation == name:
            return fn
    raise RuntimeError("--activation-fn {} not supported".format(activation))
class GaussianLayer(nn.Module):
    """
    Expands a pairwise scalar into K Gaussian kernel features.

    Attributes:
        - K: Number of Gaussian kernels.
        - means, stds: Single-row embeddings holding the K kernel means and
          standard deviations.
        - mul, bias: Per-edge-type scale and offset applied to the input.
    """

    def __init__(self, K=128, edge_types=1024):
        """
        Create and initialise the kernel parameters.

        Means and standard deviations are drawn uniformly from [0, 3]; the
        per-edge-type scale starts at 1 and the offset at 0.

        :param K: Number of Gaussian kernels.
        :param edge_types: Number of different edge types.
        """
        super().__init__()
        self.K = K
        self.means = nn.Embedding(1, K)
        self.stds = nn.Embedding(1, K)
        self.mul = nn.Embedding(edge_types, 1)
        self.bias = nn.Embedding(edge_types, 1)
        for table in (self.means, self.stds):
            nn.init.uniform_(table.weight, 0, 3)
        nn.init.constant_(self.bias.weight, 0)
        nn.init.constant_(self.mul.weight, 1)

    def forward(self, x, edge_type):
        """
        Apply the per-edge-type affine map, then evaluate the K Gaussians.

        :param x: Input tensor; a trailing axis is appended and expanded to K, and
            the four-argument ``expand`` requires the result to be rank 4.
        :param edge_type: Integer tensor of edge-type indices used to look up the
            scale and offset.
        :return: Kernel features cast to the dtype of ``means.weight``.
        """
        scale = self.mul(edge_type).type_as(x)
        offset = self.bias(edge_type).type_as(x)
        affine = (scale * x.unsqueeze(-1) + offset).expand(-1, -1, -1, self.K)
        centers = self.means.weight.float().view(-1)
        # abs() + epsilon keeps the standard deviation strictly positive.
        widths = self.stds.weight.float().view(-1).abs() + 1e-5
        features = gaussian(affine.float(), centers, widths)
        return features.type_as(self.means.weight)
class NumericalEmbed(nn.Module):
    """
    Embeds a pairwise scalar into K features using edge-type-dependent parameters.

    Attributes:
        - K: Output dimension of the embedding.
        - mul, bias: Per-edge-type scale and offset applied before the sigmoid gate.
        - w_edge: Per-edge-type K-dimensional weight multiplied by the gate.
        - proj: NonLinearHead projecting the raw scalar to K features.
        - ln: Layer normalization applied to the projection.
    """

    def __init__(self, K=128, edge_types=1024, activation_fn='gelu'):
        """
        Create and initialise the embedding parameters.

        :param K: Output dimension of the embedding.
        :param edge_types: Number of different edge types.
        :param activation_fn: Activation name used inside the projection head.
        """
        super().__init__()
        self.K = K
        self.mul = nn.Embedding(edge_types, 1)
        self.bias = nn.Embedding(edge_types, 1)
        self.w_edge = nn.Embedding(edge_types, K)
        self.proj = NonLinearHead(1, K, activation_fn, hidden=2 * K)
        self.ln = nn.LayerNorm(K)
        nn.init.constant_(self.bias.weight, 0)
        nn.init.constant_(self.mul.weight, 1)
        nn.init.kaiming_normal_(self.w_edge.weight)

    def forward(self, x, edge_type):  # edge_type, atoms
        """
        Combine a gated edge-type embedding with a projection of the raw input.

        :param x: Input tensor; a trailing axis of size 1 is appended before use.
        :param edge_type: Integer tensor of edge-type indices used for the lookups.
        :return: Sum of the normalized projection and the gated edge embedding,
            cast to the dtype of ``mul.weight``.
        """
        value = x.unsqueeze(-1)
        scale = self.mul(edge_type).type_as(x)
        offset = self.bias(edge_type).type_as(x)
        edge_weight = self.w_edge(edge_type).type_as(x)
        gated = edge_weight * torch.sigmoid(scale * value + offset)
        projected = self.ln(self.proj(value.type_as(self.mul.weight)))
        return (projected + gated).type_as(self.mul.weight)
def _build_architecture(encoder_layers, pooler_dropout, max_seq_len, kernel):
    """Return a fresh config holding the shared defaults plus the four varying fields."""
    args = Dict()
    args.encoder_layers = encoder_layers
    args.encoder_embed_dim = 512
    args.encoder_ffn_embed_dim = 2048
    args.encoder_attention_heads = 64
    args.dropout = 0.1
    args.emb_dropout = 0.1
    args.attention_dropout = 0.1
    args.activation_dropout = 0.0
    args.pooler_dropout = pooler_dropout
    args.max_seq_len = max_seq_len
    args.activation_fn = "gelu"
    args.pooler_activation_fn = "tanh"
    args.post_ln = False
    args.backbone = "transformer"
    args.kernel = kernel
    args.delta_pair_repr_norm_loss = -1.0
    return args


def molecule_architecture():
    """Config for molecules: 15 layers, pooler dropout 0.2, max length 512, gaussian kernel."""
    return _build_architecture(
        encoder_layers=15, pooler_dropout=0.2, max_seq_len=512, kernel="gaussian"
    )


def protein_architecture():
    """Config for proteins: same values as the molecule config."""
    return _build_architecture(
        encoder_layers=15, pooler_dropout=0.2, max_seq_len=512, kernel="gaussian"
    )


def crystal_architecture():
    """Config for crystals: 8 layers, pooler dropout 0.0, max length 1024, linear kernel."""
    return _build_architecture(
        encoder_layers=8, pooler_dropout=0.0, max_seq_len=1024, kernel="linear"
    )


def oled_architecture():
    """Config for OLED data: same values as the crystal config."""
    return _build_architecture(
        encoder_layers=8, pooler_dropout=0.0, max_seq_len=1024, kernel="linear"
    )