/mmdet/core/fp16/hooks.py

https://github.com/WXinlong/SOLO
import copy

import torch
import torch.nn as nn
from mmcv.runner import OptimizerHook

from ..utils.dist_utils import allreduce_grads
from .utils import cast_tensor_type


class Fp16OptimizerHook(OptimizerHook):
    """FP16 optimizer hook.

    The steps of the fp16 optimizer are as follows.

    1. Scale the loss value.
    2. Backpropagate in the fp16 model.
    3. Copy gradients from the fp16 model to the fp32 weight copy.
    4. Update the fp32 weights.
    5. Copy the updated parameters from the fp32 weight copy back to the
       fp16 model.

    Refer to https://arxiv.org/abs/1710.03740 for more details.

    Args:
        loss_scale (float): Scale factor multiplied with loss.
    """

    def __init__(self,
                 grad_clip=None,
                 coalesce=True,
                 bucket_size_mb=-1,
                 loss_scale=512.,
                 distributed=True):
        self.grad_clip = grad_clip
        self.coalesce = coalesce
        self.bucket_size_mb = bucket_size_mb
        self.loss_scale = loss_scale
        self.distributed = distributed

    def before_run(self, runner):
        # keep a copy of the fp32 weights
        runner.optimizer.param_groups = copy.deepcopy(
            runner.optimizer.param_groups)
        # convert the model to fp16
        wrap_fp16_model(runner.model)

    def copy_grads_to_fp32(self, fp16_net, fp32_weights):
        """Copy gradients from the fp16 model to the fp32 weight copy."""
        for fp32_param, fp16_param in zip(fp32_weights,
                                          fp16_net.parameters()):
            if fp16_param.grad is not None:
                if fp32_param.grad is None:
                    fp32_param.grad = fp32_param.data.new(fp32_param.size())
                fp32_param.grad.copy_(fp16_param.grad)

    def copy_params_to_fp16(self, fp16_net, fp32_weights):
        """Copy updated params from the fp32 weight copy to the fp16 model."""
        for fp16_param, fp32_param in zip(fp16_net.parameters(),
                                          fp32_weights):
            fp16_param.data.copy_(fp32_param.data)

    def after_train_iter(self, runner):
        # clear grads of the last iteration
        runner.model.zero_grad()
        runner.optimizer.zero_grad()
        # scale the loss value
        scaled_loss = runner.outputs['loss'] * self.loss_scale
        scaled_loss.backward()
        # copy fp16 grads in the model to fp32 params in the optimizer
        fp32_weights = []
        for param_group in runner.optimizer.param_groups:
            fp32_weights += param_group['params']
        self.copy_grads_to_fp32(runner.model, fp32_weights)
        # allreduce grads
        if self.distributed:
            allreduce_grads(fp32_weights, self.coalesce, self.bucket_size_mb)
        # scale the gradients back
        for param in fp32_weights:
            if param.grad is not None:
                param.grad.div_(self.loss_scale)
        if self.grad_clip is not None:
            self.clip_grads(fp32_weights)
        # update fp32 params
        runner.optimizer.step()
        # copy fp32 params to the fp16 model
        self.copy_params_to_fp16(runner.model, fp32_weights)
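
# Usage sketch, not part of the training logic above: roughly how this hook
# could be attached to a training run. The `runner` object and the grad_clip
# settings below are assumptions for illustration (an mmcv runner built
# elsewhere).
#
#     optimizer_config = Fp16OptimizerHook(
#         grad_clip=dict(max_norm=35, norm_type=2),
#         loss_scale=512.,
#         distributed=True)
#     runner.register_hook(optimizer_config)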


def wrap_fp16_model(model):
    """Convert a model to fp16 and patch it for mixed-precision training."""
    # convert model weights to fp16
    model.half()
    # patch the normalization layers to make them work in fp32 mode
    patch_norm_fp32(model)
    # set the `fp16_enabled` flag on modules that support it
    for m in model.modules():
        if hasattr(m, 'fp16_enabled'):
            m.fp16_enabled = True
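
# Minimal sketch of the effect of `wrap_fp16_model`, assuming a toy
# `nn.Sequential` model: conv weights are cast to fp16, while the BatchNorm
# patched by `patch_norm_fp32` keeps its parameters in fp32.
#
#     model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
#     wrap_fp16_model(model)
#     assert model[0].weight.dtype == torch.half
#     assert model[1].weight.dtype == torch.float32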


def patch_norm_fp32(module):
    """Recursively keep normalization layers in fp32 and patch their forward
    methods to cast fp16 inputs to fp32 (and outputs back to fp16)."""
    if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
        module.float()
        module.forward = patch_forward_method(module.forward, torch.half,
                                              torch.float)
    for child in module.children():
        patch_norm_fp32(child)
    return module


def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Patch the forward method of a module.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Type of input arguments to be converted from.
        dst_type (torch.dtype): Type of input arguments to be converted to.
        convert_output (bool): Whether to convert the output back to src_type.

    Returns:
        callable: The patched forward method.
    """

    def new_forward(*args, **kwargs):
        output = func(*cast_tensor_type(args, src_type, dst_type),
                      **cast_tensor_type(kwargs, src_type, dst_type))
        if convert_output:
            output = cast_tensor_type(output, dst_type, src_type)
        return output

    return new_forward
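
# Minimal sketch of `patch_forward_method` in isolation, assuming a GroupNorm
# layer kept in fp32 and an fp16 input: the patched forward casts the input to
# fp32 before calling the original forward and, with `convert_output=True`,
# casts the result back to fp16.
#
#     gn = nn.GroupNorm(2, 8).float()
#     gn.forward = patch_forward_method(gn.forward, torch.half, torch.float)
#     out = gn(torch.randn(1, 8, 4, 4).half())
#     assert out.dtype == torch.half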