PageRenderTime 48ms CodeModel.GetById 22ms RepoModel.GetById 1ms app.codeStats 0ms

/dipy/reconst/tests/test_rumba.py

https://github.com/nipy/dipy
Python | 394 lines | 266 code | 85 blank | 43 comment | 15 complexity | 0a40eeb66935ab4c47232d09183b5b04 MD5 | raw file
Possible License(s): BSD-3-Clause
  1. import warnings
  2. import numpy as np
  3. from numpy.testing import (assert_equal,
  4. assert_almost_equal,
  5. assert_array_equal,
  6. assert_allclose,
  7. assert_raises,)
  8. from numpy.testing import assert_
  9. from dipy.reconst.rumba import RumbaSDModel, generate_kernel
  10. from dipy.reconst.csdeconv import AxSymShResponse
  11. from dipy.data import get_fnames, dsi_voxels, default_sphere, get_sphere
  12. from dipy.core.gradients import gradient_table
  13. from dipy.core.geometry import cart2sphere
  14. from dipy.core.sphere_stats import angular_similarity
  15. from dipy.reconst.tests.test_dsi import sticks_and_ball_dummies
  16. from dipy.sims.voxel import sticks_and_ball, multi_tensor, single_tensor
  17. from dipy.direction.peaks import peak_directions
  18. from dipy.reconst.shm import descoteaux07_legacy_msg
  19. def test_rumba():
  20. # Test fODF results from ideal examples.
  21. sphere = default_sphere # repulsion 724
  22. sphere2 = get_sphere('symmetric362')
  23. btable = np.loadtxt(get_fnames('dsi515btable'))
  24. bvals = btable[:, 0]
  25. bvecs = btable[:, 1:]
  26. gtab = gradient_table(bvals, bvecs)
  27. data, golden_directions = sticks_and_ball(gtab, d=0.0015, S0=100,
  28. angles=[(0, 0), (90, 0)],
  29. fractions=[50, 50], snr=None)
  30. # Testing input validation
  31. gtab_broken = gradient_table(
  32. bvals[~gtab.b0s_mask], bvecs[~gtab.b0s_mask])
  33. assert_raises(ValueError, RumbaSDModel, gtab_broken)
  34. with warnings.catch_warnings(record=True) as w:
  35. _ = RumbaSDModel(gtab, verbose=True)
  36. assert_equal(len(w), 1)
  37. assert_(w[0].category, UserWarning)
  38. assert_raises(ValueError, RumbaSDModel, gtab, use_tv=True)
  39. assert_raises(ValueError, RumbaSDModel, gtab, n_iter=0)
  40. rumba_broken = RumbaSDModel(gtab, recon_type='test')
  41. assert_raises(ValueError, rumba_broken.fit, data)
  42. # Models to validate
  43. rumba_smf = RumbaSDModel(gtab, n_iter=20, recon_type='smf', n_coils=1,
  44. sphere=sphere)
  45. rumba_sos = RumbaSDModel(gtab, n_iter=20, recon_type='sos', n_coils=32,
  46. sphere=sphere)
  47. model_list = [rumba_smf, rumba_sos]
  48. # Test on repulsion724 sphere
  49. for model in model_list:
  50. model_fit = model.fit(data)
  51. # Verify only works on original sphere
  52. assert_raises(ValueError, model_fit.odf, sphere2)
  53. odf = model_fit.odf(sphere)
  54. directions, _, _ = peak_directions(odf, sphere, .35, 25)
  55. assert_equal(len(directions), 2)
  56. assert_almost_equal(angular_similarity(directions, golden_directions),
  57. 2, 1)
  58. # Test on data with 1, 2, 3, or no peaks
  59. sb_dummies = sticks_and_ball_dummies(gtab)
  60. for model in model_list:
  61. for sbd in sb_dummies:
  62. data, golden_directions = sb_dummies[sbd]
  63. model_fit = model.fit(data)
  64. odf = model_fit.odf(sphere)
  65. directions, _, _ = peak_directions(
  66. odf, sphere, .35, 25)
  67. if len(directions) <= 3:
  68. # Verify small isotropic fraction in anisotropic case
  69. assert_equal(model_fit.f_iso < 0.1, True)
  70. assert_equal(len(directions), len(golden_directions))
  71. if len(directions) > 3:
  72. # Verify large isotropic fraction in isotropic case
  73. assert_equal(model_fit.f_iso > 0.8, True)
  74. def test_predict():
  75. # Test signal reconstruction on ideal example
  76. sphere = default_sphere
  77. btable = np.loadtxt(get_fnames('dsi515btable'))
  78. bvals = btable[:, 0]
  79. bvecs = btable[:, 1:]
  80. gtab = gradient_table(bvals, bvecs)
  81. rumba = RumbaSDModel(gtab, n_iter=600, sphere=sphere)
  82. # Simulated data
  83. data = single_tensor(gtab, S0=1, evals=rumba.wm_response)
  84. rumba_fit = rumba.fit(data)
  85. data_pred = rumba_fit.predict()
  86. assert_allclose(data_pred, data, atol=0.01, rtol=0.05)
  87. def test_recursive_rumba():
  88. # Test with recursive data-driven response
  89. sphere = default_sphere # repulsion 724
  90. btable = np.loadtxt(get_fnames('dsi515btable'))
  91. bvals = btable[:, 0]
  92. bvecs = btable[:, 1:]
  93. gtab = gradient_table(bvals, bvecs)
  94. data, golden_directions = sticks_and_ball(gtab, d=0.0015, S0=100,
  95. angles=[(0, 0), (90, 0)],
  96. fractions=[50, 50], snr=None)
  97. wm_response = AxSymShResponse(480, np.array([570.35065982,
  98. -262.81741086,
  99. 80.23104069,
  100. -16.93940972,
  101. 2.57628738]))
  102. model = RumbaSDModel(gtab, wm_response, n_iter=20, sphere=sphere)
  103. with warnings.catch_warnings():
  104. warnings.filterwarnings(
  105. "ignore", message=descoteaux07_legacy_msg,
  106. category=PendingDeprecationWarning)
  107. model_fit = model.fit(data)
  108. # Test peaks
  109. odf = model_fit.odf(sphere)
  110. directions, _, _ = peak_directions(odf, sphere, .35, 25)
  111. assert_equal(len(directions), 2)
  112. assert_almost_equal(angular_similarity(directions, golden_directions),
  113. 2, 1)
  114. def test_multishell_rumba():
  115. # Test with multi-shell response
  116. sphere = default_sphere # repulsion 724
  117. btable = np.loadtxt(get_fnames('dsi515btable'))
  118. bvals = btable[:, 0]
  119. bvecs = btable[:, 1:]
  120. gtab = gradient_table(bvals, bvecs)
  121. data, golden_directions = sticks_and_ball(gtab, d=0.0015, S0=100,
  122. angles=[(0, 0), (90, 0)],
  123. fractions=[50, 50], snr=None)
  124. wm_response = np.tile(np.array([1.7E-3, 0.2E-3, 0.2E-3]), (22, 1))
  125. model = RumbaSDModel(gtab, wm_response, n_iter=20, sphere=sphere)
  126. model_fit = model.fit(data)
  127. # Test peaks
  128. odf = model_fit.odf(sphere)
  129. directions, _, _ = peak_directions(odf, sphere, .35, 25)
  130. assert_equal(len(directions), 2)
  131. assert_almost_equal(angular_similarity(directions, golden_directions),
  132. 2, 1)
  133. def test_mvoxel_rumba():
  134. # Verify form of results in multi-voxel situation.
  135. data, gtab = dsi_voxels() # multi-voxel data
  136. sphere = default_sphere # repulsion 724
  137. # Models to validate
  138. rumba_smf = RumbaSDModel(gtab, n_iter=5, recon_type='smf', n_coils=1,
  139. sphere=sphere)
  140. rumba_sos = RumbaSDModel(gtab, n_iter=5, recon_type='sos', n_coils=32,
  141. sphere=sphere)
  142. model_list = [rumba_smf, rumba_sos]
  143. for model in model_list:
  144. model_fit = model.fit(data)
  145. odf = model_fit.odf(sphere)
  146. f_iso = model_fit.f_iso
  147. f_wm = model_fit.f_wm
  148. f_gm = model_fit.f_gm
  149. f_csf = model_fit.f_csf
  150. combined = model_fit.combined_odf_iso
  151. # Verify prediction properties
  152. pred_sig_1 = model_fit.predict()
  153. pred_sig_2 = model_fit.predict(S0=1)
  154. pred_sig_3 = model_fit.predict(S0=np.ones(odf.shape[:-1]))
  155. pred_sig_4 = model_fit.predict(gtab=gtab)
  156. assert_equal(pred_sig_1, pred_sig_2)
  157. assert_equal(pred_sig_3, pred_sig_4)
  158. assert_equal(pred_sig_1, pred_sig_3)
  159. assert_equal(data.shape, pred_sig_1.shape)
  160. assert_equal(np.alltrue(np.isreal(pred_sig_1)), True)
  161. assert_equal(np.alltrue(pred_sig_1 > 0), True)
  162. # Verify shape, positivity, realness of results
  163. assert_equal(data.shape[:-1] + (len(sphere.vertices),), odf.shape)
  164. assert_equal(np.alltrue(np.isreal(odf)), True)
  165. assert_equal(np.alltrue(odf > 0), True)
  166. assert_equal(data.shape[:-1], f_iso.shape)
  167. assert_equal(np.alltrue(np.isreal(f_iso)), True)
  168. assert_equal(np.alltrue(f_iso > 0), True)
  169. # Verify properties of fODF and volume fractions
  170. assert_equal(f_iso, f_gm + f_csf)
  171. assert_equal(combined, odf + f_iso[..., None] / len(sphere.vertices))
  172. assert_almost_equal(f_iso + f_wm, np.ones(f_iso.shape))
  173. assert_almost_equal(np.sum(combined, axis=3), np.ones(f_iso.shape))
  174. assert_equal(np.sum(odf, axis=3), f_wm)
  175. def test_global_fit():
  176. # Test fODF results on ideal examples in global fitting paradigm.
  177. sphere = default_sphere # repulsion 724
  178. btable = np.loadtxt(get_fnames('dsi515btable'))
  179. bvals = btable[:, 0]
  180. bvecs = btable[:, 1:]
  181. gtab = gradient_table(bvals, bvecs)
  182. data, golden_directions = sticks_and_ball(gtab, d=0.0015, S0=100,
  183. angles=[(0, 0), (90, 0)],
  184. fractions=[50, 50], snr=None)
  185. # global_fit requires 4D argument
  186. data = data[None, None, None, :]
  187. # TV requires non-singleton size in all volume dimensions
  188. data_mvoxel = np.tile(data, (2, 2, 2, 1))
  189. # Model to validate
  190. rumba = RumbaSDModel(gtab, n_iter=20, recon_type='smf', n_coils=1, R=2,
  191. voxelwise=False, sphere=sphere)
  192. rumba_tv = RumbaSDModel(gtab, n_iter=20, recon_type='smf', n_coils=1, R=2,
  193. voxelwise=False, use_tv=True, sphere=sphere)
  194. # Testing input validation
  195. assert_raises(ValueError, rumba.fit, data[:, :, :, 0]) # Must be 4D
  196. # TV can't work with singleton dimensions in data volume
  197. assert_raises(ValueError, rumba_tv.fit, data)
  198. # Mask must match first 3 dimensions of data
  199. assert_raises(ValueError, rumba.fit, data, mask=np.ones(data.shape))
  200. # Recon type validation
  201. rumba_broken = RumbaSDModel(gtab, recon_type='test', voxelwise=False)
  202. assert_raises(ValueError, rumba_broken.fit, data)
  203. # Test on repulsion 724 sphere, with/wihout TV regularization
  204. for ix, model in enumerate([rumba, rumba_tv]):
  205. if ix:
  206. model_fit = model.fit(data_mvoxel)
  207. else:
  208. model_fit = model.fit(data)
  209. odf = model_fit.odf(sphere)
  210. directions, _, _ = peak_directions(
  211. odf[0, 0, 0], sphere, .35, 25)
  212. assert_equal(len(directions), 2)
  213. assert_almost_equal(angular_similarity(directions, golden_directions),
  214. 2, 1)
  215. # Test on data with 1, 2, 3, or no peaks
  216. sb_dummies = sticks_and_ball_dummies(gtab)
  217. for sbd in sb_dummies:
  218. data, golden_directions = sb_dummies[sbd]
  219. data = data[None, None, None, :] # make 4D
  220. rumba_fit = rumba.fit(data)
  221. odf = rumba_fit.odf(sphere)
  222. f_iso = rumba_fit.f_iso
  223. directions, _, _ = peak_directions(
  224. odf[0, 0, 0], sphere, .35, 25)
  225. if len(directions) <= 3:
  226. # Verify small isotropic fraction in anisotropic case
  227. assert_equal(f_iso[0, 0, 0] < 0.1, True)
  228. assert_equal(len(directions), len(golden_directions))
  229. if len(directions) > 3:
  230. # Verify large isotropic fraction in isotropic case
  231. assert_equal(f_iso[0, 0, 0] > 0.8, True)
  232. def test_mvoxel_global_fit():
  233. # Verify form of results in global fitting paradigm.
  234. data, gtab = dsi_voxels() # multi-voxel data
  235. sphere = default_sphere # repulsion 724
  236. # Models to validate
  237. rumba_sos = RumbaSDModel(gtab, recon_type='sos', n_iter=5, n_coils=32, R=1,
  238. voxelwise=False, verbose=True, sphere=sphere)
  239. rumba_sos_tv = RumbaSDModel(gtab, recon_type='sos', n_iter=5, n_coils=32,
  240. R=1, voxelwise=False, use_tv=True,
  241. sphere=sphere)
  242. rumba_r = RumbaSDModel(gtab, recon_type='smf', n_iter=5, n_coils=1, R=2,
  243. voxelwise=False, sphere=sphere)
  244. rumba_r_tv = RumbaSDModel(gtab, recon_type='smf', n_iter=5, n_coils=1, R=2,
  245. voxelwise=False, use_tv=True, sphere=sphere)
  246. model_list = [rumba_sos, rumba_sos_tv, rumba_r, rumba_r_tv]
  247. # Test each model with/without TV regularization
  248. for model in model_list:
  249. model_fit = model.fit(data)
  250. odf = model_fit.odf(sphere)
  251. f_iso = model_fit.f_iso
  252. f_wm = model_fit.f_wm
  253. f_gm = model_fit.f_gm
  254. f_csf = model_fit.f_csf
  255. combined = model_fit.combined_odf_iso
  256. # Verify shape, positivity, realness of results
  257. assert_equal(data.shape[:-1] + (len(sphere.vertices),), odf.shape)
  258. assert_equal(np.alltrue(np.isreal(odf)), True)
  259. assert_equal(np.alltrue(odf > 0), True)
  260. assert_equal(data.shape[:-1], f_iso.shape)
  261. assert_equal(np.alltrue(np.isreal(f_iso)), True)
  262. assert_equal(np.alltrue(f_iso > 0), True)
  263. # Verify normalization
  264. assert_equal(f_iso, f_gm + f_csf)
  265. assert_equal(combined, odf +
  266. f_iso[..., None] / len(sphere.vertices))
  267. assert_almost_equal(f_iso + f_wm, np.ones(f_iso.shape))
  268. assert_almost_equal(np.sum(combined, axis=3), np.ones(f_iso.shape))
  269. assert_equal(np.sum(odf, axis=3), f_wm)
  270. def test_generate_kernel():
  271. # Test form and content of kernel generation result.
  272. # load repulsion 724 sphere
  273. sphere = default_sphere
  274. btable = np.loadtxt(get_fnames('dsi515btable'))
  275. bvals = btable[:, 0]
  276. bvecs = btable[:, 1:]
  277. gtab = gradient_table(bvals, bvecs)
  278. # Kernel parameters
  279. wm_response = np.array([1.7e-3, 0.2e-3, 0.2e-3])
  280. gm_response = 0.2e-4
  281. csf_response = 3.0e-3
  282. # Test kernel shape
  283. kernel = generate_kernel(
  284. gtab, sphere, wm_response, gm_response, csf_response)
  285. assert_equal(kernel.shape, (len(gtab.bvals), len(sphere.vertices) + 2))
  286. # Verify first column of kernel
  287. _, theta, phi = cart2sphere(
  288. sphere.x,
  289. sphere.y,
  290. sphere.z
  291. )
  292. S0 = 1 # S0 assumed to be 1
  293. fi = 100 # volume fraction assumed to be 100%
  294. S, _ = multi_tensor(gtab, np.array([wm_response]),
  295. S0, [[theta[0] * 180 / np.pi, phi[0] * 180 / np.pi]],
  296. [fi],
  297. None)
  298. assert_almost_equal(kernel[:, 0], S)
  299. # Multi-shell version
  300. wm_response_multi = np.tile(wm_response, (22, 1))
  301. kernel_multi = generate_kernel(
  302. gtab, sphere, wm_response_multi, gm_response, csf_response)
  303. assert_equal(kernel.shape, (len(gtab.bvals), len(sphere.vertices) + 2))
  304. assert_almost_equal(kernel, kernel_multi)
  305. # Test optional isotropic compartment; should cause last column of zeroes
  306. kernel = generate_kernel(
  307. gtab, sphere, wm_response, gm_response=None, csf_response=None)
  308. assert_array_equal(kernel[:, -2], np.zeros(len(gtab.bvals)))
  309. assert_array_equal(kernel[:, -1], np.zeros(len(gtab.bvals)))