import copy
import os
import numpy as np
import pandas as pd
import torch
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import (
accuracy_score,
average_precision_score,
cohen_kappa_score,
f1_score,
log_loss,
matthews_corrcoef,
mean_absolute_error,
mean_squared_error,
precision_score,
r2_score,
recall_score,
roc_auc_score,
)
from .base_logger import logger
def cal_nan_metric(y_true, y_pred, nan_value=None, metric_func=None):
    """
    Average a metric over label columns while ignoring missing labels.

    Every column is scored independently with ``metric_func`` using only the
    rows whose label is present; the per-column scores are then averaged.
    Columns without a single valid label are skipped.

    :param y_true: ground-truth labels (array, DataFrame or Series). A 1-D input
        is treated as a single column.
    :param y_pred: predictions with the same shape as ``y_true``.
    :param nan_value: optional sentinel that marks a missing label in addition to NaN.
    :param metric_func: callable invoked as ``metric_func(y_true_col, y_pred_col)``.
    :return: mean of the per-column scores, or NaN when no column has a valid label.
    :raises ValueError: if ``y_true`` and ``y_pred`` have different shapes.
    """
    if y_true.shape != y_pred.shape:
        raise ValueError('y_true and y_pred must have same shape')
    if isinstance(y_true, (pd.DataFrame, pd.Series)):
        y_true = y_true.to_numpy()
    if isinstance(y_pred, (pd.DataFrame, pd.Series)):
        y_pred = y_pred.to_numpy()
    if y_true.ndim == 1:
        # Single-task input: promote to one column so the per-column loop applies.
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)
    if not np.issubdtype(y_true.dtype, np.floating):
        # np.isnan is only defined for floating dtypes.
        y_true = y_true.astype(np.float64)

    valid = ~np.isnan(y_true)
    if nan_value is not None:
        valid = valid & (y_true != nan_value)

    scores = [
        metric_func(y_true[:, col][rows], y_pred[:, col][rows])
        for col, rows in enumerate(valid.T)
        if rows.any()
    ]
    if not scores:
        # Nothing to average; return NaN explicitly instead of triggering
        # numpy's "mean of empty slice" warning.
        return np.nan
    return np.mean(scores)
def multi_acc(y_true, y_pred):
    """
    Accuracy of arg-max class predictions.

    :param y_true: class-index labels; flattened to 1-D before comparison.
    :param y_pred: per-class scores; the predicted class is the arg-max along axis 1.
    :return: fraction of samples whose predicted class equals the label.
    """
    predicted_class = np.argmax(y_pred, axis=1)
    flat_labels = y_true.flatten()
    hits = flat_labels == predicted_class
    return np.mean(hits)
def log_loss_with_label(y_true, y_pred, labels=None):
    """
    Log loss that forwards an explicit label set only when one is given.

    :param y_true: ground-truth labels.
    :param y_pred: predicted probabilities.
    :param labels: optional list of class labels passed through to ``log_loss``;
        omitted from the call entirely when ``None``.
    :return: the value returned by ``log_loss``.
    """
    optional_kwargs = {} if labels is None else {'labels': labels}
    return log_loss(y_true, y_pred, **optional_kwargs)
def reg_preasonr(y_true, y_pred):
    """
    Pearson correlation coefficient between labels and predictions.

    :return: the first element of scipy's ``pearsonr`` result (the coefficient).
    """
    result = pearsonr(y_true, y_pred)
    coefficient = result[0]
    return coefficient
def reg_spearmanr(y_true, y_pred):
    """
    Spearman rank correlation coefficient between labels and predictions.

    :return: the first element of scipy's ``spearmanr`` result (the coefficient).
    """
    result = spearmanr(y_true, y_pred)
    coefficient = result[0]
    return coefficient
