/dipy/reconst/tests/test_shore.py

https://github.com/arokem/dipy · Python · 105 lines · 82 code · 21 blank · 2 comment · 4 complexity · 553a42f8a22d965b4d3ebae46d367c70 MD5 · raw file

# Tests for SHORE model fitting (dipy.reconst.shore.ShoreModel)
  2. import warnings
  3. from math import factorial
  4. import numpy as np
  5. import numpy.testing as npt
  6. from scipy.special import genlaguerre, gamma
  7. from dipy.data import get_gtab_taiwan_dsi
  8. from dipy.reconst.shm import descoteaux07_legacy_msg
  9. from dipy.reconst.shore import ShoreModel
  10. from dipy.sims.voxel import multi_tensor
  11. import pytest
  12. from dipy.utils.optpkg import optional_package
  13. cvxpy, have_cvxpy, _ = optional_package("cvxpy")
  14. needs_cvxpy = pytest.mark.skipif(not have_cvxpy, reason="Requires CVXPY")
  15. # Object to hold module global data
  16. class _C(object):
  17. pass
  18. data = _C()
  19. def setup():
  20. data.gtab = get_gtab_taiwan_dsi()
  21. data.mevals = np.array(([0.0015, 0.0003, 0.0003],
  22. [0.0015, 0.0003, 0.0003]))
  23. data.angl = [(0, 0), (60, 0)]
  24. data.S, sticks = multi_tensor(data.gtab, data.mevals, S0=100.0,
  25. angles=data.angl, fractions=[50, 50],
  26. snr=None)
  27. data.radial_order = 6
  28. data.zeta = 700
  29. data.lambdaN = 1e-12
  30. data.lambdaL = 1e-12
  31. def test_shore_error():
  32. data.gtab = get_gtab_taiwan_dsi()
  33. npt.assert_raises(ValueError, ShoreModel, data.gtab, radial_order=-4)
  34. npt.assert_raises(ValueError, ShoreModel, data.gtab, radial_order=7)
  35. npt.assert_raises(ValueError, ShoreModel, data.gtab, constrain_e0=False,
  36. positive_constraint=True)
  37. @needs_cvxpy
  38. def test_shore_positive_constrain():
  39. asm = ShoreModel(data.gtab,
  40. radial_order=data.radial_order,
  41. zeta=data.zeta,
  42. lambdaN=data.lambdaN,
  43. lambdaL=data.lambdaL,
  44. constrain_e0=True,
  45. positive_constraint=True,
  46. pos_grid=11,
  47. pos_radius=20e-03)
  48. with warnings.catch_warnings():
  49. warnings.filterwarnings(
  50. "ignore", message=descoteaux07_legacy_msg,
  51. category=PendingDeprecationWarning)
  52. asmfit = asm.fit(data.S)
  53. eap = asmfit.pdf_grid(11, 20e-03)
  54. npt.assert_almost_equal(eap[eap < 0].sum(), 0, 3)
  55. def test_shore_fitting_no_constrain_e0():
  56. asm = ShoreModel(data.gtab, radial_order=data.radial_order,
  57. zeta=data.zeta, lambdaN=data.lambdaN,
  58. lambdaL=data.lambdaL)
  59. with warnings.catch_warnings():
  60. warnings.filterwarnings(
  61. "ignore", message=descoteaux07_legacy_msg,
  62. category=PendingDeprecationWarning)
  63. asmfit = asm.fit(data.S)
  64. npt.assert_almost_equal(compute_e0(asmfit), 1)
  65. @needs_cvxpy
  66. def test_shore_fitting_constrain_e0():
  67. asm = ShoreModel(data.gtab, radial_order=data.radial_order,
  68. zeta=data.zeta, lambdaN=data.lambdaN,
  69. lambdaL=data.lambdaL,
  70. constrain_e0=True)
  71. with warnings.catch_warnings():
  72. warnings.filterwarnings(
  73. "ignore", message=descoteaux07_legacy_msg,
  74. category=PendingDeprecationWarning)
  75. asmfit = asm.fit(data.S)
  76. npt.assert_almost_equal(compute_e0(asmfit), 1)
  77. def compute_e0(shorefit):
  78. signal_0 = 0
  79. for n in range(int(shorefit.model.radial_order / 2) + 1):
  80. signal_0 += (shorefit.shore_coeff[n] * (genlaguerre(n, 0.5)(0) *
  81. ((factorial(n)) / (2 * np.pi *
  82. (shorefit.model.zeta ** 1.5) *
  83. gamma(n + 1.5))) ** 0.5))
  84. return signal_0