/tensorflow_data_validation/statistics/stats_impl.py

https://github.com/tensorflow/data-validation · Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of statistics generators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import math
import random

import apache_beam as beam
import numpy as np
import pyarrow as pa
import six
from six.moves import zip
from tensorflow_data_validation import constants
from tensorflow_data_validation import types
from tensorflow_data_validation.arrow import arrow_util
from tensorflow_data_validation.statistics import stats_options
from tensorflow_data_validation.statistics.generators import basic_stats_generator
from tensorflow_data_validation.statistics.generators import image_stats_generator
from tensorflow_data_validation.statistics.generators import lift_stats_generator
from tensorflow_data_validation.statistics.generators import natural_language_stats_generator
from tensorflow_data_validation.statistics.generators import sparse_feature_stats_generator
from tensorflow_data_validation.statistics.generators import stats_generator
from tensorflow_data_validation.statistics.generators import time_stats_generator
from tensorflow_data_validation.statistics.generators import top_k_uniques_combiner_stats_generator
from tensorflow_data_validation.statistics.generators import top_k_uniques_stats_generator
from tensorflow_data_validation.statistics.generators import weighted_feature_stats_generator
from tensorflow_data_validation.utils import slicing_util
from tensorflow_data_validation.utils import stats_util
from tfx_bsl.arrow import table_util
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple

from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_metadata.proto.v0 import statistics_pb2

@beam.typehints.with_input_types(pa.RecordBatch)
@beam.typehints.with_output_types(statistics_pb2.DatasetFeatureStatisticsList)
class GenerateStatisticsImpl(beam.PTransform):
  """PTransform that applies a set of generators over input examples."""

  def __init__(
      self,
      options: stats_options.StatsOptions = stats_options.StatsOptions()
  ) -> None:
    self._options = options

  def expand(self, dataset: beam.pvalue.PCollection) -> beam.pvalue.PCollection:
    # If a whitelist of features is provided, keep only those features.
    if self._options.feature_whitelist:
      dataset |= ('RemoveNonWhitelistedFeatures' >> beam.Map(
          _filter_features, feature_whitelist=self._options.feature_whitelist))

    if self._options.slice_functions:
      # Add the default slicing function.
      slice_functions = [slicing_util.default_slicer]
      slice_functions.extend(self._options.slice_functions)
      dataset = (
          dataset
          | 'GenerateSliceKeys' >> beam.FlatMap(
              slicing_util.generate_slices, slice_functions=slice_functions))
    else:
      # TODO(pachristopher): Remove this special case if this doesn't give any
      # performance improvement.
      dataset = (dataset
                 | 'KeyWithVoid' >> beam.Map(lambda v: (None, v)))

    return dataset | GenerateSlicedStatisticsImpl(self._options)
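

# A minimal usage sketch (illustrative; not part of this module): statistics
# are computed over a PCollection of Arrow RecordBatches and emitted as a
# single DatasetFeatureStatisticsList.
#
#   import apache_beam as beam
#   import pyarrow as pa
#   from tensorflow_data_validation.statistics import stats_impl
#   from tensorflow_data_validation.statistics import stats_options
#
#   batch = pa.RecordBatch.from_arrays(
#       [pa.array([[1.0], [2.0, 3.0]])], ['feature_a'])
#   with beam.Pipeline() as p:
#     _ = (p
#          | beam.Create([batch])
#          | stats_impl.GenerateStatisticsImpl(stats_options.StatsOptions())
#          | beam.Map(print))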


# This transform will be used by the example validation API to compute
# statistics over anomalous examples. Specifically, it is used to compute
# statistics over examples found for each anomaly (i.e., the anomaly type
# will be the slice key).
@beam.typehints.with_input_types(types.BeamSlicedRecordBatch)
@beam.typehints.with_output_types(statistics_pb2.DatasetFeatureStatisticsList)
class GenerateSlicedStatisticsImpl(beam.PTransform):
  """PTransform that applies a set of generators to sliced input examples."""

  def __init__(
      self,
      options: stats_options.StatsOptions = stats_options.StatsOptions(),
      is_slicing_enabled: bool = False,
  ) -> None:
    """Initializes GenerateSlicedStatisticsImpl.

    Args:
      options: `tfdv.StatsOptions` for generating data statistics.
      is_slicing_enabled: Whether to include slice keys in the resulting proto,
        even if slice functions are not provided in `options`. If slice
        functions are provided in `options`, slice keys are included regardless
        of this value.
    """
    self._options = options
    self._is_slicing_enabled = (
        is_slicing_enabled or bool(self._options.slice_functions))

  def expand(self, dataset: beam.pvalue.PCollection) -> beam.pvalue.PCollection:
    # Handles generators by their type:
    #   - CombinerStatsGenerators will be wrapped in a single CombinePerKey by
    #     _CombinerStatsGeneratorsCombineFn.
    #   - TransformStatsGenerator will be invoked separately with `dataset`.
    combiner_stats_generators = []
    result_protos = []
    for generator in get_generators(self._options):
      if isinstance(generator, stats_generator.CombinerStatsGenerator):
        combiner_stats_generators.append(generator)
      elif isinstance(generator, stats_generator.TransformStatsGenerator):
        result_protos.append(
            dataset
            | generator.name >> generator.ptransform)
      else:
        raise TypeError('Statistics generator must extend one of '
                        'CombinerStatsGenerator or TransformStatsGenerator, '
                        'found object of type %s' %
                        generator.__class__.__name__)
    if combiner_stats_generators:
      # TODO(b/162543416): Obviate the need for explicit fanout.
      fanout = 5 * int(math.ceil(math.sqrt(len(combiner_stats_generators))))
      result_protos.append(dataset
                           | 'RunCombinerStatsGenerators'
                           >> beam.CombinePerKey(
                               _CombinerStatsGeneratorsCombineFn(
                                   combiner_stats_generators,
                                   self._options.desired_batch_size
                               )).with_hot_key_fanout(fanout))

    # result_protos is a list of PCollections of (slice key,
    # DatasetFeatureStatistics proto) pairs. We now flatten the list into a
    # single PCollection, combine the DatasetFeatureStatistics protos by key,
    # and then merge the DatasetFeatureStatistics protos in the PCollection
    # into a single DatasetFeatureStatisticsList proto.
    return (result_protos
            | 'FlattenFeatureStatistics' >> beam.Flatten()
            | 'MergeDatasetFeatureStatisticsProtos' >>
            beam.CombinePerKey(_merge_dataset_feature_stats_protos)
            | 'AddSliceKeyToStatsProto' >> beam.Map(
                _add_slice_key,
                self._is_slicing_enabled)
            | 'ToList' >> beam.combiners.ToList()
            | 'MakeDatasetFeatureStatisticsListProto' >>
            beam.Map(_make_dataset_feature_statistics_list_proto))
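

# Input format sketch (illustrative): GenerateSlicedStatisticsImpl consumes
# (slice_key, pa.RecordBatch) tuples, where the slice key is None for the
# default (overall) slice, e.g.:
#
#   sliced_batches = [
#       (None, batch),           # the overall dataset
#       ('country_US', batch),   # a slice produced by a custom slice function
#   ]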


def get_generators(options: stats_options.StatsOptions,
                   in_memory: bool = False
                  ) -> List[stats_generator.StatsGenerator]:
  """Initializes the list of stats generators, including custom generators.

  Args:
    options: A StatsOptions object.
    in_memory: Whether the generators will be used to generate statistics in
      memory (True) or using Beam (False).

  Returns:
    A list of stats generator objects.
  """
  generators = _get_default_generators(options, in_memory)
  if options.generators:
    # Add custom stats generators.
    generators.extend(options.generators)
  if options.enable_semantic_domain_stats:
    semantic_domain_feature_stats_generators = [
        image_stats_generator.ImageStatsGenerator(),
        natural_language_stats_generator.NLStatsGenerator(),
        time_stats_generator.TimeStatsGenerator(),
    ]
    # Wrap the semantic domain feature stats generators as a separate combiner
    # stats generator, so that sampling can be applied only to them without
    # affecting the other feature stats generators.
    generators.append(
        CombinerFeatureStatsWrapperGenerator(
            semantic_domain_feature_stats_generators,
            weight_feature=options.weight_feature,
            sample_rate=options.semantic_domain_stats_sample_rate))
  if options.schema is not None:
    if _schema_has_sparse_features(options.schema):
      generators.append(
          sparse_feature_stats_generator.SparseFeatureStatsGenerator(
              options.schema))
    if options.schema.weighted_feature:
      generators.append(
          weighted_feature_stats_generator.WeightedFeatureStatsGenerator(
              options.schema))
  if options.label_feature and not in_memory:
    # The LiftStatsGenerator is not a CombinerStatsGenerator and therefore
    # cannot currently be used for in_memory executions.
    generators.append(
        lift_stats_generator.LiftStatsGenerator(
            y_path=types.FeaturePath([options.label_feature]),
            schema=options.schema,
            weight_column_name=options.weight_feature,
            output_custom_stats=True))

  # Replace all CombinerFeatureStatsGenerators with a single
  # CombinerFeatureStatsWrapperGenerator.
  feature_generators = [
      x for x in generators
      if isinstance(x, stats_generator.CombinerFeatureStatsGenerator)
  ]
  if feature_generators:
    generators = [
        x for x in generators
        if not isinstance(x, stats_generator.CombinerFeatureStatsGenerator)
    ] + [
        CombinerFeatureStatsWrapperGenerator(
            feature_generators, weight_feature=options.weight_feature)
    ]
  if in_memory:
    for generator in generators:
      if not isinstance(generator, stats_generator.CombinerStatsGenerator):
        raise TypeError('Statistics generator used in '
                        'generate_statistics_in_memory must '
                        'extend CombinerStatsGenerator, found object of '
                        'type %s.' % generator.__class__.__name__)
  return generators
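

# Example (sketch): custom generators are supplied through
# StatsOptions.generators; get_generators() folds any
# CombinerFeatureStatsGenerator among them, together with the defaults, into a
# single CombinerFeatureStatsWrapperGenerator. 'MyFeatureStatsGenerator' below
# is hypothetical.
#
#   options = stats_options.StatsOptions(
#       generators=[MyFeatureStatsGenerator()])
#   generators = get_generators(options, in_memory=True)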


def _get_default_generators(
    options: stats_options.StatsOptions, in_memory: bool = False
) -> List[stats_generator.StatsGenerator]:
  """Initializes default list of stats generators.

  Args:
    options: A StatsOptions object.
    in_memory: Whether the generators will be used to generate statistics in
      memory (True) or using Beam (False).

  Returns:
    A list of stats generator objects.
  """
  stats_generators = [
      basic_stats_generator.BasicStatsGenerator(
          schema=options.schema,
          weight_feature=options.weight_feature,
          num_values_histogram_buckets=options.num_values_histogram_buckets,
          num_histogram_buckets=options.num_histogram_buckets,
          num_quantiles_histogram_buckets=(
              options.num_quantiles_histogram_buckets),
          epsilon=options.epsilon),
      NumExamplesStatsGenerator(options.weight_feature)
  ]
  if in_memory:
    stats_generators.append(
        top_k_uniques_combiner_stats_generator
        .TopKUniquesCombinerStatsGenerator(
            schema=options.schema,
            weight_feature=options.weight_feature,
            num_top_values=options.num_top_values,
            frequency_threshold=options.frequency_threshold,
            weighted_frequency_threshold=options.weighted_frequency_threshold,
            num_rank_histogram_buckets=options.num_rank_histogram_buckets))
  else:
    stats_generators.extend([
        top_k_uniques_stats_generator.TopKUniquesStatsGenerator(
            schema=options.schema,
            weight_feature=options.weight_feature,
            num_top_values=options.num_top_values,
            frequency_threshold=options.frequency_threshold,
            weighted_frequency_threshold=options.weighted_frequency_threshold,
            num_rank_histogram_buckets=options.num_rank_histogram_buckets),
    ])
  return stats_generators


def _schema_has_sparse_features(schema: schema_pb2.Schema) -> bool:
  """Returns whether there are any sparse features in the specified schema."""

  def _has_sparse_features(
      feature_container: Iterable[schema_pb2.Feature]
  ) -> bool:
    """Helper function used to determine whether there are sparse features."""
    for f in feature_container:
      if isinstance(f, schema_pb2.SparseFeature):
        return True
      if f.type == schema_pb2.STRUCT:
        if f.struct_domain.sparse_feature:
          return True
        return _has_sparse_features(f.struct_domain.feature)
    return False

  if schema.sparse_feature:
    return True
  return _has_sparse_features(schema.feature)
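

# Example (sketch): sparse features are detected both at the top level and
# nested inside a STRUCT feature's struct_domain.
#
#   schema = schema_pb2.Schema()
#   struct = schema.feature.add(name='parent', type=schema_pb2.STRUCT)
#   struct.struct_domain.sparse_feature.add(name='nested_sparse')
#   assert _schema_has_sparse_features(schema)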


def _filter_features(
    record_batch: pa.RecordBatch,
    feature_whitelist: List[types.FeatureName]) -> pa.RecordBatch:
  """Removes features that are not whitelisted.

  Args:
    record_batch: Input Arrow RecordBatch.
    feature_whitelist: A set of feature names to whitelist.

  Returns:
    An Arrow RecordBatch containing only the whitelisted features of the input.
  """
  schema = record_batch.schema
  column_names = set(schema.names)
  columns_to_select = []
  column_names_to_select = []
  for feature_name in feature_whitelist:
    if feature_name in column_names:
      columns_to_select.append(
          record_batch.column(schema.get_field_index(feature_name)))
      column_names_to_select.append(feature_name)
  return pa.RecordBatch.from_arrays(columns_to_select, column_names_to_select)


def _add_slice_key(
    stats_proto_per_slice: Tuple[types.SliceKey,
                                 statistics_pb2.DatasetFeatureStatistics],
    is_slicing_enabled: bool
) -> statistics_pb2.DatasetFeatureStatistics:
  """Add slice key to stats proto."""
  result = statistics_pb2.DatasetFeatureStatistics()
  result.CopyFrom(stats_proto_per_slice[1])
  if is_slicing_enabled:
    result.name = stats_proto_per_slice[0]
  return result


def _merge_dataset_feature_stats_protos(
    stats_protos: Iterable[statistics_pb2.DatasetFeatureStatistics]
) -> statistics_pb2.DatasetFeatureStatistics:
  """Merges together a list of DatasetFeatureStatistics protos.

  Args:
    stats_protos: A list of DatasetFeatureStatistics protos to merge.

  Returns:
    The merged DatasetFeatureStatistics proto.
  """
  stats_per_feature = {}
  # Create a new DatasetFeatureStatistics proto.
  result = statistics_pb2.DatasetFeatureStatistics()
  # Iterate over each DatasetFeatureStatistics proto and merge the
  # FeatureNameStatistics protos per feature and add the cross feature stats.
  for stats_proto in stats_protos:
    if stats_proto.cross_features:
      result.cross_features.extend(stats_proto.cross_features)
    for feature_stats_proto in stats_proto.features:
      feature_path = types.FeaturePath.from_proto(feature_stats_proto.path)
      if feature_path not in stats_per_feature:
        # Make a copy for the "cache" since we are modifying it in 'else' below.
        new_feature_stats_proto = statistics_pb2.FeatureNameStatistics()
        new_feature_stats_proto.CopyFrom(feature_stats_proto)
        stats_per_feature[feature_path] = new_feature_stats_proto
      else:
        stats_for_feature = stats_per_feature[feature_path]
        # MergeFrom would concatenate repeated fields which is not what we want
        # for path.step.
        del stats_for_feature.path.step[:]
        stats_for_feature.MergeFrom(feature_stats_proto)

  num_examples = None
  for feature_stats_proto in six.itervalues(stats_per_feature):
    # Add the merged FeatureNameStatistics proto for the feature
    # into the DatasetFeatureStatistics proto.
    new_feature_stats_proto = result.features.add()
    new_feature_stats_proto.CopyFrom(feature_stats_proto)

    # Get the number of examples from one of the features that
    # has common stats.
    if num_examples is None:
      stats_type = feature_stats_proto.WhichOneof('stats')
      stats_proto = None
      if stats_type == 'num_stats':
        stats_proto = feature_stats_proto.num_stats
      else:
        stats_proto = feature_stats_proto.string_stats

      if stats_proto.HasField('common_stats'):
        num_examples = (stats_proto.common_stats.num_non_missing +
                        stats_proto.common_stats.num_missing)

  # Set the num_examples field.
  if num_examples is not None:
    result.num_examples = num_examples
  return result
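

# Example (sketch): partial protos carrying different stats for the same
# feature path are merged into a single FeatureNameStatistics entry.
#
#   p1 = statistics_pb2.DatasetFeatureStatistics()
#   f1 = p1.features.add()
#   f1.path.step.append('age')
#   f1.num_stats.common_stats.num_non_missing = 8
#   p2 = statistics_pb2.DatasetFeatureStatistics()
#   f2 = p2.features.add()
#   f2.path.step.append('age')
#   f2.num_stats.max = 73
#   merged = _merge_dataset_feature_stats_protos([p1, p2])
#   assert len(merged.features) == 1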


def _update_example_and_missing_count(
    stats: statistics_pb2.DatasetFeatureStatistics) -> None:
  """Updates example count of the dataset and missing count for all features."""
  if not stats.features:
    return
  dummy_feature = stats_util.get_feature_stats(stats, _DUMMY_FEATURE_PATH)
  num_examples = stats_util.get_custom_stats(dummy_feature, _NUM_EXAMPLES_KEY)
  weighted_num_examples = stats_util.get_custom_stats(
      dummy_feature, _WEIGHTED_NUM_EXAMPLES_KEY)
  stats.features.remove(dummy_feature)
  for feature_stats in stats.features:
    # For features nested under a STRUCT feature, their num_missing is computed
    # in the basic stats generator (because their num_missing is relative to
    # their parent's value count).
    if len(feature_stats.path.step) > 1:
      continue
    common_stats = None
    which_oneof_stats = feature_stats.WhichOneof('stats')
    if which_oneof_stats is None:
      # There are no common_stats for this feature (which can be the case when
      # generating only custom_stats for a sparse or weighted feature). In that
      # case, simply continue without modifying the common stats.
      continue
    common_stats = getattr(feature_stats, which_oneof_stats).common_stats
    assert num_examples >= common_stats.num_non_missing, (
        'Total number of examples: {} is less than number of non missing '
        'examples: {} for feature {}.'.format(
            num_examples, common_stats.num_non_missing,
            '.'.join(feature_stats.path.step)))
    num_missing = int(num_examples - common_stats.num_non_missing)
    common_stats.num_missing = num_missing
    if common_stats.presence_and_valency_stats:
      common_stats.presence_and_valency_stats[0].num_missing = num_missing
    if weighted_num_examples != 0:
      weighted_num_missing = (
          weighted_num_examples -
          common_stats.weighted_common_stats.num_non_missing)
      common_stats.weighted_common_stats.num_missing = weighted_num_missing
      if common_stats.weighted_presence_and_valency_stats:
        common_stats.weighted_presence_and_valency_stats[0].num_missing = (
            weighted_num_missing)
  stats.num_examples = int(num_examples)
  stats.weighted_num_examples = weighted_num_examples


def _make_dataset_feature_statistics_list_proto(
    stats_protos: List[statistics_pb2.DatasetFeatureStatistics]
) -> statistics_pb2.DatasetFeatureStatisticsList:
  """Constructs a DatasetFeatureStatisticsList proto.

  Args:
    stats_protos: List of DatasetFeatureStatistics protos.

  Returns:
    The DatasetFeatureStatisticsList proto containing the input stats protos.
  """
  # Create a new DatasetFeatureStatisticsList proto.
  result = statistics_pb2.DatasetFeatureStatisticsList()

  for stats_proto in stats_protos:
    # Add the input DatasetFeatureStatistics proto.
    new_stats_proto = result.datasets.add()
    new_stats_proto.CopyFrom(stats_proto)
    # We now update the example count for the dataset and the missing count
    # for all the features, using the number of examples computed separately
    # using NumExamplesStatsGenerator. Note that we compute the number of
    # examples separately to avoid ignoring example counts for features which
    # may be completely missing in a shard. We set the missing count of a
    # feature to be num_examples - non_missing_count.
    _update_example_and_missing_count(new_stats_proto)
  if not stats_protos:
    # Handle the case in which there are no examples. In that case, we want to
    # output a DatasetFeatureStatisticsList proto with a dataset containing
    # num_examples == 0 instead of an empty DatasetFeatureStatisticsList proto.
    result.datasets.add(num_examples=0)
  return result


_DUMMY_FEATURE_PATH = types.FeaturePath(['__TFDV_INTERNAL_FEATURE__'])
_NUM_EXAMPLES_KEY = '__NUM_EXAMPLES__'
_WEIGHTED_NUM_EXAMPLES_KEY = '__WEIGHTED_NUM_EXAMPLES__'


class NumExamplesStatsGenerator(stats_generator.CombinerStatsGenerator):
  """Computes total number of examples."""

  def __init__(self,
               weight_feature: Optional[types.FeatureName] = None) -> None:
    self._weight_feature = weight_feature

  def create_accumulator(self) -> List[float]:
    return [0, 0]  # [num_examples, weighted_num_examples]

  def add_input(self, accumulator: List[float],
                examples: pa.RecordBatch) -> List[float]:
    accumulator[0] += examples.num_rows
    if self._weight_feature:
      weights_column = examples.column(
          examples.schema.get_field_index(self._weight_feature))
      accumulator[1] += np.sum(np.asarray(weights_column.flatten()))
    return accumulator

  def merge_accumulators(self, accumulators: Iterable[List[float]]
                        ) -> List[float]:
    result = self.create_accumulator()
    for acc in accumulators:
      result[0] += acc[0]
      result[1] += acc[1]
    return result

  def extract_output(self, accumulator: List[float]
                    ) -> statistics_pb2.DatasetFeatureStatistics:
    result = statistics_pb2.DatasetFeatureStatistics()
    dummy_feature = result.features.add()
    dummy_feature.path.CopyFrom(_DUMMY_FEATURE_PATH.to_proto())
    dummy_feature.custom_stats.add(name=_NUM_EXAMPLES_KEY, num=accumulator[0])
    dummy_feature.custom_stats.add(name=_WEIGHTED_NUM_EXAMPLES_KEY,
                                   num=accumulator[1])
    return result
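

# Example (sketch): the combiner lifecycle for this generator. `batch` is
# assumed to be a pa.RecordBatch.
#
#   gen = NumExamplesStatsGenerator()
#   acc = gen.add_input(gen.create_accumulator(), batch)
#   stats = gen.extract_output(gen.merge_accumulators([acc]))
#   # stats.features[0] holds the __NUM_EXAMPLES__ and
#   # __WEIGHTED_NUM_EXAMPLES__ custom stats for the internal dummy feature.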


class _CombinerStatsGeneratorsCombineFnAcc(object):
  """Accumulator for _CombinerStatsGeneratorsCombineFn."""

  __slots__ = [
      'partial_accumulators', 'input_record_batches', 'curr_batch_size',
      'curr_byte_size'
  ]

  def __init__(self, partial_accumulators: List[Any]):
    # Partial accumulator states of the underlying CombinerStatsGenerators.
    self.partial_accumulators = partial_accumulators
    # Input record batches to be processed.
    self.input_record_batches = []
    # Current batch size.
    self.curr_batch_size = 0
    # Current total byte size of all the pa.RecordBatches accumulated.
    self.curr_byte_size = 0


@beam.typehints.with_input_types(pa.RecordBatch)
@beam.typehints.with_output_types(statistics_pb2.DatasetFeatureStatistics)
class _CombinerStatsGeneratorsCombineFn(beam.CombineFn):
  """A beam.CombineFn wrapping a list of CombinerStatsGenerators with batching.

  This wrapper does two things:
    1. Wraps a list of combiner stats generators. Its accumulator is a list
       of accumulators, one for each wrapped stats generator.
    2. Batches input examples before passing them to the underlying
       stats generators.

  We do this by accumulating examples in the combiner state until we
  accumulate a large enough batch, at which point we send them through the
  add_input step of each of the underlying combiner stats generators. When
  merging, we merge the accumulators of the stats generators and accumulate
  examples accordingly. We finally process any remaining examples
  before producing the final output value.

  This wrapper is needed to support slicing, as we need the ability to
  perform slice-aware batching. But currently there is no way to do key-aware
  batching in Beam. Hence, this wrapper does batching and combining together.

  See also:
    BEAM-3737: Key-aware batching function
    (https://issues.apache.org/jira/browse/BEAM-3737).
  """

  __slots__ = ['_generators', '_desired_batch_size', '_combine_batch_size',
               '_combine_byte_size', '_num_compacts', '_num_instances']

  # This needs to be large enough to allow for efficient merging of
  # accumulators in the individual stats generators, but shouldn't be too large
  # as it also acts as cap on the maximum memory usage of the computation.
  # TODO(b/73789023): Ideally we should automatically infer the batch size.
  _DESIRED_MERGE_ACCUMULATOR_BATCH_SIZE = 100

  # The combiner accumulates record batches from the upstream and merges them
  # when certain conditions are met. A merged record batch would allow better
  # vectorized processing, but we have to pay for copying and the RAM to
  # contain the merged record batch. If the total byte size of accumulated
  # record batches exceeds this threshold a merge will be forced to avoid
  # consuming too much memory.
  _MERGE_RECORD_BATCH_BYTE_SIZE_THRESHOLD = 20 << 20  # 20MiB

  def __init__(
      self,
      generators: List[stats_generator.CombinerStatsGenerator],
      desired_batch_size: Optional[int] = None) -> None:
    self._generators = generators

    # We really want the batch size to be adaptive like it is in
    # beam.BatchElements(), but there isn't an easy way to make it so.
    # TODO(b/73789023): Figure out how to make this batch size dynamic.
    if desired_batch_size and desired_batch_size > 0:
      self._desired_batch_size = desired_batch_size
    else:
      self._desired_batch_size = constants.DEFAULT_DESIRED_INPUT_BATCH_SIZE

    # Metrics
    self._combine_batch_size = beam.metrics.Metrics.distribution(
        constants.METRICS_NAMESPACE, 'combine_batch_size')
    self._combine_byte_size = beam.metrics.Metrics.distribution(
        constants.METRICS_NAMESPACE, 'combine_byte_size')
    self._num_compacts = beam.metrics.Metrics.counter(
        constants.METRICS_NAMESPACE, 'num_compacts')
    self._num_instances = beam.metrics.Metrics.counter(
        constants.METRICS_NAMESPACE, 'num_instances')

  def _for_each_generator(self,
                          func: Callable[..., Any],
                          *args: Iterable[Any]) -> List[Any]:
    """Applies `func` to each wrapped generator.

    Args:
      func: A function that takes N + 1 arguments where N is the size of
        `args`. The first argument is the stats generator.
      *args: Iterables parallel to the wrapped stats generators (i.e. the i-th
        item corresponds to self._generators[i]).

    Returns:
      A list whose i-th element is the result of
      func(self._generators[i], args[0][i], args[1][i], ...).
    """
    return [func(gen, *args_for_func) for gen, args_for_func in zip(
        self._generators, zip(*args))]

  def create_accumulator(self
                        ) -> _CombinerStatsGeneratorsCombineFnAcc:  # pytype: disable=invalid-annotation
    return _CombinerStatsGeneratorsCombineFnAcc(
        [g.create_accumulator() for g in self._generators])

  def _should_do_batch(self, accumulator: _CombinerStatsGeneratorsCombineFnAcc,
                       force: bool) -> bool:
    curr_batch_size = accumulator.curr_batch_size
    if force and curr_batch_size > 0:
      return True

    if curr_batch_size >= self._desired_batch_size:
      return True

    if (accumulator.curr_byte_size >=
        self._MERGE_RECORD_BATCH_BYTE_SIZE_THRESHOLD):
      return True

    return False

  def _maybe_do_batch(
      self,
      accumulator: _CombinerStatsGeneratorsCombineFnAcc,
      force: bool = False) -> None:
    """Maybe updates accumulator in place.

    Checks if the accumulator has enough examples for a batch, and if so, does
    the stats computation for the batch and updates the accumulator in place.

    Args:
      accumulator: Accumulator. Will be updated in place.
      force: Force computation of stats even if the accumulator has fewer
        examples than the desired batch size.
    """
    if self._should_do_batch(accumulator, force):
      self._combine_batch_size.update(accumulator.curr_batch_size)
      self._combine_byte_size.update(accumulator.curr_byte_size)
      if len(accumulator.input_record_batches) == 1:
        record_batch = accumulator.input_record_batches[0]
      else:
        record_batch = table_util.MergeRecordBatches(
            accumulator.input_record_batches)
      accumulator.partial_accumulators = self._for_each_generator(
          lambda gen, gen_acc: gen.add_input(gen_acc, record_batch),
          accumulator.partial_accumulators)
      del accumulator.input_record_batches[:]
      accumulator.curr_batch_size = 0
      accumulator.curr_byte_size = 0

  def add_input(
      self, accumulator: _CombinerStatsGeneratorsCombineFnAcc,
      input_record_batch: pa.RecordBatch
  ) -> _CombinerStatsGeneratorsCombineFnAcc:
    accumulator.input_record_batches.append(input_record_batch)
    num_rows = input_record_batch.num_rows
    accumulator.curr_batch_size += num_rows
    accumulator.curr_byte_size += table_util.TotalByteSize(input_record_batch)
    self._maybe_do_batch(accumulator)
    self._num_instances.inc(num_rows)
    return accumulator

  def merge_accumulators(
      self,
      accumulators: Iterable[_CombinerStatsGeneratorsCombineFnAcc]
  ) -> _CombinerStatsGeneratorsCombineFnAcc:
    result = self.create_accumulator()
    # Make sure accumulators is an iterator (so it remembers its position).
    accumulators = iter(accumulators)
    while True:
      # Repeatedly take the next N from `accumulators` (an iterator).
      # If fewer than N remain, all of them are taken.
      batched_accumulators = list(itertools.islice(
          accumulators, self._DESIRED_MERGE_ACCUMULATOR_BATCH_SIZE))
      if not batched_accumulators:
        break

      # Batch together remaining examples in each accumulator, and
      # feed to each generator. Note that there might still be remaining
      # examples after this, but a compact() might follow and flush the
      # remaining examples, and extract_output() in the end will flush anyway.
      batched_partial_accumulators = []
      for acc in batched_accumulators:
        result.input_record_batches.extend(acc.input_record_batches)
        result.curr_batch_size += acc.curr_batch_size
        result.curr_byte_size += acc.curr_byte_size
        self._maybe_do_batch(result)
        batched_partial_accumulators.append(acc.partial_accumulators)

      batched_accumulators_by_generator = list(
          zip(*batched_partial_accumulators))

      result.partial_accumulators = self._for_each_generator(
          lambda gen, b, m: gen.merge_accumulators(itertools.chain((b,), m)),
          result.partial_accumulators,
          batched_accumulators_by_generator)

    return result

  # TODO(pachristopher): Consider adding CombinerStatsGenerator.compact method.
  def compact(
      self,
      accumulator: _CombinerStatsGeneratorsCombineFnAcc
  ) -> _CombinerStatsGeneratorsCombineFnAcc:
    self._maybe_do_batch(accumulator, force=True)
    self._num_compacts.inc(1)
    return accumulator

  def extract_output(
      self,
      accumulator: _CombinerStatsGeneratorsCombineFnAcc
  ) -> statistics_pb2.DatasetFeatureStatistics:  # pytype: disable=invalid-annotation
    # Make sure we have processed all the examples.
    self._maybe_do_batch(accumulator, force=True)

    return _merge_dataset_feature_stats_protos(
        self._for_each_generator(lambda gen, acc: gen.extract_output(acc),
                                 accumulator.partial_accumulators))
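

# Example (sketch): driving the CombineFn by hand (Beam normally does this).
# `batch` is assumed to be a pa.RecordBatch.
#
#   fn = _CombinerStatsGeneratorsCombineFn(
#       [NumExamplesStatsGenerator()], desired_batch_size=1000)
#   acc = fn.add_input(fn.create_accumulator(), batch)
#   stats_proto = fn.extract_output(fn.compact(acc))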


def generate_partial_statistics_in_memory(
    record_batch: pa.RecordBatch, options: stats_options.StatsOptions,
    stats_generators: List[stats_generator.CombinerStatsGenerator]
) -> List[Any]:
  """Generates partial statistics for an in-memory Arrow RecordBatch.

  Args:
    record_batch: Arrow RecordBatch.
    options: Options for generating data statistics.
    stats_generators: A list of combiner statistics generators.

  Returns:
    A list of accumulators containing partial statistics.
  """
  result = []
  if options.feature_whitelist:
    schema = record_batch.schema
    whitelisted_columns = [
        record_batch.column(schema.get_field_index(f))
        for f in options.feature_whitelist
    ]
    record_batch = pa.RecordBatch.from_arrays(whitelisted_columns,
                                              list(options.feature_whitelist))
  for generator in stats_generators:
    result.append(
        generator.add_input(generator.create_accumulator(), record_batch))
  return result


def generate_statistics_in_memory(
    record_batch: pa.RecordBatch,
    options: stats_options.StatsOptions = stats_options.StatsOptions()
) -> statistics_pb2.DatasetFeatureStatisticsList:
  """Generates statistics for an in-memory Arrow RecordBatch of examples.

  Args:
    record_batch: Arrow RecordBatch.
    options: Options for generating data statistics.

  Returns:
    A DatasetFeatureStatisticsList proto.
  """
  stats_generators = get_generators(options, in_memory=True)  # type: List[stats_generator.CombinerStatsGenerator]
  partial_stats = generate_partial_statistics_in_memory(record_batch, options,
                                                        stats_generators)
  return extract_statistics_output(partial_stats, stats_generators)


def extract_statistics_output(
    partial_stats: List[Any],
    stats_generators: List[stats_generator.CombinerStatsGenerator]
) -> statistics_pb2.DatasetFeatureStatisticsList:
  """Extracts final stats output from the accumulators holding partial stats."""
  outputs = [
      gen.extract_output(stats)
      for (gen, stats) in zip(stats_generators, partial_stats)  # pytype: disable=attribute-error
  ]
  return _make_dataset_feature_statistics_list_proto(
      [_merge_dataset_feature_stats_protos(outputs)])
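

# Example (sketch): in-memory statistics over a single Arrow RecordBatch.
#
#   batch = pa.RecordBatch.from_arrays(
#       [pa.array([['a'], ['a', 'b']])], ['feature_x'])
#   stats_list = generate_statistics_in_memory(batch)
#   assert stats_list.datasets[0].num_examples == 2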


# Type for the wrapper_accumulator of a CombinerFeatureStatsWrapperGenerator.
# See documentation below for more details.
WrapperAccumulator = Dict[types.FeaturePath, List[Any]]


class CombinerFeatureStatsWrapperGenerator(
    stats_generator.CombinerStatsGenerator):
  """A combiner that wraps multiple CombinerFeatureStatsGenerators.

  This combiner wraps multiple CombinerFeatureStatsGenerators by generating
  and updating wrapper_accumulators where:
  wrapper_accumulator[feature_path][feature_generator_index] contains the
  generator-specific accumulator for the pair (feature_path,
  feature_generator_index).
  """

  def __init__(self,
               feature_stats_generators: List[
                   stats_generator.CombinerFeatureStatsGenerator],
               name: Text = 'CombinerFeatureStatsWrapperGenerator',
               schema: Optional[schema_pb2.Schema] = None,
               weight_feature: Optional[types.FeatureName] = None,
               sample_rate: Optional[float] = None) -> None:
    """Initializes a CombinerFeatureStatsWrapperGenerator.

    Args:
      feature_stats_generators: A list of CombinerFeatureStatsGenerators.
      name: An optional unique name associated with the statistics generator.
      schema: An optional schema for the dataset.
      weight_feature: An optional feature name whose numeric value represents
        the weight of an example. Currently the weight feature is ignored by
        feature-level stats generators.
      sample_rate: An optional sampling rate. If specified, statistics are
        computed over the sample.
    """
    super(CombinerFeatureStatsWrapperGenerator, self).__init__(name, schema)
    self._feature_stats_generators = feature_stats_generators
    self._weight_feature = weight_feature
    self._sample_rate = sample_rate

  def _perhaps_initialize_for_feature_path(
      self, wrapper_accumulator: WrapperAccumulator,
      feature_path: types.FeaturePath) -> None:
    """Initializes the feature_path key if it does not exist."""
    # Note: This manual initialization could have been avoided if
    # wrapper_accumulator was a defaultdict, but this breaks pickling.
    if feature_path not in wrapper_accumulator:
      wrapper_accumulator[feature_path] = [
          generator.create_accumulator()
          for generator in self._feature_stats_generators
      ]

  def create_accumulator(self) -> WrapperAccumulator:
    """Returns a fresh, empty wrapper_accumulator.

    Returns:
      An empty wrapper_accumulator.
    """
    return {}

  def add_input(self, wrapper_accumulator: WrapperAccumulator,
                input_record_batch: pa.RecordBatch) -> WrapperAccumulator:
    """Returns the result of folding a batch of inputs into wrapper_accumulator.

    Args:
      wrapper_accumulator: The current wrapper accumulator.
      input_record_batch: An Arrow RecordBatch representing a batch of examples
        that should be added to the accumulator.

    Returns:
      The wrapper_accumulator after updating the statistics for the batch of
      inputs.
    """
    if self._sample_rate is not None and random.random() <= self._sample_rate:
      return wrapper_accumulator

    for feature_path, feature_array, _ in arrow_util.enumerate_arrays(
        input_record_batch,
        weight_column=self._weight_feature,
        enumerate_leaves_only=True):
      for index, generator in enumerate(self._feature_stats_generators):
        self._perhaps_initialize_for_feature_path(wrapper_accumulator,
                                                  feature_path)
        wrapper_accumulator[feature_path][index] = generator.add_input(
            generator.create_accumulator(), feature_path, feature_array)

    return wrapper_accumulator

  def merge_accumulators(
      self,
      wrapper_accumulators: Iterable[WrapperAccumulator]) -> WrapperAccumulator:
    """Merges several wrapper_accumulators into a single one.

    Args:
      wrapper_accumulators: The wrapper accumulators to merge.

    Returns:
      The merged accumulator.
    """
    result = self.create_accumulator()
    for wrapper_accumulator in wrapper_accumulators:
      for feature_path, accumulator_for_feature in six.iteritems(
          wrapper_accumulator):
        self._perhaps_initialize_for_feature_path(result, feature_path)
        for index, generator in enumerate(self._feature_stats_generators):
          result[feature_path][index] = generator.merge_accumulators(
              [result[feature_path][index], accumulator_for_feature[index]])
    return result

  def extract_output(self, wrapper_accumulator: WrapperAccumulator
                    ) -> statistics_pb2.DatasetFeatureStatistics:
    """Returns the result of converting wrapper_accumulator into the output.

    Args:
      wrapper_accumulator: The final wrapper_accumulator value.

    Returns:
      A proto representing the result of this stats generator.
    """
    result = statistics_pb2.DatasetFeatureStatistics()
    for feature_path, accumulator_for_feature in six.iteritems(
        wrapper_accumulator):
      feature_stats = result.features.add()
      feature_stats.path.CopyFrom(feature_path.to_proto())
      for index, generator in enumerate(self._feature_stats_generators):
        feature_stats.MergeFrom(
            generator.extract_output(accumulator_for_feature[index]))
    return result
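

# Shape sketch (illustrative): with two wrapped feature stats generators, a
# wrapper_accumulator looks like
#
#   {
#       types.FeaturePath(['feature_a']): [gen0_acc, gen1_acc],
#       types.FeaturePath(['feature_b']): [gen0_acc, gen1_acc],
#   }
#
# where gen<i>_acc is the accumulator of self._feature_stats_generators[i]
# for that feature path.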