/data_preprocess.py
import os
import glob
from random import shuffle

import numpy as np
import tables

from utils.normalize import normalize_data_storage, reslice_image_set
from utils import pickle_dump, pickle_load
from main import config


def create_data_file(out_file, n_channels, n_samples, image_shape):
    hdf5_file = tables.open_file(out_file, mode='w')
    filters = tables.Filters(complevel=5, complib='blosc')
    data_shape = tuple([0, n_channels] + list(image_shape))
    truth_shape = tuple([0, 1] + list(image_shape))
    data_storage = hdf5_file.create_earray(hdf5_file.root, 'data', tables.Float32Atom(), shape=data_shape,
                                           filters=filters, expectedrows=n_samples)
    truth_storage = hdf5_file.create_earray(hdf5_file.root, 'truth', tables.UInt8Atom(), shape=truth_shape,
                                            filters=filters, expectedrows=n_samples)
    affine_storage = hdf5_file.create_earray(hdf5_file.root, 'affine', tables.Float32Atom(), shape=(0, 4, 4),
                                             filters=filters, expectedrows=n_samples)
    return hdf5_file, data_storage, truth_storage, affine_storage
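
# A sketch of the HDF5 layout produced above (names, dtypes, and shapes follow
# create_data_file; the first axis grows from 0 to n_samples as subjects are
# appended):
#
#   /data    float32, shape (n_samples, n_channels, *image_shape)
#   /truth   uint8,   shape (n_samples, 1, *image_shape)
#   /affine  float32, shape (n_samples, 4, 4)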


def write_image_data_to_file(image_files, data_storage, truth_storage, image_shape, n_channels, affine_storage,
                             truth_dtype=np.uint8, crop=True, label_indices=None, save_truth=True):
    for set_of_files in image_files:
        if label_indices is None:
            # By default, the last file in each tuple is treated as the label image.
            _label_indices = len(set_of_files) - 1
        else:
            _label_indices = label_indices
        images = reslice_image_set(set_of_files, image_shape, label_indices=_label_indices, crop=crop)
        # nibabel has removed Image.get_data(); np.asanyarray(image.dataobj) is the
        # recommended drop-in replacement that preserves the on-disk dtype.
        subject_data = [np.asanyarray(image.dataobj) for image in images]
        add_data_to_storage(data_storage, truth_storage, affine_storage, subject_data, images[0].affine,
                            n_channels, truth_dtype, save_truth=save_truth)
    return data_storage, truth_storage


def add_data_to_storage(data_storage, truth_storage, affine_storage, subject_data, affine, n_channels,
                        truth_dtype, save_truth=True):
    data_storage.append(np.asarray(subject_data[:n_channels])[np.newaxis])
    if save_truth:
        truth_storage.append(np.asarray(subject_data[n_channels], dtype=truth_dtype)[np.newaxis][np.newaxis])
    affine_storage.append(np.asarray(affine)[np.newaxis])
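
# Shape bookkeeping for the appends above (image_shape is whatever was passed
# to create_data_file; [np.newaxis] adds the leading sample axis):
#   data:   (n_channels, *image_shape) -> (1, n_channels, *image_shape)
#   truth:  image_shape                -> (1, 1, *image_shape)
#   affine: (4, 4)                     -> (1, 4, 4)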


def write_data_to_file(training_data_files, out_file, image_shape, truth_dtype=np.uint8, subject_ids=None,
                       normalize=True, crop=True, save_truth=True):
    """
    Takes in a set of training images and writes those images to an hdf5 file.
    :param training_data_files: List of tuples containing the training data files. The modalities should be listed
    in the same order in each tuple. The last item in each tuple must be the labeled image. If the label image is
    not available, set save_truth to False.
    Example: [('sub1-T1.nii.gz', 'sub1-T2.nii.gz', 'sub1-truth.nii.gz'),
              ('sub2-T1.nii.gz', 'sub2-T2.nii.gz', 'sub2-truth.nii.gz')]
    :param out_file: Where the hdf5 file will be written to.
    :param image_shape: Shape of the images that will be saved to the hdf5 file.
    :param truth_dtype: Default is 8-bit unsigned integer.
    :param subject_ids: Optional list of subject identifiers to store alongside the image data.
    :param normalize: If True, normalize the image data after writing it.
    :param crop: If True, crop the images to the foreground before reslicing.
    :param save_truth: Set to False when no label images are available.
    :return: Location of the hdf5 file with the image data written to it.
    """
    n_samples = len(training_data_files)
    n_channels = len(training_data_files[0])
    if save_truth:
        # The last file in each tuple is the label image, not a data channel.
        n_channels = n_channels - 1
    try:
        hdf5_file, data_storage, truth_storage, affine_storage = create_data_file(
            out_file, n_channels=n_channels, n_samples=n_samples, image_shape=image_shape)
    except Exception as e:
        # If something goes wrong, delete the incomplete data file (if it was created at all).
        if os.path.exists(out_file):
            os.remove(out_file)
        raise e
    label_indices = None
    if not save_truth:
        # An empty list means no file in the tuple is treated as a label image.
        label_indices = []
    write_image_data_to_file(training_data_files, data_storage, truth_storage, image_shape,
                             truth_dtype=truth_dtype, n_channels=n_channels, affine_storage=affine_storage,
                             crop=crop, label_indices=label_indices, save_truth=save_truth)
    if subject_ids:
        hdf5_file.create_array(hdf5_file.root, 'subject_ids', obj=subject_ids)
    if normalize:
        normalize_data_storage(data_storage)
    hdf5_file.close()
    return out_file
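
# A usage sketch (the file names here are hypothetical; the tuple layout
# follows the docstring above):
#
#   files = [('sub1-T1.nii.gz', 'sub1-T2.nii.gz', 'sub1-truth.nii.gz')]
#   write_data_to_file(files, 'data.h5', image_shape=(144, 144, 144),
#                      subject_ids=['sub1'])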


def open_data_file(filename, readwrite="r"):
    return tables.open_file(filename, readwrite)


def split_list(input_list, split=0.8, shuffle_list=True):
    if shuffle_list:
        shuffle(input_list)  # note: shuffles input_list in place
    n_training = int(len(input_list) * split)
    training = input_list[:n_training]
    testing = input_list[n_training:]
    return training, testing
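
# Example: split_list(list(range(10)), split=0.8) returns 8 training and
# 2 testing indices, shuffled in place first unless shuffle_list=False.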


def get_validation_split(data_file, training_file, validation_file, data_split=0.8, overwrite=False):
    """
    Splits the sample indices of data_file into training and validation lists and pickles them to
    training_file and validation_file. If a split already exists on disk and overwrite is False,
    the saved split is loaded instead.
    """
    if overwrite or not os.path.exists(training_file):
        print("Creating validation split...")
        nb_samples = data_file.root.data.shape[0]
        sample_list = list(range(nb_samples))
        training_list, validation_list = split_list(sample_list, split=data_split)
        pickle_dump(training_list, training_file)
        pickle_dump(validation_list, validation_file)
        return training_list, validation_list
    else:
        print("Loading previous validation split...")
        return pickle_load(training_file), pickle_load(validation_file)
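
# The split is persisted as two pickled lists of integer row indices into
# /data, e.g. training_file might hold [3, 0, 7, ...]; downstream readers can
# index data_file.root.data[i] directly with these.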


def fetch_training_data_files(data_dir, return_subject_ids=True):
    training_data_files = list()
    subject_ids = list()
    for subject_dir in glob.glob(os.path.join(data_dir, "*", "*")):
        subject_ids.append(os.path.basename(subject_dir))
        subject_files = list()
        for modality in config["all_modalities"] + ["truth"]:
            subject_files.append(os.path.join(subject_dir, modality + ".nii.gz"))
        training_data_files.append(tuple(subject_files))
    if return_subject_ids:
        return training_data_files, subject_ids
    else:
        return training_data_files
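
# Expected layout under data_dir (group and subject names are illustrative;
# the two "*" globs above match <group>/<subject>):
#
#   data_dir/<group>/<subject>/<modality>.nii.gz
#
# with one file per entry in config["all_modalities"] plus truth.nii.gz.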


if __name__ == '__main__':
    data_dir = os.path.join(os.path.dirname(__file__), "data")
    training_files, subject_ids = fetch_training_data_files(data_dir, return_subject_ids=True)
    write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                       subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])
    get_validation_split(data_file_opened, config["training_file"], config["validation_file"])
    data_file_opened.close()  # release the HDF5 handle once the split is saved