/tensorflow/core/util/tensor_slice_reader_test.cc
C++ | 461 lines | 265 code | 49 blank | 147 comment | 7 complexity | bd4fc903c2b17de82487f8eb85bd087e MD5 | raw file
- /* Copyright 2015 Google Inc. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
- #include "tensorflow/core/util/tensor_slice_reader.h"
- #include "tensorflow/core/framework/types.h"
- #include "tensorflow/core/lib/core/status_test_util.h"
- #include "tensorflow/core/lib/core/stringpiece.h"
- #include "tensorflow/core/lib/io/path.h"
- #include "tensorflow/core/lib/strings/strcat.h"
- #include "tensorflow/core/platform/env.h"
- #include "tensorflow/core/platform/logging.h"
- #include "tensorflow/core/platform/protobuf.h"
- #include "tensorflow/core/platform/test.h"
- #include "tensorflow/core/platform/types.h"
- #include "tensorflow/core/public/version.h"
- #include "tensorflow/core/util/saved_tensor_slice_util.h"
- #include "tensorflow/core/util/tensor_slice_reader_cache.h"
- #include "tensorflow/core/util/tensor_slice_writer.h"
- namespace tensorflow {
- namespace checkpoint {
- namespace {
- // A simple test where we write a few tensor slices with a number of tensor
- // slice writers and then read them back from a tensor slice reader.
- //
- // We have a 2-d tensor of shape 4 X 5 that looks like this:
- //
- // 0 1 2 3 4
- // 5 6 7 8 9
- // 10 11 12 13 14
- // 15 16 17 18 19
- //
- // We assume this is a row-major matrix.
- void SimpleFloatHelper(TensorSliceWriter::CreateBuilderFunction create_function,
- TensorSliceReader::OpenTableFunction open_function) {
- const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint");
- TensorShape shape({4, 5});
- // File #0 contains a slice that is the top two rows:
- //
- // 0 1 2 3 4
- // 5 6 7 8 9
- // . . . . .
- // . . . . .
- {
- const string fname = strings::StrCat(fname_base, "_0");
- TensorSliceWriter writer(fname, create_function);
- const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
- TF_CHECK_OK(writer.Add("test", shape, slice, data));
- TF_CHECK_OK(writer.Finish());
- }
- // File #1 contains two slices:
- //
- // slice #0 is the bottom left corner
- // . . . . .
- // . . . . .
- // 10 11 12 . .
- // 15 16 17 . .
- //
- // slice #1 is the bottom right corner
- // . . . . .
- // . . . . .
- // . . . . .
- // . . . 18 19
- {
- const string fname = strings::StrCat(fname_base, "_1");
- TensorSliceWriter writer(fname, create_function);
- // slice #0
- {
- const float data[] = {10, 11, 12, 15, 16, 17};
- TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3");
- TF_CHECK_OK(writer.Add("test", shape, slice, data));
- }
- // slice #1
- {
- const float data[] = {18, 19};
- TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2");
- TF_CHECK_OK(writer.Add("test", shape, slice, data));
- }
- TF_CHECK_OK(writer.Finish());
- }
- // Notice that we leave a hole in the tensor
- // . . . . .
- // . . . . .
- // . . . (13) (14)
- // . . . . .
- // Now we need to read the tensor slices
- const string filepattern = strings::StrCat(fname_base, "_*");
- TensorSliceReader reader(filepattern, open_function);
- TF_EXPECT_OK(reader.status());
- EXPECT_EQ(2, reader.num_files());
- // We query some of the tensors
- {
- TensorShape shape;
- DataType type;
- EXPECT_TRUE(reader.HasTensor("test", &shape, &type));
- EXPECT_EQ("[4,5]", shape.DebugString());
- EXPECT_EQ(DT_FLOAT, type);
- EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr));
- }
- // Now we query some slices
- //
- // Slice #1 is an exact match
- // 0 1 2 3 4
- // 5 6 7 8 9
- // . . . . .
- // . . . . .
- {
- TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
- float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- float results[10];
- EXPECT_TRUE(reader.CopySliceData("test", s, results));
- for (int i = 0; i < 10; ++i) {
- EXPECT_EQ(expected[i], results[i]);
- }
- }
- // Slice #2 is a subset match
- // . . . . .
- // 5 6 7 8 9
- // . . . . .
- // . . . . .
- {
- TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
- float expected[] = {5, 6, 7, 8, 9};
- float results[5];
- EXPECT_TRUE(reader.CopySliceData("test", s, results));
- for (int i = 0; i < 5; ++i) {
- EXPECT_EQ(expected[i], results[i]);
- }
- }
- // Slice #4 includes the hole and so there is no match
- // . . . . .
- // . . 7 8 9
- // . . 12 13 14
- // . . . . .
- {
- TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
- float results[6];
- EXPECT_FALSE(reader.CopySliceData("test", s, results));
- }
- }
- TEST(TensorSliceReaderTest, SimpleFloat) {
- SimpleFloatHelper(CreateTableTensorSliceBuilder, OpenTableTensorSliceReader);
- }
- template <typename T, typename U>
- void SimpleIntXHelper(TensorSliceWriter::CreateBuilderFunction create_function,
- TensorSliceReader::OpenTableFunction open_function,
- const string& checkpoint_file) {
- const string fname_base = io::JoinPath(testing::TmpDir(), checkpoint_file);
- TensorShape shape({4, 5});
- // File #0 contains a slice that is the top two rows:
- //
- // 0 1 2 3 4
- // 5 6 7 8 9
- // . . . . .
- // . . . . .
- {
- const string fname = strings::StrCat(fname_base, "_0");
- TensorSliceWriter writer(fname, create_function);
- const T data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
- TF_CHECK_OK(writer.Add("test", shape, slice, data));
- TF_CHECK_OK(writer.Finish());
- }
- // File #1 contains two slices:
- //
- // slice #0 is the bottom left corner
- // . . . . .
- // . . . . .
- // 10 11 12 . .
- // 15 16 17 . .
- //
- // slice #1 is the bottom right corner
- // . . . . .
- // . . . . .
- // . . . . .
- // . . . 18 19
- {
- const string fname = strings::StrCat(fname_base, "_1");
- TensorSliceWriter writer(fname, create_function);
- // slice #0
- {
- const T data[] = {10, 11, 12, 15, 16, 17};
- TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3");
- TF_CHECK_OK(writer.Add("test", shape, slice, data));
- }
- // slice #1
- {
- const T data[] = {18, 19};
- TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2");
- TF_CHECK_OK(writer.Add("test", shape, slice, data));
- }
- TF_CHECK_OK(writer.Finish());
- }
- // Notice that we leave a hole in the tensor
- // . . . . .
- // . . . . .
- // . . . (13) (14)
- // . . . . .
- // Now we need to read the tensor slices
- const string filepattern = strings::StrCat(fname_base, "_*");
- TensorSliceReader reader(filepattern, open_function);
- TF_EXPECT_OK(reader.status());
- EXPECT_EQ(2, reader.num_files());
- // We query some of the tensors
- {
- TensorShape shape;
- DataType type;
- EXPECT_TRUE(reader.HasTensor("test", &shape, &type));
- EXPECT_EQ("[4,5]", shape.DebugString());
- EXPECT_EQ(DataTypeToEnum<T>::v(), type);
- EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr));
- }
- // Now we query some slices
- //
- // Slice #1 is an exact match
- // 0 1 2 3 4
- // 5 6 7 8 9
- // . . . . .
- // . . . . .
- {
- TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
- T expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- U results[10];
- EXPECT_TRUE(reader.CopySliceData("test", s, results));
- for (int i = 0; i < 10; ++i) {
- EXPECT_EQ(expected[i], results[i]);
- }
- }
- // Slice #2 is a subset match
- // . . . . .
- // 5 6 7 8 9
- // . . . . .
- // . . . . .
- {
- TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
- T expected[] = {5, 6, 7, 8, 9};
- U results[5];
- EXPECT_TRUE(reader.CopySliceData("test", s, results));
- for (int i = 0; i < 5; ++i) {
- EXPECT_EQ(expected[i], results[i]);
- }
- }
- // Slice #4 includes the hole and so there is no match
- // . . . . .
- // . . 7 8 9
- // . . 12 13 14
- // . . . . .
- {
- TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
- U results[6];
- EXPECT_FALSE(reader.CopySliceData("test", s, results));
- }
- }
- #define TEST_SIMPLE_INT(TYPE, SAVED_TYPE) \
- TEST(TensorSliceReaderTest, Simple##TYPE) { \
- SimpleIntXHelper<TYPE, SAVED_TYPE>(CreateTableTensorSliceBuilder, \
- OpenTableTensorSliceReader, \
- #TYPE "_checkpoint"); \
- }
- TEST_SIMPLE_INT(int32, int32)
- TEST_SIMPLE_INT(int64, int64)
- TEST_SIMPLE_INT(int16, int32)
- TEST_SIMPLE_INT(int8, int32)
- TEST_SIMPLE_INT(uint8, int32)
- void CachedTensorSliceReaderTesterHelper(
- TensorSliceWriter::CreateBuilderFunction create_function,
- TensorSliceReader::OpenTableFunction open_function) {
- const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint");
- TensorShape shape({4, 5});
- // File #0 contains a slice that is the top two rows:
- //
- // 0 1 2 3 4
- // 5 6 7 8 9
- // . . . . .
- // . . . . .
- {
- const string fname = strings::StrCat(fname_base, "_0");
- TensorSliceWriter writer(fname, create_function);
- const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
- TF_CHECK_OK(writer.Add("test", shape, slice, data));
- TF_CHECK_OK(writer.Finish());
- }
- // File #1 contains two slices:
- //
- // slice #0 is the bottom left corner
- // . . . . .
- // . . . . .
- // 10 11 12 . .
- // 15 16 17 . .
- //
- // slice #1 is the bottom right corner
- // . . . . .
- // . . . . .
- // . . . . .
- // . . . 18 19
- {
- const string fname = strings::StrCat(fname_base, "_1");
- TensorSliceWriter writer(fname, create_function);
- // slice #0
- {
- const float data[] = {10, 11, 12, 15, 16, 17};
- TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3");
- TF_CHECK_OK(writer.Add("test", shape, slice, data));
- }
- // slice #1
- {
- const float data[] = {18, 19};
- TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2");
- TF_CHECK_OK(writer.Add("test", shape, slice, data));
- }
- TF_CHECK_OK(writer.Finish());
- }
- // Notice that we leave a hole in the tensor
- // . . . . .
- // . . . . .
- // . . . (13) (14)
- // . . . . .
- // Now we need to read the tensor slices
- TensorSliceReaderCache cache;
- const string filepattern = strings::StrCat(fname_base, "_*");
- const TensorSliceReader* reader = cache.GetReader(
- filepattern, open_function, TensorSliceReader::kLoadAllShards);
- EXPECT_TRUE(reader != nullptr);
- EXPECT_EQ(2, reader->num_files());
- // We query some of the tensors
- {
- TensorShape shape;
- DataType type;
- EXPECT_TRUE(reader->HasTensor("test", &shape, &type));
- EXPECT_EQ("[4,5]", shape.DebugString());
- EXPECT_EQ(DT_FLOAT, type);
- EXPECT_FALSE(reader->HasTensor("don't exist", nullptr, nullptr));
- }
- // Make sure the reader is cached.
- const TensorSliceReader* reader2 = cache.GetReader(
- filepattern, open_function, TensorSliceReader::kLoadAllShards);
- EXPECT_EQ(reader, reader2);
- reader = cache.GetReader("file_does_not_exist", open_function,
- TensorSliceReader::kLoadAllShards);
- EXPECT_TRUE(reader == nullptr);
- }
- TEST(CachedTensorSliceReaderTest, SimpleFloat) {
- CachedTensorSliceReaderTesterHelper(CreateTableTensorSliceBuilder,
- OpenTableTensorSliceReader);
- }
- static void VersionTest(const VersionDef& versions, const string& error) {
- const string path = io::JoinPath(testing::TmpDir(), "checkpoint");
- {
- // Prepare an empty checkpoint with some version information
- SavedTensorSlices sts;
- sts.mutable_meta()->mutable_versions()->CopyFrom(versions);
- string contents;
- EXPECT_TRUE(sts.SerializeToString(&contents));
- // Write it to disk
- TensorSliceWriter::Builder* builder;
- TF_ASSERT_OK(CreateTableTensorSliceBuilder(path, &builder));
- builder->Add(kSavedTensorSlicesKey, contents);
- int64 file_size;
- builder->Finish(&file_size);
- delete builder;
- }
- // Read it back in and verify that we get the expected error
- TensorSliceReader reader(path, OpenTableTensorSliceReader);
- EXPECT_TRUE(reader.status().code() == error::INVALID_ARGUMENT &&
- StringPiece(reader.status().error_message()).starts_with(error))
- << "Expected error starting with '" << errors::InvalidArgument(error)
- << "', got '" << reader.status() << "'";
- }
- TEST(CheckpointVersionTest, MinConsumer) {
- VersionDef versions;
- versions.set_producer(TF_CHECKPOINT_VERSION + 1);
- versions.set_min_consumer(TF_CHECKPOINT_VERSION + 1);
- VersionTest(
- versions,
- strings::StrCat("Checkpoint min consumer version ",
- TF_CHECKPOINT_VERSION + 1, " above current version ",
- TF_CHECKPOINT_VERSION, " for TensorFlow"));
- }
- TEST(CheckpointVersionTest, MinProducer) {
- VersionDef versions;
- versions.set_producer(TF_CHECKPOINT_VERSION_MIN_PRODUCER - 1);
- VersionTest(versions, strings::StrCat("Checkpoint producer version ",
- TF_CHECKPOINT_VERSION_MIN_PRODUCER - 1,
- " below min producer ",
- TF_CHECKPOINT_VERSION_MIN_PRODUCER,
- " supported by TensorFlow"));
- }
- TEST(CheckpointVersionTest, BadConsumer) {
- VersionDef versions;
- versions.set_producer(TF_CHECKPOINT_VERSION + 1);
- versions.add_bad_consumers(TF_CHECKPOINT_VERSION);
- VersionTest(
- versions,
- strings::StrCat(
- "Checkpoint disallows consumer version ", TF_CHECKPOINT_VERSION,
- ". Please upgrade TensorFlow: this version is likely buggy."));
- }
- } // namespace
- } // namespace checkpoint
- } // namespace tensorflow