PageRenderTime 40ms CodeModel.GetById 14ms RepoModel.GetById 1ms app.codeStats 0ms

/jni-build/jni/include/tensorflow/contrib/learn/python/learn/tests/dataframe/dataframe_test.py

https://gitlab.com/zharfi/GunSafety
Python | 149 lines | 101 code | 28 blank | 20 comment | 1 complexity | f5544ce8c51364769aba880efc4fbdba MD5 | raw file
  1. # pylint: disable=g-bad-file-header
  2. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # ==============================================================================
  16. """Tests of the DataFrame class."""
  17. from __future__ import absolute_import
  18. from __future__ import division
  19. from __future__ import print_function
  20. import tensorflow as tf
  21. from tensorflow.contrib.learn.python import learn
  22. from tensorflow.contrib.learn.python.learn.tests.dataframe import mocks
  23. def setup_test_df():
  24. """Create a dataframe populated with some test columns."""
  25. df = learn.DataFrame()
  26. df["a"] = learn.TransformedSeries(
  27. [mocks.MockSeries("foobar", [])],
  28. mocks.MockTwoOutputTransform("iue", "eui", "snt"), "out1")
  29. df["b"] = learn.TransformedSeries(
  30. [mocks.MockSeries("foobar", [])],
  31. mocks.MockTwoOutputTransform("iue", "eui", "snt"), "out2")
  32. df["c"] = learn.TransformedSeries(
  33. [mocks.MockSeries("foobar", [])],
  34. mocks.MockTwoOutputTransform("iue", "eui", "snt"), "out1")
  35. return df
  36. class DataFrameTest(tf.test.TestCase):
  37. """Test of `DataFrame`."""
  38. def test_create(self):
  39. df = setup_test_df()
  40. self.assertEqual(df.columns(), frozenset(["a", "b", "c"]))
  41. def test_select(self):
  42. df = setup_test_df()
  43. df2 = df.select(["a", "c"])
  44. self.assertEqual(df2.columns(), frozenset(["a", "c"]))
  45. def test_get_item(self):
  46. df = setup_test_df()
  47. c1 = df["b"]
  48. self.assertEqual("Fake Tensor 2", c1.build())
  49. def test_set_item_column(self):
  50. df = setup_test_df()
  51. self.assertEqual(3, len(df))
  52. col1 = mocks.MockSeries("QuackColumn", [])
  53. df["quack"] = col1
  54. self.assertEqual(4, len(df))
  55. col2 = df["quack"]
  56. self.assertEqual(col1, col2)
  57. def test_set_item_column_multi(self):
  58. df = setup_test_df()
  59. self.assertEqual(3, len(df))
  60. col1 = mocks.MockSeries("QuackColumn", [])
  61. col2 = mocks.MockSeries("MooColumn", [])
  62. df["quack", "moo"] = [col1, col2]
  63. self.assertEqual(5, len(df))
  64. col3 = df["quack"]
  65. self.assertEqual(col1, col3)
  66. col4 = df["moo"]
  67. self.assertEqual(col2, col4)
  68. def test_set_item_pandas(self):
  69. # TODO(jamieas)
  70. pass
  71. def test_set_item_numpy(self):
  72. # TODO(jamieas)
  73. pass
  74. def test_build(self):
  75. df = setup_test_df()
  76. result = df.build()
  77. expected = {"a": "Fake Tensor 1",
  78. "b": "Fake Tensor 2",
  79. "c": "Fake Tensor 1"}
  80. self.assertEqual(expected, result)
  81. def test_to_input_fn_all_features(self):
  82. df = setup_test_df()
  83. input_fn = df.to_input_fn()
  84. f, t = input_fn()
  85. expected_f = {"a": "Fake Tensor 1",
  86. "b": "Fake Tensor 2",
  87. "c": "Fake Tensor 1"}
  88. self.assertEqual(expected_f, f)
  89. expected_t = {}
  90. self.assertEqual(expected_t, t)
  91. def test_to_input_fn_features_only(self):
  92. df = setup_test_df()
  93. input_fn = df.to_input_fn(["b", "c"])
  94. f, t = input_fn()
  95. expected_f = {"b": "Fake Tensor 2", "c": "Fake Tensor 1"}
  96. self.assertEqual(expected_f, f)
  97. expected_t = {}
  98. self.assertEqual(expected_t, t)
  99. def test_to_input_fn_targets_only(self):
  100. df = setup_test_df()
  101. input_fn = df.to_input_fn(target_keys=["b", "c"])
  102. f, t = input_fn()
  103. expected_f = {"a": "Fake Tensor 1"}
  104. self.assertEqual(expected_f, f)
  105. expected_t = {"b": "Fake Tensor 2", "c": "Fake Tensor 1"}
  106. self.assertEqual(expected_t, t)
  107. def test_to_input_fn_both(self):
  108. df = setup_test_df()
  109. input_fn = df.to_input_fn(feature_keys=["a"], target_keys=["b"])
  110. f, t = input_fn()
  111. expected_f = {"a": "Fake Tensor 1"}
  112. self.assertEqual(expected_f, f)
  113. expected_t = {"b": "Fake Tensor 2"}
  114. self.assertEqual(expected_t, t)
  115. def test_to_input_fn_not_disjoint(self):
  116. df = setup_test_df()
  117. def get_not_disjoint():
  118. df.to_input_fn(feature_keys=["a", "b"], target_keys=["b"])
  119. self.assertRaises(ValueError, get_not_disjoint)
  120. if __name__ == "__main__":
  121. tf.test.main()