/dipy/reconst/tests/test_shore_metrics.py

https://github.com/arokem/dipy · Python · 117 lines · 90 code · 17 blank · 10 comment · 5 complexity · d627b5a7f9aa3f418b3179387f816b2f MD5 · raw file

  1. import warnings
  2. import numpy as np
  3. import numpy.testing as npt
  4. from scipy.special import genlaguerre
  5. from dipy.data import get_gtab_taiwan_dsi, get_sphere
  6. from dipy.reconst.shm import descoteaux07_legacy_msg
  7. from dipy.reconst.shore import (ShoreModel,
  8. shore_matrix,
  9. shore_indices,
  10. shore_order)
  11. from dipy.sims.voxel import (multi_tensor, multi_tensor_rtop,
  12. multi_tensor_msd, multi_tensor_pdf)
  13. def test_shore_metrics():
  14. gtab = get_gtab_taiwan_dsi()
  15. mevals = np.array(([0.0015, 0.0003, 0.0003],
  16. [0.0015, 0.0003, 0.0003]))
  17. angl = [(0, 0), (60, 0)]
  18. S, _ = multi_tensor(gtab, mevals, S0=100.0, angles=angl,
  19. fractions=[50, 50], snr=None)
  20. # test shore_indices
  21. n = 7
  22. l = 6
  23. m = -4
  24. radial_order, c = shore_order(n, l, m)
  25. n2, l2, m2 = shore_indices(radial_order, c)
  26. npt.assert_equal(n, n2)
  27. npt.assert_equal(l, l2)
  28. npt.assert_equal(m, m2)
  29. radial_order = 6
  30. c = 41
  31. n, l, m = shore_indices(radial_order, c)
  32. radial_order2, c2 = shore_order(n, l, m)
  33. npt.assert_equal(radial_order, radial_order2)
  34. npt.assert_equal(c, c2)
  35. npt.assert_raises(ValueError, shore_indices, 6, 100)
  36. npt.assert_raises(ValueError, shore_order, m, n, l)
  37. # since we are testing without noise we can use higher order and lower
  38. # lambdas, with respect to the default.
  39. radial_order = 8
  40. zeta = 700
  41. lambdaN = 1e-12
  42. lambdaL = 1e-12
  43. asm = ShoreModel(gtab, radial_order=radial_order,
  44. zeta=zeta, lambdaN=lambdaN, lambdaL=lambdaL)
  45. with warnings.catch_warnings():
  46. warnings.filterwarnings(
  47. "ignore", message=descoteaux07_legacy_msg,
  48. category=PendingDeprecationWarning)
  49. asmfit = asm.fit(S)
  50. c_shore = asmfit.shore_coeff
  51. with warnings.catch_warnings():
  52. warnings.filterwarnings(
  53. "ignore", message=descoteaux07_legacy_msg,
  54. category=PendingDeprecationWarning)
  55. cmat = shore_matrix(radial_order, zeta, gtab)
  56. S_reconst = np.dot(cmat, c_shore)
  57. # test the signal reconstruction
  58. S = S / S[0]
  59. nmse_signal = np.sqrt(np.sum((S - S_reconst) ** 2)) / (S.sum())
  60. npt.assert_almost_equal(nmse_signal, 0.0, 4)
  61. # test if the analytical integral of the pdf is equal to one
  62. integral = 0
  63. for n in range(int(radial_order/2 + 1)):
  64. integral += c_shore[n] * (np.pi**(-1.5) * zeta ** (-1.5) *
  65. genlaguerre(n, 0.5)(0)) ** 0.5
  66. npt.assert_almost_equal(integral, 1.0, 10)
  67. # test if the integral of the pdf calculated on a discrete grid is
  68. # equal to one
  69. with warnings.catch_warnings():
  70. warnings.filterwarnings(
  71. "ignore", message=descoteaux07_legacy_msg,
  72. category=PendingDeprecationWarning)
  73. pdf_discrete = asmfit.pdf_grid(17, 40e-3)
  74. integral = pdf_discrete.sum()
  75. npt.assert_almost_equal(integral, 1.0, 1)
  76. # compare the shore pdf with the ground truth multi_tensor pdf
  77. sphere = get_sphere('symmetric724')
  78. v = sphere.vertices
  79. radius = 10e-3
  80. with warnings.catch_warnings():
  81. warnings.filterwarnings(
  82. "ignore", message=descoteaux07_legacy_msg,
  83. category=PendingDeprecationWarning)
  84. pdf_shore = asmfit.pdf(v * radius)
  85. pdf_mt = multi_tensor_pdf(v * radius, mevals=mevals,
  86. angles=angl, fractions=[50, 50])
  87. nmse_pdf = np.sqrt(np.sum((pdf_mt - pdf_shore) ** 2)) / (pdf_mt.sum())
  88. npt.assert_almost_equal(nmse_pdf, 0.0, 2)
  89. # compare the shore rtop with the ground truth multi_tensor rtop
  90. rtop_shore_signal = asmfit.rtop_signal()
  91. rtop_shore_pdf = asmfit.rtop_pdf()
  92. npt.assert_almost_equal(rtop_shore_signal, rtop_shore_pdf, 9)
  93. rtop_mt = multi_tensor_rtop([.5, .5], mevals=mevals)
  94. npt.assert_(rtop_mt / rtop_shore_signal > 0.95)
  95. npt.assert_(rtop_mt / rtop_shore_signal < 1.10)
  96. # compare the shore msd with the ground truth multi_tensor msd
  97. msd_mt = multi_tensor_msd([.5, .5], mevals=mevals)
  98. msd_shore = asmfit.msd()
  99. npt.assert_(msd_mt / msd_shore > 0.95)
  100. npt.assert_(msd_mt / msd_shore < 1.05)