/datasets/svhn.py

https://github.com/jerryli27/TwinGAN · Python · 114 lines · 78 code · 11 blank · 25 comment · 0 complexity · 18f54d18e6dfd0ddca321e74e7fc8693 MD5 · raw file

  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Provides data for the SVHN dataset.
  16. The dataset scripts used to create the dataset can be found at:
  17. tensorflow/models/research/slim/datasets/download_and_convert_mnist.py
  18. """
  19. from __future__ import absolute_import
  20. from __future__ import division
  21. from __future__ import print_function
  22. import os
  23. import tensorflow as tf
  24. from datasets import dataset_utils
  25. slim = tf.contrib.slim
  26. _FILE_PATTERN = '%s*'
  27. _SPLITS_TO_SIZES = {'train': 73257, 'test': 26032}
  28. _IMAGE_HW = 32
  29. _NUM_CLASSES = 10
  30. _ITEMS_TO_DESCRIPTIONS = {
  31. 'image': 'A [32 x 32 x 1] grayscale image.',
  32. 'label': 'A single integer between 0 and 9',
  33. }
  34. def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  35. """Gets a dataset tuple with instructions for reading MNIST.
  36. Args:
  37. split_name: A train/test split name.
  38. dataset_dir: The base directory of the dataset sources.
  39. file_pattern: The file pattern to use when matching the dataset sources.
  40. It is assumed that the pattern contains a '%s' string so that the split
  41. name can be inserted.
  42. reader: The TensorFlow reader type.
  43. Returns:
  44. A `Dataset` namedtuple.
  45. Raises:
  46. ValueError: if `split_name` is not a valid train/test split.
  47. """
  48. if split_name not in _SPLITS_TO_SIZES:
  49. raise ValueError('split name %s was not recognized.' % split_name)
  50. if not file_pattern:
  51. file_pattern = _FILE_PATTERN
  52. file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
  53. # Allowing None in the signature so that dataset_factory can use the default.
  54. if reader is None:
  55. reader = tf.TFRecordReader
  56. keys_to_features = {
  57. 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
  58. 'image/format': tf.FixedLenFeature((), tf.string, default_value='raw'),
  59. # 'image/class/label': tf.FixedLenFeature(
  60. # [1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)),
  61. # Changed to varlen to be compatible with the rest of the multi-label framework.
  62. 'image/class/label': tf.VarLenFeature(
  63. tf.int64,),
  64. }
  65. num_channels = 3
  66. if hasattr(tf.flags.FLAGS, 'color_space') and tf.flags.FLAGS.color_space =="gray":
  67. num_channels = 1
  68. items_to_handlers = {
  69. 'image': slim.tfexample_decoder.Image(shape=[_IMAGE_HW, _IMAGE_HW, num_channels], channels=num_channels),
  70. # 'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]),
  71. # Took off shape to be compatible with the rest of the multi-label framework.
  72. 'label': slim.tfexample_decoder.Tensor('image/class/label',),
  73. }
  74. items_to_handlers['source'] = items_to_handlers['image']
  75. items_to_handlers['target'] = items_to_handlers['label']
  76. items_to_handlers['conditional_labels'] = items_to_handlers['label']
  77. decoder = slim.tfexample_decoder.TFExampleDecoder(
  78. keys_to_features, items_to_handlers)
  79. labels_to_names = None
  80. if dataset_utils.has_labels(dataset_dir):
  81. labels_to_names = dataset_utils.read_label_file(dataset_dir)
  82. return slim.dataset.Dataset(
  83. data_sources=file_pattern,
  84. reader=reader,
  85. decoder=decoder,
  86. num_samples=_SPLITS_TO_SIZES[split_name],
  87. num_classes=_NUM_CLASSES,
  88. items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
  89. labels_to_names=labels_to_names,
  90. items_used=['image', 'label', 'source', 'target', 'conditional_labels'],
  91. items_need_preprocessing=['image', 'label', 'source', 'target', 'conditional_labels'],
  92. has_source=True,)