"""
Copyright (c) 2022 Poutyne and all respective contributors.
Each contributor holds copyright over their respective contributions. The project versioning (Git)
records all such contribution source information.
This file is part of Poutyne.
Poutyne is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
version.
Poutyne is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with Poutyne. If not, see
<https://www.gnu.org/licenses/>.
"""
# pylint: disable=too-many-lines
import os
import warnings
from typing import Any, Callable, Dict, List, Tuple, Union
try:
import pandas as pd
except ImportError:
pd = None
try:
# pylint: disable=unused-import
import matplotlib.pyplot # noqa: F401
is_matplotlib_available = True
except ImportError:
is_matplotlib_available = False
import torch
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
SummaryWriter = None
from poutyne.framework.callbacks import (
AtomicCSVLogger,
BestModelRestore,
LRSchedulerCheckpoint,
ModelCheckpoint,
OptimizerCheckpoint,
PeriodicSaveLambda,
RandomStatesCheckpoint,
TensorBoardLogger,
)
from poutyne.framework.model import Model
from poutyne.plotting import plot_history
from poutyne.utils import load_random_states, set_seeds
class ModelBundle:
"""
The :class:`~poutyne.ModelBundle` class provides a straightforward experimentation tool for efficient and entirely
customizable finetuning of the whole neural network training procedure with PyTorch. The
:class:`~poutyne.ModelBundle` object takes care of the training and testing processes while keeping track of all
pertinent information via the automatic logging option.
Use ``ModelBundle.from_*`` methods to instantiate a :class:`~poutyne.ModelBundle`.
"""
BEST_CHECKPOINT_FILENAME = 'checkpoint_epoch_{epoch}.ckpt'
MODEL_CHECKPOINT_FILENAME = 'checkpoint.ckpt'
OPTIMIZER_CHECKPOINT_FILENAME = 'checkpoint.optim'
RANDOM_STATE_CHECKPOINT_FILENAME = 'checkpoint.randomstate'
LOG_FILENAME = 'log.tsv'
TENSORBOARD_DIRECTORY = 'tensorboard'
EPOCH_FILENAME = 'last.epoch'
LR_SCHEDULER_FILENAME = 'lr_sched_%d.lrsched'
PLOTS_DIRECTORY = 'plots'
TEST_LOG_FILENAME = '{name}_log.tsv'
def __init__(
self,
directory: str,
model: Model,
*,
logging: bool = True,
monitoring: bool = True,
monitor_metric: Union[str, None] = None,
monitor_mode: Union[str, None] = None,
_is_direct=True,
) -> None:
if _is_direct:
raise TypeError("Create a ModelBundle with ModelBundle.from_* methods.")
if pd is None:
raise ImportError("pandas needs to be installed to use the class ModelBundle.")
self.directory = directory
self.model = model
self.logging = logging
self.monitoring = monitoring
self.monitor_metric = monitor_metric
self.monitor_mode = monitor_mode
self.set_paths()
@classmethod
def from_network(
cls,
directory: str,
network: torch.nn.Module,
*,
device: Union[torch.device, List[torch.device], List[str], None, str] = None,
logging: bool = True,
optimizer: Union[torch.optim.Optimizer, str] = 'sgd',
loss_function: Union[Callable, str, None] = None,
batch_metrics: Union[List, None] = None,
epoch_metrics: Union[List, None] = None,
monitoring: bool = True,
monitor_metric: Union[str, None] = None,
monitor_mode: Union[str, None] = None,
task: Union[str, None] = None,
):
# pylint: disable=line-too-long
"""
Instantiate a :class:`~poutyne.ModelBundle` from a PyTorch :class:`~torch.nn.Module` instance.
Args:
directory (str): Path to the model bundle's working directory. Will be used for automatic logging.
network (torch.nn.Module): A PyTorch network.
device (Union[torch.device, List[torch.device], List[str], str, None]): The device to which the model is
sent or, for multi-GPU use, the list of devices to which the model is to be sent. When using a string for
multiple GPUs, the only valid value is "all", meaning "take them all". By default, the current device is
used as the main one. If None, the model will be kept on its current device.
(Default value = None)
logging (bool): Whether or not to log the model bundle's progress. If True, various logging
callbacks will be inserted to output training and testing stats as well as to save model checkpoints,
for example, automatically. See :func:`~ModelBundle.train()` and :func:`~ModelBundle.test()` for more
details. (Default value = True)
optimizer (Union[torch.optim.Optimizer, str]): If a PyTorch optimizer, it must already be initialized.
If a string, it should be the name of the optimizer in PyTorch (i.e. 'Adam' for torch.optim.Adam).
(Default value = 'sgd')
loss_function (Union[Callable, str], optional): It can be any PyTorch
loss layer or custom loss function. It can also be a string with the same name as a PyTorch
loss function (either the functional or object name). The loss function must have the signature
``loss_function(input, target)`` where ``input`` is the prediction of the network and ``target``
is the ground truth. If ``None``, will default to, in priority order, either the model's own
loss function or the default loss function associated with the ``task``.
(Default value = None)
batch_metrics (list): List of functions with the same signature as a loss function or objects with the same
signature as either :class:`~poutyne.Metric` or :class:`torchmetrics.Metric <torchmetrics.Metric>`. It can
also be a string with the same name as a PyTorch loss function (either the functional or object name).
Some metrics, such as 'accuracy' (or just 'acc'), are also available as strings. See :ref:`metrics` and
the `TorchMetrics documentation <https://torchmetrics.readthedocs.io/en/latest/references/modules.html>`__
for available metrics.
Batch metrics are computed for each batch.
(Default value = None)
.. warning:: When using this argument, the metrics are computed for each batch. This can significantly slow
down the computations depending on the metrics used. This mostly happens on non-decomposable metrics
such as :class:`torchmetrics.AUROC <torchmetrics.AUROC>` where an ordering of the elements is necessary
to compute the metric. In such cases, we advise using them as epoch metrics instead.
epoch_metrics (list): List of functions with the same signature as a loss function or objects with the same
signature as either :class:`~poutyne.Metric` or :class:`torchmetrics.Metric <torchmetrics.Metric>`. It can
also be a string with the same name as a PyTorch loss function (either the functional or object name).
Some metrics, such as 'accuracy' (or just 'acc'), are also available as strings. See :ref:`metrics` and
the `TorchMetrics documentation <https://torchmetrics.readthedocs.io/en/latest/references/modules.html>`__
for available metrics.
Epoch metrics are computed only at the end of the epoch.
(Default value = None)
monitoring (bool): Whether or not to monitor the training. If True, the best epoch will be tracked.
If False, ``monitor_metric`` and ``monitor_mode`` are not used, and when testing, the last epoch is used
to test the model instead of the best epoch.
(Default value = True)
monitor_metric (str, optional): Which metric to consider for best model performance calculation. Should be
in the format '{metric_name}' or 'val_{metric_name}' (i.e. 'val_loss'). If None, will follow the value
suggested by ``task`` or default to 'val_loss'. If ``monitoring`` is set to False, will be ignored.
.. warning:: If you do not plan on using a validation set, you must set the monitor metric to another
value.
monitor_mode (str, optional): Which mode, either 'min' or 'max', should be used when considering the
``monitor_metric`` value. If None, will follow the value suggested by ``task`` or default to 'min'.
If ``monitoring`` is set to False, will be ignored.
task (str, optional): Any str beginning with either 'classif' or 'reg'. Specifying a ``task``
can assign default values to the ``loss_function``, ``batch_metrics``, ``epoch_metrics``, ``monitor_metric``
and ``monitor_mode``. For a ``task`` beginning with 'reg', the only default value is the loss function,
which is the mean squared error. When beginning with 'classif', the default loss function is the
cross-entropy loss, the default batch metric is the accuracy, the default epoch metric is the F1 score
and the default monitoring is set on 'val_acc' with a 'max' mode.
(Default value = None)
Examples:
Using a PyTorch DataLoader, on classification task with SGD optimizer::
import torch
from torch.utils.data import DataLoader, TensorDataset
from poutyne import ModelBundle
num_features = 20
num_classes = 5
# Our training dataset with 800 samples.
num_train_samples = 800
train_x = torch.rand(num_train_samples, num_features)
train_y = torch.randint(num_classes, (num_train_samples, ), dtype=torch.long)
train_dataset = TensorDataset(train_x, train_y)
train_generator = DataLoader(train_dataset, batch_size=32)
# Our validation dataset with 200 samples.
num_valid_samples = 200
valid_x = torch.rand(num_valid_samples, num_features)
valid_y = torch.randint(num_classes, (num_valid_samples, ), dtype=torch.long)
valid_dataset = TensorDataset(valid_x, valid_y)
valid_generator = DataLoader(valid_dataset, batch_size=32)
# Our network
pytorch_network = torch.nn.Linear(num_features, num_classes)
# Initialization of our experimentation and network training
exp = ModelBundle.from_network('./simple_example',
pytorch_network,
optimizer='sgd',
task='classif')
exp.train(train_generator, valid_generator, epochs=5)
The above code will yield an output similar to the below lines. Note the automatic checkpoint saving
in the model bundle directory when the monitored metric improved.
.. code-block:: none
Epoch 1/5 0.09s Step 25/25: loss: 6.351375, acc: 1.375000, val_loss: 6.236106, val_acc: 5.000000
Epoch 1: val_acc improved from -inf to 5.00000, saving file to ./simple_example/checkpoint_epoch_1.ckpt
Epoch 2/5 0.10s Step 25/25: loss: 6.054254, acc: 14.000000, val_loss: 5.944495, val_acc: 19.500000
Epoch 2: val_acc improved from 5.00000 to 19.50000, saving file to ./simple_example/checkpoint_epoch_2.ckpt
Epoch 3/5 0.09s Step 25/25: loss: 5.759377, acc: 22.875000, val_loss: 5.655412, val_acc: 21.000000
Epoch 3: val_acc improved from 19.50000 to 21.00000, saving file to ./simple_example/checkpoint_epoch_3.ckpt
...
Training can now easily be resumed from the best checkpoint::
exp.train(train_generator, valid_generator, epochs=10)
.. code-block:: none
Restoring model from ./simple_example/checkpoint_epoch_3.ckpt
Loading weights from ./simple_example/checkpoint.ckpt and starting at epoch 6.
Loading optimizer state from ./simple_example/checkpoint.optim and starting at epoch 6.
Epoch 6/10 0.16s Step 25/25: loss: 4.897135, acc: 22.875000, val_loss: 4.813141, val_acc: 20.500000
Epoch 7/10 0.10s Step 25/25: loss: 4.621514, acc: 22.625000, val_loss: 4.545359, val_acc: 20.500000
Epoch 8/10 0.24s Step 25/25: loss: 4.354721, acc: 23.625000, val_loss: 4.287117, val_acc: 20.500000
...
Testing is also very intuitive::
exp.test(test_generator)
.. code-block:: none
Restoring model from ./simple_example/checkpoint_epoch_9.ckpt
Found best checkpoint at epoch: 9
lr: 0.01, loss: 4.09892, acc: 23.625, val_loss: 4.04057, val_acc: 21.5
On best model: test_loss: 4.06664, test_acc: 17.5
Finally, all the pertinent metrics specified to the ModelBundle at each epoch are stored in a specific logging
file, found here at './simple_example/log.tsv'.
.. code-block:: none
epoch time lr loss acc val_loss val_acc
1 0.0721172170015052 0.01 6.351375141143799 1.375 6.23610631942749 5.0
2 0.0298177790245972 0.01 6.054253826141357 14.000 5.94449516296386 19.5
3 0.0637106419890187 0.01 5.759376544952392 22.875 5.65541223526001 21.0
...
Also, we could use more than one GPU (on a single node) by using the device argument::
# Initialization of our experimentation and network training
exp = ModelBundle.from_network('./simple_example',
pytorch_network,
optimizer='sgd',
task='classif',
device="all")
exp.train(train_generator, valid_generator, epochs=5)
"""
if task is not None and not task.startswith('classif') and not task.startswith('reg'):
raise ValueError(f"Invalid task '{task}'")
batch_metrics = [] if batch_metrics is None else batch_metrics
epoch_metrics = [] if epoch_metrics is None else epoch_metrics
loss_function = cls._get_loss_function(loss_function, network, task)
batch_metrics = cls._get_batch_metrics(batch_metrics, network, task)
epoch_metrics = cls._get_epoch_metrics(epoch_metrics, network, task)
monitoring, monitor_metric, monitor_mode = cls._get_monitoring_config(
monitoring, monitor_metric, monitor_mode, task
)
model = Model(
network,
optimizer,
loss_function,
batch_metrics=batch_metrics,
epoch_metrics=epoch_metrics,
device=device,
)
return ModelBundle(
directory,
model,
logging=logging,
monitoring=monitoring,
monitor_metric=monitor_metric,
monitor_mode=monitor_mode,
_is_direct=False,
)
@classmethod
def from_model(
cls,
directory: str,
model: Model,
*,
logging: bool = True,
monitoring: bool = True,
monitor_metric: Union[str, None] = None,
monitor_mode: Union[str, None] = None,
):
# pylint: disable=line-too-long
"""
Instantiate a :class:`~poutyne.ModelBundle` from a :class:`~poutyne.Model` instance.
Args:
directory (str): Path to the model bundle's working directory. Will be used for automatic logging.
model (poutyne.Model): A Model instance.
logging (bool): Whether or not to log the model bundle's progress. If True, various logging
callbacks will be inserted to output training and testing stats as well as to save model checkpoints,
for example, automatically. See :func:`~ModelBundle.train()` and :func:`~ModelBundle.test()` for more
details. (Default value = True)
monitoring (bool): Whether or not to monitor the training. If True, the best epoch will be tracked.
If False, ``monitor_metric`` and ``monitor_mode`` are not used, and when testing, the last epoch is used
to test the model instead of the best epoch.
(Default value = True)
monitor_metric (str, optional): Which metric to consider for best model performance calculation. Should be
in the format '{metric_name}' or 'val_{metric_name}' (i.e. 'val_loss'). If None, will default to
'val_loss'. If ``monitoring`` is set to False, will be ignored.
.. warning:: If you do not plan on using a validation set, you must set the monitor metric to another
value.
monitor_mode (str, optional): Which mode, either 'min' or 'max', should be used when considering the
``monitor_metric`` value. If None, will default to 'min'.
If ``monitoring`` is set to False, will be ignored.
Examples:
Using a PyTorch DataLoader, on classification task with SGD optimizer::
import torch
from torch.utils.data import DataLoader, TensorDataset
from poutyne import Model, ModelBundle
num_features = 20
num_classes = 5
# Our training dataset with 800 samples.
num_train_samples = 800
train_x = torch.rand(num_train_samples, num_features)
train_y = torch.randint(num_classes, (num_train_samples, ), dtype=torch.long)
train_dataset = TensorDataset(train_x, train_y)
train_generator = DataLoader(train_dataset, batch_size=32)
# Our validation dataset with 200 samples.
num_valid_samples = 200
valid_x = torch.rand(num_valid_samples, num_features)
valid_y = torch.randint(num_classes, (num_valid_samples, ), dtype=torch.long)
valid_dataset = TensorDataset(valid_x, valid_y)
valid_generator = DataLoader(valid_dataset, batch_size=32)
# Our network
pytorch_network = torch.nn.Linear(num_features, num_classes)
model = Model(pytorch_network, 'sgd', 'crossentropy', batch_metrics=['accuracy'])
# Initialization of our experimentation and network training
exp = ModelBundle.from_model('./simple_example', model)
exp.train(train_generator, valid_generator, epochs=5)
The above code will yield an output similar to the below lines. Note the automatic checkpoint saving
in the model bundle directory when the monitored metric improved.
.. code-block:: none
Epoch 1/5 0.09s Step 25/25: loss: 6.351375, acc: 1.375000, val_loss: 6.236106, val_acc: 5.000000
Epoch 1: val_acc improved from -inf to 5.00000, saving file to ./simple_example/checkpoint_epoch_1.ckpt
Epoch 2/5 0.10s Step 25/25: loss: 6.054254, acc: 14.000000, val_loss: 5.944495, val_acc: 19.500000
Epoch 2: val_acc improved from 5.00000 to 19.50000, saving file to ./simple_example/checkpoint_epoch_2.ckpt
Epoch 3/5 0.09s Step 25/25: loss: 5.759377, acc: 22.875000, val_loss: 5.655412, val_acc: 21.000000
Epoch 3: val_acc improved from 19.50000 to 21.00000, saving file to ./simple_example/checkpoint_epoch_3.ckpt
...
Training can now easily be resumed from the best checkpoint::
exp.train(train_generator, valid_generator, epochs=10)
.. code-block:: none
Restoring model from ./simple_example/checkpoint_epoch_3.ckpt
Loading weights from ./simple_example/checkpoint.ckpt and starting at epoch 6.
Loading optimizer state from ./simple_example/checkpoint.optim and starting at epoch 6.
Epoch 6/10 0.16s Step 25/25: loss: 4.897135, acc: 22.875000, val_loss: 4.813141, val_acc: 20.500000
Epoch 7/10 0.10s Step 25/25: loss: 4.621514, acc: 22.625000, val_loss: 4.545359, val_acc: 20.500000
Epoch 8/10 0.24s Step 25/25: loss: 4.354721, acc: 23.625000, val_loss: 4.287117, val_acc: 20.500000
...
Testing is also very intuitive::
exp.test(test_generator)
.. code-block:: none
Restoring model from ./simple_example/checkpoint_epoch_9.ckpt
Found best checkpoint at epoch: 9
lr: 0.01, loss: 4.09892, acc: 23.625, val_loss: 4.04057, val_acc: 21.5
On best model: test_loss: 4.06664, test_acc: 17.5
Finally, all the pertinent metrics specified to the ModelBundle at each epoch are stored in a specific logging
file, found here at './simple_example/log.tsv'.
.. code-block:: none
epoch time lr loss acc val_loss val_acc
1 0.0721172170015052 0.01 6.351375141143799 1.375 6.23610631942749 5.0
2 0.0298177790245972 0.01 6.054253826141357 14.000 5.94449516296386 19.5
3 0.0637106419890187 0.01 5.759376544952392 22.875 5.65541223526001 21.0
...
"""
monitoring, monitor_metric, monitor_mode = cls._get_monitoring_config(monitoring, monitor_metric, monitor_mode)
return ModelBundle(
directory,
model,
logging=logging,
monitoring=monitoring,
monitor_metric=monitor_metric,
monitor_mode=monitor_mode,
_is_direct=False,
)
def set_paths(self):
self.best_checkpoint_filename = self.get_path(ModelBundle.BEST_CHECKPOINT_FILENAME)
self.model_checkpoint_filename = self.get_path(ModelBundle.MODEL_CHECKPOINT_FILENAME)
self.optimizer_checkpoint_filename = self.get_path(ModelBundle.OPTIMIZER_CHECKPOINT_FILENAME)
self.random_state_checkpoint_filename = self.get_path(ModelBundle.RANDOM_STATE_CHECKPOINT_FILENAME)
self.log_filename = self.get_path(ModelBundle.LOG_FILENAME)
self.tensorboard_directory = self.get_path(ModelBundle.TENSORBOARD_DIRECTORY)
self.epoch_filename = self.get_path(ModelBundle.EPOCH_FILENAME)
self.lr_scheduler_filename = self.get_path(ModelBundle.LR_SCHEDULER_FILENAME)
self.plots_directory = self.get_path(ModelBundle.PLOTS_DIRECTORY)
self.test_log_filename = self.get_path(ModelBundle.TEST_LOG_FILENAME)
def get_path(self, *paths: str) -> str:
"""
Returns the path inside the model bundle directory.
"""
return os.path.join(self.directory, *paths)
@classmethod
def _get_loss_function(
cls, loss_function: Union[Callable, str], network: torch.nn.Module, task: Union[str, None]
) -> Union[Callable, str]:
if loss_function is None:
if hasattr(network, 'loss_function'):
return network.loss_function
if task is not None:
if task.startswith('classif'):
return 'cross_entropy'
if task.startswith('reg'):
return 'mse'
return loss_function
@classmethod
def _get_batch_metrics(
cls, batch_metrics: Union[List, None], network: torch.nn.Module, task: Union[str, None]
) -> Union[List, None]:
if batch_metrics is None or len(batch_metrics) == 0:
if hasattr(network, 'batch_metrics'):
return network.batch_metrics
if task is not None and task.startswith('classif'):
return ['accuracy']
return batch_metrics
@classmethod
def _get_epoch_metrics(
cls, epoch_metrics: Union[List, None], network: torch.nn.Module, task: Union[str, None]
) -> Union[List, None]:
if epoch_metrics is None or len(epoch_metrics) == 0:
if hasattr(network, 'epoch_metrics'):
return network.epoch_metrics
if task is not None and task.startswith('classif'):
return ['f1']
return epoch_metrics
@classmethod
def _get_monitoring_config(
cls,
monitoring: bool,
monitor_metric: Union[str, None],
monitor_mode: Union[str, None],
task: Union[str, None] = None,
) -> Tuple:
if not monitoring:
return False, None, None
if monitor_mode is not None and monitor_mode not in ['min', 'max']:
raise ValueError(f"Invalid mode '{monitor_mode}'")
if monitor_metric is None:
if task is not None and task.startswith('classif'):
monitor_metric = 'val_acc'
monitor_mode = 'max'
else:
monitor_metric = 'val_loss'
if monitor_mode is None:
monitor_mode = 'min'
return True, monitor_metric, monitor_mode
def get_stats(self):
if not os.path.isfile(self.log_filename):
raise ValueError("There are no logs available. Did you forget to train with logging enabled?")
return pd.read_csv(self.log_filename, sep='\t')
def get_best_epoch_stats(self) -> Dict:
"""
Returns all computed statistics corresponding to the best epoch according to the
``monitor_metric`` and ``monitor_mode`` attributes.
Returns:
A one-row pandas DataFrame where each column is a statistic from the logging output file
and the values are those found at the best epoch.
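Example:
A minimal usage sketch, assuming ``exp`` is a ModelBundle already trained with logging and
monitoring enabled (as in the examples above) and that a 'val_acc' metric was logged::
best_stats = exp.get_best_epoch_stats()
best_epoch = best_stats['epoch'].item()
best_val_acc = best_stats['val_acc'].item()
print(f"Best epoch: {best_epoch} with val_acc: {best_val_acc}")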
"""
if not self.monitoring:
raise ValueError("Monitoring was disabled. Cannot get best epoch.")
history = self.get_stats()
if self.monitor_mode == 'min':
best_epoch_index = history[self.monitor_metric].idxmin()
else:
best_epoch_index = history[self.monitor_metric].idxmax()
return history.iloc[best_epoch_index : best_epoch_index + 1]
def get_saved_epochs(self):
"""
Returns a pandas DataFrame in which each row corresponds to an epoch having
a saved checkpoint.
Returns:
A pandas DataFrame in which each row corresponds to an epoch having a saved
checkpoint.
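Example:
A minimal sketch, assuming ``exp`` is a ModelBundle trained with logging and monitoring enabled::
saved_epochs = exp.get_saved_epochs()
# Each row corresponds to an epoch for which a 'checkpoint_epoch_{epoch}.ckpt' file was written.
print(saved_epochs['epoch'].tolist())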
"""
if not self.monitoring:
raise ValueError("Monitoring was disabled. Except the last epoch, no epoch checkpoint were saved.")
history = self.get_stats()
metrics = history[self.monitor_metric].tolist()
if self.monitor_mode == 'min':
def monitor_op(x, y):
return x < y
current_best = float('Inf')
elif self.monitor_mode == 'max':
def monitor_op(x, y):
return x > y
current_best = -float('Inf')
saved_epoch_indices = []
for i, metric in enumerate(metrics):
if monitor_op(metric, current_best):
current_best = metric
saved_epoch_indices.append(i)
return history.iloc[saved_epoch_indices]
def _warn_missing_file(self, filename: str) -> None:
warnings.warn(f"Missing checkpoint: {filename}.")
def _load_epoch_state(self, lr_schedulers: List) -> int:
# pylint: disable=broad-except
initial_epoch = 1
if os.path.isfile(self.epoch_filename):
with open(self.epoch_filename, 'r', encoding='utf-8') as f:
initial_epoch = int(f.read()) + 1
if os.path.isfile(self.model_checkpoint_filename):
print(f"Loading weights from {self.model_checkpoint_filename} and starting at epoch {initial_epoch:d}.")
self.model.load_weights(self.model_checkpoint_filename)
else:
self._warn_missing_file(self.model_checkpoint_filename)
if os.path.isfile(self.optimizer_checkpoint_filename):
print(
f"Loading optimizer state from {self.optimizer_checkpoint_filename} and "
f"starting at epoch {initial_epoch:d}."
)
self.model.load_optimizer_state(self.optimizer_checkpoint_filename)
else:
self._warn_missing_file(self.optimizer_checkpoint_filename)
if os.path.isfile(self.random_state_checkpoint_filename):
print(
f"Loading random states from {self.random_state_checkpoint_filename} and "
f"starting at epoch {initial_epoch:d}."
)
load_random_states(self.random_state_checkpoint_filename)
else:
self._warn_missing_file(self.random_state_checkpoint_filename)
for i, lr_scheduler in enumerate(lr_schedulers):
filename = self.lr_scheduler_filename % i
if os.path.isfile(filename):
print(f"Loading LR scheduler state from {filename} and starting at epoch {initial_epoch:d}.")
lr_scheduler.load_state(filename)
else:
self._warn_missing_file(filename)
return initial_epoch
def _init_model_restoring_callbacks(
self, initial_epoch: int, keep_only_last_best: bool, save_every_epoch: bool
) -> List:
callbacks = []
if not save_every_epoch:
best_checkpoint = ModelCheckpoint(
self.best_checkpoint_filename,
monitor=self.monitor_metric,
mode=self.monitor_mode,
keep_only_last_best=keep_only_last_best,
save_best_only=True,
restore_best=True,
verbose=True,
)
callbacks.append(best_checkpoint)
else:
best_restore = BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode, verbose=True)
callbacks.append(best_restore)
if initial_epoch > 1:
# We set the current best metric score in the ModelCheckpoint so that
# it does not save checkpoints it would not have saved had the
# optimization not been interrupted.
best_epoch_stats = self.get_best_epoch_stats()
best_epoch = best_epoch_stats['epoch'].item()
best_filename = self.best_checkpoint_filename.format(epoch=best_epoch)
if not save_every_epoch:
best_checkpoint.best_filename = best_filename
best_checkpoint.current_best = best_epoch_stats[self.monitor_metric].item()
else:
best_restore.best_weights = torch.load(best_filename, map_location='cpu')
best_restore.current_best = best_epoch_stats[self.monitor_metric].item()
return callbacks
def _init_tensorboard_callbacks(self, disable_tensorboard: bool) -> Tuple:
tensorboard_writer = None
callbacks = []
if not disable_tensorboard:
if SummaryWriter is None:
warnings.warn(
"tensorboard does not seem to be installed. "
"To remove this warning, set the 'disable_tensorboard' "
"flag to True or install tensorboard.",
stacklevel=3,
)
else:
tensorboard_writer = SummaryWriter(self.tensorboard_directory)
callbacks += [TensorBoardLogger(tensorboard_writer)]
return tensorboard_writer, callbacks
def _init_lr_scheduler_callbacks(self, lr_schedulers: List) -> List:
callbacks = []
if self.logging:
for i, lr_scheduler in enumerate(lr_schedulers):
filename = self.lr_scheduler_filename % i
callbacks += [LRSchedulerCheckpoint(lr_scheduler, filename, verbose=False)]
else:
callbacks += lr_schedulers
return callbacks
def _save_history(self):
if is_matplotlib_available:
history = self.get_stats()
plot_history(
history,
show=False,
save=True,
close=True,
save_directory=self.plots_directory,
save_extensions=('png', 'pdf'),
)
def train(self, train_generator, valid_generator=None, **kwargs) -> List[Dict]:
"""
Trains or finetunes the model on a dataset using a generator. If a previous training already occurred
and lasted a total of `n_previous` epochs, then the model's weights will be set to the last checkpoint and the
training will be resumed for epochs range (`n_previous`, `epochs`].
If the ModelBundle has logging enabled (i.e. self.logging is True), numerous callbacks will be automatically
included. Notably, two :class:`~poutyne.ModelCheckpoint` objects will take care of saving the last and every
new best (according to monitor mode) model weights in appropriate checkpoint files.
:class:`~poutyne.OptimizerCheckpoint` and :class:`~poutyne.LRSchedulerCheckpoint` will also respectively
handle the saving of the optimizer and LR scheduler's respective states for future retrieval. Moreover, a
:class:`~poutyne.AtomicCSVLogger` will save all available epoch statistics in an output .tsv file. Lastly, a
:class:`~poutyne.TensorBoardLogger` handles automatic TensorBoard logging of various neural network
statistics.
Args:
train_generator: Generator-like object for the training set. See :func:`~Model.fit_generator()`
for details on the types of generators supported.
valid_generator (optional): Generator-like object for the validation set. See
:func:`~Model.fit_generator()` for details on the types of generators supported.
(Default value = None)
callbacks (List[~poutyne.Callback]): List of callbacks that will be called during
training. These callbacks are added after those used in this method (see above), so they can be
assumed to be called after them.
(Default value = None)
lr_schedulers: List of learning rate schedulers. (Default value = None)
keep_only_last_best (bool): Whether only the last saved best checkpoint is kept. Applies only when
`save_every_epoch` is false.
(Default value = False)
save_every_epoch (bool, optional): Whether or not to save the model bundle's model's weights after
every epoch.
(Default value = False)
disable_tensorboard (bool, optional): Whether or not to disable the automatic tensorboard logging
callbacks.
(Default value = False)
seed (int, optional): Seed used to make the sampling deterministic.
(Default value = 42)
kwargs: Any keyword arguments to pass to :func:`~Model.fit_generator()`.
Returns:
List of dict containing the history of each epoch.
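Example:
A minimal sketch of training with a learning rate scheduler and a fixed seed, assuming ``exp``,
``train_generator`` and ``valid_generator`` are defined as in the :func:`~ModelBundle.from_network()`
example and that Poutyne's ``ExponentialLR`` scheduler callback is available::
from poutyne import ExponentialLR
# Decay the learning rate by a factor of 0.95 at the end of every epoch.
lr_schedulers = [ExponentialLR(gamma=0.95)]
history = exp.train(train_generator, valid_generator, epochs=10, lr_schedulers=lr_schedulers, seed=42)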
"""
return self._train(self.model.fit_generator, train_generator, valid_generator, **kwargs)
def train_dataset(self, train_dataset, valid_dataset=None, **kwargs) -> List[Dict]:
"""
Trains or finetunes the model on a dataset. If a previous training already occurred
and lasted a total of `n_previous` epochs, then the model's weights will be set to the last checkpoint and the
training will be resumed for epochs range (`n_previous`, `epochs`].
If the ModelBundle has logging enabled (i.e. self.logging is True), numerous callbacks will be automatically
included. Notably, two :class:`~poutyne.ModelCheckpoint` objects will take care of saving the last and every
new best (according to monitor mode) model weights in appropriate checkpoint files.
:class:`~poutyne.OptimizerCheckpoint` and :class:`~poutyne.LRSchedulerCheckpoint` will also respectively
handle the saving of the optimizer and LR scheduler's respective states for future retrieval. Moreover, a
:class:`~poutyne.AtomicCSVLogger` will save all available epoch statistics in an output .tsv file. Lastly, a
:class:`~poutyne.TensorBoardLogger` handles automatic TensorBoard logging of various neural network
statistics.
Args:
train_dataset (~torch.utils.data.Dataset): Training dataset.
valid_dataset (~torch.utils.data.Dataset): Validation dataset.
callbacks (List[~poutyne.Callback]): List of callbacks that will be called during
training. These callbacks are added after those used in this method (see above), so they can be
assumed to be called after them.
(Default value = None)
lr_schedulers: List of learning rate schedulers. (Default value = None)
keep_only_last_best (bool): Whether only the last saved best checkpoint is kept. Applies only when
`save_every_epoch` is false.
(Default value = False)
save_every_epoch (bool, optional): Whether or not to save the model bundle's model's weights after
every epoch.
(Default value = False)
disable_tensorboard (bool, optional): Whether or not to disable the automatic tensorboard logging
callbacks.
(Default value = False)
seed (int, optional): Seed used to make the sampling deterministic.
(Default value = 42)
kwargs: Any keyword arguments to pass to :func:`~Model.fit_dataset()`.
Returns:
List of dict containing the history of each epoch.
"""
return self._train(self.model.fit_dataset, train_dataset, valid_dataset, **kwargs)
def train_data(self, x, y, validation_data=None, **kwargs) -> List[Dict]:
"""
Trains or finetunes the model on data under the form of NumPy arrays or torch tensors. If a previous
training already occurred and lasted a total of `n_previous` epochs, then the model's weights will be set to the
last checkpoint and the training will be resumed for epochs range (`n_previous`, `epochs`].
If the ModelBundle has logging enabled (i.e. self.logging is True), numerous callbacks will be automatically
included. Notably, two :class:`~poutyne.ModelCheckpoint` objects will take care of saving the last and every
new best (according to monitor mode) model weights in appropriate checkpoint files.
:class:`~poutyne.OptimizerCheckpoint` and :class:`~poutyne.LRSchedulerCheckpoint` will also respectively
handle the saving of the optimizer and LR scheduler's respective states for future retrieval. Moreover, a
:class:`~poutyne.AtomicCSVLogger` will save all available epoch statistics in an output .tsv file. Lastly, a
:class:`~poutyne.TensorBoardLogger` handles automatic TensorBoard logging of various neural network
statistics.
Args:
x (Union[~torch.Tensor, ~numpy.ndarray] or Union[tuple, list] of Union[~torch.Tensor, ~numpy.ndarray]):
Training dataset. Union[Tensor, ndarray] if the model has a single input.
Union[tuple, list] of Union[Tensor, ndarray] if the model has multiple inputs.
y (Union[~torch.Tensor, ~numpy.ndarray] or Union[tuple, list] of Union[~torch.Tensor, ~numpy.ndarray]):
Target. Union[Tensor, ndarray] if the model has a single output.
Union[tuple, list] of Union[Tensor, ndarray] if the model has multiple outputs.
validation_data (Tuple[``x_val``, ``y_val``]):
Same format as ``x`` and ``y`` previously described. Validation dataset on which to
evaluate the loss and any model metrics at the end of each epoch. The model will not be
trained on this data.
(Default value = None)
callbacks (List[~poutyne.Callback]): List of callbacks that will be called during
training. These callbacks are added after those used in this method (see above), so they can be
assumed to be called after them.
(Default value = None)
lr_schedulers: List of learning rate schedulers. (Default value = None)
keep_only_last_best (bool): Whether only the last saved best checkpoint is kept. Applies only when
`save_every_epoch` is false.
(Default value = False)
save_every_epoch (bool, optional): Whether or not to save the model bundle's model's weights after
every epoch.
(Default value = False)
disable_tensorboard (bool, optional): Whether or not to disable the automatic tensorboard logging
callbacks.
(Default value = False)
seed (int, optional): Seed used to make the sampling deterministic.
(Default value = 42)
kwargs: Any keyword arguments to pass to :func:`~Model.fit()`.
Returns:
List of dict containing the history of each epoch.
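Example:
A minimal sketch with tensors, reusing ``train_x``, ``train_y``, ``valid_x`` and ``valid_y`` from the
:func:`~ModelBundle.from_network()` example; ``batch_size`` is simply forwarded to :func:`~Model.fit()`::
exp.train_data(train_x, train_y, validation_data=(valid_x, valid_y), epochs=5, batch_size=32)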
"""
return self._train(self.model.fit, x, y, validation_data, **kwargs)
def _train(
self,
training_func,
*args,
callbacks: Union[List, None] = None,
lr_schedulers: Union[List, None] = None,
keep_only_last_best: bool = False,
save_every_epoch: bool = False,
disable_tensorboard: bool = False,
seed: int = 42,
**kwargs,
) -> List[Dict]:
set_seeds(seed)
lr_schedulers = [] if lr_schedulers is None else lr_schedulers
expt_callbacks = []
tensorboard_writer = None
initial_epoch = 1
if self.logging:
if not os.path.exists(self.directory):
os.makedirs(self.directory)
# Restarting optimization if needed.
initial_epoch = self._load_epoch_state(lr_schedulers)
expt_callbacks += [AtomicCSVLogger(self.log_filename, separator='\t', append=initial_epoch != 1)]
if self.monitoring:
expt_callbacks += self._init_model_restoring_callbacks(
initial_epoch, keep_only_last_best, save_every_epoch
)
if save_every_epoch:
expt_callbacks += [
ModelCheckpoint(
self.best_checkpoint_filename,
save_best_only=False,
restore_best=False,
verbose=False,
)
]
expt_callbacks += [ModelCheckpoint(self.model_checkpoint_filename, verbose=False)]
expt_callbacks += [OptimizerCheckpoint(self.optimizer_checkpoint_filename, verbose=False)]
expt_callbacks += [RandomStatesCheckpoint(self.random_state_checkpoint_filename, verbose=False)]
# We save the last epoch number at the end of the epoch so that
# _load_epoch_state() knows from which epoch to restart the optimization.
expt_callbacks += [
PeriodicSaveLambda(lambda fd, epoch, logs: print(epoch, file=fd), self.epoch_filename, open_mode='w')
]
tensorboard_writer, cb_list = self._init_tensorboard_callbacks(disable_tensorboard)
expt_callbacks += cb_list
else:
if self.monitoring:
expt_callbacks += [BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode, verbose=True)]
# This method returns callbacks that checkpoints the LR scheduler if logging is enabled.
# Otherwise, it just returns the list of LR schedulers with a BestModelRestore callback.
expt_callbacks += self._init_lr_scheduler_callbacks(lr_schedulers)
if callbacks is not None:
expt_callbacks += callbacks
try:
return training_func(*args, initial_epoch=initial_epoch, callbacks=expt_callbacks, **kwargs)
finally:
if self.logging:
self._save_history()
if tensorboard_writer is not None:
tensorboard_writer.close()
def load_checkpoint(
self, checkpoint: Union[int, str], *, verbose: bool = False, strict: bool = True
) -> Union[Dict, None]:
"""
Loads the model's weights with the weights at a given checkpoint epoch.
Args:
checkpoint (Union[int, str]): Which checkpoint to load the model's weights from.
- If 'best', will load the best weights according to ``monitor_metric`` and ``monitor_mode``.
- If 'last', will load the last model checkpoint.
- If int, will load the checkpoint of the specified epoch.
- If a path (str), will load the model pickled state_dict weights (for instance, saved as
``torch.save(a_pytorch_network.state_dict(), "./a_path.p")``).
verbose (bool, optional): Whether or not to print the checkpoint filename, and the best epoch
number and stats when checkpoint is 'best'.
(Default value = False)
Returns:
If checkpoint is 'best', will return the best epoch stats, as per :func:`~get_best_epoch_stats()`;
if checkpoint is 'last', will return the last epoch stats; if checkpoint is an int, will return the
stats of that epoch. If checkpoint is a path, will return None.
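Example:
A minimal sketch of the accepted values, assuming ``exp`` was trained with logging and monitoring
enabled; the pickled path at the end is purely hypothetical::
exp.load_checkpoint('best', verbose=True)  # best epoch according to the monitored metric
exp.load_checkpoint('last', verbose=True)  # last saved checkpoint
exp.load_checkpoint(3)  # checkpoint of epoch 3
exp.load_checkpoint('./a_path.p')  # a state_dict saved with torch.save()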
"""
epoch_stats = None
if isinstance(checkpoint, int):
epoch_stats, incompatible_keys = self._load_epoch_checkpoint(checkpoint, verbose=verbose, strict=strict)
elif checkpoint == 'best':
epoch_stats, incompatible_keys = self._load_best_checkpoint(verbose=verbose, strict=strict)
elif checkpoint == 'last':
epoch_stats, incompatible_keys = self._load_last_checkpoint(verbose=verbose, strict=strict)
else:
incompatible_keys = self._load_path_checkpoint(path=checkpoint, verbose=verbose, strict=strict)
if len(incompatible_keys.unexpected_keys) > 0:
warnings.warn(
'Unexpected key(s): ' + ', '.join(f'"{k}"' for k in incompatible_keys.unexpected_keys) + '.',
stacklevel=2,
)
if len(incompatible_keys.missing_keys) > 0:
warnings.warn(
'Missing key(s): ' + ', '.join(f'"{k}"' for k in incompatible_keys.missing_keys) + '.',
stacklevel=2,
)
return epoch_stats
def _print_epoch_stats(self, epoch_stats):
metrics_str = ', '.join(
f'{metric_name}: {epoch_stats[metric_name].item():g}' for metric_name in epoch_stats.columns[2:]
)
print(metrics_str)
def _load_epoch_checkpoint(self, epoch: int, *, verbose: bool = False, strict: bool = True) -> Tuple:
ckpt_filename = self.best_checkpoint_filename.format(epoch=epoch)
history = self.get_stats()
epoch_stats = history.iloc[epoch - 1 : epoch]
if verbose:
print(f"Loading checkpoint {ckpt_filename}")
self._print_epoch_stats(epoch_stats)
if not os.path.isfile(ckpt_filename):
raise ValueError(f"No checkpoint found for epoch {epoch}")
return epoch_stats, self.model.load_weights(ckpt_filename, strict=strict)
def _load_best_checkpoint(self, *, verbose: bool = False, strict: bool = True) -> Tuple:
best_epoch_stats = self.get_best_epoch_stats()
best_epoch = best_epoch_stats['epoch'].item()
ckpt_filename = self.best_checkpoint_filename.format(epoch=best_epoch)
if verbose:
print(f"Found best checkpoint at epoch: {best_epoch}")
self._print_epoch_stats(best_epoch_stats)
print(f"Loading checkpoint {ckpt_filename}")
return best_epoch_stats, self.model.load_weights(ckpt_filename, strict=strict)
def _load_last_checkpoint(self, *, verbose: bool = False, strict: bool = True) -> Tuple:
history = self.get_stats()
epoch_stats = history.iloc[-1:]
if verbose:
print(f"Loading checkpoint {self.model_checkpoint_filename}")
self._print_epoch_stats(epoch_stats)
return epoch_stats, self.model.load_weights(self.model_checkpoint_filename, strict=strict)
def _load_path_checkpoint(self, path, verbose: bool = False, strict: bool = True) -> Any:
if verbose:
print(f"Loading checkpoint {path}")
return self.model.load_weights(path, strict=strict)
def test(self, test_generator, **kwargs):
"""
Computes and returns the loss and the metrics of the model on a given generator of test
examples.
If the ModelBundle has logging enabled (i.e. self.logging is True), a checkpoint (the best one by default)
is loaded and test and validation statistics are saved in a specific test output .tsv file. Otherwise, the
current weights of the network are used for testing and statistics are only shown in the standard output.
Args:
test_generator: Generator-like object for the test set. See :func:`~Model.fit_generator()` for
details on the types of generators supported.
checkpoint (Union[str, int]): Which model checkpoint weights to load for the test evaluation.
- If 'best', will load the best weights according to ``monitor_metric`` and ``monitor_mode``.
- If 'last', will load the last model checkpoint.
- If int, will load the checkpoint of the specified epoch.
- If a path (str), will load the model pickled state_dict weights (for instance, saved as
``torch.save(a_pytorch_network.state_dict(), "./a_path.p")``).
This argument has no effect when logging is disabled. (Default value = 'best')
seed (int, optional): Seed used to make the sampling deterministic.
(Default value = 42)
name (str): Prefix of the test log file. (Default value = 'test')
kwargs: Any keyword arguments to pass to :func:`~Model.evaluate_generator()`.
If the ModelBundle has logging enabled (i.e. self.logging is True), one callback will be automatically
included to save the test metrics. Moreover, a :class:`~poutyne.AtomicCSVLogger` will save the test
metrics in an output .tsv file.
Returns:
Dict mapping the names of the test metrics to their values.
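Example:
A minimal sketch, assuming ``exp`` was trained with logging enabled and that ``test_generator`` is a
DataLoader built like the other generators in the examples above::
test_metrics = exp.test(test_generator, checkpoint='best', name='test')
# Keys are typically 'test_loss', 'test_acc', etc.
print(test_metrics)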
"""
return self._test(self.model.evaluate_generator, test_generator, **kwargs)
def test_dataset(self, test_dataset, **kwargs) -> Dict:
"""
Computes and returns the loss and the metrics of the model on a given test dataset.
If the ModelBundle has logging enabled (i.e. self.logging is True), a checkpoint (the best one by default)
is loaded and test and validation statistics are saved in a specific test output .tsv file. Otherwise, the
current weights of the network are used for testing and statistics are only shown in the standard output.
Args:
test_dataset (~torch.utils.data.Dataset): Test dataset.
checkpoint (Union[str, int]): Which model checkpoint weights to load for the test evaluation.
- If 'best', will load the best weights according to ``monitor_metric`` and ``monitor_mode``.
- If 'last', will load the last model checkpoint.
- If int, will load the checkpoint of the specified epoch.
- If a path (str), will load the model pickled state_dict weights (for instance, saved as
``torch.save(a_pytorch_network.state_dict(), "./a_path.p")``).
This argument has no effect when logging is disabled. (Default value = 'best')
seed (int, optional): Seed used to make the sampling deterministic.
(Default value = 42)
name (str): Prefix of the test log file. (Default value = 'test')
kwargs: Any keyword arguments to pass to :func:`~Model.evaluate_dataset()`.
If the ModelBundle has logging enabled (i.e. self.logging is True), one callback will be automatically
included to save the test metrics. Moreover, a :class:`~poutyne.AtomicCSVLogger` will save the test
metrics in an output .tsv file.
Returns:
Dict mapping the names of the test metrics to their values.
"""
return self._test(self.model.evaluate_dataset, test_dataset, **kwargs)
def test_data(self, x, y, **kwargs) -> Dict:
"""
Computes and returns the loss and the metrics of the model on a given test dataset.
If the ModelBundle has logging enabled (i.e. self.logging is True), a checkpoint (the best one by default)
is loaded and test and validation statistics are saved in a specific test output .tsv file. Otherwise, the
current weights of the network are used for testing and statistics are only shown in the standard output.
Args:
x (Union[~torch.Tensor, ~numpy.ndarray] or Union[tuple, list] of Union[~torch.Tensor, ~numpy.ndarray]):
Input to the model. Union[Tensor, ndarray] if the model has a single input.
Union[tuple, list] of Union[Tensor, ndarray] if the model has multiple inputs.
y (Union[~torch.Tensor, ~numpy.ndarray] or Union[tuple, list] of Union[~torch.Tensor, ~numpy.ndarray]):
Target, corresponding ground truth.
Union[Tensor, ndarray] if the model has a single output.
Union[tuple, list] of Union[Tensor, ndarray] if the model has multiple outputs.
checkpoint (Union[str, int]): Which model checkpoint weights to load for the test evaluation.
- If 'best', will load the best weights according to ``monitor_metric`` and ``monitor_mode``.
- If 'last', will load the last model checkpoint.
- If int, will load the checkpoint of the specified epoch.
- If a path (str), will load the model pickled state_dict weights (for instance, saved as
``torch.save(a_pytorch_network.state_dict(), "./a_path.p")``).
This argument has no effect when logging is disabled. (Default value = 'best')
seed (int, optional): Seed used to make the sampling deterministic.
(Default value = 42)
name (str): Prefix of the test log file. (Default value = 'test')
kwargs: Any keyword arguments to pass to :func:`~Model.evaluate()`.
If the ModelBundle has logging enabled (i.e. self.logging is True), one callback will be automatically
included to save the test metrics. Moreover, a :class:`~poutyne.AtomicCSVLogger` will save the test
metrics in an output .tsv file.
Returns:
Dict mapping the names of the test metrics to their values.
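Example:
A minimal sketch with tensors, assuming ``test_x`` and ``test_y`` are hypothetical tensors with the same
format as the training data in the examples above::
test_metrics = exp.test_data(test_x, test_y, checkpoint='last', name='final_test')
# With logging enabled, the results are also written to './simple_example/final_test_log.tsv'.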
"""
return self._test(self.model.evaluate, x, y, **kwargs)
def _test(
self,
evaluate_func,
*args,
checkpoint: Union[str, int] = 'best',
seed: int = 42,
name='test',
verbose=True,
**kwargs,
) -> Dict:
if kwargs.get('return_dict_format') is False:
raise ValueError("This method only returns a dict.")
kwargs['return_dict_format'] = True
set_seeds(seed)
if self.logging:
if not self.monitoring and checkpoint == 'best':
checkpoint = 'last'
epoch_stats = self.load_checkpoint(checkpoint, verbose=verbose)
if verbose:
print(f"Running {name}")
ret = evaluate_func(*args, **kwargs, verbose=verbose)
if self.logging:
test_metrics_dict = ret[0] if isinstance(ret, tuple) else ret
test_stats = pd.DataFrame([list(test_metrics_dict.values())], columns=list(test_metrics_dict.keys()))
test_stats.drop(['time'], axis=1, inplace=True)
if epoch_stats is not None:
epoch_stats = epoch_stats.reset_index(drop=True)
test_stats = epoch_stats.join(test_stats)
test_stats.to_csv(self.test_log_filename.format(name=name), sep='\t', index=False)
return ret
def infer(self, generator, **kwargs) -> Any:
"""
Returns the predictions of the network given batches of samples ``x``, where the tensors are
converted into Numpy arrays.
Args:
generator: Generator-like object for the dataset. The generator must yield a batch of
samples. See the :func:`fit_generator()` method for details on the types of generators
supported. This should only yield input data ``x`` and NOT the target ``y``.
checkpoint (Union[str, int]): Which model checkpoint weights to load for the prediction.
- If 'best', will load the best weights according to ``monitor_metric`` and ``monitor_mode``.
- If 'last', will load the last model checkpoint.
- If int, will load the checkpoint of the specified epoch.
- If a path (str), will load the model pickled state_dict weights (for instance, saved as
``torch.save(a_pytorch_network.state_dict(), "./a_path.p")``).
This argument has no effect when logging is disabled. (Default value = 'best')
kwargs: Any keyword arguments to pass to :func:`~Model.predict_generator()`.
Returns:
Depends on the value of ``concatenate_returns``. By default, (``concatenate_returns`` is true),
the data structures (tensor, tuple, list, dict) returned as predictions for the batches are
merged together. In the merge, the tensors are converted into Numpy arrays and are then
concatenated together. If ``concatenate_returns`` is false, then a list of the predictions
for the batches is returned with tensors converted into Numpy arrays.
"""
return self._predict(self.model.predict_generator, generator, **kwargs)
def infer_dataset(self, dataset, **kwargs) -> Any:
"""
Returns the inferred predictions of the network given a dataset, where the tensors are
converted into Numpy arrays.
Args:
dataset (~torch.utils.data.Dataset): Dataset. Must not return ``y``, just ``x``.
checkpoint (Union[str, int]): Which model checkpoint weights to load for the prediction.
- If 'best', will load the best weights according to ``monitor_metric`` and ``monitor_mode``.
- If 'last', will load the last model checkpoint.
- If int, will load the checkpoint of the specified epoch.
- If a path (str), will load the model pickled state_dict weights (for instance, saved as
``torch.save(a_pytorch_network.state_dict(), "./a_path.p")``).
This argument has no effect when logging is disabled. (Default value = 'best')
kwargs: Any keyword arguments to pass to :func:`~Model.predict_dataset()`.
Returns:
Return the predictions in the format outputted by the model.
"""
return self._predict(self.model.predict_dataset, dataset, **kwargs)
def infer_data(self, x, **kwargs) -> Any:
"""
Returns the inferred predictions of the network given a dataset ``x``, where the tensors are
converted into Numpy arrays.
Args:
x (Union[~torch.Tensor, ~numpy.ndarray] or Union[tuple, list] of Union[~torch.Tensor, ~numpy.ndarray]):
Input to the model. Union[Tensor, ndarray] if the model has a single input.
Union[tuple, list] of Union[Tensor, ndarray] if the model has multiple inputs.
checkpoint (Union[str, int]): Which model checkpoint weights to load for the prediction.
- If 'best', will load the best weights according to ``monitor_metric`` and ``monitor_mode``.
- If 'last', will load the last model checkpoint.
- If int, will load the checkpoint of the specified epoch.
- If a path (str), will load the model pickled state_dict weights (for instance, saved as
``torch.save(a_pytorch_network.state_dict(), "./a_path.p")``).
This argument has no effect when logging is disabled. (Default value = 'best')
kwargs: Any keyword arguments to pass to :func:`~Model.predict()`.
Returns:
Return the predictions in the format outputted by the model.
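Example:
A minimal sketch, assuming ``exp`` was trained with logging enabled and that ``x`` is a tensor shaped
like the training inputs in the examples above; ``concatenate_returns`` is forwarded to
:func:`~Model.predict()`::
predictions = exp.infer_data(x, checkpoint='best', concatenate_returns=True)
# 'predictions' is a NumPy array of the network's outputs for every sample in x.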
"""
return self._predict(self.model.predict, x, **kwargs)
def _predict(
self, predict_func: Callable, *args, verbose=True, checkpoint: Union[str, int] = 'best', **kwargs
) -> Any:
if self.logging:
if not self.monitoring and checkpoint == 'best':
checkpoint = 'last'
self.load_checkpoint(checkpoint, verbose=verbose)
ret = predict_func(*args, verbose=verbose, **kwargs)
return ret
def is_better_than(self, another_model_bundle) -> bool:
"""
Compare the results of the ModelBundle with another model bundle. To compare, both ModelBundles need to
have logging enabled and to monitor the same metric with the same monitor mode ("min" or "max").
Args:
another_model_bundle (~poutyne.ModelBundle): Another Poutyne model bundle to compare results with.
Returns:
Whether the ModelBundle is better than the ModelBundle to compare with.
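Example:
A minimal sketch comparing two bundles trained in separate directories on the same task; the networks
``network_a`` and ``network_b`` are hypothetical::
exp_a = ModelBundle.from_network('./bundle_a', network_a, task='classif')
exp_b = ModelBundle.from_network('./bundle_b', network_b, task='classif')
# ... train both bundles with a validation set so that 'val_acc' is logged ...
a_is_better = exp_a.is_better_than(exp_b)
print("Bundle A is better" if a_is_better else "Bundle B is better or equal")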
"""
if not self.logging:
raise ValueError("The model bundle is not logged.")
if not another_model_bundle.logging:
raise ValueError("The model bundle to compare to is not logged.")
if self.monitor_metric != another_model_bundle.monitor_metric:
raise ValueError("The monitored metric is not the same between the two model bundles.")
monitored_metric = self.monitor_metric
if self.monitor_mode != another_model_bundle.monitor_mode:
raise ValueError("The monitored mode is not the same between the two model bundles.")
monitor_mode = self.monitor_mode
checkpoint = 'best' if self.monitoring else 'last'
self_stats = self.load_checkpoint(checkpoint, verbose=False)
self_monitored_metric = self_stats[monitored_metric]
self_monitored_metric_value = self_monitored_metric.item()
other_checkpoint = 'best' if another_model_bundle.monitoring else 'last'
other_stats = another_model_bundle.load_checkpoint(other_checkpoint, verbose=False)
other_monitored_metric = other_stats[monitored_metric]
other_monitored_metric_value = other_monitored_metric.item()
if monitor_mode == 'min':
is_better_than = self_monitored_metric_value < other_monitored_metric_value
else:
is_better_than = self_monitored_metric_value > other_monitored_metric_value
return is_better_than