# Registry of supported metrics, keyed by task name and then by metric name.
# Each entry is a list ``[metric_func, is_increase, value_type]``:
#   metric_func - callable invoked as ``metric_func(y_true, y_pred)``
#   is_increase - True when a larger score is better, False when smaller is better
#   value_type  - 'float' metrics receive raw prediction scores; 'int' metrics
#                 receive predictions binarised with a threshold
METRICS_REGISTER = {
    # Single-target regression.
    'regression': {
        'mae': [mean_absolute_error, False, 'float'],
        'pearsonr': [reg_preasonr, True, 'float'],
        'spearmanr': [reg_spearmanr, True, 'float'],
        'mse': [mean_squared_error, False, 'float'],
        'r2': [r2_score, True, 'float'],
    },
    # Binary classification ('auroc' and 'auc' are aliases of the same function).
    'classification': {
        'auroc': [roc_auc_score, True, 'float'],
        'auc': [roc_auc_score, True, 'float'],
        'auprc': [average_precision_score, True, 'float'],
        'log_loss': [log_loss, False, 'float'],
        'acc': [accuracy_score, True, 'int'],
        'f1_score': [f1_score, True, 'int'],
        'mcc': [matthews_corrcoef, True, 'int'],
        'precision': [precision_score, True, 'int'],
        'recall': [recall_score, True, 'int'],
        'cohen_kappa': [cohen_kappa_score, True, 'int'],
    },
    # Multi-class classification.
    'multiclass': {
        'log_loss': [log_loss_with_label, False, 'float'],
        'acc': [multi_acc, True, 'int'],
    },
    # Multi-label classification.
    'multilabel_classification': {
        'auroc': [roc_auc_score, True, 'float'],
        'auc': [roc_auc_score, True, 'float'],
        'auprc': [average_precision_score, True, 'float'],
        'log_loss': [log_loss_with_label, False, 'float'],
        'acc': [accuracy_score, True, 'int'],
        'mcc': [matthews_corrcoef, True, 'int'],
    },
    # Multi-target regression.
    'multilabel_regression': {
        'mae': [mean_absolute_error, False, 'float'],
        'mse': [mean_squared_error, False, 'float'],
        'r2': [r2_score, True, 'float'],
    },
}
# Metric names used for each task when the caller does not request any.
# Order matters: the first name of a list becomes the first key of the
# resulting metric dictionary.
DEFAULT_METRICS = {
    'regression': ['mse', 'mae', 'r2', 'spearmanr', 'pearsonr'],
    'classification': ['log_loss', 'auc', 'f1_score', 'mcc', 'acc', 'precision', 'recall'],
    'multiclass': ['log_loss', 'acc'],
    'multilabel_classification': ['log_loss', 'auc', 'auprc'],
    'multilabel_regression': ['mse', 'mae', 'r2'],
}
class Metrics(object):
    """
    Class for calculating metrics for different tasks.

    :param task: The task type. Supported tasks are 'regression', 'multilabel_regression',
        'classification', 'multilabel_classification', and 'multiclass'.
    :param metrics_str: Comma-separated string of metric names. The named metrics are
        placed first, in the given order, followed by the task's remaining registered
        metrics. If not a string, an empty string or 'none', the default metrics for
        the task are used.
    """

    def __init__(self, task=None, metrics_str=None, **params):
        """
        Build the ordered metric dictionary for ``task``.

        :raises ValueError: if ``task`` or a requested metric name is unknown.
        """
        self.task = task
        self.threshold = np.arange(0, 1.0, 0.1)
        self.metric_dict = self._init_metrics(self.task, metrics_str, **params)
        self.METRICS_REGISTER = METRICS_REGISTER[task]

    def _init_metrics(self, task, metrics_str, **params):
        """
        Resolve ``metrics_str`` into an ordered ``{name: [func, is_increase, value_type]}`` dict.

        :param task: task name; must be a key of ``METRICS_REGISTER``.
        :param metrics_str: comma-separated metric names (surrounding whitespace and
            empty items are ignored), or a non-string / '' / 'none' for the defaults.
        :return: requested metrics first, then every other metric registered for the task;
            or only the task's default metrics when nothing was requested.
        :raises ValueError: if the task or one of the metric names is unknown.
        """
        if task not in METRICS_REGISTER:
            raise ValueError('Unknown task: {}'.format(task))
        registry = METRICS_REGISTER[task]

        requested = []
        if isinstance(metrics_str, str) and metrics_str.strip() != 'none':
            # Tolerate "mae, mse" and trailing commas instead of rejecting them.
            requested = [key.strip() for key in metrics_str.split(',') if key.strip()]
        if not requested:
            return {key: registry[key] for key in DEFAULT_METRICS[task]}

        for key in requested:
            if key not in registry:
                raise ValueError('Unknown metric: {}'.format(key))
        ordered = requested + [key for key in registry if key not in requested]
        return {key: registry[key] for key in ordered}

    def cal_classification_metric(self, label, predict, nan_value=-1.0, threshold=None):
        """
        Compute every configured metric for a (multi-label) classification task.

        :param label: the labels of the dataset.
        :param predict: the predict values of the model.
        :param nan_value: label value that marks a missing label.
        :param threshold: cut-off used to binarise predictions for 'int' metrics
            (a prediction is positive when strictly greater); 0.5 when ``None``.
        :return: dict mapping metric name to its score.
        """
        cutoff = 0.5 if threshold is None else threshold
        # NOTE: labels are cast to int before masking, so a NaN label cannot
        # survive the cast; missing labels must be encoded as ``nan_value``.
        int_label = label.astype(int)
        res_dict = {}
        for metric_type, (metric, _, value_type) in self.metric_dict.items():
            if value_type == 'float':
                scores = predict.astype(np.float32)
            elif value_type == 'int':
                scores = (predict > cutoff).astype(int)
            else:
                continue
            res_dict[metric_type] = cal_nan_metric(int_label, scores, nan_value, metric)
        # TODO: add more metrics by grid search threshold
        return res_dict

    def cal_reg_metric(self, label, predict, nan_value=-1.0):
        """
        Compute every configured metric for a (multi-target) regression task.

        :param label: the labels of the dataset.
        :param predict: the predict values of the model.
        :param nan_value: label value that marks a missing label.
        :return: dict mapping metric name to its score.
        """
        return {
            metric_type: cal_nan_metric(label, predict, nan_value, metric)
            for metric_type, (metric, _, _) in self.metric_dict.items()
        }

    def cal_multiclass_metric(self, label, predict, nan_value=-1.0, label_cnt=-1):
        """
        Compute every configured metric for a multi-class task.

        :param label: the labels of the dataset.
        :param predict: the predict values of the model.
        :param nan_value: unused; kept for a signature parallel to the other tasks.
        :param label_cnt: number of classes. When positive, ``list(range(label_cnt))``
            is passed to the 'log_loss' metric as its label set; ``None`` or a
            non-positive value means "unknown" and no label set is passed.
        :return: dict mapping metric name to its score.
        """
        # A non-positive count would otherwise yield an empty label list.
        has_label_cnt = label_cnt is not None and label_cnt > 0
        res_dict = {}
        for metric_type, (metric, _, _) in self.metric_dict.items():
            if metric_type == 'log_loss' and has_label_cnt:
                res_dict[metric_type] = metric(label, predict, list(range(label_cnt)))
            else:
                res_dict[metric_type] = metric(label, predict)
        return res_dict

    def cal_metric(self, label, predict, nan_value=-1.0, threshold=0.5, label_cnt=None):
        """
        Dispatch to the metric routine that matches ``self.task``.

        :param label: the labels of the dataset.
        :param predict: the predict values of the model.
        :param nan_value: label value that marks a missing label.
        :param threshold: classification cut-off; only used by classification tasks.
        :param label_cnt: number of classes; only used by the multiclass task.
        :return: dict mapping metric name to its score.
        :raises ValueError: if ``self.task`` has no metric routine.
        """
        if self.task in ['regression', 'multilabel_regression']:
            return self.cal_reg_metric(label, predict, nan_value)
        if self.task in ['classification', 'multilabel_classification']:
            # Forward the caller's threshold; it used to be silently dropped.
            return self.cal_classification_metric(label, predict, nan_value, threshold)
        if self.task in ['multiclass']:
            return self.cal_multiclass_metric(label, predict, nan_value, label_cnt)
        raise ValueError("We will add more tasks soon")

    def _early_stop_choice(
        self,
        wait,
        min_score,
        metric_score,
        max_score,
        model,
        dump_dir,
        fold,
        patience,
        epoch,
    ):
        """
        Update early-stopping state from the first entry of ``metric_score``.

        The direction (higher or lower is better) is looked up in
        ``METRICS_REGISTER`` for the first metric name of ``metric_score``.

        :return: ``(is_early_stop, min_score, wait, max_score)``; only the bound
            matching the metric's direction is updated.
        """
        judge_metric, score = next(iter(metric_score.items()))
        is_increase = METRICS_REGISTER[self.task][judge_metric][1]
        if is_increase:
            is_early_stop, max_score, wait = self._judge_early_stop_increase(
                wait, score, max_score, model, dump_dir, fold, patience, epoch
            )
        else:
            is_early_stop, min_score, wait = self._judge_early_stop_decrease(
                wait, score, min_score, model, dump_dir, fold, patience, epoch
            )
        return is_early_stop, min_score, wait, max_score

    def _save_best_model(self, model, dump_dir, fold):
        """Write ``model.state_dict()`` to ``<dump_dir>/model_<fold>.pth``."""
        info = {'model_state_dict': model.state_dict()}
        os.makedirs(dump_dir, exist_ok=True)
        torch.save(info, os.path.join(dump_dir, f'model_{fold}.pth'))

    def _judge_early_stop_decrease(
        self, wait, score, min_score, model, dump_dir, fold, patience, epoch
    ):
        """
        Early-stopping step for a metric where lower is better.

        A score not above ``min_score`` becomes the new best, resets ``wait`` and
        checkpoints the model. Any other score, including NaN, increments ``wait``;
        stopping is signalled once ``wait`` reaches ``patience``.

        :return: ``(is_early_stop, min_score, wait)``.
        """
        is_early_stop = False
        if score <= min_score:
            min_score = score
            wait = 0
            self._save_best_model(model, dump_dir, fold)
        else:
            # Plain ``else`` so a NaN score counts as "no improvement" rather
            # than leaving ``wait`` frozen forever.
            wait += 1
            if wait >= patience:
                logger.warning(f'Early stopping at epoch: {epoch+1}')
                is_early_stop = True
        return is_early_stop, min_score, wait

    def _judge_early_stop_increase(
        self, wait, score, max_score, model, dump_dir, fold, patience, epoch
    ):
        """
        Early-stopping step for a metric where higher is better.

        A score not below ``max_score`` becomes the new best, resets ``wait`` and
        checkpoints the model. Any other score, including NaN, increments ``wait``;
        stopping is signalled once ``wait`` reaches ``patience``.

        :return: ``(is_early_stop, max_score, wait)``.
        """
        is_early_stop = False
        if score >= max_score:
            max_score = score
            wait = 0
            self._save_best_model(model, dump_dir, fold)
        else:
            # Plain ``else`` so a NaN score counts as "no improvement".
            wait += 1
            if wait >= patience:
                logger.warning(f'Early stopping at epoch: {epoch+1}')
                is_early_stop = True
        return is_early_stop, max_score, wait

    def calculate_single_classification_threshold(
        self, target, pred, metrics_key=None, step=20
    ):
        """
        Grid-search the decision threshold that optimises one metric.

        ``step`` evenly spaced thresholds between ``min(pred)`` and ``max(pred)``
        are tried; predictions strictly greater than the threshold are labelled 1.

        :param target: ground-truth labels.
        :param pred: predicted scores.
        :param metrics_key: optional ``[metric_func, is_increase, value_type]`` entry
            to optimise. When ``None``, the first 'int' metric of ``self.metric_dict``
            is used, falling back to the classification 'f1_score' entry.
        :param step: number of candidate thresholds.
        :return: the best threshold found (0.5 if no candidate improved on the
            initial bound).
        """
        if metrics_key is None:
            for metric_value in self.metric_dict.values():
                if metric_value[2] == 'int':
                    metrics_key = metric_value
                    break
        # default threshold metrics
        if metrics_key is None:
            metrics_key = METRICS_REGISTER['classification']['f1_score']

        metric_func, is_increase = metrics_key[0], metrics_key[1]
        metric_name = getattr(metric_func, '__name__', repr(metric_func))
        logger.info("metrics for threshold: {0}".format(metric_name))

        range_min = np.min(pred).item()
        range_max = np.max(pred).item()
        best_metric = float('-inf') if is_increase else float('inf')
        best_threshold = 0.5
        for threshold in np.linspace(range_min, range_max, step):
            pred_label = np.zeros_like(pred)
            pred_label[pred > threshold] = 1
            score = metric_func(target, pred_label)
            improved = score > best_metric if is_increase else score < best_metric
            if improved:
                best_metric = score
                best_threshold = threshold
        logger.info(
            "best threshold: {0}, metrics: {1}".format(best_threshold, best_metric)
        )
        return best_threshold

    def calculate_classification_threshold(self, target, pred):
        """
        Find one decision threshold per label column.

        :param target: 2-D ground-truth labels, one column per label.
        :param pred: 2-D predicted scores, indexed by the same columns as ``target``.
        :return: float array with one threshold per column of ``target``.
        """
        n_columns = target.shape[1]
        threshold = np.zeros(n_columns)
        for idx in range(n_columns):
            threshold[idx] = self.calculate_single_classification_threshold(
                target[:, idx].reshape(-1, 1),
                pred[:, idx].reshape(-1, 1),
                metrics_key=None,
                step=20,
            )
        return threshold