/dipy/direction/tests/test_pmf.py

https://github.com/arokem/dipy · Python · 138 lines · 103 code · 22 blank · 13 comment · 7 complexity · db9674768a24030d34b6ad61c3f6e9b4 MD5 · raw file

  1. import warnings
  2. import numpy as np
  3. import numpy.testing as npt
  4. from dipy.core.gradients import gradient_table
  5. from dipy.core.sphere import HemiSphere, unit_octahedron
  6. from dipy.data import default_sphere, get_sphere
  7. from dipy.direction.pmf import SimplePmfGen, SHCoeffPmfGen, BootPmfGen
  8. from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel
  9. from dipy.reconst.dti import TensorModel
  10. from dipy.reconst.shm import descoteaux07_legacy_msg
  11. from dipy.sims.voxel import single_tensor
  12. response = (np.array([1.5e3, 0.3e3, 0.3e3]), 1)
  13. def test_pmf_val():
  14. sphere = get_sphere('symmetric724')
  15. with warnings.catch_warnings():
  16. warnings.filterwarnings(
  17. "ignore", message=descoteaux07_legacy_msg,
  18. category=PendingDeprecationWarning)
  19. pmfgen = SHCoeffPmfGen(np.random.random([2, 2, 2, 28]), sphere, None)
  20. point = np.array([1, 1, 1], dtype='float')
  21. for idx in [0, 5, 15, -1]:
  22. pmf = pmfgen.get_pmf(point)
  23. # Create a direction vector close to the vertex idx
  24. xyz = sphere.vertices[idx] + np.random.random([3]) / 100
  25. pmf_idx = pmfgen.get_pmf_value(point, xyz)
  26. # Test that the pmf sampled for the direction xyz is correct
  27. npt.assert_array_equal(pmf[idx], pmf_idx)
  28. def test_pmf_from_sh():
  29. sphere = HemiSphere.from_sphere(unit_octahedron)
  30. with warnings.catch_warnings():
  31. warnings.filterwarnings(
  32. "ignore", message=descoteaux07_legacy_msg,
  33. category=PendingDeprecationWarning)
  34. pmfgen = SHCoeffPmfGen(np.ones([2, 2, 2, 28]), sphere, None)
  35. # Test that the pmf is greater than 0 for a valid point
  36. pmf = pmfgen.get_pmf(np.array([0, 0, 0], dtype='float'))
  37. npt.assert_equal(np.sum(pmf) > 0, True)
  38. # Test that the pmf is 0 for invalid Points
  39. npt.assert_array_equal(pmfgen.get_pmf(np.array([-1, 0, 0], dtype='float')),
  40. np.zeros(len(sphere.vertices)))
  41. npt.assert_array_equal(pmfgen.get_pmf(np.array([0, 0, 10], dtype='float')),
  42. np.zeros(len(sphere.vertices)))
  43. def test_pmf_from_array():
  44. sphere = HemiSphere.from_sphere(unit_octahedron)
  45. pmfgen = SimplePmfGen(np.ones([2, 2, 2, len(sphere.vertices)]), sphere)
  46. # Test that the pmf is greater than 0 for a valid point
  47. pmf = pmfgen.get_pmf(np.array([0, 0, 0], dtype='float'))
  48. npt.assert_equal(np.sum(pmf) > 0, True)
  49. # Test that the pmf is 0 for invalid Points
  50. npt.assert_array_equal(pmfgen.get_pmf(np.array([-1, 0, 0], dtype='float')),
  51. np.zeros(len(sphere.vertices)))
  52. npt.assert_array_equal(pmfgen.get_pmf(np.array([0, 0, 10], dtype='float')),
  53. np.zeros(len(sphere.vertices)))
  54. # Test ValueError for negative pmf
  55. npt.assert_raises(
  56. ValueError,
  57. lambda: SimplePmfGen(np.ones([2, 2, 2, len(sphere.vertices)])*-1,
  58. sphere))
  59. # Test ValueError for non matching pmf and sphere
  60. npt.assert_raises(
  61. ValueError,
  62. lambda: SimplePmfGen(np.ones([2, 2, 2, len(sphere.vertices)]),
  63. default_sphere))
  64. def test_boot_pmf():
  65. # This tests the local model used for the bootstrapping.
  66. hsph_updated = HemiSphere.from_sphere(unit_octahedron)
  67. vertices = hsph_updated.vertices
  68. bvecs = vertices
  69. bvals = np.ones(len(vertices)) * 1000
  70. bvecs = np.insert(bvecs, 0, np.array([0, 0, 0]), axis=0)
  71. bvals = np.insert(bvals, 0, 0)
  72. gtab = gradient_table(bvals, bvecs)
  73. voxel = single_tensor(gtab)
  74. data = np.tile(voxel, (3, 3, 3, 1))
  75. point = np.array([1., 1., 1.])
  76. tensor_model = TensorModel(gtab)
  77. with warnings.catch_warnings():
  78. warnings.filterwarnings(
  79. "ignore", message=descoteaux07_legacy_msg,
  80. category=PendingDeprecationWarning)
  81. boot_pmf_gen = BootPmfGen(
  82. data, model=tensor_model, sphere=hsph_updated)
  83. no_boot_pmf = boot_pmf_gen.get_pmf_no_boot(point)
  84. model_pmf = tensor_model.fit(voxel).odf(hsph_updated)
  85. npt.assert_equal(len(hsph_updated.vertices), no_boot_pmf.shape[0])
  86. npt.assert_array_almost_equal(no_boot_pmf, model_pmf)
  87. # test model spherical harmonic order different than bootstrap order
  88. with warnings.catch_warnings(record=True) as w:
  89. warnings.simplefilter("always", category=UserWarning)
  90. warnings.simplefilter("always", category=PendingDeprecationWarning)
  91. csd_model = ConstrainedSphericalDeconvModel(gtab, response,
  92. sh_order=6)
  93. # Tests that the first caught warning comes from the CSD model
  94. # constructor
  95. npt.assert_(issubclass(w[0].category, UserWarning))
  96. npt.assert_("Number of parameters required " in str(w[0].message))
  97. # Tests that additional warnings are raised for outdated SH basis
  98. npt.assert_(len(w) > 1)
  99. with warnings.catch_warnings():
  100. warnings.filterwarnings(
  101. "ignore", message=descoteaux07_legacy_msg,
  102. category=PendingDeprecationWarning)
  103. boot_pmf_gen_sh4 = BootPmfGen(data, sphere=hsph_updated,
  104. model=csd_model, sh_order=4)
  105. pmf_sh4 = boot_pmf_gen_sh4.get_pmf(point)
  106. npt.assert_equal(len(hsph_updated.vertices), pmf_sh4.shape[0])
  107. npt.assert_(np.sum(pmf_sh4.shape) > 0)
  108. with warnings.catch_warnings():
  109. warnings.filterwarnings(
  110. "ignore", message=descoteaux07_legacy_msg,
  111. category=PendingDeprecationWarning)
  112. boot_pmf_gen_sh8 = BootPmfGen(data, model=csd_model,
  113. sphere=hsph_updated, sh_order=8)
  114. pmf_sh8 = boot_pmf_gen_sh8.get_pmf(point)
  115. npt.assert_equal(len(hsph_updated.vertices), pmf_sh8.shape[0])
  116. npt.assert_(np.sum(pmf_sh8.shape) > 0)