PageRenderTime 28ms CodeModel.GetById 18ms RepoModel.GetById 1ms app.codeStats 0ms

/dipy/workflows/tests/test_tracking.py

https://github.com/nipy/dipy
Python | 262 lines | 209 code | 38 blank | 15 comment | 13 complexity | 48c2c4efcc2005d8529482c2981e34c8 MD5 | raw file
Possible License(s): BSD-3-Clause
  1. import warnings
  2. import numpy as np
  3. from numpy.testing import assert_equal
  4. from dipy.testing import assert_false, assert_true
  5. from os.path import join
  6. import nibabel as nib
  7. from nibabel.tmpdirs import TemporaryDirectory
  8. from dipy.data import get_fnames
  9. from dipy.io.image import save_nifti, load_nifti
  10. from dipy.io.streamline import load_tractogram
  11. from dipy.reconst.shm import descoteaux07_legacy_msg
  12. from dipy.workflows.mask import MaskFlow
  13. from dipy.workflows.reconst import ReconstCSDFlow
  14. from dipy.workflows.tracking import (LocalFiberTrackingPAMFlow,
  15. PFTrackingPAMFlow)
  16. def test_particle_filtering_tracking_workflows():
  17. with TemporaryDirectory() as out_dir:
  18. dwi_path, bval_path, bvec_path = get_fnames('small_64D')
  19. volume, affine = load_nifti(dwi_path)
  20. # Create some mask
  21. mask = np.ones_like(volume[:, :, :, 0], dtype=np.uint8)
  22. mask_path = join(out_dir, 'tmp_mask.nii.gz')
  23. save_nifti(mask_path, mask, affine)
  24. simple_wm = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  25. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  26. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  27. [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
  28. [0, 0, 1, 1, 1, 1, 0, 1, 0, 0],
  29. [0, 0, 1, 0, 1, 0, 1, 0, 0, 0],
  30. [0, 0, 1, 0, 1, 1, 0, 1, 0, 0],
  31. [0, 0, 0, 1, 1, 0, 1, 0, 0, 0],
  32. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  33. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  34. ])
  35. simple_wm = np.dstack([np.zeros(simple_wm.shape),
  36. np.zeros(simple_wm.shape),
  37. simple_wm, simple_wm, simple_wm,
  38. simple_wm, simple_wm, simple_wm,
  39. np.zeros(simple_wm.shape),
  40. np.zeros(simple_wm.shape)])
  41. simple_gm = np.array([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
  42. [0, 0, 1, 1, 0, 0, 1, 1, 1, 0],
  43. [0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
  44. [0, 1, 0, 0, 0, 0, 0, 0, 1, 0],
  45. [0, 1, 0, 0, 0, 0, 0, 0, 1, 0],
  46. [0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
  47. [1, 0, 0, 0, 0, 0, 0, 0, 1, 0],
  48. [0, 1, 0, 0, 0, 0, 1, 0, 1, 0],
  49. [0, 1, 1, 0, 1, 1, 1, 0, 1, 1],
  50. [0, 0, 0, 1, 0, 0, 0, 1, 1, 0],
  51. ])
  52. simple_gm = np.dstack([np.zeros(simple_gm.shape),
  53. np.zeros(simple_gm.shape),
  54. simple_gm, simple_gm, simple_gm,
  55. simple_gm, simple_gm, simple_gm,
  56. np.zeros(simple_gm.shape),
  57. np.zeros(simple_gm.shape)])
  58. simple_csf = np.ones(simple_wm.shape) - simple_wm - simple_gm
  59. wm_path = join(out_dir, 'tmp_wm.nii.gz')
  60. gm_path = join(out_dir, 'tmp_gm.nii.gz')
  61. csf_path = join(out_dir, 'tmp_csf.nii.gz')
  62. for path, arr in zip([wm_path, gm_path, csf_path],
  63. [simple_wm, simple_gm, simple_csf]):
  64. save_nifti(path, arr.astype(np.uint8), affine)
  65. # CSD Reconstruction
  66. reconst_csd_flow = ReconstCSDFlow()
  67. with warnings.catch_warnings():
  68. warnings.filterwarnings(
  69. "ignore", message=descoteaux07_legacy_msg,
  70. category=PendingDeprecationWarning)
  71. reconst_csd_flow.run(dwi_path, bval_path, bvec_path, mask_path,
  72. out_dir=out_dir, extract_pam_values=True)
  73. pam_path = reconst_csd_flow.last_generated_outputs['out_pam']
  74. gfa_path = reconst_csd_flow.last_generated_outputs['out_gfa']
  75. # Create seeding mask by thresholding the gfa
  76. mask_flow = MaskFlow()
  77. mask_flow.run(gfa_path, 0.8, out_dir=out_dir)
  78. seeds_path = mask_flow.last_generated_outputs['out_mask']
  79. # Test tracking
  80. pf_track_pam = PFTrackingPAMFlow()
  81. assert_equal(pf_track_pam.get_short_name(), 'track_pft')
  82. with warnings.catch_warnings():
  83. warnings.filterwarnings(
  84. "ignore", message=descoteaux07_legacy_msg,
  85. category=PendingDeprecationWarning)
  86. pf_track_pam.run(pam_path, wm_path, gm_path, csf_path, seeds_path)
  87. tractogram_path = \
  88. pf_track_pam.last_generated_outputs['out_tractogram']
  89. assert_false(is_tractogram_empty(tractogram_path))
  90. # Test that tracking returns seeds
  91. pf_track_pam = PFTrackingPAMFlow()
  92. pf_track_pam._force_overwrite = True
  93. with warnings.catch_warnings():
  94. warnings.filterwarnings(
  95. "ignore", message=descoteaux07_legacy_msg,
  96. category=PendingDeprecationWarning)
  97. pf_track_pam.run(pam_path,
  98. wm_path,
  99. gm_path,
  100. csf_path,
  101. seeds_path,
  102. save_seeds=True)
  103. tractogram_path = \
  104. pf_track_pam.last_generated_outputs['out_tractogram']
  105. assert_true(tractogram_has_seeds(tractogram_path))
  106. assert_true(seeds_are_same_space_as_streamlines(tractogram_path))
  107. def test_local_fiber_tracking_workflow():
  108. with TemporaryDirectory() as out_dir:
  109. data_path, bval_path, bvec_path = get_fnames('small_64D')
  110. volume, affine = load_nifti(data_path)
  111. mask = np.ones_like(volume[:, :, :, 0], dtype=np.uint8)
  112. mask_path = join(out_dir, 'tmp_mask.nii.gz')
  113. save_nifti(mask_path, mask, affine)
  114. reconst_csd_flow = ReconstCSDFlow()
  115. with warnings.catch_warnings():
  116. warnings.filterwarnings(
  117. "ignore", message=descoteaux07_legacy_msg,
  118. category=PendingDeprecationWarning)
  119. reconst_csd_flow.run(data_path, bval_path, bvec_path, mask_path,
  120. out_dir=out_dir, extract_pam_values=True)
  121. pam_path = reconst_csd_flow.last_generated_outputs['out_pam']
  122. gfa_path = reconst_csd_flow.last_generated_outputs['out_gfa']
  123. # Create seeding mask by thresholding the gfa
  124. mask_flow = MaskFlow()
  125. mask_flow.run(gfa_path, 0.8, out_dir=out_dir)
  126. seeds_path = mask_flow.last_generated_outputs['out_mask']
  127. mask_path = mask_flow.last_generated_outputs['out_mask']
  128. gfa_img, gfa_affine = load_nifti(gfa_path)
  129. save_nifti(gfa_path, gfa_img, gfa_affine)
  130. # Test tracking with pam no sh
  131. lf_track_pam = LocalFiberTrackingPAMFlow()
  132. lf_track_pam._force_overwrite = True
  133. assert_equal(lf_track_pam.get_short_name(), 'track_local')
  134. lf_track_pam.run(pam_path, gfa_path, seeds_path)
  135. tractogram_path = \
  136. lf_track_pam.last_generated_outputs['out_tractogram']
  137. assert_false(is_tractogram_empty(tractogram_path))
  138. # Test tracking with binary stopping criterion
  139. lf_track_pam = LocalFiberTrackingPAMFlow()
  140. lf_track_pam._force_overwrite = True
  141. lf_track_pam.run(pam_path, mask_path, seeds_path,
  142. use_binary_mask=True)
  143. tractogram_path = \
  144. lf_track_pam.last_generated_outputs['out_tractogram']
  145. assert_false(is_tractogram_empty(tractogram_path))
  146. # Test tracking with pam with sh
  147. lf_track_pam = LocalFiberTrackingPAMFlow()
  148. lf_track_pam._force_overwrite = True
  149. lf_track_pam.run(pam_path, gfa_path, seeds_path,
  150. tracking_method="eudx")
  151. tractogram_path = \
  152. lf_track_pam.last_generated_outputs['out_tractogram']
  153. assert_false(is_tractogram_empty(tractogram_path))
  154. # Test tracking with pam with sh and deterministic getter
  155. lf_track_pam = LocalFiberTrackingPAMFlow()
  156. lf_track_pam._force_overwrite = True
  157. with warnings.catch_warnings():
  158. warnings.filterwarnings(
  159. "ignore", message=descoteaux07_legacy_msg,
  160. category=PendingDeprecationWarning)
  161. lf_track_pam.run(pam_path, gfa_path, seeds_path,
  162. tracking_method="deterministic")
  163. tractogram_path = \
  164. lf_track_pam.last_generated_outputs['out_tractogram']
  165. assert_false(is_tractogram_empty(tractogram_path))
  166. # Test tracking with pam with sh and probabilistic getter
  167. lf_track_pam = LocalFiberTrackingPAMFlow()
  168. lf_track_pam._force_overwrite = True
  169. with warnings.catch_warnings():
  170. warnings.filterwarnings(
  171. "ignore", message=descoteaux07_legacy_msg,
  172. category=PendingDeprecationWarning)
  173. lf_track_pam.run(pam_path, gfa_path, seeds_path,
  174. tracking_method="probabilistic")
  175. tractogram_path = \
  176. lf_track_pam.last_generated_outputs['out_tractogram']
  177. assert_false(is_tractogram_empty(tractogram_path))
  178. # Test tracking with pam with sh and closest peaks getter
  179. lf_track_pam = LocalFiberTrackingPAMFlow()
  180. lf_track_pam._force_overwrite = True
  181. with warnings.catch_warnings():
  182. warnings.filterwarnings(
  183. "ignore", message=descoteaux07_legacy_msg,
  184. category=PendingDeprecationWarning)
  185. lf_track_pam.run(pam_path, gfa_path, seeds_path,
  186. tracking_method="closestpeaks")
  187. tractogram_path = \
  188. lf_track_pam.last_generated_outputs['out_tractogram']
  189. assert_false(is_tractogram_empty(tractogram_path))
  190. # Test that tracking returns seeds
  191. lf_track_pam = LocalFiberTrackingPAMFlow()
  192. lf_track_pam._force_overwrite = True
  193. with warnings.catch_warnings():
  194. warnings.filterwarnings(
  195. "ignore", message=descoteaux07_legacy_msg,
  196. category=PendingDeprecationWarning)
  197. lf_track_pam.run(pam_path, gfa_path, seeds_path,
  198. tracking_method="deterministic",
  199. save_seeds=True)
  200. tractogram_path = \
  201. lf_track_pam.last_generated_outputs['out_tractogram']
  202. assert_true(tractogram_has_seeds(tractogram_path))
  203. assert_true(seeds_are_same_space_as_streamlines(tractogram_path))
  204. def is_tractogram_empty(tractogram_path):
  205. tractogram_file = \
  206. nib.streamlines.load(tractogram_path)
  207. return len(tractogram_file.tractogram) == 0
  208. def tractogram_has_seeds(tractogram_path):
  209. tractogram = \
  210. nib.streamlines.load(tractogram_path).tractogram
  211. return len(tractogram.data_per_streamline['seeds']) > 0
  212. def seeds_are_same_space_as_streamlines(tractogram_path):
  213. sft = load_tractogram(tractogram_path, 'same', bbox_valid_check=False)
  214. seeds = sft.data_per_streamline['seeds']
  215. streamlines = sft.streamlines
  216. for seed, streamline in zip(seeds, streamlines):
  217. map_res = list(map(lambda x: np.allclose(seed, x,
  218. atol=1e-2,
  219. rtol=1e-4), streamline))
  220. # If no point is close to the seed, it likely means that the seed is
  221. # not in the same space as the streamline
  222. if not np.any(map_res):
  223. return False
  224. return True