/src/transformers/hf_argparser.py

https://github.com/huggingface/pytorch-pretrained-BERT · Python · 180 lines · 153 code · 5 blank · 22 comment · 5 complexity · 4ba263ba3c7e75d852f7cb6d2b01a6d7 MD5 · raw file

import dataclasses
import json
import sys
import types
from argparse import ArgumentParser
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union
  8. DataClass = NewType("DataClass", Any)
  9. DataClassType = NewType("DataClassType", Any)
  10. class HfArgumentParser(ArgumentParser):
  11. """
  12. This subclass of `argparse.ArgumentParser` uses type hints on dataclasses
  13. to generate arguments.
  14. The class is designed to play well with the native argparse. In particular,
  15. you can add more (non-dataclass backed) arguments to the parser after initialization
  16. and you'll get the output back after parsing as an additional namespace.
  17. """
  18. dataclass_types: Iterable[DataClassType]
  19. def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
  20. """
  21. Args:
  22. dataclass_types:
  23. Dataclass type, or list of dataclass types for which we will "fill" instances
  24. with the parsed args.
  25. kwargs:
  26. (Optional) Passed to `argparse.ArgumentParser()` in the regular way.
  27. """
  28. super().__init__(**kwargs)
  29. if dataclasses.is_dataclass(dataclass_types):
  30. dataclass_types = [dataclass_types]
  31. self.dataclass_types = dataclass_types
  32. for dtype in self.dataclass_types:
  33. self._add_dataclass_arguments(dtype)
  34. def _add_dataclass_arguments(self, dtype: DataClassType):
  35. for field in dataclasses.fields(dtype):
  36. field_name = f"--{field.name}"
  37. kwargs = field.metadata.copy()
  38. # field.metadata is not used at all by Data Classes,
  39. # it is provided as a third-party extension mechanism.
  40. if isinstance(field.type, str):
  41. raise ImportError(
  42. "This implementation is not compatible with Postponed Evaluation of Annotations (PEP 563),"
  43. "which can be opted in from Python 3.7 with `from __future__ import annotations`."
  44. "We will add compatibility when Python 3.9 is released."
  45. )
  46. typestring = str(field.type)
  47. for prim_type in (int, float, str):
  48. for collection in (List,):
  49. if typestring == f"typing.Union[{collection[prim_type]}, NoneType]":
  50. field.type = collection[prim_type]
  51. if typestring == f"typing.Union[{prim_type.__name__}, NoneType]":
  52. field.type = prim_type
  53. if isinstance(field.type, type) and issubclass(field.type, Enum):
  54. kwargs["choices"] = list(field.type)
  55. kwargs["type"] = field.type
  56. if field.default is not dataclasses.MISSING:
  57. kwargs["default"] = field.default
  58. elif field.type is bool or field.type is Optional[bool]:
  59. kwargs["action"] = "store_false" if field.default is True else "store_true"
  60. if field.default is True:
  61. field_name = f"--no-{field.name}"
  62. kwargs["dest"] = field.name
  63. elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
  64. kwargs["nargs"] = "+"
  65. kwargs["type"] = field.type.__args__[0]
  66. assert all(
  67. x == kwargs["type"] for x in field.type.__args__
  68. ), "{} cannot be a List of mixed types".format(field.name)
  69. if field.default_factory is not dataclasses.MISSING:
  70. kwargs["default"] = field.default_factory()
  71. else:
  72. kwargs["type"] = field.type
  73. if field.default is not dataclasses.MISSING:
  74. kwargs["default"] = field.default
  75. elif field.default_factory is not dataclasses.MISSING:
  76. kwargs["default"] = field.default_factory()
  77. else:
  78. kwargs["required"] = True
  79. self.add_argument(field_name, **kwargs)
  80. def parse_args_into_dataclasses(
  81. self, args=None, return_remaining_strings=False, look_for_args_file=True, args_filename=None
  82. ) -> Tuple[DataClass, ...]:
  83. """
  84. Parse command-line args into instances of the specified dataclass types.
  85. This relies on argparse's `ArgumentParser.parse_known_args`.
  86. See the doc at:
  87. docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args
  88. Args:
  89. args:
  90. List of strings to parse. The default is taken from sys.argv.
  91. (same as argparse.ArgumentParser)
  92. return_remaining_strings:
  93. If true, also return a list of remaining argument strings.
  94. look_for_args_file:
  95. If true, will look for a ".args" file with the same base name
  96. as the entry point script for this process, and will append its
  97. potential content to the command line args.
  98. args_filename:
  99. If not None, will uses this file instead of the ".args" file
  100. specified in the previous argument.
  101. Returns:
  102. Tuple consisting of:
  103. - the dataclass instances in the same order as they
  104. were passed to the initializer.abspath
  105. - if applicable, an additional namespace for more
  106. (non-dataclass backed) arguments added to the parser
  107. after initialization.
  108. - The potential list of remaining argument strings.
  109. (same as argparse.ArgumentParser.parse_known_args)
  110. """
  111. if args_filename or (look_for_args_file and len(sys.argv)):
  112. if args_filename:
  113. args_file = Path(args_filename)
  114. else:
  115. args_file = Path(sys.argv[0]).with_suffix(".args")
  116. if args_file.exists():
  117. fargs = args_file.read_text().split()
  118. args = fargs + args if args is not None else fargs + sys.argv[1:]
  119. # in case of duplicate arguments the first one has precedence
  120. # so we append rather than prepend.
  121. namespace, remaining_args = self.parse_known_args(args=args)
  122. outputs = []
  123. for dtype in self.dataclass_types:
  124. keys = {f.name for f in dataclasses.fields(dtype)}
  125. inputs = {k: v for k, v in vars(namespace).items() if k in keys}
  126. for k in keys:
  127. delattr(namespace, k)
  128. obj = dtype(**inputs)
  129. outputs.append(obj)
  130. if len(namespace.__dict__) > 0:
  131. # additional namespace.
  132. outputs.append(namespace)
  133. if return_remaining_strings:
  134. return (*outputs, remaining_args)
  135. else:
  136. if remaining_args:
  137. raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
  138. return (*outputs,)
  139. def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
  140. """
  141. Alternative helper method that does not use `argparse` at all,
  142. instead loading a json file and populating the dataclass types.
  143. """
  144. data = json.loads(Path(json_file).read_text())
  145. outputs = []
  146. for dtype in self.dataclass_types:
  147. keys = {f.name for f in dataclasses.fields(dtype)}
  148. inputs = {k: v for k, v in data.items() if k in keys}
  149. obj = dtype(**inputs)
  150. outputs.append(obj)
  151. return (*outputs,)
  152. def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
  153. """
  154. Alternative helper method that does not use `argparse` at all,
  155. instead uses a dict and populating the dataclass types.
  156. """
  157. outputs = []
  158. for dtype in self.dataclass_types:
  159. keys = {f.name for f in dataclasses.fields(dtype)}
  160. inputs = {k: v for k, v in args.items() if k in keys}
  161. obj = dtype(**inputs)
  162. outputs.append(obj)
  163. return (*outputs,)