/dipy/workflows/tests/test_reconst_csa_csd.py

https://github.com/arokem/dipy · Python · 137 lines · 110 code · 24 blank · 3 comment · 9 complexity · c1b7464f612af0cfae7d578236f9af06 MD5 · raw file

  1. import warnings
  2. import logging
  3. import numpy as np
  4. from os.path import join as pjoin
  5. import numpy.testing as npt
  6. from dipy.io.peaks import load_peaks
  7. from dipy.io.gradients import read_bvals_bvecs
  8. from dipy.io.image import load_nifti, save_nifti, load_nifti_data
  9. from dipy.core.gradients import generate_bvecs
  10. from nibabel.tmpdirs import TemporaryDirectory
  11. from dipy.data import get_fnames
  12. from dipy.workflows.reconst import ReconstCSDFlow, ReconstCSAFlow
  13. from dipy.reconst.shm import descoteaux07_legacy_msg, sph_harm_ind_list
  14. logging.getLogger().setLevel(logging.INFO)
  15. def test_reconst_csa():
  16. with warnings.catch_warnings():
  17. warnings.filterwarnings(
  18. "ignore", message=descoteaux07_legacy_msg,
  19. category=PendingDeprecationWarning)
  20. reconst_flow_core(ReconstCSAFlow)
  21. def test_reconst_csd():
  22. with warnings.catch_warnings():
  23. warnings.filterwarnings(
  24. "ignore", message=descoteaux07_legacy_msg,
  25. category=PendingDeprecationWarning)
  26. reconst_flow_core(ReconstCSDFlow)
  27. def reconst_flow_core(flow):
  28. with TemporaryDirectory() as out_dir:
  29. data_path, bval_path, bvec_path = get_fnames('small_64D')
  30. volume, affine = load_nifti(data_path)
  31. mask = np.ones_like(volume[:, :, :, 0])
  32. mask_path = pjoin(out_dir, 'tmp_mask.nii.gz')
  33. save_nifti(mask_path, mask.astype(np.uint8), affine)
  34. reconst_flow = flow()
  35. for sh_order in [4, 6, 8]:
  36. if flow.get_short_name() == 'csd':
  37. reconst_flow.run(data_path, bval_path, bvec_path, mask_path,
  38. sh_order=sh_order,
  39. out_dir=out_dir, extract_pam_values=True)
  40. elif flow.get_short_name() == 'csa':
  41. reconst_flow.run(data_path, bval_path, bvec_path, mask_path,
  42. sh_order=sh_order,
  43. odf_to_sh_order=sh_order,
  44. out_dir=out_dir, extract_pam_values=True)
  45. gfa_path = reconst_flow.last_generated_outputs['out_gfa']
  46. gfa_data = load_nifti_data(gfa_path)
  47. npt.assert_equal(gfa_data.shape, volume.shape[:-1])
  48. peaks_dir_path =\
  49. reconst_flow.last_generated_outputs['out_peaks_dir']
  50. peaks_dir_data = load_nifti_data(peaks_dir_path)
  51. npt.assert_equal(peaks_dir_data.shape[-1], 15)
  52. npt.assert_equal(peaks_dir_data.shape[:-1], volume.shape[:-1])
  53. peaks_idx_path = \
  54. reconst_flow.last_generated_outputs['out_peaks_indices']
  55. peaks_idx_data = load_nifti_data(peaks_idx_path)
  56. npt.assert_equal(peaks_idx_data.shape[-1], 5)
  57. npt.assert_equal(peaks_idx_data.shape[:-1], volume.shape[:-1])
  58. peaks_vals_path = \
  59. reconst_flow.last_generated_outputs['out_peaks_values']
  60. peaks_vals_data = load_nifti_data(peaks_vals_path)
  61. npt.assert_equal(peaks_vals_data.shape[-1], 5)
  62. npt.assert_equal(peaks_vals_data.shape[:-1], volume.shape[:-1])
  63. shm_path = reconst_flow.last_generated_outputs['out_shm']
  64. shm_data = load_nifti_data(shm_path)
  65. # Test that the number of coefficients is what you would expect
  66. # given the order of the sh basis:
  67. npt.assert_equal(shm_data.shape[-1],
  68. sph_harm_ind_list(sh_order)[0].shape[0])
  69. npt.assert_equal(shm_data.shape[:-1], volume.shape[:-1])
  70. pam = load_peaks(reconst_flow.last_generated_outputs['out_pam'])
  71. npt.assert_allclose(pam.peak_dirs.reshape(peaks_dir_data.shape),
  72. peaks_dir_data)
  73. npt.assert_allclose(pam.peak_values, peaks_vals_data)
  74. npt.assert_allclose(pam.peak_indices, peaks_idx_data)
  75. npt.assert_allclose(pam.shm_coeff, shm_data)
  76. npt.assert_allclose(pam.gfa, gfa_data)
  77. bvals, bvecs = read_bvals_bvecs(bval_path, bvec_path)
  78. bvals[0] = 5.
  79. bvecs = generate_bvecs(len(bvals))
  80. tmp_bval_path = pjoin(out_dir, "tmp.bval")
  81. tmp_bvec_path = pjoin(out_dir, "tmp.bvec")
  82. np.savetxt(tmp_bval_path, bvals)
  83. np.savetxt(tmp_bvec_path, bvecs.T)
  84. reconst_flow._force_overwrite = True
  85. if flow.get_short_name() == 'csd':
  86. reconst_flow = flow()
  87. reconst_flow._force_overwrite = True
  88. reconst_flow.run(data_path, bval_path, bvec_path, mask_path,
  89. out_dir=out_dir, frf=[15, 5, 5])
  90. reconst_flow = flow()
  91. reconst_flow._force_overwrite = True
  92. reconst_flow.run(data_path, bval_path, bvec_path, mask_path,
  93. out_dir=out_dir, frf='15, 5, 5')
  94. reconst_flow = flow()
  95. reconst_flow._force_overwrite = True
  96. reconst_flow.run(data_path, bval_path, bvec_path, mask_path,
  97. out_dir=out_dir, frf=None)
  98. reconst_flow2 = flow()
  99. reconst_flow2._force_overwrite = True
  100. reconst_flow2.run(data_path, bval_path, bvec_path, mask_path,
  101. out_dir=out_dir, frf=None,
  102. roi_center=[5, 5, 5])
  103. else:
  104. with npt.assert_raises(BaseException):
  105. npt.assert_warns(UserWarning, reconst_flow.run, data_path,
  106. tmp_bval_path, tmp_bvec_path, mask_path,
  107. out_dir=out_dir, extract_pam_values=True)
  108. # test parallel implementation
  109. reconst_flow = flow()
  110. reconst_flow._force_overwrite = True
  111. reconst_flow.run(data_path, bval_path, bvec_path, mask_path,
  112. out_dir=out_dir,
  113. parallel=True, num_processes=2)