# Source code for poutyne.framework.callbacks.best_model_restore

import operator
import warnings

from .callbacks import Callback

class BestModelRestore(Callback):
    """
    Restore the weights of the best model at the end of the training depending on a monitored quantity.

    Every time the monitored quantity strictly improves on the best value seen so far, a copy of the
    model weights is taken; when training ends, the last such copy is loaded back into the model.

    Args:
        monitor (string): Quantity to monitor.
            (Default value = 'val_loss')
        mode (string): One of {min, max}.
            Whether the monitored quantity has to be maximized or minimized. For instance, for
            `val_accuracy`, this should be `max`, and for `val_loss`, this should be `min`, etc.
            (Default value = 'min')
        verbose (bool): Whether to display a message when the model has improved or when restoring the
            best model.
            (Default value = False)

    Raises:
        ValueError: If ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    # NOTE(review): ``current_best`` and ``best_weights`` are only set in ``__init__`` and never reset at
    # the start of a training, so reusing one instance across several trainings compares against the
    # previous runs' best value -- confirm this is intended.

    def __init__(self, *, monitor='val_loss', mode='min', verbose=False):
        super().__init__()
        self.monitor = monitor

        # ``operator.lt`` / ``operator.gt`` are equivalent to ``lambda x, y: x < y`` / ``x > y`` but,
        # unlike lambdas, can be pickled, so the callback instance stays picklable.
        if mode == 'min':
            self.monitor_op = operator.lt
            self.current_best = float('Inf')
        elif mode == 'max':
            self.monitor_op = operator.gt
            self.current_best = -float('Inf')
        else:
            raise ValueError("Invalid mode '%s'" % mode)

        # Copy of the best weights seen so far; stays ``None`` until the first improvement.
        self.best_weights = None
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs):
        """
        Take a copy of the model weights if the monitored quantity improved at this epoch.

        Args:
            epoch (int): Number of the epoch that just ended (only used in the verbose message).
            logs (dict): Logs of the epoch; expected to contain the ``monitor`` key.

        Raises:
            KeyError: If the monitored quantity is not present in ``logs``.
        """
        try:
            current = logs[self.monitor]
        except KeyError:
            # Re-raise with the available keys so a misspelled ``monitor`` is easy to diagnose.
            raise KeyError(
                "BestModelRestore: monitored quantity '%s' not found in logs (available: %s)"
                % (self.monitor, ', '.join(map(str, logs)))
            ) from None

        # Strict comparison: on ties the earliest epoch reaching the best value is kept, and a NaN
        # value never counts as an improvement.
        if self.monitor_op(current, self.current_best):
            old_best = self.current_best
            self.current_best = current

            if self.verbose:
                print('Epoch %d: %s improved from %0.5f to %0.5f' % (
                    epoch, self.monitor, old_best, self.current_best
                ))
            self.best_weights = self.model.get_weight_copies()

    def on_train_end(self, logs):
        """
        Load the best recorded weights back into the model, or warn if none were recorded.

        No weights are recorded when no epoch improved on the initial +/-infinity bound (for example if
        no epoch ran, or every monitored value was NaN).

        Args:
            logs (dict): Training logs (unused).
        """
        if self.best_weights is None:
            warnings.warn('No weights to restore!')
            return

        if self.verbose:
            print('Restoring best model')
        self.model.set_weights(self.best_weights)