Source code for poutyne.framework.metrics.epoch_metrics.sklearn_metrics

from typing import Optional, Union, List, Callable, Dict, Tuple
import numpy as np
import torch
from .base import EpochMetric


class SKLearnMetrics(EpochMetric):
    """
    Wrap metrics with Scikit-learn-like interface
    (``metric(y_true, y_pred, sample_weight=sample_weight, **kwargs)``).
    The ``SKLearnMetrics`` object has to keep in memory the ground truths and
    predictions so that it can compute the metric at the end.

    Example:
        .. code-block:: python

            from sklearn.metrics import roc_auc_score, average_precision_score
            from poutyne import SKLearnMetrics
            my_epoch_metric = SKLearnMetrics([roc_auc_score, average_precision_score])

    Args:
        funcs (Union[Callable, List[Callable]]): A metric or a list of metrics with a
            scikit-learn-like interface.
        kwargs (Optional[Union[dict, List[dict]]]): Optional dictionary or list of dictionaries
            corresponding to keyword arguments to pass to each corresponding metric.
            (Default value = None)
        names (Optional[Union[str, List[str]]]): Optional string or list of strings corresponding to
            the names given to the metrics. By default, the names are the names of the functions.

    Raises:
        ValueError: If ``kwargs`` or ``names`` is provided with a length different
            from the number of metric functions.
    """

    def __init__(self,
                 funcs: Union[Callable, List[Callable]],
                 kwargs: Optional[Union[dict, List[dict]]] = None,
                 names: Optional[Union[str, List[str]]] = None) -> None:
        super().__init__()
        # A single callable is normalized to a one-element list so that the rest
        # of the class can always iterate over ``self.funcs``.
        self.funcs = funcs if isinstance(funcs, (list, tuple)) else [funcs]
        self.kwargs = self._validate_kwargs(kwargs)
        # ``__name__`` holds the list of metric names; ``get_metric`` uses it as
        # the keys of the returned dictionary.
        self.__name__ = self._validate_names(names)
        self.reset()

    def _validate_kwargs(self, kwargs):
        """
        Normalize ``kwargs`` to a sequence with one dictionary per metric function.

        Returns:
            A sequence of dictionaries of the same length as ``self.funcs``.

        Raises:
            ValueError: If ``kwargs`` is provided and its length differs from
                the number of metric functions.
        """
        if kwargs is None:
            # One independent empty dictionary per function (no shared object).
            return [{} for _ in self.funcs]

        kwargs = kwargs if isinstance(kwargs, (list, tuple)) else [kwargs]
        if len(self.funcs) != len(kwargs):
            raise ValueError("`kwargs` has to have the same length as `funcs` when provided")
        return kwargs

    def _validate_names(self, names):
        """
        Normalize ``names`` to a sequence with one name per metric function.

        Returns:
            The provided names (wrapped in a list if a single name was given), or
            the ``__name__`` of each function when ``names`` is ``None``.

        Raises:
            ValueError: If ``names`` is provided and its length differs from
                the number of metric functions.
        """
        if names is None:
            return [func.__name__ for func in self.funcs]

        names = names if isinstance(names, (list, tuple)) else [names]
        if len(self.funcs) != len(names):
            raise ValueError("`names` has to have the same length as `funcs` when provided")
        return names

    def forward(self, y_pred: torch.Tensor,
                y_true: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> None:
        """
        Accumulate the predictions, ground truths and sample weights if any.

        Args:
            y_pred (torch.Tensor): A tensor of predictions of the shape expected by
                the metric functions passed to the class.
            y_true (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]):
                Ground truths. A tensor of ground truths of the shape expected by
                the metric functions passed to the class.
                It can also be a tuple with two tensors, the first being the
                ground truths and the second corresponding to the ``sample_weight``
                argument passed to the metric functions in Scikit-Learn.
        """
        # ``detach()`` is required because ``Tensor.numpy()`` refuses tensors that
        # require grad, which is typically the case for raw network outputs.
        self.y_pred_list.append(y_pred.detach().cpu().numpy())
        if isinstance(y_true, (tuple, list)):
            y_true, sample_weight = y_true
            self.sample_weight_list.append(sample_weight.detach().cpu().numpy())
        self.y_true_list.append(y_true.detach().cpu().numpy())

    def get_metric(self) -> Dict:
        """
        Returns the metrics as a dictionary with the names as keys.

        Every metric function is called on the concatenation of all accumulated
        batches as ``func(y_true, y_pred, sample_weight=sample_weight, **kwargs)``,
        where ``sample_weight`` is ``None`` when no sample weights were accumulated.

        Note:
            This does not clear the accumulated values; call :meth:`reset` to
            start a new accumulation.
        """
        sample_weight = None
        if self.sample_weight_list:
            sample_weight = np.concatenate(self.sample_weight_list)
        y_pred = np.concatenate(self.y_pred_list)
        y_true = np.concatenate(self.y_true_list)
        return {
            name: func(y_true, y_pred, sample_weight=sample_weight, **kwargs)
            for name, func, kwargs in zip(self.__name__, self.funcs, self.kwargs)
        }

    def reset(self) -> None:
        """
        Discard all accumulated ground truths, predictions and sample weights.
        """
        self.y_true_list = []
        self.y_pred_list = []
        self.sample_weight_list = []