"""
The source code of this file was copied from the Keras project, and has been modified. All modifications
made from the original source code are under the LGPLv3 license.
COPYRIGHT
All contributions by François Chollet:
Copyright (c) 2015, François Chollet.
All rights reserved.
All contributions by Google:
Copyright (c) 2015, Google, Inc.
All rights reserved.
All contributions by Microsoft:
Copyright (c) 2017, Microsoft, Inc.
All rights reserved.
All other contributions:
Copyright (c) 2015 - 2017, the respective contributors.
All rights reserved.
Copyright (c) 2022 Poutyne.
Each contributor holds copyright over their respective contributions. The project versioning (Git)
records all such contribution source information on the Poutyne and Keras repository.
LICENSE
The LGPLv3 License
This file is part of Poutyne.
Poutyne is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
version.
Poutyne is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with Poutyne. If not, see
<https://www.gnu.org/licenses/>.
The MIT License (MIT)
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute,
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial
portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
import os
import warnings
from abc import ABC, abstractmethod
from typing import IO, Callable, Dict, Optional
from poutyne.framework.callbacks._utils import atomic_lambda_save
from poutyne.framework.callbacks.callbacks import Callback
class PeriodicSaveCallback(ABC, Callback):
    """
    Write a file (or checkpoint) after every epoch. ``filename`` can contain named formatting options,
    which will be filled with the value of ``epoch`` and keys in ``logs`` (passed in ``on_epoch_end``). For
    example: if ``filename`` is ``weights.{epoch:02d}-{val_loss:.2f}.txt``, then ``save_file()`` will be
    called with a file descriptor for a file with the epoch number and the validation loss in the
    filename.

    By default, the file is written atomically to the specified filename so that the training can
    be killed and restarted later using the same filename for periodic file saving. To do so, a
    temporary file is created with the name of ``filename + '.tmp'`` and is then moved to the final
    destination after the checkpoint is done. The ``temporary_filename`` argument allows to change the
    path of this temporary file.

    Args:
        filename (str): Path to save the model file. It is formatted with ``str.format_map(logs)`` at
            the end of every epoch.
        monitor (str): Quantity to monitor. Only used when ``save_best_only`` is true.
            (Default value = 'val_loss')
        verbose (bool): Whether to display a message when saving and restoring a checkpoint.
            (Default value = False)
        save_best_only (bool): If `save_best_only` is true, the latest best model according to the
            quantity monitored will not be overwritten.
            (Default value = False)
        keep_only_last_best (bool): Whether only the last saved best checkpoint is kept. Applies only when
            `save_best_only` is true.
            (Default value = False)
        restore_best (bool): If `restore_best` is true, the model will be reset to the last best checkpoint done.
            This option only works when `save_best_only` is also true. (Default value = False)
        mode (str): One of {'min', 'max'}.
            If `save_best_only` is true, the decision to overwrite the current save file is made based
            on either the maximization or the minimization of the monitored quantity. For
            `val_accuracy`, this should be `max`, for `val_loss` this should be `min`, etc.
            (Default value = 'min')
        period (int): Interval (number of epochs) between checkpoints. A file is saved at every epoch whose
            number is a multiple of ``period``. Must be at least 1. Ignored when ``save_best_only`` is
            true, since a file is then saved at every improvement of the monitored quantity.
            (Default value = 1)
        temporary_filename (str, optional): Temporary filename for the checkpoint so that the
            last checkpoint can be written atomically. See the ``atomic_write`` argument.
        atomic_write (bool): Whether to write atomically the checkpoint. See the description above
            for details.
            (Default value = True)
        open_mode (str): ``mode`` option passed to :func:`open()` when saving.
            (Default value = 'wb')
        read_mode (str): ``mode`` option passed to :func:`open()` when restoring the best checkpoint
            (see ``restore_best``). When it is a text mode (no ``'b'`` in it), the file is opened with
            UTF-8 encoding.
            (Default value = 'rb')

    Raises:
        ValueError: If ``keep_only_last_best`` or ``restore_best`` is true without ``save_best_only``,
            if ``mode`` is not 'min' or 'max' while ``save_best_only`` is true, or if ``period`` is
            lower than 1 while ``save_best_only`` is false.
    """

    def __init__(
        self,
        filename: str,
        *,
        monitor: str = 'val_loss',
        mode: str = 'min',
        save_best_only: bool = False,
        keep_only_last_best: bool = False,
        restore_best: bool = False,
        period: int = 1,
        verbose: bool = False,
        temporary_filename: Optional[str] = None,
        atomic_write: bool = True,
        open_mode: str = 'wb',
        read_mode: str = 'rb',
    ) -> None:
        super().__init__()
        self.filename = filename
        self.monitor = monitor
        self.verbose = verbose
        self.save_best_only = save_best_only
        self.keep_only_last_best = keep_only_last_best
        self.restore_best = restore_best
        self.temporary_filename = temporary_filename
        self.atomic_write = atomic_write
        self.open_mode = open_mode
        self.read_mode = read_mode
        self.period = period
        # Formatted filename of the last best checkpoint written; stays None until the first improvement.
        self.best_filename = None

        if self.keep_only_last_best and not self.save_best_only:
            raise ValueError("The 'keep_only_last_best' argument only works when 'save_best_only' is also true.")
        if self.restore_best and not self.save_best_only:
            raise ValueError("The 'restore_best' argument only works when 'save_best_only' is also true.")

        if self.save_best_only:
            # ``monitor_op(new, best)`` tells whether ``new`` is an improvement over ``best``.
            if mode == 'min':
                self.monitor_op = lambda x, y: x < y
                self.current_best = float('Inf')
            elif mode == 'max':
                self.monitor_op = lambda x, y: x > y
                self.current_best = -float('Inf')
            else:
                raise ValueError(f"Invalid mode '{mode}'")
        elif self.period < 1:
            # ``on_epoch_end`` computes ``epoch_number % period``: a zero period would otherwise only
            # fail there, with a ZeroDivisionError after a full epoch of training.
            raise ValueError(f"The 'period' argument must be at least 1, got {period}.")

    @abstractmethod
    def save_file(self, fd: IO, epoch_number: int, logs: Dict) -> None:
        """
        Abstract method that is called every time a save needs to be done.

        Args:
            fd (IO): The descriptor of the file in which to write.
            epoch_number (int): The epoch number.
            logs (Dict): Dictionary passed on epoch end.
        """
        pass

    def _save_file(self, filename: str, epoch_number: int, logs: Dict) -> None:
        # Delegates the (optionally atomic) write to ``atomic_lambda_save``, which calls
        # ``self.save_file(fd, epoch_number, logs)`` with the opened file descriptor.
        atomic_lambda_save(
            filename,
            self.save_file,
            (epoch_number, logs),
            temporary_filename=self.temporary_filename,
            open_mode=self.open_mode,
            atomic=self.atomic_write,
        )

    def on_epoch_end(self, epoch_number: int, logs: Dict) -> None:
        """
        Save a file at the end of the epoch when required.

        With ``save_best_only``, a file is saved each time the monitored quantity improves (and the previous
        best file is removed if ``keep_only_last_best`` is true and its name differs). Otherwise, a file is
        saved every ``period`` epochs.

        Args:
            epoch_number (int): The epoch number.
            logs (Dict): Dictionary passed on epoch end; also used to format ``filename``.

        Raises:
            KeyError: If ``save_best_only`` is true and ``monitor`` is not a key of ``logs``.
        """
        filename = self.filename.format_map(logs)

        if self.save_best_only:
            if self.monitor not in logs:
                raise KeyError(f"The monitored metric name {self.monitor} is not found in computed metrics.")

            if self.monitor_op(logs[self.monitor], self.current_best):
                old_best = self.current_best
                self.current_best = logs[self.monitor]
                old_best_filename = self.best_filename
                self.best_filename = filename

                if self.verbose:
                    print(
                        f'Epoch {epoch_number:d}: {self.monitor} improved from {old_best:0.5f} '
                        f'to {self.current_best:0.5f}, saving file to {self.best_filename}'
                    )
                self._save_file(self.best_filename, epoch_number, logs)

                # When the filename does not depend on the logs, the new best file has just overwritten
                # the old one, so there is nothing to remove.
                if (
                    self.keep_only_last_best
                    and self.best_filename != old_best_filename
                    and old_best_filename is not None
                ):
                    os.remove(old_best_filename)
        elif epoch_number % self.period == 0:
            if self.verbose:
                print(f'Epoch {epoch_number:d}: saving file to {filename}')
            self._save_file(filename, epoch_number, logs)

    @abstractmethod
    def restore(self, fd: IO) -> None:
        """
        Abstract method that is called when a save needs to be restored. This happens at the end of the training when
        ``restore_best`` is true.

        Args:
            fd (IO): The descriptor of the file to read.
        """
        pass

    def on_train_end(self, logs: Dict) -> None:
        """
        Restore the last best checkpoint when ``restore_best`` is true.

        Emits a warning instead if no best checkpoint was ever saved.

        Args:
            logs (Dict): Dictionary passed on train end (unused).
        """
        if self.restore_best:
            if self.best_filename is not None:
                if self.verbose:
                    print(f'Restoring data from {self.best_filename}')
                # Text read modes get an explicit UTF-8 encoding; binary modes must not receive one.
                # pylint: disable=unspecified-encoding
                open_kwargs = dict(encoding='utf-8') if 'b' not in self.read_mode else {}
                with open(self.best_filename, self.read_mode, **open_kwargs) as fd:
                    self.restore(fd)
            else:
                warnings.warn('No data to restore!')
class PeriodicSaveLambda(PeriodicSaveCallback):
    """
    Call a lambda with a file descriptor after every epoch. See
    :class:`~poutyne.PeriodicSaveCallback` for the arguments' descriptions.

    Args:
        func (Callable[[fd, int, dict], None]): The lambda that will be called with a file descriptor, the
            epoch number and the epoch logs.
        restore (Callable[[fd], None], optional): The lambda that will be called with a file descriptor to
            restore the state if necessary. When ``None``, restoring does nothing.
            (Default value = None)

    See:
        :class:`~poutyne.PeriodicSaveCallback`
    """

    def __init__(
        self,
        func: Callable[[IO, int, Dict], None],
        *args,
        restore: Optional[Callable[[IO], None]] = None,
        **kwargs,
    ) -> None:
        # Remaining positional and keyword arguments (``filename``, ``monitor``, ...) are forwarded
        # unchanged to ``PeriodicSaveCallback``.
        super().__init__(*args, **kwargs)
        self.func = func
        # Stored under a private name so it does not shadow the ``restore`` method.
        self._restore = restore

    def save_file(self, fd: IO, epoch_number: int, logs: Dict) -> None:
        """Forward the save to the user-supplied ``func``."""
        self.func(fd, epoch_number, logs)

    def restore(self, fd: IO) -> None:
        """Forward the restore to the user-supplied ``restore`` callable, if any."""
        if self._restore is not None:
            self._restore(fd)