PageRenderTime 24ms CodeModel.GetById 23ms RepoModel.GetById 0ms app.codeStats 0ms

/tensorflow/python/training/summary_writer_test.py

https://gitlab.com/wilane/tensorflow
Python | 151 lines | 135 code | 11 blank | 5 comment | 2 complexity | 00539006c38c5761a1588d92b4b00ba6 MD5 | raw file
  1. """Tests for summary_writer.py."""
  2. import glob
  3. import os.path
  4. import shutil
  5. import time
  6. import tensorflow.python.platform
  7. import tensorflow as tf
  8. class SummaryWriterTestCase(tf.test.TestCase):
  9. def _TestDir(self, test_name):
  10. test_dir = os.path.join(self.get_temp_dir(), test_name)
  11. return test_dir
  12. def _CleanTestDir(self, test_name):
  13. test_dir = self._TestDir(test_name)
  14. if os.path.exists(test_dir):
  15. shutil.rmtree(test_dir)
  16. return test_dir
  17. def _EventsReader(self, test_dir):
  18. event_paths = glob.glob(os.path.join(test_dir, "event*"))
  19. # If the tests runs multiple time in the same directory we can have
  20. # more than one matching event file. We only want to read the last one.
  21. self.assertTrue(event_paths)
  22. return tf.train.summary_iterator(event_paths[-1])
  23. def _assertRecent(self, t):
  24. self.assertTrue(abs(t - time.time()) < 5)
  25. def testBasics(self):
  26. test_dir = self._CleanTestDir("basics")
  27. sw = tf.train.SummaryWriter(test_dir)
  28. sw.add_summary(tf.Summary(value=[tf.Summary.Value(tag="mee",
  29. simple_value=10.0)]),
  30. 10)
  31. sw.add_summary(tf.Summary(value=[tf.Summary.Value(tag="boo",
  32. simple_value=20.0)]),
  33. 20)
  34. with tf.Graph().as_default() as g:
  35. tf.constant([0], name="zero")
  36. gd = g.as_graph_def()
  37. sw.add_graph(gd, global_step=30)
  38. sw.close()
  39. rr = self._EventsReader(test_dir)
  40. # The first event should list the file_version.
  41. ev = next(rr)
  42. self._assertRecent(ev.wall_time)
  43. self.assertEquals("brain.Event:1", ev.file_version)
  44. # The next event should have the value 'mee=10.0'.
  45. ev = next(rr)
  46. self._assertRecent(ev.wall_time)
  47. self.assertEquals(10, ev.step)
  48. self.assertProtoEquals("""
  49. value { tag: 'mee' simple_value: 10.0 }
  50. """, ev.summary)
  51. # The next event should have the value 'boo=20.0'.
  52. ev = next(rr)
  53. self._assertRecent(ev.wall_time)
  54. self.assertEquals(20, ev.step)
  55. self.assertProtoEquals("""
  56. value { tag: 'boo' simple_value: 20.0 }
  57. """, ev.summary)
  58. # The next event should have the graph_def.
  59. ev = next(rr)
  60. self._assertRecent(ev.wall_time)
  61. self.assertEquals(30, ev.step)
  62. self.assertProtoEquals(gd, ev.graph_def)
  63. # We should be done.
  64. self.assertRaises(StopIteration, lambda: next(rr))
  65. def testConstructWithGraph(self):
  66. test_dir = self._CleanTestDir("basics_with_graph")
  67. with tf.Graph().as_default() as g:
  68. tf.constant([12], name="douze")
  69. gd = g.as_graph_def()
  70. sw = tf.train.SummaryWriter(test_dir, graph_def=gd)
  71. sw.close()
  72. rr = self._EventsReader(test_dir)
  73. # The first event should list the file_version.
  74. ev = next(rr)
  75. self._assertRecent(ev.wall_time)
  76. self.assertEquals("brain.Event:1", ev.file_version)
  77. # The next event should have the graph.
  78. ev = next(rr)
  79. self._assertRecent(ev.wall_time)
  80. self.assertEquals(0, ev.step)
  81. self.assertProtoEquals(gd, ev.graph_def)
  82. # We should be done.
  83. self.assertRaises(StopIteration, lambda: next(rr))
  84. # Checks that values returned from session Run() calls are added correctly to
  85. # summaries. These are numpy types so we need to check they fit in the
  86. # protocol buffers correctly.
  87. def testSummariesAndStopFromSessionRunCalls(self):
  88. test_dir = self._CleanTestDir("global_step")
  89. sw = tf.train.SummaryWriter(test_dir)
  90. with self.test_session():
  91. i = tf.constant(1, dtype=tf.int32, shape=[])
  92. l = tf.constant(2, dtype=tf.int64, shape=[])
  93. # Test the summary can be passed serialized.
  94. summ = tf.Summary(value=[tf.Summary.Value(tag="i", simple_value=1.0)])
  95. sw.add_summary(summ.SerializeToString(), i.eval())
  96. sw.add_summary(tf.Summary(value=[tf.Summary.Value(tag="l",
  97. simple_value=2.0)]),
  98. l.eval())
  99. sw.close()
  100. rr = self._EventsReader(test_dir)
  101. # File_version.
  102. ev = next(rr)
  103. self.assertTrue(ev)
  104. self._assertRecent(ev.wall_time)
  105. self.assertEquals("brain.Event:1", ev.file_version)
  106. # Summary passed serialized.
  107. ev = next(rr)
  108. self.assertTrue(ev)
  109. self._assertRecent(ev.wall_time)
  110. self.assertEquals(1, ev.step)
  111. self.assertProtoEquals("""
  112. value { tag: 'i' simple_value: 1.0 }
  113. """, ev.summary)
  114. # Summary passed as SummaryObject.
  115. ev = next(rr)
  116. self.assertTrue(ev)
  117. self._assertRecent(ev.wall_time)
  118. self.assertEquals(2, ev.step)
  119. self.assertProtoEquals("""
  120. value { tag: 'l' simple_value: 2.0 }
  121. """, ev.summary)
  122. # We should be done.
  123. self.assertRaises(StopIteration, lambda: next(rr))
  124. if __name__ == "__main__":
  125. tf.test.main()