/dipy/reconst/tests/test_forecast.py

https://github.com/arokem/dipy · Python · 300 lines · 245 code · 48 blank · 7 comment · 22 complexity · 1265599c921a427e9cd4869c11c2da9c MD5 · raw file

  1. # Tests for FORECAST fitting and metrics
  2. import warnings
  3. import numpy as np
  4. from dipy.data import get_sphere, default_sphere, get_3shell_gtab
  5. from dipy.reconst.forecast import ForecastModel
  6. from dipy.reconst.shm import descoteaux07_legacy_msg
  7. from dipy.sims.voxel import multi_tensor
  8. from numpy.testing import assert_almost_equal, assert_equal
  9. import pytest
  10. from dipy.direction.peaks import peak_directions
  11. from dipy.core.sphere_stats import angular_similarity
  12. from dipy.utils.optpkg import optional_package
  13. cvxpy, have_cvxpy, _ = optional_package("cvxpy")
  14. needs_cvxpy = pytest.mark.skipif(not have_cvxpy, reason="Requires CVXPY")
  15. # Object to hold module global data
  16. class _C(object):
  17. pass
  18. data = _C()
  19. def setup_module():
  20. global data
  21. data.gtab = get_3shell_gtab()
  22. data.mevals = np.array(([0.0017, 0.0003, 0.0003],
  23. [0.0017, 0.0003, 0.0003]))
  24. data.angl = [(0, 0), (60, 0)]
  25. data.S, data.sticks = multi_tensor(data.gtab, data.mevals, S0=100.0,
  26. angles=data.angl, fractions=[50, 50],
  27. snr=None)
  28. data.sh_order = 6
  29. data.lambda_lb = 1e-8
  30. data.lambda_csd = 1.0
  31. sphere = get_sphere('repulsion100')
  32. data.sphere = sphere.vertices[0:int(sphere.vertices.shape[0]/2), :]
  33. @needs_cvxpy
  34. def test_forecast_positive_constrain():
  35. with warnings.catch_warnings():
  36. warnings.filterwarnings(
  37. "ignore", message=descoteaux07_legacy_msg,
  38. category=PendingDeprecationWarning)
  39. fm = ForecastModel(data.gtab,
  40. sh_order=data.sh_order,
  41. lambda_lb=data.lambda_lb,
  42. dec_alg='POS',
  43. sphere=data.sphere)
  44. f_fit = fm.fit(data.S)
  45. sphere = get_sphere('repulsion100')
  46. with warnings.catch_warnings():
  47. warnings.filterwarnings(
  48. "ignore", message=descoteaux07_legacy_msg,
  49. category=PendingDeprecationWarning)
  50. fodf = f_fit.odf(sphere, clip_negative=False)
  51. assert_almost_equal(fodf[fodf < 0].sum(), 0, 2)
  52. coeff = f_fit.sh_coeff
  53. c0 = np.sqrt(1.0/(4*np.pi))
  54. assert_almost_equal(coeff[0], c0, 5)
  55. def test_forecast_csd():
  56. sphere = get_sphere('repulsion100')
  57. with warnings.catch_warnings():
  58. warnings.filterwarnings(
  59. "ignore", message=descoteaux07_legacy_msg,
  60. category=PendingDeprecationWarning)
  61. fm = ForecastModel(data.gtab, dec_alg='CSD',
  62. sphere=data.sphere, lambda_csd=data.lambda_csd)
  63. f_fit = fm.fit(data.S)
  64. with warnings.catch_warnings():
  65. warnings.filterwarnings(
  66. "ignore", message=descoteaux07_legacy_msg,
  67. category=PendingDeprecationWarning)
  68. fodf_csd = f_fit.odf(sphere, clip_negative=False)
  69. with warnings.catch_warnings():
  70. warnings.filterwarnings(
  71. "ignore", message=descoteaux07_legacy_msg,
  72. category=PendingDeprecationWarning)
  73. fm = ForecastModel(data.gtab, sh_order=data.sh_order,
  74. lambda_lb=data.lambda_lb, dec_alg='WLS')
  75. f_fit = fm.fit(data.S)
  76. with warnings.catch_warnings():
  77. warnings.filterwarnings(
  78. "ignore", message=descoteaux07_legacy_msg,
  79. category=PendingDeprecationWarning)
  80. fodf_wls = f_fit.odf(sphere, clip_negative=False)
  81. value = fodf_wls[fodf_wls < 0].sum() < fodf_csd[fodf_csd < 0].sum()
  82. assert_equal(value, 1)
  83. def test_forecast_odf():
  84. # check FORECAST fODF at different SH order
  85. with warnings.catch_warnings():
  86. warnings.filterwarnings(
  87. "ignore", message=descoteaux07_legacy_msg,
  88. category=PendingDeprecationWarning)
  89. fm = ForecastModel(data.gtab, sh_order=4,
  90. dec_alg='CSD', sphere=data.sphere)
  91. f_fit = fm.fit(data.S)
  92. sphere = default_sphere
  93. with warnings.catch_warnings():
  94. warnings.filterwarnings(
  95. "ignore", message=descoteaux07_legacy_msg,
  96. category=PendingDeprecationWarning)
  97. fodf = f_fit.odf(sphere)
  98. directions, _, _ = peak_directions(fodf, sphere, .35, 25)
  99. assert_equal(len(directions), 2)
  100. assert_almost_equal(
  101. angular_similarity(directions, data.sticks), 2, 1)
  102. with warnings.catch_warnings():
  103. warnings.filterwarnings(
  104. "ignore", message=descoteaux07_legacy_msg,
  105. category=PendingDeprecationWarning)
  106. fm = ForecastModel(data.gtab, sh_order=6,
  107. dec_alg='CSD', sphere=data.sphere)
  108. f_fit = fm.fit(data.S)
  109. with warnings.catch_warnings():
  110. warnings.filterwarnings(
  111. "ignore", message=descoteaux07_legacy_msg,
  112. category=PendingDeprecationWarning)
  113. fodf = f_fit.odf(sphere)
  114. directions, _, _ = peak_directions(fodf, sphere, .35, 25)
  115. assert_equal(len(directions), 2)
  116. assert_almost_equal(
  117. angular_similarity(directions, data.sticks), 2, 1)
  118. with warnings.catch_warnings():
  119. warnings.filterwarnings(
  120. "ignore", message=descoteaux07_legacy_msg,
  121. category=PendingDeprecationWarning)
  122. fm = ForecastModel(data.gtab, sh_order=8,
  123. dec_alg='CSD', sphere=data.sphere)
  124. f_fit = fm.fit(data.S)
  125. with warnings.catch_warnings():
  126. warnings.filterwarnings(
  127. "ignore", message=descoteaux07_legacy_msg,
  128. category=PendingDeprecationWarning)
  129. fodf = f_fit.odf(sphere)
  130. directions, _, _ = peak_directions(fodf, sphere, .35, 25)
  131. assert_equal(len(directions), 2)
  132. assert_almost_equal(
  133. angular_similarity(directions, data.sticks), 2, 1)
  134. # stronger regularization is required for high order SH
  135. with warnings.catch_warnings():
  136. warnings.filterwarnings(
  137. "ignore", message=descoteaux07_legacy_msg,
  138. category=PendingDeprecationWarning)
  139. fm = ForecastModel(data.gtab, sh_order=10,
  140. dec_alg='CSD', sphere=sphere.vertices)
  141. f_fit = fm.fit(data.S)
  142. with warnings.catch_warnings():
  143. warnings.filterwarnings(
  144. "ignore", message=descoteaux07_legacy_msg,
  145. category=PendingDeprecationWarning)
  146. fodf = f_fit.odf(sphere)
  147. directions, _, _ = peak_directions(fodf, sphere, .35, 25)
  148. assert_equal(len(directions), 2)
  149. assert_almost_equal(
  150. angular_similarity(directions, data.sticks), 2, 1)
  151. with warnings.catch_warnings():
  152. warnings.filterwarnings(
  153. "ignore", message=descoteaux07_legacy_msg,
  154. category=PendingDeprecationWarning)
  155. fm = ForecastModel(data.gtab, sh_order=12,
  156. dec_alg='CSD', sphere=sphere.vertices)
  157. f_fit = fm.fit(data.S)
  158. with warnings.catch_warnings():
  159. warnings.filterwarnings(
  160. "ignore", message=descoteaux07_legacy_msg,
  161. category=PendingDeprecationWarning)
  162. fodf = f_fit.odf(sphere)
  163. directions, _, _ = peak_directions(fodf, sphere, .35, 25)
  164. assert_equal(len(directions), 2)
  165. assert_almost_equal(
  166. angular_similarity(directions, data.sticks), 2, 1)
  167. def test_forecast_indices():
  168. # check anisotropic tensor
  169. with warnings.catch_warnings():
  170. warnings.filterwarnings(
  171. "ignore", message=descoteaux07_legacy_msg,
  172. category=PendingDeprecationWarning)
  173. fm = ForecastModel(data.gtab, sh_order=2,
  174. lambda_lb=data.lambda_lb, dec_alg='WLS')
  175. f_fit = fm.fit(data.S)
  176. d_par = f_fit.dpar
  177. d_perp = f_fit.dperp
  178. assert_almost_equal(d_par, data.mevals[0, 0], 5)
  179. assert_almost_equal(d_perp, data.mevals[0, 1], 5)
  180. gt_fa = np.sqrt(0.5 * (2*(data.mevals[0, 0] - data.mevals[0, 1])**2) / (
  181. data.mevals[0, 0]**2 + 2*data.mevals[0, 1]**2))
  182. gt_md = (data.mevals[0, 0] + 2*data.mevals[0, 1])/3.0
  183. assert_almost_equal(f_fit.fractional_anisotropy(), gt_fa, 2)
  184. assert_almost_equal(f_fit.mean_diffusivity(), gt_md, 5)
  185. # check isotropic tensor
  186. mevals = np.array(([0.003, 0.003, 0.003],
  187. [0.003, 0.003, 0.003]))
  188. data.angl = [(0, 0), (60, 0)]
  189. S, sticks = multi_tensor(data.gtab, mevals, S0=100.0, angles=data.angl,
  190. fractions=[50, 50], snr=None)
  191. with warnings.catch_warnings():
  192. warnings.filterwarnings(
  193. "ignore", message=descoteaux07_legacy_msg,
  194. category=PendingDeprecationWarning)
  195. fm = ForecastModel(data.gtab, sh_order=data.sh_order,
  196. lambda_lb=data.lambda_lb, dec_alg='WLS')
  197. f_fit = fm.fit(S)
  198. d_par = f_fit.dpar
  199. d_perp = f_fit.dperp
  200. assert_almost_equal(d_par, 3e-03, 5)
  201. assert_almost_equal(d_perp, 3e-03, 5)
  202. assert_almost_equal(f_fit.fractional_anisotropy(), 0.0, 5)
  203. assert_almost_equal(f_fit.mean_diffusivity(), 3e-03, 10)
  204. def test_forecast_predict():
  205. # check anisotropic tensor
  206. with warnings.catch_warnings():
  207. warnings.filterwarnings(
  208. "ignore", message=descoteaux07_legacy_msg,
  209. category=PendingDeprecationWarning)
  210. fm = ForecastModel(data.gtab, sh_order=8,
  211. dec_alg='CSD', sphere=data.sphere)
  212. f_fit = fm.fit(data.S)
  213. with warnings.catch_warnings():
  214. warnings.filterwarnings(
  215. "ignore", message=descoteaux07_legacy_msg,
  216. category=PendingDeprecationWarning)
  217. S = f_fit.predict(S0=1.0)
  218. mse = np.sum((S-data.S/100.0)**2) / len(S)
  219. assert_almost_equal(mse, 0.0, 3)
  220. def test_multivox_forecast():
  221. gtab = get_3shell_gtab()
  222. mevals = np.array(([0.0017, 0.0003, 0.0003],
  223. [0.0017, 0.0003, 0.0003]))
  224. angl1 = [(0, 0), (60, 0)]
  225. angl2 = [(90, 0), (45, 90)]
  226. angl3 = [(0, 0), (90, 0)]
  227. S = np.zeros((3, 1, 1, len(gtab.bvals)))
  228. S[0, 0, 0], _ = multi_tensor(gtab, mevals, S0=1.0, angles=angl1,
  229. fractions=[50, 50], snr=None)
  230. S[1, 0, 0], _ = multi_tensor(gtab, mevals, S0=1.0, angles=angl2,
  231. fractions=[50, 50], snr=None)
  232. S[2, 0, 0], _ = multi_tensor(gtab, mevals, S0=1.0, angles=angl3,
  233. fractions=[50, 50], snr=None)
  234. with warnings.catch_warnings():
  235. warnings.filterwarnings(
  236. "ignore", message=descoteaux07_legacy_msg,
  237. category=PendingDeprecationWarning)
  238. fm = ForecastModel(gtab, sh_order=8,
  239. dec_alg='CSD')
  240. f_fit = fm.fit(S)
  241. with warnings.catch_warnings():
  242. warnings.filterwarnings(
  243. "ignore", message=descoteaux07_legacy_msg,
  244. category=PendingDeprecationWarning)
  245. S_predict = f_fit.predict()
  246. assert_equal(S_predict.shape, S.shape)
  247. mse1 = np.sum((S_predict[0, 0, 0]-S[0, 0, 0])**2) / len(gtab.bvals)
  248. assert_almost_equal(mse1, 0.0, 3)
  249. mse2 = np.sum((S_predict[1, 0, 0]-S[1, 0, 0])**2) / len(gtab.bvals)
  250. assert_almost_equal(mse2, 0.0, 3)
  251. mse3 = np.sum((S_predict[2, 0, 0]-S[2, 0, 0])**2) / len(gtab.bvals)
  252. assert_almost_equal(mse3, 0.0, 3)