# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function
import os
import joblib
import numpy as np
from scipy.stats import kurtosis, skew
from sklearn.preprocessing import (
FunctionTransformer,
MaxAbsScaler,
MinMaxScaler,
Normalizer,
PowerTransformer,
QuantileTransformer,
RobustScaler,
StandardScaler,
)
from ..utils import logger
[docs]
class TargetScaler(object):
    """
    Scale regression targets and persist the fitted scaler with joblib.

    Targets of classification-style tasks are never scaled. For the
    ``'regression'`` task ``self.scaler`` holds a single scaler object; for
    ``'multilabel_regression'`` it holds a list with one scaler per target column.
    """

    # Tasks for which ``fit`` is a no-op and ``is_skewed`` is always False.
    _CLASSIFICATION_TASKS = (
        'classification',
        'multiclass',
        'multilabel_classification',
    )
    # File name of the serialized scaler inside ``load_dir`` / ``dump_dir``.
    _SCALER_FILENAME = 'target_scaler.ss'
    # 'auto' mode treats a target as skewed when |skewness| or |kurtosis|
    # exceeds these limits.
    _SKEW_LIMIT = 5.0
    _KURTOSIS_LIMIT = 20.0

    def __init__(self, ss_method, task, load_dir=None):
        """
        Initializes the TargetScaler object for scaling target values.

        :param ss_method: (str) The scaling method to be used ('none', 'auto', or
            one of the names accepted by :meth:`scaler_choose`).
        :param task: (str) The type of machine learning task (e.g., 'classification', 'regression').
        :param load_dir: (str, optional) Directory from which to load an existing scaler.
            When it is falsy, or contains no 'target_scaler.ss' file, ``self.scaler``
            is set to None.
        """
        self.ss_method = ss_method
        self.task = task
        self.scaler = None
        if load_dir:
            scaler_path = os.path.join(load_dir, self._SCALER_FILENAME)
            if os.path.exists(scaler_path):
                # NOTE(security): joblib.load unpickles the file, which can run
                # arbitrary code; only point load_dir at trusted directories.
                self.scaler = joblib.load(scaler_path)

    def fit(self, target, dump_dir):
        """
        Fits the scaler to the target values and optionally saves the scaler to disk.

        Nothing is fitted or saved for classification-style tasks or when
        ``ss_method`` is 'none'.

        :param target: (array-like) The target values to fit the scaler. The
            'multilabel_regression' task reads ``target.shape[1]`` and fits one
            scaler per column.
        :param dump_dir: (str) Directory where the fitted scaler will be saved as
            'target_scaler.ss'. The directory is created if missing. Pass None to
            skip saving.
        :raises ValueError: If ``ss_method`` is not a known scaling method.
        """
        if self.task in self._CLASSIFICATION_TASKS or self.ss_method == 'none':
            return

        if self.task == 'regression':
            self.scaler = self._select_scaler(target)
            self.scaler.fit(target)
        elif self.task == 'multilabel_regression':
            auto = self.ss_method == 'auto'
            if auto:
                # Mask NaN/inf entries before the skewness check.
                target = np.ma.masked_invalid(target)
            self.scaler = []
            for i in range(target.shape[1]):
                column = target[:, i : i + 1]  # 2-D slice used for fitting
                # 'auto' inspects the flat column; explicit methods get the 2-D slice.
                sample = target[:, i] if auto else column
                column_scaler = self._select_scaler(sample)
                column_scaler.fit(column)
                self.scaler.append(column_scaler)

        if dump_dir is None:
            # Saving is optional: without a directory only the in-memory scaler is kept.
            return
        self._save(dump_dir)

    def _select_scaler(self, target):
        """Return an unfitted scaler for ``target`` according to ``self.ss_method``."""
        if self.ss_method != 'auto':
            return self.scaler_choose(self.ss_method, target)
        if self.is_skewed(target):
            # NOTE: the message text says 'robust', but the transformer selected
            # here is log1p / expm1.
            # NOTE(review): np.log1p is undefined for values <= -1; confirm that
            # skewed targets are expected to stay above -1.
            logger.info('Auto select robust transformer.')
            return FunctionTransformer(func=np.log1p, inverse_func=np.expm1)
        return StandardScaler()

    def _save(self, dump_dir):
        """Write ``self.scaler`` to 'target_scaler.ss' inside ``dump_dir``."""
        scaler_path = os.path.join(dump_dir, self._SCALER_FILENAME)
        os.makedirs(dump_dir, exist_ok=True)
        try:
            # Best-effort removal of a stale file; it usually does not exist yet.
            os.remove(scaler_path)
        except OSError:
            pass
        joblib.dump(self.scaler, scaler_path)

    def scaler_choose(self, method, target):
        """
        Selects the scaler matching the scaling method. The scaler is NOT fitted here.

        :param method: (str) The scaling method to be used.
            currently support:

            - 'minmax': MinMaxScaler,
            - 'standard': StandardScaler,
            - 'robust': RobustScaler,
            - 'maxabs': MaxAbsScaler,
            - 'quantile': QuantileTransformer,
            - 'power_trans': PowerTransformer,
            - 'normalizer': Normalizer,
            - 'log1p': FunctionTransformer,
        :param target: (array-like) The target values. Only inspected for
            'power_trans', to choose between Box-Cox and Yeo-Johnson.
        :return: The newly created, unfitted scaler object.
        :raises ValueError: If ``method`` is not one of the supported names.
        """
        if method == 'minmax':
            scaler = MinMaxScaler()
        elif method == 'standard':
            scaler = StandardScaler()
        elif method == 'robust':
            scaler = RobustScaler()
        elif method == 'maxabs':
            scaler = MaxAbsScaler()
        elif method == 'quantile':
            scaler = QuantileTransformer()
        elif method == 'power_trans':
            # Box-Cox needs strictly positive data. nanmin ignores NaN entries and
            # works for any array shape, unlike the builtin min().
            strictly_positive = np.nanmin(target) > 0
            scaler = PowerTransformer(
                method='box-cox' if strictly_positive else 'yeo-johnson'
            )
        elif method == 'normalizer':
            scaler = Normalizer()
        elif method == 'log1p':
            scaler = FunctionTransformer(func=np.log1p, inverse_func=np.expm1)
        else:
            raise ValueError('Unknown scaler method: {}'.format(method))
        return scaler

    def is_skewed(self, target):
        """
        Determines whether the target values are skewed based on skewness and kurtosis metrics.

        The target counts as skewed when the absolute skewness exceeds 5.0 or the
        absolute kurtosis exceeds 20.0. For 2-D input the statistics are computed
        per column and a single skewed column is enough.

        :param target: (array-like) The target values to be checked for skewness.
        :return: (bool) True if the target is skewed, False otherwise. Always False
            for classification-style tasks.
        """
        if self.task in self._CLASSIFICATION_TASKS:
            return False
        skew_exceeded = np.any(np.abs(skew(target)) > self._SKEW_LIMIT)
        kurtosis_exceeded = np.any(np.abs(kurtosis(target)) > self._KURTOSIS_LIMIT)
        # Collapse NumPy scalars/arrays into a plain Python bool.
        return bool(skew_exceeded or kurtosis_exceeded)