
/tensorflow/python/training/queue_runner.py

https://gitlab.com/github-cloud-corporation/tensorflow
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Create threads to run multiple enqueue ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging

class QueueRunner(object):
  """Holds a list of enqueue operations for a queue, each to be run in a thread.

  Queues are a convenient TensorFlow mechanism to compute tensors
  asynchronously using multiple threads. For example in the canonical 'Input
  Reader' setup one set of threads generates filenames in a queue; a second set
  of threads reads records from the files, processes them, and enqueues tensors
  on a second queue; a third set of threads dequeues these input records to
  construct batches and runs them through training operations.

  There are several delicate issues when running multiple threads that way:
  closing the queues in sequence as the input is exhausted, correctly catching
  and reporting exceptions, etc.

  The `QueueRunner`, combined with the `Coordinator`, helps handle these issues.
  """

  def __init__(self, queue=None, enqueue_ops=None, close_op=None,
               cancel_op=None, queue_runner_def=None):
    """Create a QueueRunner.

    On construction the `QueueRunner` adds an op to close the queue. That op
    will be run if the enqueue ops raise exceptions.

    When you later call the `create_threads()` method, the `QueueRunner` will
    create one thread for each op in `enqueue_ops`. Each thread will run its
    enqueue op in parallel with the other threads. The enqueue ops do not have
    to all be the same op, but it is expected that they all enqueue tensors in
    `queue`.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer. If specified,
        recreates the QueueRunner from its contents. `queue_runner_def` and the
        other arguments are mutually exclusive.

    Raises:
      ValueError: If both `queue_runner_def` and `queue` are specified.
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
    """
    if queue_runner_def:
      if queue or enqueue_ops:
        raise ValueError("queue_runner_def and queue are mutually exclusive.")
      self._init_from_proto(queue_runner_def)
    else:
      self._init_from_args(queue=queue, enqueue_ops=enqueue_ops,
                           close_op=close_op, cancel_op=cancel_op)
    # Protect the count of runs to wait for.
    self._lock = threading.Lock()
    self._runs = 0
    # List of exceptions raised by the running threads.
    self._exceptions_raised = []

  def _init_from_args(self, queue=None, enqueue_ops=None, close_op=None,
                      cancel_op=None):
    """Create a QueueRunner from arguments.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.

    Raises:
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
    """
    if not queue or not enqueue_ops:
      raise ValueError("Must provide queue and enqueue_ops.")
    self._queue = queue
    self._enqueue_ops = enqueue_ops
    self._close_op = close_op
    self._cancel_op = cancel_op
    # Close when no more will be produced, but pending enqueues should be
    # preserved.
    if not self._close_op:
      self._close_op = self._queue.close()
    # Close and cancel pending enqueues since there was an error and we want
    # to unblock everything so we can cleanly exit.
    if not self._cancel_op:
      self._cancel_op = self._queue.close(cancel_pending_enqueues=True)

  def _init_from_proto(self, queue_runner_def):
    """Create a QueueRunner from `QueueRunnerDef`.

    Args:
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer.
    """
    assert isinstance(queue_runner_def, queue_runner_pb2.QueueRunnerDef)
    g = ops.get_default_graph()
    self._queue = g.as_graph_element(queue_runner_def.queue_name)
    self._enqueue_ops = [g.as_graph_element(op) for op
                         in queue_runner_def.enqueue_op_name]
    self._close_op = g.as_graph_element(queue_runner_def.close_op_name)
    self._cancel_op = g.as_graph_element(queue_runner_def.cancel_op_name)

  @property
  def queue(self):
    return self._queue

  @property
  def enqueue_ops(self):
    return self._enqueue_ops

  @property
  def close_op(self):
    return self._close_op

  @property
  def cancel_op(self):
    return self._cancel_op

  @property
  def exceptions_raised(self):
    """Exceptions raised but not handled by the `QueueRunner` threads.

    Exceptions raised in queue runner threads are handled in one of two ways
    depending on whether or not a `Coordinator` was passed to
    `create_threads()`:

    * With a `Coordinator`, exceptions are reported to the coordinator and
      forgotten by the `QueueRunner`.
    * Without a `Coordinator`, exceptions are captured by the `QueueRunner` and
      made available in this `exceptions_raised` property.

    Returns:
      A list of Python `Exception` objects. The list is empty if no exception
      was captured. (No exceptions are captured when using a Coordinator.)
    """
    return self._exceptions_raised

  @property
  def name(self):
    """The string name of the underlying Queue."""
    return self._queue.name

  # pylint: disable=broad-except
  def _run(self, sess, enqueue_op, coord=None):
    """Execute the enqueue op in a loop, close the queue in case of error.

    Args:
      sess: A Session.
      enqueue_op: The Operation to run.
      coord: Optional Coordinator object for reporting errors and checking
        for stop conditions.
    """
    if coord:
      coord.register_thread(threading.current_thread())
    decremented = False
    try:
      while True:
        if coord and coord.should_stop():
          break
        try:
          sess.run(enqueue_op)
        except errors.OutOfRangeError:
          # This exception indicates that a queue was closed.
          with self._lock:
            self._runs -= 1
            decremented = True
            if self._runs == 0:
              try:
                sess.run(self._close_op)
              except Exception as e:
                # Intentionally ignore errors from close_op.
                logging.vlog(1, "Ignored exception: %s", str(e))
            return
    except Exception as e:
      # This catches all other exceptions.
      if coord:
        coord.request_stop(e)
      else:
        logging.error("Exception in QueueRunner: %s", str(e))
        with self._lock:
          self._exceptions_raised.append(e)
        raise
    finally:
      # Make sure we account for all terminations: normal or errors.
      if not decremented:
        with self._lock:
          self._runs -= 1

  def _close_on_stop(self, sess, cancel_op, coord):
    """Close the queue when the Coordinator requests stop.

    Args:
      sess: A Session.
      cancel_op: The Operation to run.
      coord: Coordinator.
    """
    coord.register_thread(threading.current_thread())
    coord.wait_for_stop()
    try:
      sess.run(cancel_op)
    except Exception as e:
      # Intentionally ignore errors from cancel_op.
      logging.vlog(1, "Ignored exception: %s", str(e))
  # pylint: enable=broad-except

  def create_threads(self, sess, coord=None, daemon=False, start=False):
    """Create threads to run the enqueue ops.

    This method requires a session in which the graph was launched. It creates
    a list of threads, optionally starting them. There is one thread for each
    op passed in `enqueue_ops`.

    The `coord` argument is an optional coordinator that the threads will use
    to terminate together and report exceptions. If a coordinator is given,
    this method starts an additional thread to close the queue when the
    coordinator requests a stop.

    This method may be called again as long as all threads from a previous call
    have stopped.

    Args:
      sess: A `Session`.
      coord: Optional `Coordinator` object for reporting errors and checking
        stop conditions.
      daemon: Boolean. If `True` make the threads daemon threads.
      start: Boolean. If `True` starts the threads. If `False` the
        caller must call the `start()` method of the returned threads.

    Returns:
      A list of threads.

    Raises:
      RuntimeError: If threads from a previous call to `create_threads()` are
        still running.
    """
    with self._lock:
      if self._runs > 0:
        # Already started: no new threads to return.
        return []
      self._runs = len(self._enqueue_ops)
      self._exceptions_raised = []

    ret_threads = [threading.Thread(target=self._run, args=(sess, op, coord))
                   for op in self._enqueue_ops]
    if coord:
      ret_threads.append(threading.Thread(target=self._close_on_stop,
                                          args=(sess, self._cancel_op, coord)))
    for t in ret_threads:
      if daemon:
        t.daemon = True
      if start:
        t.start()
    return ret_threads

  def to_proto(self):
    """Converts this `QueueRunner` to a `QueueRunnerDef` protocol buffer.

    Returns:
      A `QueueRunnerDef` protocol buffer.
    """
    queue_runner_def = queue_runner_pb2.QueueRunnerDef()
    queue_runner_def.queue_name = self.queue.name
    for enqueue_op in self.enqueue_ops:
      queue_runner_def.enqueue_op_name.append(enqueue_op.name)
    queue_runner_def.close_op_name = self.close_op.name
    queue_runner_def.cancel_op_name = self.cancel_op.name
    return queue_runner_def

  @staticmethod
  def from_proto(queue_runner_def):
    """Returns a `QueueRunner` object created from `queue_runner_def`."""
    return QueueRunner(queue_runner_def=queue_runner_def)
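
# Illustrative usage sketch: a `QueueRunner` is normally driven together with
# a `Coordinator`. Assuming a graph that already defines a queue `q`, an op
# `enqueue_op` that feeds it, and a `train_op` that consumes dequeued tensors
# (hypothetical names, using the public `tf` API), manual use looks roughly
# like:
#
#   qr = QueueRunner(q, [enqueue_op] * 4)
#   coord = tf.train.Coordinator()
#   with tf.Session() as sess:
#     threads = qr.create_threads(sess, coord=coord, start=True)
#     try:
#       while not coord.should_stop():
#         sess.run(train_op)
#     except tf.errors.OutOfRangeError:
#       pass  # Input exhausted; the queue runners close the queue themselves.
#     finally:
#       coord.request_stop()
#       coord.join(threads)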


def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Adds a `QueueRunner` to a collection in the graph.

  When building a complex model that uses many queues it is often difficult to
  gather all the queue runners that need to be run. This convenience function
  allows you to add a queue runner to a well known collection in the graph.

  The companion method `start_queue_runners()` can be used to start threads for
  all the collected queue runners.

  Args:
    qr: A `QueueRunner`.
    collection: A `GraphKey` specifying the graph collection to add
      the queue runner to. Defaults to `GraphKeys.QUEUE_RUNNERS`.
  """
  ops.add_to_collection(collection, qr)


def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
                        collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Starts all queue runners collected in the graph.

  This is a companion method to `add_queue_runner()`. It just starts
  threads for all queue runners collected in the graph. It returns
  the list of all threads.

  Args:
    sess: `Session` used to run the queue ops. Defaults to the
      default session.
    coord: Optional `Coordinator` for coordinating the started threads.
    daemon: Whether the threads should be marked as `daemons`, meaning
      they don't block program exit.
    start: Set to `False` to only create the threads, not start them.
    collection: A `GraphKey` specifying the graph collection to
      get the queue runners from. Defaults to `GraphKeys.QUEUE_RUNNERS`.

  Returns:
    A list of threads.
  """
  if sess is None:
    sess = ops.get_default_session()
    if not sess:
      raise ValueError("Cannot start queue runners: No default session is "
                       "registered. Use `with sess.as_default()` or pass an "
                       "explicit session to tf.start_queue_runners(sess=sess)")
  with sess.graph.as_default():
    threads = []
    for qr in ops.get_collection(collection):
      threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon,
                                       start=start))
  return threads


ops.register_proto_function(ops.GraphKeys.QUEUE_RUNNERS,
                            proto_type=queue_runner_pb2.QueueRunnerDef,
                            to_proto=QueueRunner.to_proto,
                            from_proto=QueueRunner.from_proto)
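
# Illustrative usage sketch: input pipelines register their runners with
# `add_queue_runner()`, and training code then starts them all in one call.
# Assuming a hypothetical `train_op` and a pipeline that has already added its
# queue runners to the default collection (public `tf` API names):
#
#   coord = tf.train.Coordinator()
#   with tf.Session() as sess:
#     threads = start_queue_runners(sess=sess, coord=coord)
#     try:
#       while not coord.should_stop():
#         sess.run(train_op)
#     finally:
#       coord.request_stop()
#       coord.join(threads)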