/lottery/desc.py

https://github.com/facebookresearch/open_lth

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import copy
from dataclasses import dataclass, replace
import os
from typing import Union

from cli import arg_utils
from datasets import registry as datasets_registry
from foundations.desc import Desc
from foundations import hparams
from foundations.step import Step
from platforms.platform import get_platform
import pruning.registry


@dataclass
class LotteryDesc(Desc):
    """The hyperparameters necessary to describe a lottery ticket training backbone."""

    model_hparams: hparams.ModelHparams
    dataset_hparams: hparams.DatasetHparams
    training_hparams: hparams.TrainingHparams
    pruning_hparams: hparams.PruningHparams
    pretrain_dataset_hparams: hparams.DatasetHparams = None
    pretrain_training_hparams: hparams.TrainingHparams = None

    @staticmethod
    def name_prefix(): return 'lottery'

    @staticmethod
    def _add_pretrain_argument(parser):
        help_text = \
            'Perform a pre-training phase prior to running the main lottery ticket process. Setting this argument '\
            'enables additional arguments that control the dataset and training procedure used during this '\
            'pre-training phase. Rewinding is a specific case of pre-training in which pre-training uses the same '\
            'dataset and training procedure as the main training run.'
        parser.add_argument('--pretrain', action='store_true', help=help_text)

    @staticmethod
    def _add_rewinding_argument(parser):
        help_text = \
            'The number of steps for which to train the network before the lottery ticket process begins. This is ' \
            'the \'rewinding\' step as described in recent lottery ticket research. Can be expressed as a number of ' \
            'epochs (\'160ep\') or a number of iterations (\'50000it\'). If this flag is present, no other '\
            'pretraining arguments may be set. Pretraining will be conducted using the same dataset and training '\
            'hyperparameters as for the main training run. For the full range of pre-training options, use --pretrain.'
        parser.add_argument('--rewinding_steps', type=str, help=help_text)
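
    # Illustration (not part of the original file): the two mutually exclusive ways to request pre-training.
    # The prefixed flag spelling below is an assumption based on the prefix='pretrain' convention in add_args.
    #
    #   --rewinding_steps 1000it
    #       pre-train for 1,000 iterations with the same dataset and training hyperparameters as the main run
    #   --pretrain --pretrain_training_steps 5ep
    #       pre-train for five epochs, with the pretraining dataset and training hyperparameters set independently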

    @staticmethod
    def add_args(parser: argparse.ArgumentParser, defaults: 'LotteryDesc' = None):
        # Add the rewinding/pretraining arguments.
        rewinding_steps = arg_utils.maybe_get_arg('rewinding_steps')
        pretrain = arg_utils.maybe_get_arg('pretrain', boolean_arg=True)

        if rewinding_steps is not None and pretrain: raise ValueError('Cannot set --rewinding_steps and --pretrain')

        pretraining_parser = parser.add_argument_group(
            'Rewinding/Pretraining Arguments', 'Arguments that control how the network is pre-trained')
        LotteryDesc._add_rewinding_argument(pretraining_parser)
        LotteryDesc._add_pretrain_argument(pretraining_parser)

        # Get the proper pruning hparams.
        pruning_strategy = arg_utils.maybe_get_arg('pruning_strategy')
        if defaults and not pruning_strategy: pruning_strategy = defaults.pruning_hparams.pruning_strategy
        if pruning_strategy:
            pruning_hparams = pruning.registry.get_pruning_hparams(pruning_strategy)
            if defaults and defaults.pruning_hparams.pruning_strategy == pruning_strategy:
                def_ph = defaults.pruning_hparams
        else:
            pruning_hparams = hparams.PruningHparams
            def_ph = None

        # Add the main arguments.
        hparams.DatasetHparams.add_args(parser, defaults=defaults.dataset_hparams if defaults else None)
        hparams.ModelHparams.add_args(parser, defaults=defaults.model_hparams if defaults else None)
        hparams.TrainingHparams.add_args(parser, defaults=defaults.training_hparams if defaults else None)
        pruning_hparams.add_args(parser, defaults=def_ph if defaults else None)

        # Handle pretraining.
        if pretrain:
            if defaults: def_th = replace(defaults.training_hparams, training_steps='0ep')
            hparams.TrainingHparams.add_args(parser, defaults=def_th if defaults else None,
                                             name='Training Hyperparameters for Pretraining', prefix='pretrain')
            hparams.DatasetHparams.add_args(parser, defaults=defaults.dataset_hparams if defaults else None,
                                            name='Dataset Hyperparameters for Pretraining', prefix='pretrain')

    @classmethod
    def create_from_args(cls, args: argparse.Namespace) -> 'LotteryDesc':
        # Get the main arguments.
        dataset_hparams = hparams.DatasetHparams.create_from_args(args)
        model_hparams = hparams.ModelHparams.create_from_args(args)
        training_hparams = hparams.TrainingHparams.create_from_args(args)
        pruning_hparams = pruning.registry.get_pruning_hparams(args.pruning_strategy).create_from_args(args)

        # Create the desc.
        desc = cls(model_hparams, dataset_hparams, training_hparams, pruning_hparams)

        # Handle pretraining.
        if args.pretrain and not Step.str_is_zero(args.pretrain_training_steps):
            desc.pretrain_dataset_hparams = hparams.DatasetHparams.create_from_args(args, prefix='pretrain')
            desc.pretrain_dataset_hparams._name = 'Pretraining ' + desc.pretrain_dataset_hparams._name
            desc.pretrain_training_hparams = hparams.TrainingHparams.create_from_args(args, prefix='pretrain')
            desc.pretrain_training_hparams._name = 'Pretraining ' + desc.pretrain_training_hparams._name
        elif 'rewinding_steps' in args and args.rewinding_steps and not Step.str_is_zero(args.rewinding_steps):
            desc.pretrain_dataset_hparams = copy.deepcopy(dataset_hparams)
            desc.pretrain_dataset_hparams._name = 'Pretraining ' + desc.pretrain_dataset_hparams._name
            desc.pretrain_training_hparams = copy.deepcopy(training_hparams)
            desc.pretrain_training_hparams._name = 'Pretraining ' + desc.pretrain_training_hparams._name
            desc.pretrain_training_hparams.training_steps = args.rewinding_steps

        return desc

    def str_to_step(self, s: str, pretrain: bool = False) -> Step:
        dataset_hparams = self.pretrain_dataset_hparams if pretrain else self.dataset_hparams
        iterations_per_epoch = datasets_registry.iterations_per_epoch(dataset_hparams)
        return Step.from_str(s, iterations_per_epoch)
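
    # Example for illustration (hypothetical numbers, not part of the original file): with a dataset that
    # yields 390 iterations per epoch, str_to_step('2ep') corresponds to iteration 780, while
    # str_to_step('500it') corresponds to iteration 500 regardless of the epoch length.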

    @property
    def pretrain_end_step(self):
        return self.str_to_step(self.pretrain_training_hparams.training_steps, True)

    @property
    def train_start_step(self):
        if self.pretrain_training_hparams: return self.str_to_step(self.pretrain_training_hparams.training_steps)
        else: return self.str_to_step('0it')

    @property
    def train_end_step(self):
        return self.str_to_step(self.training_hparams.training_steps)

    @property
    def pretrain_outputs(self):
        return datasets_registry.num_classes(self.pretrain_dataset_hparams)

    @property
    def train_outputs(self):
        return datasets_registry.num_classes(self.dataset_hparams)

    def run_path(self, replicate: int, pruning_level: Union[str, int], experiment: str = 'main'):
        """The location where any run is stored."""

        if not isinstance(replicate, int) or replicate <= 0:
            raise ValueError('Bad replicate: {}'.format(replicate))

        return os.path.join(get_platform().root, self.hashname,
                            f'replicate_{replicate}', f'level_{pruning_level}', experiment)
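
    # For illustration (hypothetical values, not part of the original file): with a platform root of
    # '/home/user/open_lth_data' and a hashname of 'lottery_<hash>', run_path(replicate=1, pruning_level=3)
    # would return '/home/user/open_lth_data/lottery_<hash>/replicate_1/level_3/main'.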

    @property
    def display(self):
        ls = [self.dataset_hparams.display, self.model_hparams.display,
              self.training_hparams.display, self.pruning_hparams.display]
        if self.pretrain_training_hparams:
            ls += [self.pretrain_dataset_hparams.display, self.pretrain_training_hparams.display]
        return '\n'.join(ls)
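

# Illustrative usage sketch (not part of the original file). It assumes the standard argparse flow; the
# hyperparameter flag spellings are defined by the individual Hparams classes, not shown here.
#
#     parser = argparse.ArgumentParser()
#     LotteryDesc.add_args(parser)
#     args = parser.parse_args()
#     desc = LotteryDesc.create_from_args(args)
#     print(desc.display)
#     print(desc.run_path(replicate=1, pruning_level=0))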