from abc import ABC, abstractmethod
import torch.nn as nn
class EpochMetric(ABC, nn.Module):
    """
    The abstract class representing an epoch metric, which can be accumulated at each batch and calculated at the
    end of the epoch.

    Subclasses must implement :meth:`forward`, :meth:`get_metric` and :meth:`reset`; the class cannot be
    instantiated until all three are overridden. Because this is an ``nn.Module``, calling an instance
    (``metric(y_pred, y_true)``) dispatches to :meth:`forward`. Subclasses that define ``__init__`` must call
    ``super().__init__()`` before assigning attributes, as required by ``nn.Module``.
    """

    @abstractmethod
    def forward(self, y_pred, y_true) -> None:
        """
        To define the behavior of the metric when called.

        Args:
            y_pred: The prediction of the model.
            y_true: Target to evaluate the model.
        """

    @abstractmethod
    def get_metric(self):
        """
        Compute and return the metric. Should not modify the state of
        the epoch metric.
        """

    @abstractmethod
    def reset(self) -> None:
        """
        The information kept for the computation of the metric is cleaned so
        that a new epoch can be done.
        """