/poutyne/framework/callbacks/lr_scheduler.py

https://github.com/GRAAL-Research/pytoune
import inspect
import sys
from typing import Dict, BinaryIO

import torch.optim.lr_scheduler
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from .callbacks import Callback

class _PyTorchLRSchedulerWrapper(Callback):
    """
    Base class for the LR scheduling callbacks. Provides the default behaviour for loading and
    saving the scheduler state as well as for the epoch-end handling.
    """

    def __init__(self, torch_lr_scheduler, *args, **kwargs):
        super().__init__()
        if len(args) > 0 and isinstance(args[0], Optimizer):
            raise ValueError("In the LR scheduler callbacks, the optimizer is "
                             "automatically passed to PyTorch's LR scheduler. "
                             "You must remove it from the arguments.")
        self.args = args
        self.kwargs = kwargs
        self.scheduler = None
        self.state_to_load = None
        self.torch_lr_scheduler = torch_lr_scheduler

    def on_epoch_end(self, epoch_number: int, logs: Dict):
        self.scheduler.step()

    def on_train_begin(self, logs: Dict):
        # The optimizer only exists on the model once training starts, so the PyTorch
        # scheduler is instantiated here rather than in __init__.
        self.scheduler = self.torch_lr_scheduler(self.model.optimizer, *self.args, **self.kwargs)

        # Apply any state that was loaded before the scheduler was created.
        self._load_state_to_load()

    def load_state(self, f: BinaryIO):
        if self.scheduler is not None:
            self.scheduler.load_state_dict(torch.load(f, map_location='cpu'))
        else:
            # The scheduler does not exist yet; keep the state until on_train_begin.
            self.state_to_load = torch.load(f, map_location='cpu')

    def save_state(self, f: BinaryIO):
        torch.save(self.scheduler.state_dict(), f)

    def _load_state_to_load(self):
        if self.state_to_load is not None:
            self.scheduler.load_state_dict(self.state_to_load)
            self.state_to_load = None


def new_init(torch_lr_scheduler):
    # Builds the __init__ of the dynamically generated scheduler callbacks below.
    def f(self, *args, **kwargs):
        super(type(self), self).__init__(torch_lr_scheduler, *args, **kwargs)

    return f


# Dynamically generate one callback class per LR scheduler found in torch.optim.lr_scheduler
# (e.g. StepLR, MultiStepLR, ExponentialLR) and expose it as an attribute of this module.
for name, module_cls in torch.optim.lr_scheduler.__dict__.items():
    if inspect.isclass(module_cls) and \
            issubclass(module_cls, _LRScheduler) and \
            module_cls != _LRScheduler:
        _new_cls = type(
            name, (_PyTorchLRSchedulerWrapper, ), {
                '__init__': new_init(module_cls),
                '__doc__': """
                See:
                    :class:`~torch.optim.lr_scheduler.{name}`
                """.format(name=name)
            })
        setattr(sys.modules[__name__], name, _new_cls)
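
# A usage sketch for the generated callbacks (not part of the original file), assuming
# Poutyne's Model.fit_generator(..., callbacks=...) API; the optimizer is injected
# automatically at train time, so only the scheduler's own arguments are passed:
#
#     model.fit_generator(train_generator, valid_generator, epochs=10,
#                         callbacks=[StepLR(step_size=3, gamma=0.1)])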


class ReduceLROnPlateau(_PyTorchLRSchedulerWrapper):
    """
    Args:
        monitor (str): The quantity to monitor. (Default value = 'val_loss')

    See:
        :class:`~torch.optim.lr_scheduler.ReduceLROnPlateau`
    """

    def __init__(self, *args, monitor: str = 'val_loss', **kwargs):
        super().__init__(torch_lr_scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau, *args, **kwargs)
        self.monitor = monitor

    def on_epoch_end(self, epoch_number: int, logs: Dict):
        # ReduceLROnPlateau steps on the monitored metric instead of the epoch count.
        self.scheduler.step(logs[self.monitor])
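
A minimal end-to-end sketch, not part of the original file: it assumes Poutyne's Model class and its fit_generator(..., callbacks=...) parameter; the network, data and hyperparameter values below are placeholders made up for illustration.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from poutyne.framework import Model
from poutyne.framework.callbacks.lr_scheduler import ReduceLROnPlateau

# Placeholder network and synthetic data, for illustration only.
network = nn.Linear(10, 2)
optimizer = torch.optim.SGD(network.parameters(), lr=0.1)
inputs, targets = torch.randn(64, 10), torch.randint(0, 2, (64,))
train_loader = DataLoader(TensorDataset(inputs, targets), batch_size=16)
valid_loader = DataLoader(TensorDataset(inputs, targets), batch_size=16)

model = Model(network, optimizer, nn.CrossEntropyLoss())

# Halve the learning rate when the validation loss has not improved for two epochs;
# the extra keyword arguments are forwarded to torch.optim.lr_scheduler.ReduceLROnPlateau.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', mode='min', factor=0.5, patience=2)

model.fit_generator(train_loader, valid_loader, epochs=20, callbacks=[reduce_lr])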