/data/det/preprocess/voc/voc_det_generator.py

https://github.com/donnyyou/torchcv · Python · 138 lines · 113 code · 20 blank · 5 comment · 21 complexity · a39cf23e054f24040743ee3c0de2f352 MD5 · raw file

  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # Author: Donny You(youansheng@gmail.com)
  4. # VOC det data generator.
  5. import json
  6. import os
  7. import argparse
  8. import shutil
  9. from bs4 import BeautifulSoup
  10. JOSN_DIR = 'json'
  11. IMAGE_DIR = 'image'
  12. CAT_DICT = {
  13. 'aeroplane': 0, 'bicycle':1,'bird':2,'boat':3,'bottle':4,
  14. 'bus': 5, 'car': 6, 'cat': 7, 'chair': 8, 'cow': 9, 'diningtable': 10,
  15. 'dog': 11, 'horse': 12, 'motorbike': 13, 'person': 14, 'pottedplant': 15,
  16. 'sheep': 16, 'sofa': 17, 'train': 18, 'tvmonitor': 19
  17. }
  18. class VocDetGenerator(object):
  19. def __init__(self, args, json_dir=JOSN_DIR, image_dir=IMAGE_DIR):
  20. self.args = args
  21. self.train_json_dir = os.path.join(self.args.save_dir, 'train', json_dir)
  22. self.val_json_dir = os.path.join(self.args.save_dir, 'val', json_dir)
  23. if not os.path.exists(self.train_json_dir):
  24. os.makedirs(self.train_json_dir)
  25. if not os.path.exists(self.val_json_dir):
  26. os.makedirs(self.val_json_dir)
  27. self.train_image_dir = os.path.join(self.args.save_dir, 'train', image_dir)
  28. self.val_image_dir = os.path.join(self.args.save_dir, 'val', image_dir)
  29. if not os.path.exists(self.train_image_dir):
  30. os.makedirs(self.train_image_dir)
  31. if not os.path.exists(self.val_image_dir):
  32. os.makedirs(self.val_image_dir)
  33. def _get_info_tree(self, label_file, dataset='VOC2007'):
  34. label_file_path = os.path.join(self.args.root_dir, dataset, 'Annotations', label_file)
  35. object_list = list()
  36. tree_dict = dict()
  37. with open(label_file_path, 'r') as file_stream:
  38. xml_tree = file_stream.readlines()
  39. xml_tree = ''.join([line.strip('\t') for line in xml_tree])
  40. xml_tree = BeautifulSoup(xml_tree, "html5lib")
  41. for obj in xml_tree.findAll('object'):
  42. object = dict()
  43. for name_tag in obj.findChildren('name'):
  44. name = str(name_tag.contents[0])
  45. difficult = int(obj.find('difficult').contents[0])
  46. if name in CAT_DICT:
  47. bbox = obj.findChildren('bndbox')[0]
  48. # 1-indexing to 0-indexing.
  49. xmin = int(float(bbox.findChildren('xmin')[0].contents[0])) - 1
  50. ymin = int(float(bbox.findChildren('ymin')[0].contents[0])) - 1
  51. xmax = int(float(bbox.findChildren('xmax')[0].contents[0])) - 1
  52. ymax = int(float(bbox.findChildren('ymax')[0].contents[0])) - 1
  53. object['bbox'] = [xmin, ymin, xmax, ymax]
  54. object['label'] = CAT_DICT[name]
  55. object['difficult'] = difficult
  56. object_list.append(object)
  57. tree_dict['objects'] = object_list
  58. return tree_dict
  59. def generate_label(self):
  60. file_count = 0
  61. if self.args.dataset in ['VOC07', 'VOC07+12', 'VOC07++12']:
  62. with open(os.path.join(self.args.root_dir, 'VOC2007/ImageSets/Main/trainval.txt'), 'r') as train_stream:
  63. for img_name in train_stream.readlines():
  64. img_name = img_name.rstrip()
  65. label_file = '{}.xml'.format(img_name)
  66. file_count += 1
  67. tree_dict = self._get_info_tree(label_file, dataset='VOC2007')
  68. fw = open(os.path.join(self.train_json_dir, '{}.json'.format(img_name)), 'w')
  69. fw.write(json.dumps(tree_dict))
  70. fw.close()
  71. shutil.copy(os.path.join(self.args.root_dir, 'VOC2007/JPEGImages', '{}.jpg'.format(img_name)),
  72. os.path.join(self.train_image_dir, '{}.jpg'.format(img_name)))
  73. if self.args.dataset in ['VOC07+12', 'VOC07++12', 'VOC12']:
  74. with open(os.path.join(self.args.root_dir, 'VOC2012/ImageSets/Main/trainval.txt'), 'r') as train_stream:
  75. for img_name in train_stream.readlines():
  76. img_name = img_name.rstrip()
  77. label_file = '{}.xml'.format(img_name)
  78. file_count += 1
  79. tree_dict = self._get_info_tree(label_file, dataset='VOC2012')
  80. fw = open(os.path.join(self.train_json_dir, '{}.json'.format(img_name)), 'w')
  81. fw.write(json.dumps(tree_dict))
  82. fw.close()
  83. shutil.copy(os.path.join(self.args.root_dir, 'VOC2012/JPEGImages', '{}.jpg'.format(img_name)),
  84. os.path.join(self.train_image_dir, '{}.jpg'.format(img_name)))
  85. if self.args.dataset in ['VOC07++12']:
  86. with open(os.path.join(self.args.root_dir, 'VOC2007/ImageSets/Main/test.txt'), 'r') as train_stream:
  87. for img_name in train_stream.readlines():
  88. img_name = img_name.rstrip()
  89. label_file = '{}.xml'.format(img_name)
  90. file_count += 1
  91. tree_dict = self._get_info_tree(label_file)
  92. fw = open(os.path.join(self.train_json_dir, '{}.json'.format(img_name)), 'w')
  93. fw.write(json.dumps(tree_dict))
  94. fw.close()
  95. shutil.copy(os.path.join(self.args.root_dir, 'VOC2007/JPEGImages', '{}.jpg'.format(img_name)),
  96. os.path.join(self.train_image_dir, '{}.jpg'.format(img_name)))
  97. with open(os.path.join(self.args.root_dir, 'VOC2007/ImageSets/Main/test.txt'), 'r') as train_stream:
  98. for img_name in train_stream.readlines():
  99. img_name = img_name.rstrip()
  100. label_file = '{}.xml'.format(img_name)
  101. file_count += 1
  102. tree_dict = self._get_info_tree(label_file)
  103. fw = open(os.path.join(self.val_json_dir, '{}.json'.format(img_name)), 'w')
  104. fw.write(json.dumps(tree_dict))
  105. fw.close()
  106. shutil.copy(os.path.join(self.args.root_dir, 'VOC2007/JPEGImages', '{}.jpg'.format(img_name)),
  107. os.path.join(self.val_image_dir, '{}.jpg'.format(img_name)))
  108. if __name__ == "__main__":
  109. parser = argparse.ArgumentParser()
  110. parser.add_argument('--save_dir', default=None, type=str,
  111. dest='save_dir', help='The directory to save the data.')
  112. parser.add_argument('--root_dir', default=None, type=str,
  113. dest='root_dir', help='The directory of the voc root.')
  114. parser.add_argument('--dataset', default=None, type=str,
  115. dest='dataset', help='The target dataset that will be generated.')
  116. args = parser.parse_args()
  117. voc_det_generator = VocDetGenerator(args)
  118. voc_det_generator.generate_label()