
/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py

https://gitlab.com/github-cloud-corporation/tensorflow
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some common SessionRunHook classes.

@@
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import numpy as np
import six

from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn import session_run_hook
from tensorflow.contrib.learn.python.learn.session_run_hook import SessionRunArgs
from tensorflow.contrib.learn.python.learn.summary_writer_cache import SummaryWriterCache
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging


class LoggingTensorHook(session_run_hook.SessionRunHook):
  """Prints given tensors every N iterations.

  The tensors will be printed to the log, with `INFO` severity.
  """

  def __init__(self, tensors, every_n_iter=100):
    """Initializes a LoggingHook monitor.

    Args:
      tensors: `dict` of tag to tensors/names, or `iterable` of tensors/names.
      every_n_iter: `int`, print every N iterations.
    """
    if not isinstance(tensors, dict):
      tensors = {item: item for item in tensors}
    self._tensors = tensors
    self._every_n_iter = every_n_iter

  def begin(self):
    self._iter_count = 0
    # Convert names to tensors if given.
    self._current_tensors = {tag: _as_graph_element(tensor)
                             for (tag, tensor) in self._tensors.items()}

  def before_run(self, run_context):  # pylint: disable=unused-argument
    if self._iter_count % self._every_n_iter == 0:
      return SessionRunArgs(self._current_tensors)
    else:
      return None

  def after_run(self, run_context, run_values):
    _ = run_context
    if self._iter_count % self._every_n_iter == 0:
      stats = []
      for tag in sorted(self._current_tensors.keys()):
        stats.append("%s = %s" % (tag, run_values.results[tag]))
      logging.info("%s", ", ".join(stats))
    self._iter_count += 1
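
# Example usage for `LoggingTensorHook` (a hedged sketch, not part of the
# original module). The tensor name "loss" below is hypothetical; the hook is
# normally driven by a monitored training loop that calls its
# `begin`/`before_run`/`after_run` lifecycle:
#
#   logging_hook = LoggingTensorHook({"loss": "loss"}, every_n_iter=10)
#
# Passing a plain iterable instead of a dict tags each tensor with its own
# name, i.e. `LoggingTensorHook(["loss"])` behaves like
# `LoggingTensorHook({"loss": "loss"})`.
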
class StopAtStepHook(session_run_hook.SessionRunHook):
  """Monitor to request stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None):
    """Create a StopAtStep Hook.

    This hook requests stop after either a number of steps have been
    executed or a last step has been reached. Only one of the two options
    can be specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the `after_run()`
    call.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  def begin(self):
    self._global_step_tensor = contrib_variables.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    global_step = run_values.results
    if self._last_step is None:
      self._last_step = global_step + self._num_steps - 1
    if global_step >= self._last_step:
      run_context.request_stop()
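
# Example usage for `StopAtStepHook` (a hedged sketch, not part of the original
# module). Either form stops training once the global step condition is met;
# the step counts are hypothetical:
#
#   StopAtStepHook(num_steps=1000)   # run 1000 more steps after begin()
#   StopAtStepHook(last_step=20000)  # stop once global_step reaches 20000
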
class CheckpointSaverHook(session_run_hook.SessionRunHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None):
    """Initialize CheckpointSaverHook monitor.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, used to get the saver object.

    Raises:
      ValueError: Exactly one of `save_steps` or `save_secs` should be set.
    """
    logging.info("Create CheckpointSaverHook")
    self._saver = saver
    self._summary_writer = SummaryWriterCache.get(checkpoint_dir)
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._save_secs = save_secs
    self._save_steps = save_steps
    self._last_saved_time = None
    self._last_saved_step = None

    if save_steps is None and save_secs is None:
      raise ValueError("Either save_steps or save_secs should be provided")
    if (save_steps is not None) and (save_secs is not None):
      raise ValueError("Can not provide both save_steps and save_secs.")

  def begin(self):
    self._last_saved_time = None
    self._last_saved_step = None
    self._global_step_tensor = contrib_variables.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    global_step = run_values.results
    if self._last_saved_time is None:
      self._save(global_step, run_context.session)

    if self._save_steps is not None:
      if global_step >= self._last_saved_step + self._save_steps:
        self._save(global_step, run_context.session)

    if self._save_secs is not None:
      if time.time() >= self._last_saved_time + self._save_secs:
        self._save(global_step, run_context.session)

  def end(self, session):
    last_step = session.run(contrib_variables.get_global_step())
    self._save(last_step, session)

  def _save(self, step, session):
    """Saves the latest checkpoint."""
    if step == self._last_saved_step:
      return
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
    self._last_saved_time = time.time()
    self._last_saved_step = step
    if self._saver is None:
      self._scaffold.saver.save(session, self._save_path, global_step=step)
    else:
      self._saver.save(session, self._save_path, global_step=step)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)
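
# Example usage for `CheckpointSaverHook` (a hedged sketch, not part of the
# original module). The directory path is hypothetical; exactly one of the two
# cadence arguments may be given:
#
#   CheckpointSaverHook("/tmp/model_dir", save_steps=500)  # every 500 steps
#   CheckpointSaverHook("/tmp/model_dir", save_secs=600)   # every 10 minutes
#
# When no explicit `saver` is passed, `_save` falls back to `scaffold.saver`,
# so one of `saver` or `scaffold` must be usable at save time.
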
class StepCounterHook(session_run_hook.SessionRunHook):
  """Steps per second monitor."""

  def __init__(self, every_n_steps=100, output_dir=None, summary_writer=None):
    self._summary_tag = "global_step/sec"
    self._every_n_steps = every_n_steps
    self._summary_writer = summary_writer
    if summary_writer is None and output_dir:
      self._summary_writer = SummaryWriterCache.get(output_dir)

  def begin(self):
    self._last_reported_time = None
    self._last_reported_step = None
    self._global_step_tensor = contrib_variables.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StepCounterHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    _ = run_context
    if not self._summary_writer:
      return
    global_step = run_values.results
    current_time = time.time()
    if self._last_reported_time is None:
      self._last_reported_step = global_step
      self._last_reported_time = current_time
    else:
      if global_step >= self._every_n_steps + self._last_reported_step:
        added_steps = global_step - self._last_reported_step
        elapsed_time = current_time - self._last_reported_time
        steps_per_sec = added_steps / elapsed_time
        summary = Summary(value=[Summary.Value(
            tag=self._summary_tag, simple_value=steps_per_sec)])
        self._summary_writer.add_summary(summary, global_step)
        self._last_reported_step = global_step
        self._last_reported_time = current_time
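
# Worked example of the steps/sec computation above (hypothetical numbers, not
# from the original module): if the hook last reported at global step 100 at
# t=50.0s and next observes step 200 at t=54.0s, it writes
# (200 - 100) / (54.0 - 50.0) = 25.0 under the "global_step/sec" tag.
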
class NanLossDuringTrainingError(RuntimeError):

  def __str__(self):
    return "NaN loss during training."


class NanTensorHook(session_run_hook.SessionRunHook):
  """NaN Loss monitor.

  Monitors loss and stops training if loss is NaN.
  Can either fail with exception or just stop training.
  """

  def __init__(self, loss_tensor, fail_on_nan_loss=True):
    """Initializes NanLoss monitor.

    Args:
      loss_tensor: `Tensor`, the loss tensor.
      fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
    """
    self._loss_tensor = loss_tensor
    self._fail_on_nan_loss = fail_on_nan_loss

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._loss_tensor)

  def after_run(self, run_context, run_values):
    if np.isnan(run_values.results):
      failure_message = "Model diverged with loss = NaN."
      if self._fail_on_nan_loss:
        logging.error(failure_message)
        raise NanLossDuringTrainingError
      else:
        logging.warning(failure_message)
        # We don't raise an error but we request stop without an exception.
        run_context.request_stop()
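
# Example usage for `NanTensorHook` (a hedged sketch, not part of the original
# module). `loss_op` stands in for whatever loss tensor the training graph
# defines:
#
#   NanTensorHook(loss_op)                          # raise on NaN loss
#   NanTensorHook(loss_op, fail_on_nan_loss=False)  # just log and request stop
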
class SummarySaverHook(session_run_hook.SessionRunHook):
  """Saves summaries every N steps."""

  def __init__(self,
               save_steps=100,
               output_dir=None,
               summary_writer=None,
               scaffold=None,
               summary_op=None):
    """Initializes a `SummarySaver` monitor.

    Args:
      save_steps: `int`, save summaries every N steps. See `EveryN`.
      output_dir: `string`, the directory to save the summaries to. Only used
          if no `summary_writer` is supplied.
      summary_writer: `SummaryWriter`. If `None` and an `output_dir` was
          passed, one will be created accordingly.
      scaffold: `Scaffold` to get summary_op if it's not provided.
      summary_op: `Tensor` of type `string`. A serialized `Summary` protocol
          buffer, as output by TF summary methods like `scalar_summary` or
          `merge_all_summaries`.
    """
    # TODO(ipolosukhin): Implement every N seconds.
    self._summary_op = summary_op
    self._summary_writer = summary_writer
    if summary_writer is None and output_dir:
      self._summary_writer = SummaryWriterCache.get(output_dir)
    self._scaffold = scaffold
    self._save_steps = save_steps
    # TODO(mdan): Throw an error if output_dir and summary_writer are None.

  def begin(self):
    self._last_saved_step = None
    self._request_summary = True
    self._global_step_tensor = contrib_variables.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use SummarySaverHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    requests = {"global_step": self._global_step_tensor}
    if self._request_summary:
      if self._summary_op is not None:
        requests["summary"] = self._summary_op
      elif self._scaffold.summary_op is not None:
        requests["summary"] = self._scaffold.summary_op
    return SessionRunArgs(requests)

  def after_run(self, run_context, run_values):
    _ = run_context
    if not self._summary_writer:
      return
    global_step = run_values.results["global_step"]

    if self._last_saved_step is None:
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START), global_step)

    if self._request_summary:
      self._last_saved_step = global_step
      if "summary" in run_values.results:
        self._summary_writer.add_summary(run_values.results["summary"],
                                         global_step)

    self._request_summary = (
        global_step >= self._last_saved_step + self._save_steps - 1)

  def end(self, session=None):
    if self._summary_writer:
      self._summary_writer.flush()
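
# Example usage for `SummarySaverHook` (a hedged sketch, not part of the
# original module). The merged summary op and log directory below are
# hypothetical; either `summary_op` or a `scaffold` exposing `summary_op`
# must supply the serialized `Summary` proto:
#
#   SummarySaverHook(save_steps=100,
#                    output_dir="/tmp/model_dir",
#                    summary_op=merged_summary_op)
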
def _as_graph_element(obj):
  """Retrieves Graph element."""
  graph = ops.get_default_graph()
  if not isinstance(obj, six.string_types):
    if not hasattr(obj, "graph") or obj.graph != graph:
      raise ValueError("Passed %s should have graph attribute that is equal "
                       "to current graph %s." % (obj, graph))
    return obj
  if ":" in obj:
    element = graph.as_graph_element(obj)
  else:
    element = graph.as_graph_element(obj + ":0")
    # Check that there is no :1 (e.g. it's single output).
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      pass
    else:
      raise ValueError("Name %s is ambiguous, "
                       "as this `Operation` has multiple outputs "
                       "(at least 2)." % obj)
  return element
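
# Example of how `_as_graph_element` resolves its input (a hedged sketch, not
# part of the original module; the names are hypothetical):
#
#   _as_graph_element("loss:0")  # looked up verbatim in the default graph
#   _as_graph_element("loss")    # resolved as "loss:0", but rejected with a
#                                # ValueError if "loss:1" also exists
#   _as_graph_element(loss_op)   # returned as-is if it belongs to the default
#                                # graph, otherwise a ValueError is raised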