/Tensorflow/Detector/input_producer.py

Source: https://github.com/qjadud1994/Text_Detector

import os

import tensorflow as tf

slim = tf.contrib.slim


class InputProducer(object):

    def __init__(self, preprocess_image_fn=None, vertical_image=False):
        self.vertical_image = vertical_image
        self._preprocess_image = preprocess_image_fn if preprocess_image_fn is not None \
            else self._default_preprocess_image_fn

        # Human-readable descriptions of the items the decoder produces.
        self.ITEMS_TO_DESCRIPTIONS = {
            'image': 'A color image of varying height and width.',
            'shape': 'Shape of the image',
            'object/bbox': 'A list of bounding boxes, one per each object.',
            'object/label': 'A list of labels, one per each object.',
        }

        # Number of examples in each TFRecord split.
        self.SPLITS_TO_SIZES = {
            'train_IC13': 229,
            'val_IC13': 233,
            'train_2': 850000,
            'val_2': 8750,
            'train_quad': 850000,
            'val_quad': 8750,
            'train_IC15': 1000,
            'val_IC15': 500,
            'train_IC15_mask': 1000,
            'val_IC15_mask': 500,
        }

        # Record files are named after their split, e.g. 'train_IC13.record'.
        self.FILE_PATTERN = '%s.record'
    def num_classes(self):
        return 20

    def get_split(self, split_name, dataset_dir, is_rect=True):
        """Gets a dataset tuple with instructions for reading text-detection TFRecords.

        Args:
            split_name: A train/val split name (a key of SPLITS_TO_SIZES).
            dataset_dir: The base directory of the dataset sources. The record
                file is located inside it as FILE_PATTERN % split_name.
            is_rect: If True, decode axis-aligned rectangle annotations;
                otherwise decode quadrilateral annotations.

        Returns:
            A `slim.dataset.Dataset` namedtuple.

        Raises:
            ValueError: if `split_name` is not a valid train/val split.
        """
        if split_name not in self.SPLITS_TO_SIZES:
            raise ValueError('split name %s was not recognized.' % split_name)

        file_pattern = os.path.join(dataset_dir, self.FILE_PATTERN % split_name)
        reader = tf.TFRecordReader
        if is_rect:  # Rect annotations: one axis-aligned box per object.
            keys_to_features = {
                'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
                'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
                'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
                'image/object/class/label': tf.VarLenFeature(dtype=tf.int64),
            }
            items_to_handlers = {
                'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
                # Packs the four coordinate lists into a [num_boxes, 4] tensor
                # in [ymin, xmin, ymax, xmax] order.
                'object/bbox': slim.tfexample_decoder.BoundingBox(
                    ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
                'object/label': slim.tfexample_decoder.Tensor('image/object/class/label'),
            }
        else:  # Quad annotations: four (y, x) corner points per object.
            keys_to_features = {
                'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
                'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
                'image/object/bbox/y0': tf.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/x0': tf.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/y1': tf.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/x1': tf.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/y2': tf.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/x2': tf.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/y3': tf.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/x3': tf.VarLenFeature(dtype=tf.float32),
                'image/object/class/label': tf.VarLenFeature(dtype=tf.int64),
            }
            items_to_handlers = {
                'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
                # The BoundingBox handler packs exactly four coordinate lists,
                # so the eight quad coordinates are decoded as two
                # [num_boxes, 4] halves: (y0, x0, y1, x1) and (y2, x2, y3, x3).
                'object/quad1': slim.tfexample_decoder.BoundingBox(
                    ['y0', 'x0', 'y1', 'x1'], 'image/object/bbox/'),
                'object/quad2': slim.tfexample_decoder.BoundingBox(
                    ['y2', 'x2', 'y3', 'x3'], 'image/object/bbox/'),
                'object/label': slim.tfexample_decoder.Tensor('image/object/class/label'),
            }
        decoder = slim.tfexample_decoder.TFExampleDecoder(
            keys_to_features, items_to_handlers)

        labels_to_names = None
        # if has_labels(dataset_dir):
        #     labels_to_names = read_label_file(dataset_dir)

        return slim.dataset.Dataset(
            data_sources=file_pattern,
            reader=reader,
            decoder=decoder,
            num_samples=self.SPLITS_TO_SIZES[split_name],
            items_to_descriptions=self.ITEMS_TO_DESCRIPTIONS,
            num_classes=self.num_classes(),
            labels_to_names=labels_to_names)

    def _default_preprocess_image_fn(self, image, is_train=True):
        return image
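
For reference, a record consumed by the rect branch of get_split has to be serialized with the matching image/* feature keys. Below is a minimal writer sketch using the TF 1.x API; the file names, helper functions, and box values are illustrative, not from the repository.

# sketch_write_record.py -- illustrative only, not part of the repository.
# Serializes one tf.train.Example whose feature keys match the rect branch
# of InputProducer.get_split (TF 1.x API).
import tensorflow as tf

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_list_feature(values):
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))

def _int64_list_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

with tf.python_io.TFRecordWriter('records/train_IC13.record') as writer:
    with open('sample.jpg', 'rb') as f:
        encoded = f.read()
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': _bytes_feature(encoded),
        'image/format': _bytes_feature(b'jpg'),
        # One entry per text box; coordinates assumed normalized to [0, 1].
        'image/object/bbox/xmin': _float_list_feature([0.10]),
        'image/object/bbox/ymin': _float_list_feature([0.20]),
        'image/object/bbox/xmax': _float_list_feature([0.45]),
        'image/object/bbox/ymax': _float_list_feature([0.30]),
        'image/object/class/label': _int64_list_feature([1]),
    }))
    writer.write(example.SerializeToString())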
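
Reading a split back typically goes through slim's DatasetDataProvider, which wires the TFRecordReader and TFExampleDecoder together behind input queues. A minimal consumption sketch, again assuming TF 1.x, that this file is importable as input_producer, and that the record sits in ./records; the queue sizes are placeholders. For the quad branch, 'object/quad1' and 'object/quad2' would be fetched instead and concatenated into a [num_boxes, 8] tensor downstream.

# sketch_read_split.py -- illustrative usage, not part of the repository.
import tensorflow as tf

from input_producer import InputProducer

slim = tf.contrib.slim

producer = InputProducer()
dataset = producer.get_split('train_IC13', './records', is_rect=True)

# DatasetDataProvider hooks the reader and decoder up to prefetch queues
# and serves decoded items one example at a time.
provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset,
    num_readers=2,
    common_queue_capacity=64,
    common_queue_min=32)

# 'image' is the decoded image tensor; 'object/bbox' is [num_boxes, 4]
# in [ymin, xmin, ymax, xmax] order, as assembled by the BoundingBox handler.
image, bboxes, labels = provider.get(['image', 'object/bbox', 'object/label'])

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    img, bxs, lbs = sess.run([image, bboxes, labels])
    print(img.shape, bxs, lbs)
    coord.request_stop()
    coord.join(threads)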