# Source code for unimol_tools.data.split

# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function

import numpy as np
from sklearn.model_selection import (
    GroupKFold, 
    KFold, 
    StratifiedKFold,
)
from ..utils import logger


[docs] class Splitter(object): """ The Splitter class is responsible for splitting a dataset into train and test sets based on the specified method. """
[docs] def __init__(self, method='random', kfold=5, seed=42, **params): """ Initializes the Splitter with a specified split method and random seed. :param split_method: (str) The method for splitting the dataset, in the format 'Nfold_method'. Defaults to '5fold_random'. :param seed: (int) Random seed for reproducibility in random splitting. Defaults to 42. """ self.method = method self.n_splits = kfold self.seed = seed self.splitter = self._init_split()
[docs] def _init_split(self): """ Initializes the actual splitter object based on the specified method. :return: The initialized splitter object. :raises ValueError: If an unknown splitting method is specified. """ if self.n_splits == 1: return None if self.method == 'random': splitter = KFold(n_splits=self.n_splits, shuffle=True, random_state=self.seed) elif self.method == 'scaffold' or self.method == 'group': splitter = GroupKFold(n_splits=self.n_splits) elif self.method == 'stratified': splitter = StratifiedKFold(n_splits=self.n_splits, shuffle=True, random_state=self.seed) elif self.method == 'select': splitter = GroupKFold(n_splits=self.n_splits) else: raise ValueError('Unknown splitter method: {}fold - {}'.format(self.n_splits, self.method)) return splitter
[docs] def split(self, smiles, target=None, group=None, scaffolds=None, **params): """ Splits the dataset into train and test sets based on the initialized method. :param data: The dataset to be split. :param target: (optional) Target labels for stratified splitting. Defaults to None. :param group: (optional) Group labels for group-based splitting. Defaults to None. :return: An iterator yielding train and test set indices for each fold. :raises ValueError: If the splitter method does not support the provided parameters. """ if self.n_splits == 1: logger.warning('Only one fold is used for training, no splitting is performed.') return [(np.arange(len(smiles)), ())] if smiles is None and 'atoms' in params: smiles = params['atoms'] logger.warning('Atoms are used as SMILES for splitting.') if self.method in ['random']: self.skf = self.splitter.split(smiles) elif self.method in ['scaffold']: self.skf = self.splitter.split(smiles, target, scaffolds) elif self.method in ['group']: self.skf = self.splitter.split(smiles, target, group) elif self.method in ['stratified']: self.skf = self.splitter.split(smiles, group) elif self.method in ['select']: unique_groups = np.unique(group) if len(unique_groups) == self.n_splits: split_folds = [] for unique_group in unique_groups: train_idx = np.where(group != unique_group)[0] test_idx = np.where(group == unique_group)[0] split_folds.append((train_idx, test_idx)) self.split_folds = split_folds return self.split_folds else: logger.error('The number of unique groups is not equal to the number of splits.') exit(1) else: logger.error('Unknown splitter method: {}'.format(self.method)) exit(1) self.split_folds = list(self.skf) return self.split_folds