/utils/optimizer_wrapper.py

https://github.com/Randl/MobileNetV3-pytorch
# https://github.com/eladhoffer/utils.pytorch/blob/ca6a47a7766c50930a607d8425216d39104b7664/optim.py
from copy import deepcopy

import torch
from torch.optim.lr_scheduler import CyclicLR

from cosine_with_warmup import CosineLR

def copy_params(param_target, param_src):
    # Copy parameter values from param_src into param_target in place.
    with torch.no_grad():
        for p_src, p_target in zip(param_src, param_target):
            p_target.copy_(p_src)

def copy_params_grad(param_target, param_src):
    # Copy gradients from param_src into param_target, allocating .grad on the
    # target (via backward) when it does not exist yet.
    for p_src, p_target in zip(param_src, param_target):
        if p_target.grad is None:
            p_target.backward(p_src.grad.to(dtype=p_target.dtype))
        else:
            p_target.grad.detach().copy_(p_src.grad)

class ModuleFloatShadow(torch.nn.Module):
    """Keeps a float32 copy of ``module`` and exposes it through the usual
    parameter/module accessors, while the wrapped module stays reachable via
    the ``original_*`` accessors."""

    def __init__(self, module):
        super(ModuleFloatShadow, self).__init__()
        self.original_module = module
        self.float_module = deepcopy(module)
        self.float_module.to(dtype=torch.float32)

    def parameters(self, *kargs, **kwargs):
        return self.float_module.parameters(*kargs, **kwargs)

    def named_parameters(self, *kargs, **kwargs):
        return self.float_module.named_parameters(*kargs, **kwargs)

    def modules(self, *kargs, **kwargs):
        return self.float_module.modules(*kargs, **kwargs)

    def named_modules(self, *kargs, **kwargs):
        return self.float_module.named_modules(*kargs, **kwargs)

    def original_parameters(self, *kargs, **kwargs):
        return self.original_module.parameters(*kargs, **kwargs)

    def original_named_parameters(self, *kargs, **kwargs):
        return self.original_module.named_parameters(*kargs, **kwargs)

    def original_modules(self, *kargs, **kwargs):
        return self.original_module.modules(*kargs, **kwargs)

    def original_named_modules(self, *kargs, **kwargs):
        return self.original_module.named_modules(*kargs, **kwargs)

class OptimizerWrapper(object):
    """Bundles an optimizer and an LR scheduler behind one interface,
    optionally training float32 shadow weights for a low-precision model."""

    def __init__(self, model, optimizer_class, optimizer_params, scheduler_class, scheduler_params,
                 optimizer_state_dict=None, use_shadow_weights=False):
        if use_shadow_weights:
            model = ModuleFloatShadow(model)
            self._original_parameters = list(model.original_parameters())
        self.parameters = [p for p in model.parameters() if p.requires_grad]
        self.optimizer = optimizer_class(self.parameters, **optimizer_params)
        if optimizer_state_dict is not None:
            self.optimizer.load_state_dict(optimizer_state_dict)
        self.scheduler = scheduler_class(self.optimizer, **scheduler_params)
        self.use_shadow_weights = use_shadow_weights

    def state_dict(self):
        """Returns the state of the optimizer as a :class:`dict`."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Loads the optimizer state.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        optimizer_state_dict = state_dict['state']
        self.optimizer.__setstate__(optimizer_state_dict)

    def zero_grad(self):
        """Clears the gradients of all optimized parameters, including the
        original model's parameters when shadow weights are used."""
        self.optimizer.zero_grad()
        if self.use_shadow_weights:
            for p in self._original_parameters:
                if p.grad is not None:
                    p.grad.detach().zero_()

    def optimizer_step(self, closure=None):
        """Performs a single optimization step (parameter update).

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.
        """
        if self.use_shadow_weights:
            # Move the gradients computed on the original (low-precision)
            # parameters onto the float32 shadow copies before stepping.
            copy_params_grad(self.parameters, self._original_parameters)
        self.optimizer.step(closure)
        if self.use_shadow_weights:
            # Write the updated float32 values back into the original model.
            copy_params(self._original_parameters, self.parameters)

    def scheduler_step(self, epoch=None):
        """Performs a single lr update step (the ``epoch`` argument is unused)."""
        self.scheduler.step()

    def batch_step(self, closure=None):
        # Batch-level schedulers (CyclicLR, CosineLR) are stepped every batch,
        # just before the optimizer update.
        if isinstance(self.scheduler, (CyclicLR, CosineLR)):
            self.scheduler_step()
        self.optimizer_step(closure)

    def epoch_step(self):
        # Epoch-level schedulers are stepped once per epoch instead.
        if not isinstance(self.scheduler, (CyclicLR, CosineLR)):
            self.scheduler_step()