PageRenderTime 46ms CodeModel.GetById 18ms RepoModel.GetById 0ms app.codeStats 0ms

/tensorflow/contrib/learn/python/learn/dataframe/queues/feeding_functions.py

https://gitlab.com/github-cloud-corporation/tensorflow
Python | 202 lines | 133 code | 20 blank | 49 comment | 26 complexity | f74c9ee78c196567363a3af54d7f1d34 MD5 | raw file
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Helper functions for enqueuing data from arrays and pandas `DataFrame`s."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import random
  20. import numpy as np
  21. from tensorflow.contrib.learn.python.learn.dataframe.queues import feeding_queue_runner as fqr
  22. from tensorflow.python.framework import dtypes
  23. from tensorflow.python.framework import ops
  24. from tensorflow.python.ops import array_ops
  25. from tensorflow.python.ops import data_flow_ops
  26. from tensorflow.python.ops import logging_ops
  27. from tensorflow.python.ops import math_ops
  28. from tensorflow.python.platform import tf_logging as logging
  29. from tensorflow.python.training import queue_runner
  30. # pylint: disable=g-import-not-at-top
  31. try:
  32. import pandas as pd
  33. HAS_PANDAS = True
  34. except ImportError:
  35. HAS_PANDAS = False
  36. class _ArrayFeedFn(object):
  37. """Creates feed dictionaries from numpy arrays."""
  38. def __init__(self,
  39. placeholders,
  40. array,
  41. batch_size,
  42. random_start=False,
  43. seed=None):
  44. if len(placeholders) != 2:
  45. raise ValueError("_array_feed_fn expects 2 placeholders; got {}.".format(
  46. len(placeholders)))
  47. self._placeholders = placeholders
  48. self._array = array
  49. self._max = len(array)
  50. self._batch_size = batch_size
  51. random.seed(seed)
  52. self._trav = random.randrange(self._max) if random_start else 0
  53. def __call__(self):
  54. integer_indexes = [j % self._max
  55. for j in range(self._trav, self._trav + self._batch_size)
  56. ]
  57. self._trav = (integer_indexes[-1] + 1) % self._max
  58. return {self._placeholders[0]: integer_indexes,
  59. self._placeholders[1]: self._array[integer_indexes]}
  60. class _PandasFeedFn(object):
  61. """Creates feed dictionaries from pandas `DataFrames`."""
  62. def __init__(self,
  63. placeholders,
  64. dataframe,
  65. batch_size,
  66. random_start=False,
  67. seed=None):
  68. if len(placeholders) != len(dataframe.columns) + 1:
  69. raise ValueError("Expected {} placeholders; got {}.".format(
  70. len(dataframe.columns), len(placeholders)))
  71. self._index_placeholder = placeholders[0]
  72. self._col_placeholders = placeholders[1:]
  73. self._dataframe = dataframe
  74. self._max = len(dataframe)
  75. self._batch_size = batch_size
  76. random.seed(seed)
  77. self._trav = random.randrange(self._max) if random_start else 0
  78. def __call__(self):
  79. integer_indexes = [j % self._max
  80. for j in range(self._trav, self._trav + self._batch_size)
  81. ]
  82. self._trav = (integer_indexes[-1] + 1) % self._max
  83. result = self._dataframe.iloc[integer_indexes]
  84. cols = [result[col].values for col in result.columns]
  85. feed_dict = dict(zip(self._col_placeholders, cols))
  86. feed_dict[self._index_placeholder] = result.index.values
  87. return feed_dict
  88. def enqueue_data(data,
  89. capacity,
  90. shuffle=False,
  91. min_after_dequeue=None,
  92. num_threads=1,
  93. seed=None,
  94. name="enqueue_input",
  95. enqueue_size=1):
  96. """Creates a queue filled from a numpy array or pandas `DataFrame`.
  97. Returns a queue filled with the rows of the given array or `DataFrame`. In
  98. the case of a pandas `DataFrame`, the first enqueued `Tensor` corresponds to
  99. the index of the `DataFrame`. For numpy arrays, the first enqueued `Tensor`
  100. contains the row number.
  101. Args:
  102. data: a numpy `ndarray or` pandas `DataFrame` that will be read into the
  103. queue.
  104. capacity: the capacity of the queue.
  105. shuffle: whether or not to shuffle the rows of the array.
  106. min_after_dequeue: minimum number of elements that can remain in the queue
  107. after a dequeue operation. Only used when `shuffle` is true. If not set,
  108. defaults to `capacity` / 4.
  109. num_threads: number of threads used for reading and enqueueing.
  110. seed: used to seed shuffling and reader starting points.
  111. name: a scope name identifying the data.
  112. enqueue_size: the number of rows to enqueue per step.
  113. Returns:
  114. A queue filled with the rows of the given array or `DataFrame`.
  115. Raises:
  116. TypeError: `data` is not a Pandas `DataFrame` or a numpy `ndarray`.
  117. """
  118. with ops.name_scope(name):
  119. if isinstance(data, np.ndarray):
  120. types = [dtypes.int64, dtypes.as_dtype(data.dtype)]
  121. queue_shapes = [(), data.shape[1:]]
  122. get_feed_fn = _ArrayFeedFn
  123. elif HAS_PANDAS and isinstance(data, pd.DataFrame):
  124. types = [dtypes.as_dtype(dt)
  125. for dt in [data.index.dtype] + list(data.dtypes)]
  126. queue_shapes = [() for _ in types]
  127. get_feed_fn = _PandasFeedFn
  128. else:
  129. raise TypeError(
  130. "data must be either a numpy array or pandas DataFrame if pandas is "
  131. "installed; got {}".format(type(data).__name__))
  132. if shuffle:
  133. min_after_dequeue = int(capacity / 4 if min_after_dequeue is None else
  134. min_after_dequeue)
  135. queue = data_flow_ops.RandomShuffleQueue(capacity,
  136. min_after_dequeue,
  137. dtypes=types,
  138. shapes=queue_shapes,
  139. seed=seed)
  140. else:
  141. if num_threads > 1:
  142. # TODO(jamieas): Add TensorBoard warning here once available.
  143. logging.warning(
  144. "enqueue_data was called with shuffle=False and num_threads > 1. "
  145. "This will create multiple threads, all reading the "
  146. "array/dataframe in order. If you want examples read in order, use"
  147. " one thread; if you want multiple threads, enable shuffling.")
  148. min_after_dequeue = 0 # just for the summary text
  149. queue = data_flow_ops.FIFOQueue(capacity,
  150. dtypes=types,
  151. shapes=queue_shapes)
  152. enqueue_ops = []
  153. feed_fns = []
  154. for i in range(num_threads):
  155. # Note the placeholders have no shapes, so they will accept any
  156. # enqueue_size. enqueue_many below will break them up.
  157. placeholders = [array_ops.placeholder(t) for t in types]
  158. enqueue_ops.append(queue.enqueue_many(placeholders))
  159. seed_i = None if seed is None else (i + 1) * seed
  160. feed_fns.append(get_feed_fn(placeholders,
  161. data,
  162. enqueue_size,
  163. random_start=shuffle,
  164. seed=seed_i))
  165. runner = fqr.FeedingQueueRunner(queue=queue,
  166. enqueue_ops=enqueue_ops,
  167. feed_fns=feed_fns)
  168. queue_runner.add_queue_runner(runner)
  169. full = (math_ops.cast(
  170. math_ops.maximum(0, queue.size() - min_after_dequeue),
  171. dtypes.float32) * (1. / (capacity - min_after_dequeue)))
  172. # Note that name contains a '/' at the end so we intentionally do not place
  173. # a '/' after %s below.
  174. summary_name = ("queue/%sfraction_over_%d_of_%d_full" %
  175. (queue.name, min_after_dequeue,
  176. capacity - min_after_dequeue))
  177. logging_ops.scalar_summary(summary_name, full)
  178. return queue