/source/pic2card/commands/generate_tfrecord.py

https://github.com/Microsoft/AdaptiveCards · Python · 134 lines · 83 code · 16 blank · 35 comment · 11 complexity · 47d5059d73edc6f6620a69b20b2bf579 MD5 · raw file

  1. """
  2. Generates tensorflow records from the label mapped train and test csv
  3. files.
  4. Usage:
  5. # From tensorflow/models/
  6. # Create train data:
  7. python generate_tfrecord.py --csv_input=data/train_labels.csv \
  8. --output_path=train.record
  9. # Create test data:
  10. python generate_tfrecord.py --csv_input=data/test_labels.csv \
  11. --output_path=test.record
  12. """
  13. from __future__ import absolute_import
  14. from __future__ import division
  15. from __future__ import print_function
  16. import io
  17. import os
  18. from collections import namedtuple
  19. import pandas as pd
  20. import tensorflow as tf
  21. from PIL import Image
  22. from object_detection.utils import dataset_util
  23. flags = tf.app.flags
  24. flags.DEFINE_string("csv_input", "", "Path to the CSV input")
  25. flags.DEFINE_string("output_path", "", "Path to output TFRecord")
  26. flags.DEFINE_string("image_dir", "", "Path to images")
  27. FLAGS = flags.FLAGS
  28. # TO-DO replace this with label map
  29. def class_text_to_int(row_label):
  30. """
  31. Function to define the class lables
  32. @param row_label: integer class value from the csv
  33. """
  34. if row_label == "textbox":
  35. return 1
  36. if row_label == "radio_button":
  37. return 2
  38. if row_label == "checkbox":
  39. return 3
  40. if row_label == "actionset":
  41. return 4
  42. if row_label == "image":
  43. return 5
  44. else:
  45. return 0
  46. def create_tf_example(group, path):
  47. """
  48. Generate tf recods by parsing the xml with
  49. the properites and labels.
  50. @param group: filename group
  51. @param path: images path
  52. @return: the tf record
  53. """
  54. # import pdb; pdb.set_trace()
  55. # name_filename = group.filename[:group.filename.find(".")] + ".png"
  56. with tf.gfile.GFile(os.path.join(path, "{}".format(group.filename)),
  57. "rb") as fid:
  58. encoded_jpg = fid.read()
  59. encoded_jpg_io = io.BytesIO(encoded_jpg)
  60. image = Image.open(encoded_jpg_io)
  61. width, height = image.size
  62. filename = group.filename.encode("utf8")
  63. image_format = b"png"
  64. xmins = []
  65. xmaxs = []
  66. ymins = []
  67. ymaxs = []
  68. classes_text = []
  69. classes = []
  70. for index, row in group.object.iterrows():
  71. xmins.append(row["xmin"] / width)
  72. xmaxs.append(row["xmax"] / width)
  73. ymins.append(row["ymin"] / height)
  74. ymaxs.append(row["ymax"] / height)
  75. classes_text.append(row["class"].encode("utf8"))
  76. classes.append(class_text_to_int(row["class"]))
  77. tf_example = tf.train.Example(features=tf.train.Features(feature={
  78. "image/height": dataset_util.int64_feature(height),
  79. "image/width": dataset_util.int64_feature(width),
  80. "image/filename": dataset_util.bytes_feature(filename),
  81. "image/source_id": dataset_util.bytes_feature(filename),
  82. "image/encoded": dataset_util.bytes_feature(encoded_jpg),
  83. "image/format": dataset_util.bytes_feature(image_format),
  84. "image/object/bbox/xmin": dataset_util.float_list_feature(xmins),
  85. "image/object/bbox/xmax": dataset_util.float_list_feature(xmaxs),
  86. "image/object/bbox/ymin": dataset_util.float_list_feature(ymins),
  87. "image/object/bbox/ymax": dataset_util.float_list_feature(ymaxs),
  88. "image/object/class/text": dataset_util.bytes_list_feature(
  89. classes_text),
  90. "image/object/class/label": dataset_util.int64_list_feature(
  91. classes),
  92. }))
  93. return tf_example
  94. def main(_):
  95. """
  96. Writes the generated tensorflow records into the specified ouput
  97. directory
  98. """
  99. writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
  100. path = os.path.join(FLAGS.image_dir)
  101. examples = pd.read_csv(FLAGS.csv_input)
  102. data = namedtuple("data", ["filename", "object"])
  103. gb = examples.groupby("filename")
  104. grouped = [data(filename, gb.get_group(x))
  105. for filename, x in zip(gb.groups.keys(), gb.groups)]
  106. for group in grouped:
  107. tf_example = create_tf_example(group, path)
  108. writer.write(tf_example.SerializeToString())
  109. writer.close()
  110. output_path = os.path.join(os.getcwd(), FLAGS.output_path)
  111. print("Successfully created the TFRecords: {}".format(output_path))
  112. if __name__ == "__main__":
  113. tf.app.run()