# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function
import os
import numpy as np
import joblib
from sklearn.preprocessing import (
StandardScaler,
MinMaxScaler,
MaxAbsScaler,
RobustScaler,
Normalizer,
QuantileTransformer,
PowerTransformer,
FunctionTransformer,
)
from scipy.stats import skew, kurtosis
from ..utils import logger
# Registry of the supported target-scaling methods. Each key is a value
# accepted for ``ss_method``; each value is the scikit-learn class (not an
# instance) that implements it, so a fresh scaler is built on every lookup.
SCALER_MODE = {
    'minmax': MinMaxScaler,            # rescale to the observed [min, max] range
    'standard': StandardScaler,        # zero mean, unit variance
    'robust': RobustScaler,            # median / IQR based, outlier tolerant
    'maxabs': MaxAbsScaler,            # divide by the maximum absolute value
    'quantile': QuantileTransformer,   # map onto a reference distribution
    'power_trans': PowerTransformer,   # Box-Cox or Yeo-Johnson power transform
    'normalizer': Normalizer,          # per-sample norm scaling
    'log1p': FunctionTransformer,      # wraps np.log1p (see scaler_choose)
}
class TargetScaler(object):
    """
    A class to scale the target.

    The scaler is only used for regression-type tasks; for classification
    tasks :meth:`fit` is a no-op and ``self.scaler`` stays ``None``.

    Attributes set by ``__init__``:
        - ``ss_method``: name of the scaling method ('none', 'auto' or a key
          of ``SCALER_MODE``).
        - ``task``: task name, e.g. 'regression' or 'multilabel_regression'.
        - ``scaler``: ``None``, a single scaler ('regression'), or a list with
          one scaler per target column ('multilabel_regression').
    """

    def __init__(self, ss_method, task, load_dir=None):
        """
        Initializes the TargetScaler object for scaling target values.

        :param ss_method: (str) The scaling method to be used.
        :param task: (str) The type of machine learning task (e.g., 'classification', 'regression').
        :param load_dir: (str, optional) Directory from which to load an existing scaler.
            If it contains a ``target_scaler.ss`` file, that file is loaded;
            otherwise the scaler starts out as ``None``.
        """
        self.ss_method = ss_method
        self.task = task
        self.scaler = None
        if load_dir:
            scaler_path = os.path.join(load_dir, 'target_scaler.ss')
            if os.path.exists(scaler_path):
                # NOTE: joblib.load unpickles the file, which can execute
                # arbitrary code -- only point load_dir at trusted locations.
                self.scaler = joblib.load(scaler_path)

    def fit(self, target, dump_dir):
        """
        Fits the scaler to the target values and saves the fitted scaler to disk.

        Nothing is fitted or written for classification tasks or when
        ``ss_method`` is 'none'. Otherwise the scaler is written to
        ``<dump_dir>/target_scaler.ss`` (the directory is created if needed).

        :param target: (array-like) The target values to fit the scaler.
            The code indexes it as a 2-D array (``target[:, i:i+1]``) for
            'multilabel_regression'.
        :param dump_dir: (str) Directory where the fitted scaler will be saved.
        """
        if self.task in ['classification', 'multiclass', 'multilabel_classification']:
            return
        if self.ss_method == 'none':
            return

        if self.ss_method == 'auto':
            if self.task == 'regression':
                self.scaler = self._auto_scaler(target)
                self.scaler.fit(target)
            elif self.task == 'multilabel_regression':
                # Mask NaN/inf so missing labels do not distort the skewness test.
                target = np.ma.masked_invalid(target)
                self.scaler = []
                for i in range(target.shape[1]):
                    column_scaler = self._auto_scaler(target[:, i])
                    column_scaler.fit(target[:, i:i+1])
                    self.scaler.append(column_scaler)
        else:
            if self.task == 'regression':
                self.scaler = self.scaler_choose(self.ss_method, target)
                self.scaler.fit(target)
            elif self.task == 'multilabel_regression':
                self.scaler = []
                for i in range(target.shape[1]):
                    column = target[:, i:i+1]
                    column_scaler = self.scaler_choose(self.ss_method, column)
                    column_scaler.fit(column)
                    self.scaler.append(column_scaler)

        scaler_path = os.path.join(dump_dir, 'target_scaler.ss')
        try:
            # Best effort: drop a stale scaler file before writing the new one.
            os.remove(scaler_path)
        except OSError:
            # Missing file/directory (or a non-removable file) is not fatal
            # here; any real problem resurfaces in makedirs/dump below.
            pass
        os.makedirs(dump_dir, exist_ok=True)
        joblib.dump(self.scaler, scaler_path)

    def _auto_scaler(self, target):
        """
        Builds the scaler used by the 'auto' method for one target array.

        :param target: (array-like) The target values inspected for skewness.
        :return: An unfitted robust scaler if the target is skewed, otherwise
            an unfitted standard scaler.
        """
        if self.is_skewed(target):
            logger.info('Auto select robust transformer.')
            return SCALER_MODE['robust']()
        return SCALER_MODE['standard']()

    def scaler_choose(self, method, target):
        """
        Selects and instantiates the scaler for the given scaling method.

        The returned scaler is NOT fitted; the caller is responsible for
        calling ``fit`` on it.

        :param method: (str) The scaling method to be used.
            currently support:

            - 'minmax': MinMaxScaler,
            - 'standard': StandardScaler,
            - 'robust': RobustScaler,
            - 'maxabs': MaxAbsScaler,
            - 'quantile': QuantileTransformer,
            - 'power_trans': PowerTransformer,
            - 'normalizer': Normalizer,
            - 'log1p': FunctionTransformer,
        :param target: (array-like) The target values; only inspected for
            'power_trans' to decide between Box-Cox and Yeo-Johnson.
        :return: The (unfitted) scaler object.
        :raises KeyError: If ``method`` is not a key of ``SCALER_MODE``.
        """
        if method == 'power_trans':
            # Box-Cox is only defined for strictly positive data. nanmin
            # ignores NaN entries and works for any array shape, unlike the
            # builtin min() which is order-dependent with NaN and ambiguous
            # for multi-column input.
            power_method = 'box-cox' if np.nanmin(target) > 0 else 'yeo-johnson'
            return SCALER_MODE[method](method=power_method)
        if method == 'log1p':
            # NOTE(review): no inverse_func is supplied, so the transformer's
            # inverse_transform is the identity -- confirm that callers which
            # map predictions back account for this (np.expm1 is the inverse).
            return SCALER_MODE[method](np.log1p)
        return SCALER_MODE[method]()

    def is_skewed(self, target):
        """
        Determines whether the target values are skewed based on skewness and kurtosis metrics.

        The target counts as skewed when the absolute skewness exceeds 5.0 or
        the absolute kurtosis (as returned by ``scipy.stats.kurtosis``)
        exceeds 20.0. For 2-D input the statistics are computed per column and
        the target is skewed if any column exceeds a threshold.

        :param target: (array-like) The target values to be checked for skewness.
        :return: (bool) True if the target is skewed, False otherwise.
            Always False for classification tasks.
        """
        if self.task in ['classification', 'multiclass', 'multilabel_classification']:
            return False
        abs_skewness = np.abs(skew(target))
        abs_kurtosis = np.abs(kurtosis(target))
        # np.any collapses per-column results so multi-column targets do not
        # trigger an ambiguous-truth-value error; bool() keeps the return type
        # a plain Python bool as documented.
        return bool(np.any(abs_skewness > 5.0) or np.any(abs_kurtosis > 20.0))