/detectron2/solver/build.py

https://github.com/facebookresearch/detectron2 · Python

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Set, Type, Union
import torch

from detectron2.config import CfgNode

from .lr_scheduler import WarmupCosineLR, WarmupMultiStepLR

_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
_GradientClipper = Callable[[_GradientClipperInput], None]


class GradientClipType(Enum):
    VALUE = "value"
    NORM = "norm"
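

# Illustrative sketch (not part of the original file): any callable that takes a
# tensor or an iterable of tensors and clips their gradients in place satisfies the
# _GradientClipper alias above. The function below is a hypothetical example only.
def _example_fixed_norm_clipper(p: _GradientClipperInput) -> None:
    # Clip to a hard-coded max norm of 1.0, mirroring what _create_gradient_clipper
    # builds from cfg.CLIP_VALUE / cfg.NORM_TYPE further down in this file.
    torch.nn.utils.clip_grad_norm_(p, max_norm=1.0, norm_type=2.0)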


def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper:
    """
    Creates gradient clipping closure to clip by value or by norm,
    according to the provided config.
    """
    cfg = cfg.clone()

    def clip_grad_norm(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_norm_(p, cfg.CLIP_VALUE, cfg.NORM_TYPE)

    def clip_grad_value(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_value_(p, cfg.CLIP_VALUE)

    _GRADIENT_CLIP_TYPE_TO_CLIPPER = {
        GradientClipType.VALUE: clip_grad_value,
        GradientClipType.NORM: clip_grad_norm,
    }
    return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg.CLIP_TYPE)]
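

# Usage sketch (illustrative only, assuming the default detectron2 config keys
# SOLVER.CLIP_GRADIENTS.{CLIP_TYPE, CLIP_VALUE, NORM_TYPE}): build a clipper from
# the CLIP_GRADIENTS sub-config and apply it to a parameter's gradient in place.
def _example_create_gradient_clipper() -> None:
    from detectron2.config import get_cfg  # assumed available in detectron2

    clip_cfg = get_cfg().SOLVER.CLIP_GRADIENTS
    clip_cfg.CLIP_TYPE = "norm"
    clip_cfg.CLIP_VALUE = 1.0
    clipper = _create_gradient_clipper(clip_cfg)

    weight = torch.nn.Parameter(torch.randn(4, 4))
    weight.grad = torch.randn(4, 4) * 100.0
    clipper([weight])  # gradient norm is now at most CLIP_VALUE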


def _generate_optimizer_class_with_gradient_clipping(
    optimizer_type: Type[torch.optim.Optimizer], gradient_clipper: _GradientClipper
) -> Type[torch.optim.Optimizer]:
    """
    Dynamically creates a new type that inherits the type of a given instance
    and overrides the `step` method to add gradient clipping
    """

    def optimizer_wgc_step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                gradient_clipper(p)
        super(type(self), self).step(closure)

    OptimizerWithGradientClip = type(
        optimizer_type.__name__ + "WithGradientClip",
        (optimizer_type,),
        {"step": optimizer_wgc_step},
    )
    return OptimizerWithGradientClip
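

# Sketch of what the dynamic subclassing produces (illustrative only): wrapping
# torch.optim.SGD with a simple norm clipper yields a new class named
# "SGDWithGradientClip" whose step() clips every parameter's gradient first.
def _example_generate_clipping_optimizer_class() -> None:
    def _clip(p: _GradientClipperInput) -> None:
        torch.nn.utils.clip_grad_norm_(p, max_norm=1.0)

    ClippedSGD = _generate_optimizer_class_with_gradient_clipping(torch.optim.SGD, _clip)
    model = torch.nn.Linear(8, 2)
    opt = ClippedSGD(model.parameters(), lr=0.1)
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    opt.step()  # gradients are clipped to norm <= 1.0 before the SGD update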


def maybe_add_gradient_clipping(
    cfg: CfgNode, optimizer: torch.optim.Optimizer
) -> torch.optim.Optimizer:
    """
    If gradient clipping is enabled through config options, wraps the existing
    optimizer instance of some type OptimizerType to become an instance
    of the new dynamically created class OptimizerTypeWithGradientClip
    that inherits OptimizerType and overrides the `step` method to
    include gradient clipping.

    Args:
        cfg: CfgNode
            configuration options
        optimizer: torch.optim.Optimizer
            existing optimizer instance

    Return:
        optimizer: torch.optim.Optimizer
            either the unmodified optimizer instance (if gradient clipping is
            disabled), or the same instance with adjusted __class__ to override
            the `step` method and include gradient clipping
    """
    if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
        return optimizer
    grad_clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS)
    OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
        type(optimizer), grad_clipper
    )
    optimizer.__class__ = OptimizerWithGradientClip
    return optimizer
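

# Usage sketch (illustrative only, assuming the default SOLVER.CLIP_GRADIENTS keys):
# when clipping is enabled, the same optimizer instance is returned, but its class is
# swapped for the dynamically generated "...WithGradientClip" subclass.
def _example_maybe_add_gradient_clipping() -> None:
    from detectron2.config import get_cfg  # assumed available in detectron2

    cfg = get_cfg()
    cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
    cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "value"
    cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 0.5

    model = torch.nn.Linear(8, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    opt = maybe_add_gradient_clipping(cfg, opt)
    assert type(opt).__name__ == "SGDWithGradientClip"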


def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an optimizer from config.
    """
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        # NaiveSyncBatchNorm inherits from BatchNorm2d
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    for module in model.modules():
        for key, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)
            lr = cfg.SOLVER.BASE_LR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY
            if isinstance(module, norm_module_types):
                weight_decay = cfg.SOLVER.WEIGHT_DECAY_NORM
            elif key == "bias":
                # NOTE: unlike Detectron v1, we now default BIAS_LR_FACTOR to 1.0
                # and WEIGHT_DECAY_BIAS to WEIGHT_DECAY so that bias optimizer
                # hyperparameters are by default exactly the same as for regular
                # weights.
                lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
                weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
            params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

    optimizer = torch.optim.SGD(
        params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM, nesterov=cfg.SOLVER.NESTEROV
    )
    optimizer = maybe_add_gradient_clipping(cfg, optimizer)
    return optimizer
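

# Usage sketch (illustrative only, assuming the default detectron2 SOLVER keys):
# build_optimizer walks model.modules(), gives norm layers WEIGHT_DECAY_NORM and
# biases BASE_LR * BIAS_LR_FACTOR / WEIGHT_DECAY_BIAS, and wraps the resulting
# SGD optimizer with gradient clipping if enabled in the config.
def _example_build_optimizer() -> None:
    from detectron2.config import get_cfg  # assumed available in detectron2

    cfg = get_cfg()
    cfg.SOLVER.BASE_LR = 0.02
    cfg.SOLVER.WEIGHT_DECAY_NORM = 0.0  # no decay on norm-layer affine parameters

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.BatchNorm1d(8))
    opt = build_optimizer(cfg, model)
    # One param group per parameter; norm-layer groups carry weight_decay == 0.0.
    assert all(len(g["params"]) == 1 for g in opt.param_groups)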


def build_lr_scheduler(
    cfg: CfgNode, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler._LRScheduler:
    """
    Build a LR scheduler from config.
    """
    name = cfg.SOLVER.LR_SCHEDULER_NAME
    if name == "WarmupMultiStepLR":
        return WarmupMultiStepLR(
            optimizer,
            cfg.SOLVER.STEPS,
            cfg.SOLVER.GAMMA,
            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
            warmup_method=cfg.SOLVER.WARMUP_METHOD,
        )
    elif name == "WarmupCosineLR":
        return WarmupCosineLR(
            optimizer,
            cfg.SOLVER.MAX_ITER,
            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
            warmup_method=cfg.SOLVER.WARMUP_METHOD,
        )
    else:
        raise ValueError("Unknown LR scheduler: {}".format(name))
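

# Usage sketch (illustrative only, assuming the default SOLVER keys and the
# WarmupMultiStepLR implementation from .lr_scheduler): pair the scheduler with an
# optimizer built above and step both once per training iteration.
def _example_build_lr_scheduler() -> None:
    from detectron2.config import get_cfg  # assumed available in detectron2

    cfg = get_cfg()
    cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR"
    cfg.SOLVER.STEPS = (300, 400)
    cfg.SOLVER.WARMUP_ITERS = 100

    model = torch.nn.Linear(8, 2)
    opt = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, opt)
    for _ in range(5):
        opt.step()        # normally preceded by a forward/backward pass
        scheduler.step()  # advances the warmup / multi-step schedule by one iteration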