/pytorch_segmentation_detection/datasets/endovis_instrument_2015.py

https://github.com/warmspringwinds/pytorch-segmentation-detection

import os
import sys
import glob

import torch
import torch.utils.data as data
import numpy as np

# Reading video files
import imageio

import skimage.io as io

from PIL import Image

from ..utils.endovis_instrument import clean_up_annotation, merge_left_and_right_annotations


class Endovis_Instrument_2015(data.Dataset):

    CLASS_NAMES = ['background', 'manipulator', 'shaft', 'ambiguous']

    # URL of the EndoVis instrument segmentation sub-challenge
    URL = 'https://endovissub-instrument.grand-challenge.org/'

    relative_image_save_path_train = 'Processed/train/images'
    relative_annotation_save_path_train = 'Processed/train/annotations'

    relative_image_save_path_validation = 'Processed/val/images'
    relative_annotation_save_path_validation = 'Processed/val/annotations'
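
    # The 'Processed' folders above are created by _prepare_dataset(), which
    # unpacks the challenge videos into individually numbered frame files.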

    def __init__(self,
                 root,
                 train=True,
                 joint_transform=None,
                 prepare_dataset=False,
                 dataset_type=0,
                 split_mode=2,
                 validation_datasets_numbers=[2]):
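        """
        root -- path to the folder that contains the original 'Training'
            folder of the EndoVis'15 instrument dataset; the 'Processed'
            folder is created inside of it.
        train -- whether to load the training or the validation split.
        joint_transform -- callable that is applied jointly to the image
            and the annotation.
        prepare_dataset -- if True, unpack the challenge videos into plain
            jpg/png files first (only needs to be done once).
        dataset_type -- 0 for binary, 1 for multiclass annotations.
        split_mode -- currently unused.
        validation_datasets_numbers -- numbers of the datasets (1-4) that
            are held out for validation.
        """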

        # Dataset types:
        # 0 -- binary
        # 1 -- multiclass

        self.root = root
        self.joint_transform = joint_transform
        self.dataset_type = dataset_type
        self.validation_datasets_numbers = validation_datasets_numbers

        if prepare_dataset:
            self._prepare_dataset(train=True)
            self._prepare_dataset(train=False)

        if train:
            saved_images_path = os.path.join(self.root, self.relative_image_save_path_train)
            saved_annotations_path = os.path.join(self.root, self.relative_annotation_save_path_train)
        else:
            saved_images_path = os.path.join(self.root, self.relative_image_save_path_validation)
            saved_annotations_path = os.path.join(self.root, self.relative_annotation_save_path_validation)

        # We need these templates for __getitem__
        self.saved_images_template = os.path.join(saved_images_path, "{0:08d}.jpg")
        self.saved_annotations_template = os.path.join(saved_annotations_path, "{0:08d}.png")

        # Count all annotations: the glob pattern filters the eight-digit
        # annotation files out from anything else
        saved_annotations = glob.glob(os.path.join(saved_annotations_path, ('[0-9]' * 8) + '.png'))

        self.dataset_size = len(saved_annotations)

        # TODO: Create train/val split later

    def __len__(self):
        return self.dataset_size
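
    # Collapses all tool part labels (e.g. manipulator, shaft) into a single
    # foreground label, turning a parts annotation into a binary
    # tool-vs-background mask.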
    def merge_parts_annotation_numpy_into_binary_tool_annotation(self, parts_annotation_numpy, label_to_assign=1):

        parts_annotation_numpy_copy = parts_annotation_numpy.copy()
        parts_annotation_numpy_copy[parts_annotation_numpy_copy > 0] = label_to_assign

        return parts_annotation_numpy_copy
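
    # Returns a (PIL image, PIL annotation) pair; the annotation is collapsed
    # into a binary mask when dataset_type == 0 (binary mode).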
    def __getitem__(self, index):

        img_path = self.saved_images_template.format(index)
        annotation_path = self.saved_annotations_template.format(index)

        _img = Image.open(img_path).convert('RGB')

        # TODO: maybe can be done in a better way
        _target = Image.open(annotation_path)

        if self.dataset_type == 0:

            target_numpy = np.asarray(_target)
            target_numpy = self.merge_parts_annotation_numpy_into_binary_tool_annotation(target_numpy)
            _target = Image.fromarray(target_numpy)

        if self.joint_transform is not None:
            _img, _target = self.joint_transform([_img, _target])

        return _img, _target

    def _prepare_dataset(self, train=True):
        """
        Creates a 'Processed' folder in the root of the dataset where all
        the video frames and annotations are stored as plain jpg and png images.
        """

        datasets_numbers = set(range(1, 5))

        if train:
            annotation_folder_to_save = os.path.join(self.root, self.relative_annotation_save_path_train)
            images_folder_to_save = os.path.join(self.root, self.relative_image_save_path_train)
            datasets_numbers = datasets_numbers - set(self.validation_datasets_numbers)
        else:
            annotation_folder_to_save = os.path.join(self.root, self.relative_annotation_save_path_validation)
            images_folder_to_save = os.path.join(self.root, self.relative_image_save_path_validation)
            datasets_numbers = self.validation_datasets_numbers

        annotation_save_template = os.path.join(annotation_folder_to_save, "{0:08d}.png")
        images_save_template = os.path.join(images_folder_to_save, "{0:08d}.jpg")

        # Creating folders to save all the images and annotations
        if not os.path.exists(annotation_folder_to_save):
            os.makedirs(annotation_folder_to_save)

        if not os.path.exists(images_folder_to_save):
            os.makedirs(images_folder_to_save)
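
        # Expected raw layout inside root (inferred from the paths below):
        #   Training/Dataset{1..4}/Video.avi
        #   Training/Dataset{2..4}/Segmentation.avi
        #   Training/Dataset1/Left_Instrument_Segmentation.avi
        #   Training/Dataset1/Right_Instrument_Segmentation.avi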

        # Template used to iterate over the dataset folders
        dataset_folder_template = "Training/Dataset{}"
        dataset_template = os.path.join(self.root, dataset_folder_template)

        image_number_offset = 0

        # There are 4 datasets overall
        for current_dataset_number in datasets_numbers:

            current_dataset_path = dataset_template.format(current_dataset_number)

            if current_dataset_number == 1:

                # The first dataset has two videos with a separate annotation for each tool
                left_annotation_video_filename = os.path.join(current_dataset_path, 'Left_Instrument_Segmentation.avi')
                right_annotation_video_filename = os.path.join(current_dataset_path, 'Right_Instrument_Segmentation.avi')
            else:

                # The other datasets have just one annotation video
                annotation_video_filename = os.path.join(current_dataset_path, 'Segmentation.avi')

            # Each dataset has just one raw video and it has the same name
            images_video_filename = os.path.join(current_dataset_path, 'Video.avi')

            # Creating readers for each of our videos
            images_reader = imageio.get_reader(images_video_filename, 'ffmpeg')

            # Once again -- the first dataset is an exception
            if current_dataset_number == 1:

                left_annotations_reader = imageio.get_reader(left_annotation_video_filename, 'ffmpeg')
                right_annotations_reader = imageio.get_reader(right_annotation_video_filename, 'ffmpeg')
            else:

                annotations_reader = imageio.get_reader(annotation_video_filename, 'ffmpeg')
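
            # Note: get_length() follows the imageio API this code was written
            # against; newer imageio versions may report inf here, in which
            # case reader.count_frames() is the robust alternative.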
            current_dataset_number_of_images = images_reader.get_length()

            for current_image_number in range(current_dataset_number_of_images):

                # We need to merge the two separate annotation videos of the first dataset
                if current_dataset_number == 1:

                    current_annotation_left = left_annotations_reader.get_data(current_image_number)
                    processed_current_annotation_left = clean_up_annotation(current_annotation_left)

                    current_annotation_right = right_annotations_reader.get_data(current_image_number)
                    processed_current_annotation_right = clean_up_annotation(current_annotation_right)

                    processed_current_annotation_final = merge_left_and_right_annotations(processed_current_annotation_left,
                                                                                          processed_current_annotation_right)
                else:

                    current_annotation = annotations_reader.get_data(current_image_number)
                    processed_current_annotation_final = clean_up_annotation(current_annotation)

                current_image = images_reader.get_data(current_image_number)

                # Add the offset so that the file names follow the global
                # frame count across datasets and not just the count within
                # the current dataset
                current_annotation_name_to_save = annotation_save_template.format(current_image_number + image_number_offset)
                current_image_name_to_save = images_save_template.format(current_image_number + image_number_offset)

                io.imsave(current_annotation_name_to_save, processed_current_annotation_final)
                io.imsave(current_image_name_to_save, current_image)

            # Update the global count of images
            image_number_offset += current_dataset_number_of_images
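

# A minimal usage sketch, not part of the original file. It assumes the
# EndoVis'15 data has been extracted under the hypothetical path
# '/datasets/endovis_2015' (so that 'Training/Dataset1' ... 'Training/Dataset4'
# exist there). Because of the relative import at the top, run it as a module:
#   python -m pytorch_segmentation_detection.datasets.endovis_instrument_2015
if __name__ == '__main__':

    # Unpack the videos on the first run by passing prepare_dataset=True
    train_dataset = Endovis_Instrument_2015(root='/datasets/endovis_2015',
                                            train=True,
                                            prepare_dataset=False,
                                            dataset_type=0)

    valid_dataset = Endovis_Instrument_2015(root='/datasets/endovis_2015',
                                            train=False,
                                            dataset_type=0)

    print(len(train_dataset), len(valid_dataset))

    # Without a joint_transform the samples come back as PIL images; a
    # transform that converts both of them to tensors is needed before
    # batching with torch.utils.data.DataLoader.
    img, target = train_dataset[0]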