/pytext/data/batch_sampler.py

https://github.com/facebookresearch/pytext · Python · 276 lines · 181 code · 33 blank · 62 comment · 17 complexity · 24af18ce75f6b1c3c4f15defd5f91420 MD5 · raw file

  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from collections.abc import Iterator
from typing import Dict, List, Optional, Tuple

import numpy as np
from pytext.config.component import Component, ComponentType
  7. class BaseBatchSampler(Component):
  8. __COMPONENT_TYPE__ = ComponentType.BATCH_SAMPLER
  9. __EXPANSIBLE__ = True
  10. @classmethod
  11. def from_config(cls, config: Component.Config):
  12. return cls()
  13. def __init__(self):
  14. pass
  15. def batchify(self, iterators: Dict[str, Iterator]):
  16. pass
  17. class EvalBatchSampler(BaseBatchSampler):
  18. """
  19. This sampler takes in a dictionary of Iterators and returns batches
  20. associated with each key in the dictionary. It guarentees that we will see
  21. each batch associated with each key exactly once in the epoch.
  22. Example:
  23. Iterator 1: [A, B, C, D], Iterator 2: [a, b]
  24. Output: [A, B, C, D, a, b]
  25. """
  26. def batchify(self, iterators: Dict[str, Iterator]):
  27. """
  28. Loop through each key in the input dict and generate batches from
  29. the iterator associated with that key.
  30. Args:
  31. iterators: Dictionary of iterators
  32. """
  33. iter_dict = {name: iter(iterator) for name, iterator in iterators.items()}
  34. for name, it in iter_dict.items():
  35. for item in it:
  36. yield name, item
  37. class RoundRobinBatchSampler(BaseBatchSampler):
  38. """
  39. This sampler takes a dictionary of Iterators and returns batches in a round
  40. robin fashion till a the end of one of the iterators is reached. The end
  41. is specified by `iter_to_set_epoch`.
  42. If `iter_to_set_epoch` is set, cycle batches from each iterator until one
  43. epoch of the target iterator is fulfilled. Iterators with fewer batches
  44. than the target iterator are repeated, so they never run out.
  45. If `iter_to_set_epoch` is None, cycle over batches from each iterator until the
  46. shortest iterator completes one epoch.
  47. Example:
  48. Iterator 1: [A, B, C, D], Iterator 2: [a, b]
  49. iter_to_set_epoch = "Iterator 1"
  50. Output: [A, a, B, b, C, a, D, b]
  51. iter_to_set_epoch = None
  52. Output: [A, a, B, b]
  53. Args:
  54. iter_to_set_epoch (Optional[str]): Name of iterator to define epoch size.
  55. If this is not set, epoch size defaults to the length of
  56. the shortest iterator.
  57. """
  58. __COMPONENT_TYPE__ = ComponentType.BATCH_SAMPLER
  59. class Config(Component.Config):
  60. iter_to_set_epoch: str = ""
  61. @classmethod
  62. def from_config(cls, config: Config):
  63. return cls(config.iter_to_set_epoch)
  64. def __init__(self, iter_to_set_epoch: Optional[str] = None) -> None:
  65. self.iter_to_set_epoch = iter_to_set_epoch
  66. def batchify(self, iterators: Dict[str, Iterator]):
  67. """
  68. Loop through each key in the input dict and generate batches from
  69. the iterator associated with that key until the target iterator reaches
  70. its end.
  71. Args:
  72. iterators: Dictionary of iterators
  73. """
  74. iter_dict = {name: iter(iterator) for name, iterator in iterators.items()}
  75. while True:
  76. for name, it in iter_dict.items():
  77. try:
  78. yield name, next(it)
  79. except StopIteration:
  80. new_iter = iter(iterators[name])
  81. iter_dict[name] = new_iter
  82. if (not self.iter_to_set_epoch) or name == self.iter_to_set_epoch:
  83. self.iter_to_set_epoch = name
  84. # end of epoch
  85. return
  86. else:
  87. yield name, next(new_iter)
  88. def select_key_and_batch(
  89. iterator_names: Dict[str, str],
  90. iterator_probs: Dict[str, float],
  91. iter_dict: Dict[str, Iterator],
  92. iterators: Dict[str, Iterator],
  93. ):
  94. """ Helper function for RandomizedBatchSampler and AlternatingRandomizedBatchSampler
  95. to select a key from iterator_names using iterator_probs and return a batch
  96. for the selected key using iter_dict and iterators.
  97. """
  98. # Select a candidate iterator using the uniform distribtion
  99. selected_key = np.random.choice(iterator_names, p=iterator_probs)
  100. try:
  101. batch = next(iter_dict[selected_key])
  102. except StopIteration:
  103. iter_dict[selected_key] = iter(iterators[selected_key])
  104. batch = next(iter_dict[selected_key])
  105. return selected_key, batch
  106. def extract_iterator_properties(input_iterator_probs: Dict[str, float]):
  107. """ Helper function for RandomizedBatchSampler and AlternatingRandomizedBatchSampler
  108. to generate iterator properties: iterator_names and iterator_probs.
  109. """
  110. iterator_names = list(input_iterator_probs)
  111. iterator_probs = np.array(
  112. [float(input_iterator_probs[name]) for name in iterator_names]
  113. )
  114. iterator_probs /= iterator_probs.sum()
  115. return iterator_names, iterator_probs
  116. class RandomizedBatchSampler(BaseBatchSampler):
  117. """
  118. This sampler takes in a dictionary of iterators and returns batches according
  119. to the specified probabilities by `unnormalized_iterator_probs`. We cycle through
  120. the iterators (restarting any that "run out") indefinitely. Set batches_per_epoch
  121. in Trainer.Config.
  122. Example:
  123. Iterator A: [A, B, C, D], Iterator B: [a, b]
  124. batches_per_epoch = 3, unnormalized_iterator_probs = {"A": 0, "B": 1}
  125. Epoch 1 = [a, b, a]
  126. Epoch 2 = [b, a, b]
  127. Args:
  128. unnormalized_iterator_probs (Dict[str, float]): Iterator sampling probabilities.
  129. The keys should be the same as the keys of the underlying iterators, and the
  130. values will be normalized to sum to 1.
  131. """
  132. __COMPONENT_TYPE__ = ComponentType.BATCH_SAMPLER
  133. class Config(Component.Config):
  134. unnormalized_iterator_probs: Dict[str, float]
  135. @classmethod
  136. def from_config(cls, config: Config):
  137. return cls(config.unnormalized_iterator_probs)
  138. def __init__(self, unnormalized_iterator_probs: Dict[str, float]) -> None:
  139. self.iterator_names, self.iterator_probs = extract_iterator_properties(
  140. unnormalized_iterator_probs
  141. )
  142. # Note: we need to make `iter_dict` an instance attribute so that it persists
  143. # across calls to `batchify()`. This way subsequent epochs will continue from
  144. # previous states of the iterators (instead of recreating them).
  145. self.iter_dict = None
  146. def batchify(self, iterators: Dict[str, Iterator]):
  147. assert set(iterators) == set(self.iterator_names)
  148. if self.iter_dict is None:
  149. self.iter_dict = {
  150. name: iter(iterator) for name, iterator in iterators.items()
  151. }
  152. num_batches = 0
  153. while True:
  154. selected_key, batch = select_key_and_batch(
  155. self.iterator_names, self.iterator_probs, self.iter_dict, iterators
  156. )
  157. num_batches += 1
  158. yield selected_key, batch
  159. class AlternatingRandomizedBatchSampler(RandomizedBatchSampler):
  160. """
  161. This sampler takes in a dictionary of iterators and returns batches alternating
  162. between keys and probabilities specified by `unnormalized_iterator_probs` and
  163. 'second_unnormalized_iterator_probs', This is used for example in XLM
  164. pre-training where we alternate between MLM and TLM batches.
  165. """
  166. __COMPONENT_TYPE__ = ComponentType.BATCH_SAMPLER
  167. class Config(Component.Config):
  168. unnormalized_iterator_probs: Dict[str, float]
  169. second_unnormalized_iterator_probs: Dict[str, float]
  170. @classmethod
  171. def from_config(cls, config: Config):
  172. assert (
  173. len(config.unnormalized_iterator_probs) > 0
  174. and len(config.second_unnormalized_iterator_probs) > 0
  175. )
  176. return cls(
  177. unnormalized_iterator_probs=config.unnormalized_iterator_probs,
  178. second_unnormalized_iterator_probs=(
  179. config.second_unnormalized_iterator_probs
  180. ),
  181. )
  182. def __init__(
  183. self,
  184. unnormalized_iterator_probs: Dict[str, float],
  185. second_unnormalized_iterator_probs: Dict[str, float],
  186. ) -> None:
  187. super().__init__(unnormalized_iterator_probs)
  188. (
  189. self.second_iterator_names,
  190. self.second_iterator_probs,
  191. ) = extract_iterator_properties(second_unnormalized_iterator_probs)
  192. self.is_secondary_turn = False
  193. def batchify(self, iterators: Dict[str, Iterator]):
  194. assert set(iterators) == set(self.iterator_names).union(
  195. set(self.second_iterator_names)
  196. )
  197. if self.iter_dict is None:
  198. self.iter_dict = {
  199. name: iter(iterator) for name, iterator in iterators.items()
  200. }
  201. while True:
  202. curr_iter = (
  203. self.second_iterator_names
  204. if self.is_secondary_turn
  205. else self.iterator_names
  206. )
  207. curr_probs = (
  208. self.second_iterator_probs
  209. if self.is_secondary_turn
  210. else self.iterator_probs
  211. )
  212. selected_key, batch = select_key_and_batch(
  213. curr_iter, curr_probs, self.iter_dict, iterators
  214. )
  215. self.is_secondary_turn = not self.is_secondary_turn
  216. yield selected_key, batch