Source code for poutyne.framework.callbacks.lr_scheduler

"""
Copyright (c) 2022 Poutyne and all respective contributors.

Each contributor holds copyright over their respective contributions. The project versioning (Git)
records all such contribution source information.

This file is part of Poutyne.

Poutyne is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
version.

Poutyne is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License along with Poutyne. If not, see
<https://www.gnu.org/licenses/>.
"""

import inspect
import sys
from typing import BinaryIO, Dict

import torch.optim.lr_scheduler
from torch.optim import Optimizer

# ``LRScheduler`` is the public base class name in newer PyTorch versions; fall back to the
# older private ``_LRScheduler`` name when it is not available.
try:
    from torch.optim.lr_scheduler import LRScheduler
except ImportError:
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

from poutyne.framework.callbacks.callbacks import Callback


class _PyTorchLRSchedulerWrapper(Callback):
    """
    Default class for the LR scheduling callback. Proposes default comportment for the scheduler
    loading and saving as well as for the epoch end handling.
    """

    def __init__(self, torch_lr_scheduler, *args, **kwargs):
        super().__init__()
        if len(args) > 0 and isinstance(args[0], Optimizer):
            raise ValueError(
                "In the LR scheduler callbacks, the optimizer is "
                "automatically passed to PyTorch's LR scheduler. "
                "You must remove it from the arguments."
            )
        self.args = args
        self.kwargs = kwargs
        self.scheduler = None
        self.state_to_load = None
        self.torch_lr_scheduler = torch_lr_scheduler

    def on_epoch_end(self, epoch_number: int, logs: Dict):
        self.scheduler.step()

    def on_train_begin(self, logs: Dict):
        self.scheduler = self.torch_lr_scheduler(self.model.optimizer, *self.args, **self.kwargs)

        # Load state if the scheduler was not initialized when the user asked
        # to load its state
        if self.state_to_load is not None:
            self.load_state_dict(self.state_to_load)
            self.state_to_load = None

    def load_state_dict(self, state_dict):
        if self.scheduler is not None:
            self.scheduler.load_state_dict(state_dict)
        else:
            self.state_to_load = state_dict

    def state_dict(self):
        return self.scheduler.state_dict()

    def load_state(self, f: BinaryIO):
        self.load_state_dict(torch.load(f, map_location='cpu'))

    def save_state(self, f: BinaryIO):
        torch.save(self.state_dict(), f)
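
# Illustrative note (not part of the library code): since the PyTorch scheduler is only
# instantiated in ``on_train_begin``, calling ``load_state`` or ``load_state_dict`` before
# training merely stashes the state in ``self.state_to_load``; it is applied once the
# scheduler exists. A sketch of resuming from a checkpoint written by ``save_state``,
# assuming a generated ``ExponentialLR`` callback (defined below) and a file name chosen
# here for illustration:
#
#     lr_callback = ExponentialLR(gamma=0.9)
#     with open('lr_scheduler.ckpt', 'rb') as f:
#         lr_callback.load_state(f)  # kept pending until on_train_begin builds the scheduler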


def new_init(torch_lr_scheduler):
    # Factory returning an ``__init__`` that binds the given PyTorch scheduler class to the
    # callback class generated in the loop below.
    def f(self, *args, **kwargs):
        super(type(self), self).__init__(torch_lr_scheduler, *args, **kwargs)

    return f


# Generate, for every LR scheduler class exposed by ``torch.optim.lr_scheduler`` (except the
# base class itself), a Poutyne callback of the same name wrapping it.
for name, module_cls in torch.optim.lr_scheduler.__dict__.items():
    if inspect.isclass(module_cls) and issubclass(module_cls, LRScheduler) and module_cls != LRScheduler:
        _new_cls = type(
            name,
            (_PyTorchLRSchedulerWrapper,),
            {
                '__init__': new_init(module_cls),
                '__doc__': f"""
                            See:
                                :class:`~torch.optim.lr_scheduler.{name}`
                            """,
            },
        )
        setattr(sys.modules[__name__], name, _new_cls)
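
# Illustrative example (not part of the library code): the loop above generates, for each
# scheduler class in ``torch.optim.lr_scheduler`` (e.g. ``StepLR``, ``ExponentialLR``,
# ``CosineAnnealingLR``), a Poutyne callback of the same name taking the same arguments
# minus the optimizer, which is supplied automatically from ``self.model.optimizer`` when
# training starts. Assuming the callbacks are re-exported at the package top level, as in
# recent Poutyne versions, a generated class is used like this:
#
#     from poutyne import Model, StepLR
#
#     model = Model(network, 'sgd', 'cross_entropy')
#     model.fit_generator(train_loader, valid_loader, epochs=30,
#                         callbacks=[StepLR(step_size=10, gamma=0.1)])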


class ReduceLROnPlateau(_PyTorchLRSchedulerWrapper):
    """
    Args:
        monitor (str): The quantity to monitor. (Default value = 'val_loss')

    See:
        :class:`~torch.optim.lr_scheduler.ReduceLROnPlateau`
    """

    def __init__(self, *args, monitor: str = 'val_loss', **kwargs):
        super().__init__(torch_lr_scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau, *args, **kwargs)
        self.monitor = monitor

    def on_epoch_end(self, epoch_number: int, logs: Dict):
        self.scheduler.step(logs[self.monitor])
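
# Usage sketch (illustrative, not part of the library code): ``monitor`` selects the entry of
# the ``logs`` dict that is passed to the PyTorch scheduler's ``step``; all other keyword
# arguments are forwarded to ``torch.optim.lr_scheduler.ReduceLROnPlateau``. For example, to
# halve the learning rate when the validation accuracy stops improving (names assumed to be
# re-exported at the package top level):
#
#     from poutyne import Model, ReduceLROnPlateau
#
#     reduce_lr = ReduceLROnPlateau(monitor='val_acc', mode='max', factor=0.5, patience=3)
#     model = Model(network, 'sgd', 'cross_entropy', batch_metrics=['accuracy'])
#     model.fit_generator(train_loader, valid_loader, epochs=20, callbacks=[reduce_lr])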