/torch_ipex_py/weight_cast.py

https://github.com/intel/intel-extension-for-pytorch
import torch
import torch.nn as nn
from .optimizer_utils import refresh_optimizer_params_after_cast, patch_step_for_master_weight_training, patch_load_state_dict
import types

# IPEX does not cast all module parameters to BF16, for accuracy reasons (e.g. BatchNorm stays FP32)
IPEX_WEIGHT_CAST_MODULE = {
    # aligned with the autocast white list
    torch.nn.Linear,
    torch.nn.Conv1d,
    torch.nn.Conv2d,
    torch.nn.Conv3d,
    torch.nn.ConvTranspose1d,
    torch.nn.ConvTranspose2d,
    torch.nn.ConvTranspose3d,
    # additionally supported by IPEX
    torch.nn.EmbeddingBag,
}
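
# Note on the custom ops used below (behavior as understood from this file; the
# actual kernels live in the IPEX C++ extension):
#   torch.ops.torch_ipex.split_float_bfloat16(t) splits an FP32 tensor into a
#   BF16 "top" half plus a "trail" tensor holding the remaining lower bits, and
#   torch.ops.torch_ipex.cat_bfloat16_float(top, trail) recombines them into
#   the original FP32 values.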

def _save_to_state_dict(self, destination, prefix, keep_vars):
    # temporarily restore an FP32 view of the weight for serialization
    temp_weight = self.weight
    if self.master_weight_split:
        self.weight = torch.nn.Parameter(torch.ops.torch_ipex.cat_bfloat16_float(self.weight.data, self.weight_trail))
    else:
        self.weight = torch.nn.Parameter(self.master_weight)
    # likewise for the bias, if present
    if hasattr(self, 'bias') and self.bias is not None:
        temp_bias = self.bias
        if self.master_weight_split:
            self.bias = torch.nn.Parameter(torch.ops.torch_ipex.cat_bfloat16_float(self.bias.data, self.bias_trail))
        else:
            self.bias = torch.nn.Parameter(self.master_bias)
    super(type(self), self)._save_to_state_dict(destination, prefix, keep_vars)
    # put the low-precision parameters back after saving
    self.weight = temp_weight
    if hasattr(self, 'bias') and self.bias is not None:
        self.bias = temp_bias
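
# _weight_dtype_convert_with_ipex operates in one of two modes:
#   master_weight_split=True  -> each FP32 param is split into a BF16 param plus
#                                a "trail" tensor kept on the module; the
#                                optimizer keeps updating m.weight / m.bias.
#   master_weight_split=False -> the FP32 data is stashed as m.master_weight /
#                                m.master_bias, the module exposes a BF16
#                                nn.Parameter, and the optimizer step is patched
#                                below to update the master copy.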

def _weight_dtype_convert_with_ipex(module, optimizer, params_attr, master_weight_split):

    def cast_attr(m, attr, master_weight_split, params_attr, optimizer):
        # cast the weight/bias attribute to BF16
        float_param = getattr(m, attr)
        params_attr[float_param] = {}
        if master_weight_split:
            top_half, bot_half = torch.ops.torch_ipex.split_float_bfloat16(float_param.data)
            setattr(m, attr + '_trail', bot_half)
            setattr(m, attr, nn.Parameter(top_half.detach()))
            params_attr[float_param]['trail'] = getattr(m, attr + '_trail')
        else:
            setattr(m, 'master_' + attr, float_param.data)
            setattr(m, attr, nn.Parameter(float_param.detach().bfloat16()))
            params_attr[float_param]['bf16_param'] = getattr(m, attr)
        # update the attr entry; always use the params seen by the optimizer as the "key":
        # with master weight split the key is m.weight/bias, otherwise it is m.master_weight/master_bias
        attr_name = attr if master_weight_split else 'master_' + attr
        params_attr[getattr(m, attr_name)] = params_attr.pop(float_param)
        refresh_optimizer_params_after_cast(m, attr, float_param, master_weight_split, optimizer)
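
    # params_attr maps each optimizer-visible parameter to its BF16 counterpart
    # (or trail tensor); it is attached to the optimizer below so the patched
    # step/load_state_dict logic in optimizer_utils can keep the two
    # representations in sync.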

    def convert(m):
        if type(m) in IPEX_WEIGHT_CAST_MODULE:
            setattr(m, 'master_weight_split', master_weight_split)
            # replace weight
            cast_attr(m, 'weight', master_weight_split, params_attr, optimizer)
            if hasattr(m, 'bias') and m.bias is not None:
                # replace bias
                cast_attr(m, 'bias', master_weight_split, params_attr, optimizer)
            # to support resuming training, we always save float tensors:
            # replace the module method so that state_dict() returns float params
            setattr(m, '_save_to_state_dict', types.MethodType(_save_to_state_dict, m))
        return m

    def convert_rec(m):
        new_m = convert(m)
        for name, sub_m in m.named_children():
            setattr(new_m, name, convert_rec(sub_m))
        return new_m

    casted_model, casted_optimizer, params_attr = convert_rec(module), optimizer, params_attr
    if optimizer is not None:
        patch_load_state_dict(casted_optimizer)
        setattr(casted_optimizer, 'params_attr', params_attr)
        if not master_weight_split:
            patch_step_for_master_weight_training(casted_optimizer)
    return casted_model, casted_optimizer, params_attr
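
# A minimal usage sketch (illustrative only; in intel-extension-for-pytorch this
# helper is normally reached through the frontend, e.g.
# ipex.optimize(model, optimizer=opt, dtype=torch.bfloat16), rather than called
# directly; the model and optimizer names below are placeholders):
#
#   model = MyModel()                      # contains nn.Linear / nn.Conv2d, etc.
#   opt = torch.optim.SGD(model.parameters(), lr=0.1)
#   model, opt, params_attr = _weight_dtype_convert_with_ipex(
#       model, opt, params_attr={}, master_weight_split=False)
#   # weights of the matched modules are now BF16 nn.Parameters, the FP32 master
#   # copies live on the modules, and opt.step() has been patched to update them.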