/dipy/reconst/tests/test_cross_validation.py

https://github.com/arokem/dipy · Python · 132 lines · 96 code · 22 blank · 14 comment · 3 complexity · 73ef1759d102370c164cbe84d09c76dd MD5 · raw file

  1. """Testing cross-validation analysis."""
  2. import warnings
  3. import numpy as np
  4. import numpy.testing as npt
  5. import dipy.reconst.cross_validation as xval
  6. import dipy.data as dpd
  7. import dipy.reconst.dti as dti
  8. import dipy.core.gradients as gt
  9. import dipy.sims.voxel as sims
  10. import dipy.reconst.csdeconv as csd
  11. import dipy.reconst.base as base
  12. from dipy.io.image import load_nifti_data
  13. from dipy.reconst.shm import descoteaux07_legacy_msg
# Data, b-value and b-vector file names of dipy's 'small_64D' test dataset.
# Resolved once at import time and shared by every test in this module.
fdata, fbval, fbvec = dpd.get_fnames('small_64D')
  16. def test_coeff_of_determination():
  17. model = np.random.randn(10, 10, 10, 150)
  18. data = np.copy(model)
  19. # If the model predicts the data perfectly, the COD is all 100s:
  20. cod = xval.coeff_of_determination(data, model)
  21. npt.assert_array_equal(100, cod)
  22. def test_dti_xval():
  23. data = load_nifti_data(fdata)
  24. gtab = gt.gradient_table(fbval, fbvec)
  25. dm = dti.TensorModel(gtab, 'LS')
  26. # The data has 102 directions, so will not divide neatly into 10 bits
  27. npt.assert_raises(ValueError, xval.kfold_xval, dm, data, 10)
  28. # In simulation with no noise, COD should be perfect:
  29. psphere = dpd.get_sphere('symmetric362')
  30. bvecs = np.concatenate(([[0, 0, 0]], psphere.vertices))
  31. bvals = np.zeros(len(bvecs)) + 1000
  32. bvals[0] = 0
  33. gtab = gt.gradient_table(bvals, bvecs)
  34. mevals = np.array(([0.0015, 0.0003, 0.0001], [0.0015, 0.0003, 0.0003]))
  35. mevecs = [np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
  36. np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]])]
  37. S = sims.single_tensor(gtab, 100, mevals[0], mevecs[0], snr=None)
  38. dm = dti.TensorModel(gtab, 'LS')
  39. kf_xval = xval.kfold_xval(dm, S, 2)
  40. cod = xval.coeff_of_determination(S, kf_xval)
  41. npt.assert_array_almost_equal(cod, np.ones(kf_xval.shape[:-1]) * 100)
  42. # Test with 2D data for use of a mask
  43. S = np.array([[S, S], [S, S]])
  44. mask = np.ones(S.shape[:-1], dtype=bool)
  45. mask[1, 1] = 0
  46. kf_xval = xval.kfold_xval(dm, S, 2, mask=mask)
  47. cod2d = xval.coeff_of_determination(S, kf_xval)
  48. npt.assert_array_almost_equal(np.round(cod2d[0, 0]), cod)
  49. def test_csd_xval():
  50. # First, let's see that it works with some data:
  51. data = load_nifti_data(fdata)[1:3, 1:3, 1:3] # Make it *small*
  52. gtab = gt.gradient_table(fbval, fbvec)
  53. S0 = np.mean(data[..., gtab.b0s_mask])
  54. response = ([0.0015, 0.0003, 0.0001], S0)
  55. # In simulation, it should work rather well (high COD):
  56. psphere = dpd.get_sphere('symmetric362')
  57. bvecs = np.concatenate(([[0, 0, 0]], psphere.vertices))
  58. bvals = np.zeros(len(bvecs)) + 1000
  59. bvals[0] = 0
  60. gtab = gt.gradient_table(bvals, bvecs)
  61. mevals = np.array(([0.0015, 0.0003, 0.0001], [0.0015, 0.0003, 0.0003]))
  62. mevecs = [np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
  63. np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]])]
  64. S0 = 100
  65. S = sims.single_tensor(gtab, S0, mevals[0], mevecs[0], snr=None)
  66. with warnings.catch_warnings():
  67. warnings.filterwarnings(
  68. "ignore", message=descoteaux07_legacy_msg,
  69. category=PendingDeprecationWarning)
  70. sm = csd.ConstrainedSphericalDeconvModel(gtab, response)
  71. np.random.seed(12345)
  72. response = ([0.0015, 0.0003, 0.0001], S0)
  73. with warnings.catch_warnings():
  74. warnings.filterwarnings(
  75. "ignore", message=descoteaux07_legacy_msg,
  76. category=PendingDeprecationWarning)
  77. kf_xval = xval.kfold_xval(sm, S, 2, response, sh_order=2)
  78. # Because of the regularization, COD is not going to be perfect here:
  79. cod = xval.coeff_of_determination(S, kf_xval)
  80. # We'll just test for regressions:
  81. csd_cod = 97 # pre-computed by hand for this random seed
  82. # We're going to be really lenient here:
  83. npt.assert_array_almost_equal(np.round(cod), csd_cod)
  84. # Test for sD data with more than one voxel for use of a mask:
  85. S = np.array([[S, S], [S, S]])
  86. mask = np.ones(S.shape[:-1], dtype=bool)
  87. mask[1, 1] = 0
  88. with warnings.catch_warnings():
  89. warnings.filterwarnings(
  90. "ignore", message=descoteaux07_legacy_msg,
  91. category=PendingDeprecationWarning)
  92. kf_xval = xval.kfold_xval(sm, S, 2, response, sh_order=2,
  93. mask=mask)
  94. cod = xval.coeff_of_determination(S, kf_xval)
  95. npt.assert_array_almost_equal(np.round(cod[0]), csd_cod)
  96. def test_no_predict():
  97. # Test that if you try to do this with a model that doesn't have a `predict`
  98. # method, you get something reasonable.
  99. class NoPredictModel(base.ReconstModel):
  100. def __init__(self, gtab):
  101. base.ReconstModel.__init__(self, gtab)
  102. def fit(self, data, mask=None):
  103. return NoPredictFit(self, data, mask=mask)
  104. class NoPredictFit(base.ReconstFit):
  105. def __init__(self, model, data, mask=None):
  106. base.ReconstFit.__init__(self, model, data)
  107. gtab = gt.gradient_table(fbval, fbvec)
  108. my_model = NoPredictModel(gtab)
  109. data = load_nifti_data(fdata)[1:3, 1:3, 1:3] # Whatever
  110. npt.assert_raises(ValueError, xval.kfold_xval, my_model, data, 2)