PageRenderTime 27ms CodeModel.GetById 0ms RepoModel.GetById 0ms app.codeStats 0ms

/tensorflow/core/util/tensor_slice_reader_test.cc

https://gitlab.com/hrishikeshvganu/tensorflow
C++ | 461 lines | 265 code | 49 blank | 147 comment | 7 complexity | bd4fc903c2b17de82487f8eb85bd087e MD5 | raw file
  1. /* Copyright 2015 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include "tensorflow/core/util/tensor_slice_reader.h"
  13. #include "tensorflow/core/framework/types.h"
  14. #include "tensorflow/core/lib/core/status_test_util.h"
  15. #include "tensorflow/core/lib/core/stringpiece.h"
  16. #include "tensorflow/core/lib/io/path.h"
  17. #include "tensorflow/core/lib/strings/strcat.h"
  18. #include "tensorflow/core/platform/env.h"
  19. #include "tensorflow/core/platform/logging.h"
  20. #include "tensorflow/core/platform/protobuf.h"
  21. #include "tensorflow/core/platform/test.h"
  22. #include "tensorflow/core/platform/types.h"
  23. #include "tensorflow/core/public/version.h"
  24. #include "tensorflow/core/util/saved_tensor_slice_util.h"
  25. #include "tensorflow/core/util/tensor_slice_reader_cache.h"
  26. #include "tensorflow/core/util/tensor_slice_writer.h"
  27. namespace tensorflow {
  28. namespace checkpoint {
  29. namespace {
  30. // A simple test where we write a few tensor slices with a number of tensor
  31. // slice writers and then read them back from a tensor slice reader.
  32. //
  33. // We have a 2-d tensor of shape 4 X 5 that looks like this:
  34. //
  35. // 0 1 2 3 4
  36. // 5 6 7 8 9
  37. // 10 11 12 13 14
  38. // 15 16 17 18 19
  39. //
  40. // We assume this is a row-major matrix.
  41. void SimpleFloatHelper(TensorSliceWriter::CreateBuilderFunction create_function,
  42. TensorSliceReader::OpenTableFunction open_function) {
  43. const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint");
  44. TensorShape shape({4, 5});
  45. // File #0 contains a slice that is the top two rows:
  46. //
  47. // 0 1 2 3 4
  48. // 5 6 7 8 9
  49. // . . . . .
  50. // . . . . .
  51. {
  52. const string fname = strings::StrCat(fname_base, "_0");
  53. TensorSliceWriter writer(fname, create_function);
  54. const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  55. TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
  56. TF_CHECK_OK(writer.Add("test", shape, slice, data));
  57. TF_CHECK_OK(writer.Finish());
  58. }
  59. // File #1 contains two slices:
  60. //
  61. // slice #0 is the bottom left corner
  62. // . . . . .
  63. // . . . . .
  64. // 10 11 12 . .
  65. // 15 16 17 . .
  66. //
  67. // slice #1 is the bottom right corner
  68. // . . . . .
  69. // . . . . .
  70. // . . . . .
  71. // . . . 18 19
  72. {
  73. const string fname = strings::StrCat(fname_base, "_1");
  74. TensorSliceWriter writer(fname, create_function);
  75. // slice #0
  76. {
  77. const float data[] = {10, 11, 12, 15, 16, 17};
  78. TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3");
  79. TF_CHECK_OK(writer.Add("test", shape, slice, data));
  80. }
  81. // slice #1
  82. {
  83. const float data[] = {18, 19};
  84. TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2");
  85. TF_CHECK_OK(writer.Add("test", shape, slice, data));
  86. }
  87. TF_CHECK_OK(writer.Finish());
  88. }
  89. // Notice that we leave a hole in the tensor
  90. // . . . . .
  91. // . . . . .
  92. // . . . (13) (14)
  93. // . . . . .
  94. // Now we need to read the tensor slices
  95. const string filepattern = strings::StrCat(fname_base, "_*");
  96. TensorSliceReader reader(filepattern, open_function);
  97. TF_EXPECT_OK(reader.status());
  98. EXPECT_EQ(2, reader.num_files());
  99. // We query some of the tensors
  100. {
  101. TensorShape shape;
  102. DataType type;
  103. EXPECT_TRUE(reader.HasTensor("test", &shape, &type));
  104. EXPECT_EQ("[4,5]", shape.DebugString());
  105. EXPECT_EQ(DT_FLOAT, type);
  106. EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr));
  107. }
  108. // Now we query some slices
  109. //
  110. // Slice #1 is an exact match
  111. // 0 1 2 3 4
  112. // 5 6 7 8 9
  113. // . . . . .
  114. // . . . . .
  115. {
  116. TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
  117. float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  118. float results[10];
  119. EXPECT_TRUE(reader.CopySliceData("test", s, results));
  120. for (int i = 0; i < 10; ++i) {
  121. EXPECT_EQ(expected[i], results[i]);
  122. }
  123. }
  124. // Slice #2 is a subset match
  125. // . . . . .
  126. // 5 6 7 8 9
  127. // . . . . .
  128. // . . . . .
  129. {
  130. TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
  131. float expected[] = {5, 6, 7, 8, 9};
  132. float results[5];
  133. EXPECT_TRUE(reader.CopySliceData("test", s, results));
  134. for (int i = 0; i < 5; ++i) {
  135. EXPECT_EQ(expected[i], results[i]);
  136. }
  137. }
  138. // Slice #4 includes the hole and so there is no match
  139. // . . . . .
  140. // . . 7 8 9
  141. // . . 12 13 14
  142. // . . . . .
  143. {
  144. TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
  145. float results[6];
  146. EXPECT_FALSE(reader.CopySliceData("test", s, results));
  147. }
  148. }
  149. TEST(TensorSliceReaderTest, SimpleFloat) {
  150. SimpleFloatHelper(CreateTableTensorSliceBuilder, OpenTableTensorSliceReader);
  151. }
  152. template <typename T, typename U>
  153. void SimpleIntXHelper(TensorSliceWriter::CreateBuilderFunction create_function,
  154. TensorSliceReader::OpenTableFunction open_function,
  155. const string& checkpoint_file) {
  156. const string fname_base = io::JoinPath(testing::TmpDir(), checkpoint_file);
  157. TensorShape shape({4, 5});
  158. // File #0 contains a slice that is the top two rows:
  159. //
  160. // 0 1 2 3 4
  161. // 5 6 7 8 9
  162. // . . . . .
  163. // . . . . .
  164. {
  165. const string fname = strings::StrCat(fname_base, "_0");
  166. TensorSliceWriter writer(fname, create_function);
  167. const T data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  168. TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
  169. TF_CHECK_OK(writer.Add("test", shape, slice, data));
  170. TF_CHECK_OK(writer.Finish());
  171. }
  172. // File #1 contains two slices:
  173. //
  174. // slice #0 is the bottom left corner
  175. // . . . . .
  176. // . . . . .
  177. // 10 11 12 . .
  178. // 15 16 17 . .
  179. //
  180. // slice #1 is the bottom right corner
  181. // . . . . .
  182. // . . . . .
  183. // . . . . .
  184. // . . . 18 19
  185. {
  186. const string fname = strings::StrCat(fname_base, "_1");
  187. TensorSliceWriter writer(fname, create_function);
  188. // slice #0
  189. {
  190. const T data[] = {10, 11, 12, 15, 16, 17};
  191. TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3");
  192. TF_CHECK_OK(writer.Add("test", shape, slice, data));
  193. }
  194. // slice #1
  195. {
  196. const T data[] = {18, 19};
  197. TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2");
  198. TF_CHECK_OK(writer.Add("test", shape, slice, data));
  199. }
  200. TF_CHECK_OK(writer.Finish());
  201. }
  202. // Notice that we leave a hole in the tensor
  203. // . . . . .
  204. // . . . . .
  205. // . . . (13) (14)
  206. // . . . . .
  207. // Now we need to read the tensor slices
  208. const string filepattern = strings::StrCat(fname_base, "_*");
  209. TensorSliceReader reader(filepattern, open_function);
  210. TF_EXPECT_OK(reader.status());
  211. EXPECT_EQ(2, reader.num_files());
  212. // We query some of the tensors
  213. {
  214. TensorShape shape;
  215. DataType type;
  216. EXPECT_TRUE(reader.HasTensor("test", &shape, &type));
  217. EXPECT_EQ("[4,5]", shape.DebugString());
  218. EXPECT_EQ(DataTypeToEnum<T>::v(), type);
  219. EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr));
  220. }
  221. // Now we query some slices
  222. //
  223. // Slice #1 is an exact match
  224. // 0 1 2 3 4
  225. // 5 6 7 8 9
  226. // . . . . .
  227. // . . . . .
  228. {
  229. TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
  230. T expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  231. U results[10];
  232. EXPECT_TRUE(reader.CopySliceData("test", s, results));
  233. for (int i = 0; i < 10; ++i) {
  234. EXPECT_EQ(expected[i], results[i]);
  235. }
  236. }
  237. // Slice #2 is a subset match
  238. // . . . . .
  239. // 5 6 7 8 9
  240. // . . . . .
  241. // . . . . .
  242. {
  243. TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
  244. T expected[] = {5, 6, 7, 8, 9};
  245. U results[5];
  246. EXPECT_TRUE(reader.CopySliceData("test", s, results));
  247. for (int i = 0; i < 5; ++i) {
  248. EXPECT_EQ(expected[i], results[i]);
  249. }
  250. }
  251. // Slice #4 includes the hole and so there is no match
  252. // . . . . .
  253. // . . 7 8 9
  254. // . . 12 13 14
  255. // . . . . .
  256. {
  257. TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
  258. U results[6];
  259. EXPECT_FALSE(reader.CopySliceData("test", s, results));
  260. }
  261. }
  262. #define TEST_SIMPLE_INT(TYPE, SAVED_TYPE) \
  263. TEST(TensorSliceReaderTest, Simple##TYPE) { \
  264. SimpleIntXHelper<TYPE, SAVED_TYPE>(CreateTableTensorSliceBuilder, \
  265. OpenTableTensorSliceReader, \
  266. #TYPE "_checkpoint"); \
  267. }
  268. TEST_SIMPLE_INT(int32, int32)
  269. TEST_SIMPLE_INT(int64, int64)
  270. TEST_SIMPLE_INT(int16, int32)
  271. TEST_SIMPLE_INT(int8, int32)
  272. TEST_SIMPLE_INT(uint8, int32)
  273. void CachedTensorSliceReaderTesterHelper(
  274. TensorSliceWriter::CreateBuilderFunction create_function,
  275. TensorSliceReader::OpenTableFunction open_function) {
  276. const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint");
  277. TensorShape shape({4, 5});
  278. // File #0 contains a slice that is the top two rows:
  279. //
  280. // 0 1 2 3 4
  281. // 5 6 7 8 9
  282. // . . . . .
  283. // . . . . .
  284. {
  285. const string fname = strings::StrCat(fname_base, "_0");
  286. TensorSliceWriter writer(fname, create_function);
  287. const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  288. TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
  289. TF_CHECK_OK(writer.Add("test", shape, slice, data));
  290. TF_CHECK_OK(writer.Finish());
  291. }
  292. // File #1 contains two slices:
  293. //
  294. // slice #0 is the bottom left corner
  295. // . . . . .
  296. // . . . . .
  297. // 10 11 12 . .
  298. // 15 16 17 . .
  299. //
  300. // slice #1 is the bottom right corner
  301. // . . . . .
  302. // . . . . .
  303. // . . . . .
  304. // . . . 18 19
  305. {
  306. const string fname = strings::StrCat(fname_base, "_1");
  307. TensorSliceWriter writer(fname, create_function);
  308. // slice #0
  309. {
  310. const float data[] = {10, 11, 12, 15, 16, 17};
  311. TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3");
  312. TF_CHECK_OK(writer.Add("test", shape, slice, data));
  313. }
  314. // slice #1
  315. {
  316. const float data[] = {18, 19};
  317. TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2");
  318. TF_CHECK_OK(writer.Add("test", shape, slice, data));
  319. }
  320. TF_CHECK_OK(writer.Finish());
  321. }
  322. // Notice that we leave a hole in the tensor
  323. // . . . . .
  324. // . . . . .
  325. // . . . (13) (14)
  326. // . . . . .
  327. // Now we need to read the tensor slices
  328. TensorSliceReaderCache cache;
  329. const string filepattern = strings::StrCat(fname_base, "_*");
  330. const TensorSliceReader* reader = cache.GetReader(
  331. filepattern, open_function, TensorSliceReader::kLoadAllShards);
  332. EXPECT_TRUE(reader != nullptr);
  333. EXPECT_EQ(2, reader->num_files());
  334. // We query some of the tensors
  335. {
  336. TensorShape shape;
  337. DataType type;
  338. EXPECT_TRUE(reader->HasTensor("test", &shape, &type));
  339. EXPECT_EQ("[4,5]", shape.DebugString());
  340. EXPECT_EQ(DT_FLOAT, type);
  341. EXPECT_FALSE(reader->HasTensor("don't exist", nullptr, nullptr));
  342. }
  343. // Make sure the reader is cached.
  344. const TensorSliceReader* reader2 = cache.GetReader(
  345. filepattern, open_function, TensorSliceReader::kLoadAllShards);
  346. EXPECT_EQ(reader, reader2);
  347. reader = cache.GetReader("file_does_not_exist", open_function,
  348. TensorSliceReader::kLoadAllShards);
  349. EXPECT_TRUE(reader == nullptr);
  350. }
  351. TEST(CachedTensorSliceReaderTest, SimpleFloat) {
  352. CachedTensorSliceReaderTesterHelper(CreateTableTensorSliceBuilder,
  353. OpenTableTensorSliceReader);
  354. }
  355. static void VersionTest(const VersionDef& versions, const string& error) {
  356. const string path = io::JoinPath(testing::TmpDir(), "checkpoint");
  357. {
  358. // Prepare an empty checkpoint with some version information
  359. SavedTensorSlices sts;
  360. sts.mutable_meta()->mutable_versions()->CopyFrom(versions);
  361. string contents;
  362. EXPECT_TRUE(sts.SerializeToString(&contents));
  363. // Write it to disk
  364. TensorSliceWriter::Builder* builder;
  365. TF_ASSERT_OK(CreateTableTensorSliceBuilder(path, &builder));
  366. builder->Add(kSavedTensorSlicesKey, contents);
  367. int64 file_size;
  368. builder->Finish(&file_size);
  369. delete builder;
  370. }
  371. // Read it back in and verify that we get the expected error
  372. TensorSliceReader reader(path, OpenTableTensorSliceReader);
  373. EXPECT_TRUE(reader.status().code() == error::INVALID_ARGUMENT &&
  374. StringPiece(reader.status().error_message()).starts_with(error))
  375. << "Expected error starting with '" << errors::InvalidArgument(error)
  376. << "', got '" << reader.status() << "'";
  377. }
  378. TEST(CheckpointVersionTest, MinConsumer) {
  379. VersionDef versions;
  380. versions.set_producer(TF_CHECKPOINT_VERSION + 1);
  381. versions.set_min_consumer(TF_CHECKPOINT_VERSION + 1);
  382. VersionTest(
  383. versions,
  384. strings::StrCat("Checkpoint min consumer version ",
  385. TF_CHECKPOINT_VERSION + 1, " above current version ",
  386. TF_CHECKPOINT_VERSION, " for TensorFlow"));
  387. }
  388. TEST(CheckpointVersionTest, MinProducer) {
  389. VersionDef versions;
  390. versions.set_producer(TF_CHECKPOINT_VERSION_MIN_PRODUCER - 1);
  391. VersionTest(versions, strings::StrCat("Checkpoint producer version ",
  392. TF_CHECKPOINT_VERSION_MIN_PRODUCER - 1,
  393. " below min producer ",
  394. TF_CHECKPOINT_VERSION_MIN_PRODUCER,
  395. " supported by TensorFlow"));
  396. }
  397. TEST(CheckpointVersionTest, BadConsumer) {
  398. VersionDef versions;
  399. versions.set_producer(TF_CHECKPOINT_VERSION + 1);
  400. versions.add_bad_consumers(TF_CHECKPOINT_VERSION);
  401. VersionTest(
  402. versions,
  403. strings::StrCat(
  404. "Checkpoint disallows consumer version ", TF_CHECKPOINT_VERSION,
  405. ". Please upgrade TensorFlow: this version is likely buggy."));
  406. }
  407. } // namespace
  408. } // namespace checkpoint
  409. } // namespace tensorflow