/mmdet/core/fp16/hooks.py
import copy

import torch
import torch.nn as nn
from mmcv.runner import OptimizerHook

from ..utils.dist_utils import allreduce_grads
from .utils import cast_tensor_type

class Fp16OptimizerHook(OptimizerHook):
    """FP16 optimizer hook.

    The steps of the fp16 optimizer are as follows:

    1. Scale the loss value.
    2. BP in the fp16 model.
    3. Copy gradients from the fp16 model to the fp32 weight copy.
    4. Update the fp32 weights.
    5. Copy the updated parameters from the fp32 weight copy back to the
       fp16 model.

    Refer to https://arxiv.org/abs/1710.03740 for more details.

    Args:
        grad_clip (dict, optional): Keyword arguments passed to the parent
            hook's ``clip_grads`` for gradient clipping. ``None`` disables
            clipping.
        coalesce (bool): Whether to coalesce gradients into buckets before
            the all-reduce in distributed training.
        bucket_size_mb (int): Bucket size (in MB) used when coalescing;
            ``-1`` means no bucketing by size.
        loss_scale (float): Static scale factor multiplied with the loss.
        distributed (bool): Whether to all-reduce gradients across processes
            before the optimizer step.
    """

    def __init__(self,
                 grad_clip=None,
                 coalesce=True,
                 bucket_size_mb=-1,
                 loss_scale=512.,
                 distributed=True):
        self.grad_clip = grad_clip
        self.coalesce = coalesce
        self.bucket_size_mb = bucket_size_mb
        self.loss_scale = loss_scale
        self.distributed = distributed

    def before_run(self, runner):
        # keep a copy of fp32 weights
        runner.optimizer.param_groups = copy.deepcopy(
            runner.optimizer.param_groups)
        # convert model to fp16
        wrap_fp16_model(runner.model)

    def copy_grads_to_fp32(self, fp16_net, fp32_weights):
        """Copy gradients from fp16 model to fp32 weight copy."""
        for fp32_param, fp16_param in zip(fp32_weights, fp16_net.parameters()):
            if fp16_param.grad is not None:
                if fp32_param.grad is None:
                    fp32_param.grad = fp32_param.data.new(fp32_param.size())
                fp32_param.grad.copy_(fp16_param.grad)

    def copy_params_to_fp16(self, fp16_net, fp32_weights):
        """Copy updated params from fp32 weight copy to fp16 model."""
        for fp16_param, fp32_param in zip(fp16_net.parameters(), fp32_weights):
            fp16_param.data.copy_(fp32_param.data)

    def after_train_iter(self, runner):
        # clear grads of the last iteration
        runner.model.zero_grad()
        runner.optimizer.zero_grad()
        # scale the loss value
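        # (Illustrative note, not part of the original file: with a static
        # loss_scale of 512, a true gradient g is computed as 512 * g in
        # fp16, which keeps small gradient values from underflowing to zero;
        # the fp32 copies are divided by 512 again further below, before the
        # optimizer step, to restore the true magnitude.)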
        scaled_loss = runner.outputs['loss'] * self.loss_scale
        scaled_loss.backward()
        # copy fp16 grads in the model to fp32 params in the optimizer
        fp32_weights = []
        for param_group in runner.optimizer.param_groups:
            fp32_weights += param_group['params']
        self.copy_grads_to_fp32(runner.model, fp32_weights)
        # allreduce grads
        if self.distributed:
            allreduce_grads(fp32_weights, self.coalesce, self.bucket_size_mb)
        # scale the gradients back
        for param in fp32_weights:
            if param.grad is not None:
                param.grad.div_(self.loss_scale)
        if self.grad_clip is not None:
            self.clip_grads(fp32_weights)
        # update fp32 params
        runner.optimizer.step()
        # copy fp32 params to the fp16 model
        self.copy_params_to_fp16(runner.model, fp32_weights)


def wrap_fp16_model(model):
    """Convert a model to fp16 while keeping normalization layers in fp32."""
    # convert model to fp16
    model.half()
    # patch the normalization layers to make them work in fp32 mode
    patch_norm_fp32(model)
    # set the `fp16_enabled` flag on modules that support it
    for m in model.modules():
        if hasattr(m, 'fp16_enabled'):
            m.fp16_enabled = True
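# Usage sketch (hedged illustration; `build_detector` and `cfg` stand in for
# whatever constructs the model in the surrounding code base):
#
#     model = build_detector(cfg.model, train_cfg=cfg.train_cfg,
#                            test_cfg=cfg.test_cfg)
#     wrap_fp16_model(model)  # fp16 weights, fp32 norm layers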


def patch_norm_fp32(module):
    """Recursively convert normalization layers to fp32 and patch their
    forward methods to cast fp16 inputs to fp32 (and outputs back to fp16).
    """
    if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
        module.float()
        module.forward = patch_forward_method(module.forward, torch.half,
                                              torch.float)
    for child in module.children():
        patch_norm_fp32(child)
    return module


def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Patch the forward method of a module.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Type of input arguments to be converted from.
        dst_type (torch.dtype): Type of input arguments to be converted to.
        convert_output (bool): Whether to convert the output back to src_type.

    Returns:
        callable: The patched forward method.
    """

    def new_forward(*args, **kwargs):
        output = func(*cast_tensor_type(args, src_type, dst_type),
                      **cast_tensor_type(kwargs, src_type, dst_type))
        if convert_output:
            output = cast_tensor_type(output, dst_type, src_type)
        return output

    return new_forward
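

if __name__ == '__main__':
    # Minimal, hedged sanity check (added for illustration, not part of the
    # original module; run it as `python -m mmdet.core.fp16.hooks` so the
    # relative imports above resolve): patch a single GroupNorm layer so it
    # keeps fp32 parameters while accepting and returning fp16 tensors,
    # mirroring what wrap_fp16_model does for every norm layer in a model.
    norm = nn.GroupNorm(num_groups=2, num_channels=4).half()
    patch_norm_fp32(norm)

    x = torch.randn(1, 4, 8, 8).half()
    out = norm(x)

    assert norm.weight.dtype == torch.float32  # params stay in fp32
    assert out.dtype == torch.half             # output cast back to fp16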