
/tensorflow/contrib/metrics/python/ops/histogram_ops.py

https://gitlab.com/hrishikeshvganu/tensorflow
Python source
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-short-docstring-punctuation
"""## Metrics that use histograms.

@@auc_using_histogram
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import histogram_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope


def auc_using_histogram(boolean_labels,
                        scores,
                        score_range,
                        nbins=100,
                        collections=None,
                        check_shape=True,
                        name=None):
  """AUC computed by maintaining histograms.

  Rather than computing AUC directly, this Op maintains Variables containing
  histograms of the scores associated with `True` and `False` labels. By
  comparing these the AUC is generated, with some discretization error.
  See: "Efficient AUC Learning Curve Calculation" by Bouckaert.

  This AUC Op updates in `O(batch_size + nbins)` time and works well even with
  large class imbalance. The accuracy is limited by discretization error due
  to the finite number of bins. If scores are concentrated in fewer bins,
  accuracy is lower. If this is a concern, we recommend trying different
  numbers of bins and comparing results.

  Args:
    boolean_labels: 1-D boolean `Tensor`. Entry is `True` if the corresponding
      record is in class.
    scores: 1-D numeric `Tensor`, same shape as `boolean_labels`.
    score_range: `Tensor` of shape `[2]`, same dtype as `scores`. The min/max
      values of score that we expect. Scores outside this range will be
      clipped.
    nbins: Integer number of bins to use. Accuracy strictly increases as the
      number of bins increases.
    collections: List of graph collections keys. Internal histogram Variables
      are added to these collections. Defaults to `[GraphKeys.LOCAL_VARIABLES]`.
    check_shape: Boolean. If `True`, do a runtime shape check on the scores
      and labels.
    name: A name for this Op. Defaults to "auc_using_histogram".

  Returns:
    auc: `float32` scalar `Tensor`. Fetching this converts internal histograms
      to an AUC value.
    update_op: `Op`, when run, updates internal histograms.
  """
  if collections is None:
    collections = [ops.GraphKeys.LOCAL_VARIABLES]
  with variable_scope.variable_op_scope(
      [boolean_labels, scores, score_range], name, 'auc_using_histogram'):
    score_range = ops.convert_to_tensor(score_range, name='score_range')
    boolean_labels, scores = _check_labels_and_scores(
        boolean_labels, scores, check_shape)
    hist_true, hist_false = _make_auc_histograms(boolean_labels, scores,
                                                 score_range, nbins)
    hist_true_acc, hist_false_acc, update_op = _auc_hist_accumulate(hist_true,
                                                                    hist_false,
                                                                    nbins,
                                                                    collections)
    auc = _auc_convert_hist_to_auc(hist_true_acc, hist_false_acc, nbins)
    return auc, update_op
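

# Example usage (an illustrative sketch only, not part of the library). It
# assumes the TF 0.x-era Python API that was current when this file was
# written (tf.placeholder, tf.Session, tf.initialize_local_variables); the
# `batches` iterator is hypothetical.
#
#   labels = tf.placeholder(tf.bool, shape=[None])
#   scores = tf.placeholder(tf.float32, shape=[None])
#   auc, update_op = auc_using_histogram(labels, scores, score_range=[0.0, 1.0])
#   with tf.Session() as sess:
#     # The histogram accumulators are local variables, so initialize them.
#     sess.run(tf.initialize_local_variables())
#     for batch_labels, batch_scores in batches:
#       sess.run(update_op,
#                feed_dict={labels: batch_labels, scores: batch_scores})
#     # Fetching `auc` converts the accumulated histograms to an AUC estimate.
#     print(sess.run(auc))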


def _check_labels_and_scores(boolean_labels, scores, check_shape):
  """Check the rank of labels/scores, return tensor versions."""
  with ops.op_scope([boolean_labels, scores], '_check_labels_and_scores'):
    boolean_labels = ops.convert_to_tensor(boolean_labels,
                                           name='boolean_labels')
    scores = ops.convert_to_tensor(scores, name='scores')

    if boolean_labels.dtype != dtypes.bool:
      raise ValueError(
          'Argument boolean_labels should have dtype bool. Found: %s',
          boolean_labels.dtype)

    if check_shape:
      labels_rank_1 = logging_ops.Assert(
          math_ops.equal(1, array_ops.rank(boolean_labels)),
          ['Argument boolean_labels should have rank 1. Found: ',
           boolean_labels.name, array_ops.shape(boolean_labels)])
      scores_rank_1 = logging_ops.Assert(
          math_ops.equal(1, array_ops.rank(scores)),
          ['Argument scores should have rank 1. Found: ', scores.name,
           array_ops.shape(scores)])

      with ops.control_dependencies([labels_rank_1, scores_rank_1]):
        return boolean_labels, scores
    else:
      return boolean_labels, scores


def _make_auc_histograms(boolean_labels, scores, score_range, nbins):
  """Create histogram tensors from one batch of labels/scores."""
  with variable_scope.variable_op_scope(
      [boolean_labels, scores, nbins], None, 'make_auc_histograms'):
    # Histogram of scores for records in this batch with True label.
    hist_true = histogram_ops.histogram_fixed_width(
        array_ops.boolean_mask(scores, boolean_labels),
        score_range,
        nbins=nbins,
        dtype=dtypes.int64,
        name='hist_true')
    # Histogram of scores for records in this batch with False label.
    hist_false = histogram_ops.histogram_fixed_width(
        array_ops.boolean_mask(scores, math_ops.logical_not(boolean_labels)),
        score_range,
        nbins=nbins,
        dtype=dtypes.int64,
        name='hist_false')
    return hist_true, hist_false
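

# Note on binning (illustrative): `histogram_fixed_width` buckets values into
# `nbins` equal-width bins over `score_range`. For example, with
# score_range=[0.0, 1.0] and nbins=5, scores 0.1, 0.43, and 0.99 land in bins
# 0, 2, and 4 respectively; values outside the range are clipped into the
# first or last bin.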


def _auc_hist_accumulate(hist_true, hist_false, nbins, collections):
  """Accumulate histograms in new variables."""
  with variable_scope.variable_op_scope(
      [hist_true, hist_false], None, 'hist_accumulate'):
    # Holds running total histogram of scores for records labeled True.
    hist_true_acc = variable_scope.get_variable(
        'hist_true_acc',
        initializer=array_ops.zeros_initializer(
            [nbins],
            dtype=hist_true.dtype),
        collections=collections,
        trainable=False)
    # Holds running total histogram of scores for records labeled False.
    hist_false_acc = variable_scope.get_variable(
        'hist_false_acc',
        initializer=array_ops.zeros_initializer(
            [nbins],
            dtype=hist_false.dtype),
        collections=collections,
        trainable=False)

    update_op = control_flow_ops.group(
        hist_true_acc.assign_add(hist_true),
        hist_false_acc.assign_add(hist_false),
        name='update_op')

    return hist_true_acc, hist_false_acc, update_op
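

# Note: because the accumulators above default to `GraphKeys.LOCAL_VARIABLES`
# (see `auc_using_histogram`), callers are expected to run the local-variables
# initializer before the first `update_op` run, as in the usage sketch above.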


def _auc_convert_hist_to_auc(hist_true_acc, hist_false_acc, nbins):
  """Convert histograms to auc.

  Args:
    hist_true_acc: `Tensor` holding accumulated histogram of scores for records
      that were `True`.
    hist_false_acc: `Tensor` holding accumulated histogram of scores for
      records that were `False`.
    nbins: Integer number of bins in the histograms.

  Returns:
    Scalar `Tensor` estimating AUC.
  """
  # Note that this follows the "Approximating AUC" section in:
  # Efficient AUC learning curve calculation, R. R. Bouckaert,
  # AI'06 Proceedings of the 19th Australian joint conference on Artificial
  # Intelligence: advances in Artificial Intelligence,
  # pages 181-191.
  # Note that the above paper has an error, and we need to re-order our bins to
  # go from high to low score.

  # Normalize histograms so we get the fraction in each bin.
  normed_hist_true = math_ops.truediv(hist_true_acc,
                                      math_ops.reduce_sum(hist_true_acc))
  normed_hist_false = math_ops.truediv(hist_false_acc,
                                       math_ops.reduce_sum(hist_false_acc))

  # These become delta x, delta y from the paper.
  delta_y_t = array_ops.reverse(normed_hist_true, [True], name='delta_y_t')
  delta_x_t = array_ops.reverse(normed_hist_false, [True], name='delta_x_t')

  # strict_1d_cumsum requires float32 args.
  delta_y_t = math_ops.cast(delta_y_t, dtypes.float32)
  delta_x_t = math_ops.cast(delta_x_t, dtypes.float32)

  # Trapezoidal integration, \int_0^1 0.5 * (y_t + y_{t-1}) dx_t
  y_t = _strict_1d_cumsum(delta_y_t, nbins)
  first_trap = delta_x_t[0] * y_t[0] / 2.0
  other_traps = delta_x_t[1:] * (y_t[1:] + y_t[:nbins - 1]) / 2.0
  return math_ops.add(first_trap, math_ops.reduce_sum(other_traps), name='auc')
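

# For reference, a NumPy sketch of the same trapezoidal estimate computed by
# `_auc_convert_hist_to_auc` (illustrative only; assumes `numpy` is imported
# as `np` and both histograms are nonempty arrays of counts):
#
#   def _auc_from_histograms_np(hist_true, hist_false):
#     delta_y = (hist_true / hist_true.sum())[::-1]    # high -> low score
#     delta_x = (hist_false / hist_false.sum())[::-1]  # high -> low score
#     y = np.cumsum(delta_y)
#     first_trap = delta_x[0] * y[0] / 2.0
#     other_traps = delta_x[1:] * (y[1:] + y[:-1]) / 2.0
#     return first_trap + np.sum(other_traps)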


# TODO(langmore) Remove once a faster cumsum (accumulate_sum) Op is available.
# Also see if the cast to float32 above can be removed with the new cumsum.
# See: https://github.com/tensorflow/tensorflow/issues/813
def _strict_1d_cumsum(tensor, len_tensor):
  """Cumsum of a 1D tensor with defined shape by padding and convolving."""
  # Assumes tensor shape is fully defined.
  with ops.op_scope([tensor], 'strict_1d_cumsum'):
    if len_tensor == 0:
      return constant_op.constant([])
    len_pad = len_tensor - 1
    x = array_ops.pad(tensor, [[len_pad, 0]])
    h = array_ops.ones_like(x)
    return _strict_conv1d(x, h)[:len_tensor]
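

# Worked example (illustrative): for tensor = [1., 2., 3.] and len_tensor = 3,
# the padded input is [0., 0., 1., 2., 3.]; convolving with an all-ones kernel
# of the same length ('SAME' padding) and keeping the first 3 outputs yields
# [1., 3., 6.], i.e. the cumulative sum.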


# TODO(langmore) Remove once a faster cumsum (accumulate_sum) Op is available.
# See: https://github.com/tensorflow/tensorflow/issues/813
def _strict_conv1d(x, h):
  """Return x * h for rank 1 tensors x and h."""
  with ops.op_scope([x, h], 'strict_conv1d'):
    x = array_ops.reshape(x, (1, -1, 1, 1))
    h = array_ops.reshape(h, (-1, 1, 1, 1))
    result = nn_ops.conv2d(x, h, [1, 1, 1, 1], 'SAME')
    return array_ops.reshape(result, [-1])