/data_preprocess.py

https://github.com/athon2/BraTS2018_NvNet
import os
import glob
import numpy as np
import tables
from random import shuffle

from utils.normalize import normalize_data_storage, reslice_image_set
from utils import pickle_dump, pickle_load
from main import config

def create_data_file(out_file, n_channels, n_samples, image_shape):
    hdf5_file = tables.open_file(out_file, mode='w')
    filters = tables.Filters(complevel=5, complib='blosc')
    data_shape = tuple([0, n_channels] + list(image_shape))
    truth_shape = tuple([0, 1] + list(image_shape))
    data_storage = hdf5_file.create_earray(hdf5_file.root, 'data', tables.Float32Atom(), shape=data_shape,
                                           filters=filters, expectedrows=n_samples)
    truth_storage = hdf5_file.create_earray(hdf5_file.root, 'truth', tables.UInt8Atom(), shape=truth_shape,
                                            filters=filters, expectedrows=n_samples)
    affine_storage = hdf5_file.create_earray(hdf5_file.root, 'affine', tables.Float32Atom(), shape=(0, 4, 4),
                                             filters=filters, expectedrows=n_samples)
    return hdf5_file, data_storage, truth_storage, affine_storage
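
# A minimal read-back sketch (added for illustration, not part of the original
# repo): the node names 'data', 'truth', and 'affine' are exactly the ones
# created above; the file path is hypothetical.
def _example_read_back(h5_path="brats_data.h5"):
    with tables.open_file(h5_path, mode="r") as f:
        case = f.root.data[0]      # shape (n_channels,) + image_shape, float32
        label = f.root.truth[0]    # shape (1,) + image_shape, uint8
        affine = f.root.affine[0]  # (4, 4) voxel-to-world matrix, float32
    return case, label, affine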

def write_image_data_to_file(image_files, data_storage, truth_storage, image_shape, n_channels, affine_storage,
                             truth_dtype=np.uint8, crop=True, label_indices=None, save_truth=True):
    for set_of_files in image_files:
        if label_indices is None:
            # By convention, the label image is the last file in each tuple.
            _label_indices = len(set_of_files) - 1
        else:
            _label_indices = label_indices
        images = reslice_image_set(set_of_files, image_shape, label_indices=_label_indices, crop=crop)
        # Note: nibabel has deprecated get_data(); on nibabel >= 3.0 use image.get_fdata() instead.
        subject_data = [image.get_data() for image in images]
        add_data_to_storage(data_storage, truth_storage, affine_storage, subject_data, images[0].affine,
                            n_channels, truth_dtype, save_truth=save_truth)
    return data_storage, truth_storage

def add_data_to_storage(data_storage, truth_storage, affine_storage, subject_data, affine, n_channels, truth_dtype,
                        save_truth=True):
    # np.newaxis prepends the sample axis, so each append adds one row of shape (1, n_channels) + image_shape.
    data_storage.append(np.asarray(subject_data[:n_channels])[np.newaxis])
    if save_truth:
        # The label volume gets both a sample axis and a channel axis: (1, 1) + image_shape.
        truth_storage.append(np.asarray(subject_data[n_channels], dtype=truth_dtype)[np.newaxis][np.newaxis])
    affine_storage.append(np.asarray(affine)[np.newaxis])

def write_data_to_file(training_data_files, out_file, image_shape, truth_dtype=np.uint8, subject_ids=None,
                       normalize=True, crop=True, save_truth=True):
    """
    Takes in a set of training images and writes those images to an hdf5 file.
    :param training_data_files: List of tuples containing the training data files. The modalities should be listed
    in the same order in each tuple. The last item in each tuple must be the labeled image. If the label image is
    not available, set save_truth to False.
    Example: [('sub1-T1.nii.gz', 'sub1-T2.nii.gz', 'sub1-truth.nii.gz'),
              ('sub2-T1.nii.gz', 'sub2-T2.nii.gz', 'sub2-truth.nii.gz')]
    :param out_file: Where the hdf5 file will be written to.
    :param image_shape: Shape of the images that will be saved to the hdf5 file.
    :param truth_dtype: Default is 8-bit unsigned integer.
    :return: Location of the hdf5 file with the image data written to it.
    """
    n_samples = len(training_data_files)
    n_channels = len(training_data_files[0])
    if save_truth:
        # The last file in each tuple is the label image, not a data channel.
        n_channels = n_channels - 1
    try:
        hdf5_file, data_storage, truth_storage, affine_storage = create_data_file(out_file,
                                                                                  n_channels=n_channels,
                                                                                  n_samples=n_samples,
                                                                                  image_shape=image_shape)
    except Exception as e:
        # If something goes wrong, delete the incomplete data file.
        if os.path.exists(out_file):
            os.remove(out_file)
        raise e
    label_indices = None
    if not save_truth:
        # An empty list tells reslice_image_set that none of the files is a label image.
        label_indices = []
    write_image_data_to_file(training_data_files, data_storage, truth_storage, image_shape, truth_dtype=truth_dtype,
                             n_channels=n_channels, affine_storage=affine_storage, crop=crop,
                             label_indices=label_indices, save_truth=save_truth)
    if subject_ids:
        hdf5_file.create_array(hdf5_file.root, 'subject_ids', obj=subject_ids)
    if normalize:
        normalize_data_storage(data_storage)
    hdf5_file.close()
    return out_file

def open_data_file(filename, readwrite="r"):
    return tables.open_file(filename, readwrite)

def split_list(input_list, split=0.8, shuffle_list=True):
    if shuffle_list:
        # Note: random.shuffle reorders the caller's list in place.
        shuffle(input_list)
    n_training = int(len(input_list) * split)
    training = input_list[:n_training]
    testing = input_list[n_training:]
    return training, testing

def get_validation_split(data_file, training_file, validation_file, data_split=0.8, overwrite=False):
    """
    Splits the samples in an open data file into training and validation index lists,
    pickling both lists so the same split is reused on later runs unless overwrite is set.
    """
    if overwrite or not os.path.exists(training_file):
        print("Creating validation split...")
        nb_samples = data_file.root.data.shape[0]
        sample_list = list(range(nb_samples))
        training_list, validation_list = split_list(sample_list, split=data_split)
        pickle_dump(training_list, training_file)
        pickle_dump(validation_list, validation_file)
        return training_list, validation_list
    else:
        print("Loading previous validation split...")
        return pickle_load(training_file), pickle_load(validation_file)

def fetch_training_data_files(data_dir, return_subject_ids=True):
    training_data_files = list()
    subject_ids = list()
    for subject_dir in glob.glob(os.path.join(data_dir, "*", "*")):
        subject_ids.append(os.path.basename(subject_dir))
        subject_files = list()
        for modality in config["all_modalities"] + ["truth"]:
            subject_files.append(os.path.join(subject_dir, modality + ".nii.gz"))
        training_data_files.append(tuple(subject_files))
    if return_subject_ids:
        return training_data_files, subject_ids
    else:
        return training_data_files
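
# fetch_training_data_files expects a two-level layout under data_dir, e.g.
# (the group and subject names below are illustrative; the modality list
# actually comes from config["all_modalities"]):
#
#   data/HGG/<subject_id>/t1.nii.gz
#   data/HGG/<subject_id>/t2.nii.gz
#   ...
#   data/HGG/<subject_id>/truth.nii.gz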

if __name__ == '__main__':
    data_dir = os.path.join(os.path.dirname(__file__), "data")
    training_files, subject_ids = fetch_training_data_files(data_dir, return_subject_ids=True)
    # print(training_files)
    write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                       subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])
    get_validation_split(data_file_opened, config["training_file"], config["validation_file"])
    # Release the HDF5 handle opened above.
    data_file_opened.close()
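
# A hedged end-to-end sketch of consuming the outputs written above (added for
# illustration): the pickled index lists from get_validation_split select rows
# of the HDF5 'data' and 'truth' arrays.
#
#     data_file = open_data_file(config["data_file"])
#     training_ids, validation_ids = get_validation_split(
#         data_file, config["training_file"], config["validation_file"])
#     x0 = data_file.root.data[training_ids[0]]   # one training volume stack
#     y0 = data_file.root.truth[training_ids[0]]  # its label map
#     data_file.close()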