/dipy/reconst/tests/test_shore_odf.py

https://github.com/arokem/dipy · Python · 96 lines · 80 code · 12 blank · 4 comment · 8 complexity · b523fa82957a35953d471e7c52e28a63 MD5 · raw file

  1. import warnings
  2. import numpy as np
  3. import numpy.testing as npt
  4. from dipy.data import default_sphere, get_isbi2013_2shell_gtab, get_3shell_gtab
  5. from dipy.reconst.shore import ShoreModel, shore_matrix
  6. from dipy.reconst.shm import sh_to_sf, descoteaux07_legacy_msg
  7. from dipy.direction.peaks import peak_directions
  8. from dipy.reconst.odf import gfa
  9. from dipy.sims.voxel import sticks_and_ball
  10. from dipy.core.subdivide_octahedron import create_unit_sphere
  11. from dipy.core.sphere_stats import angular_similarity
  12. from dipy.reconst.tests.test_dsi import sticks_and_ball_dummies
  13. def test_shore_odf():
  14. gtab = get_isbi2013_2shell_gtab()
  15. # load repulsion 724 sphere
  16. sphere = default_sphere
  17. # load icosahedron sphere
  18. sphere2 = create_unit_sphere(5)
  19. data, golden_directions = sticks_and_ball(gtab, d=0.0015, S0=100,
  20. angles=[(0, 0), (90, 0)],
  21. fractions=[50, 50], snr=None)
  22. asm = ShoreModel(gtab, radial_order=6,
  23. zeta=700, lambdaN=1e-8, lambdaL=1e-8)
  24. # repulsion724
  25. with warnings.catch_warnings():
  26. warnings.filterwarnings(
  27. "ignore", message=descoteaux07_legacy_msg,
  28. category=PendingDeprecationWarning)
  29. asmfit = asm.fit(data)
  30. odf = asmfit.odf(sphere)
  31. odf_sh = asmfit.odf_sh()
  32. with warnings.catch_warnings():
  33. warnings.filterwarnings(
  34. "ignore", message=descoteaux07_legacy_msg,
  35. category=PendingDeprecationWarning)
  36. odf_from_sh = sh_to_sf(odf_sh, sphere, 6, basis_type=None,
  37. legacy=True)
  38. npt.assert_almost_equal(odf, odf_from_sh, 10)
  39. with warnings.catch_warnings():
  40. warnings.filterwarnings(
  41. "ignore", message=descoteaux07_legacy_msg,
  42. category=PendingDeprecationWarning)
  43. expected_phi = shore_matrix(radial_order=6, zeta=700, gtab=gtab)
  44. npt.assert_array_almost_equal(np.dot(expected_phi, asmfit.shore_coeff),
  45. asmfit.fitted_signal())
  46. directions, _, _ = peak_directions(odf, sphere, .35, 25)
  47. npt.assert_equal(len(directions), 2)
  48. npt.assert_almost_equal(
  49. angular_similarity(directions, golden_directions), 2, 1)
  50. # 5 subdivisions
  51. with warnings.catch_warnings():
  52. warnings.filterwarnings(
  53. "ignore", message=descoteaux07_legacy_msg,
  54. category=PendingDeprecationWarning)
  55. odf = asmfit.odf(sphere2)
  56. directions, _, _ = peak_directions(odf, sphere2, .35, 25)
  57. npt.assert_equal(len(directions), 2)
  58. npt.assert_almost_equal(
  59. angular_similarity(directions, golden_directions), 2, 1)
  60. sb_dummies = sticks_and_ball_dummies(gtab)
  61. for sbd in sb_dummies:
  62. data, golden_directions = sb_dummies[sbd]
  63. asmfit = asm.fit(data)
  64. odf = asmfit.odf(sphere2)
  65. directions, _, _ = peak_directions(odf, sphere2, .35, 25)
  66. if len(directions) <= 3:
  67. npt.assert_equal(len(directions), len(golden_directions))
  68. if len(directions) > 3:
  69. npt.assert_equal(gfa(odf) < 0.1, True)
  70. def test_multivox_shore():
  71. gtab = get_3shell_gtab()
  72. data = np.random.random([20, 30, 1, gtab.gradients.shape[0]])
  73. radial_order = 4
  74. zeta = 700
  75. asm = ShoreModel(gtab, radial_order=radial_order,
  76. zeta=zeta, lambdaN=1e-8, lambdaL=1e-8)
  77. with warnings.catch_warnings():
  78. warnings.filterwarnings(
  79. "ignore", message=descoteaux07_legacy_msg,
  80. category=PendingDeprecationWarning)
  81. asmfit = asm.fit(data)
  82. c_shore = asmfit.shore_coeff
  83. npt.assert_equal(c_shore.shape[0:3], data.shape[0:3])
  84. npt.assert_equal(np.alltrue(np.isreal(c_shore)), True)