/dipy/direction/tests/test_prob_direction_getter.py

https://github.com/arokem/dipy · Python · 102 lines · 67 code · 23 blank · 12 comment · 3 complexity · 7d398c2fbf626d9ba925ba506ed419c2 MD5 · raw file

  1. import warnings
  2. import numpy as np
  3. import numpy.testing as npt
  4. from dipy.core.sphere import unit_octahedron
  5. from dipy.reconst.shm import (
  6. descoteaux07_legacy_msg, tournier07_legacy_msg, SphHarmFit, SphHarmModel)
  7. from dipy.direction import (DeterministicMaximumDirectionGetter,
  8. ProbabilisticDirectionGetter)
  9. def test_ProbabilisticDirectionGetter():
  10. # Test the constructors and errors of the ProbabilisticDirectionGetter
  11. class SillyModel(SphHarmModel):
  12. sh_order = 4
  13. def fit(self, data, mask=None):
  14. coeff = np.zeros(data.shape[:-1] + (15,))
  15. return SphHarmFit(self, coeff, mask=None)
  16. model = SillyModel(gtab=None)
  17. data = np.zeros((3, 3, 3, 7))
  18. # Test if the tracking works on different dtype of the same data.
  19. for dtype in [np.float32, np.float64]:
  20. fit = model.fit(data.astype(dtype))
  21. # Sample point and direction
  22. point = np.zeros(3)
  23. direction = unit_octahedron.vertices[0].copy()
  24. # make a dg from a fit
  25. with warnings.catch_warnings():
  26. warnings.filterwarnings(
  27. "ignore", message=descoteaux07_legacy_msg,
  28. category=PendingDeprecationWarning)
  29. dg = ProbabilisticDirectionGetter.from_shcoeff(
  30. fit.shm_coeff, 90, unit_octahedron)
  31. state = dg.get_direction(point, direction)
  32. npt.assert_equal(state, 1)
  33. # Make a dg from a pmf
  34. N = unit_octahedron.theta.shape[0]
  35. pmf = np.zeros((3, 3, 3, N))
  36. dg = ProbabilisticDirectionGetter.from_pmf(pmf, 90, unit_octahedron)
  37. state = dg.get_direction(point, direction)
  38. npt.assert_equal(state, 1)
  39. # pmf shape must match sphere
  40. bad_pmf = pmf[..., 1:]
  41. npt.assert_raises(ValueError, ProbabilisticDirectionGetter.from_pmf,
  42. bad_pmf, 90, unit_octahedron)
  43. # pmf must have 4 dimensions
  44. bad_pmf = pmf[0, ...]
  45. npt.assert_raises(ValueError, ProbabilisticDirectionGetter.from_pmf,
  46. bad_pmf, 90, unit_octahedron)
  47. # pmf cannot have negative values
  48. pmf[0, 0, 0, 0] = -1
  49. npt.assert_raises(ValueError, ProbabilisticDirectionGetter.from_pmf,
  50. pmf, 90, unit_octahedron)
  51. # Check basis_type keyword
  52. with warnings.catch_warnings():
  53. warnings.filterwarnings(
  54. "ignore", message=tournier07_legacy_msg,
  55. category=PendingDeprecationWarning)
  56. dg = ProbabilisticDirectionGetter.from_shcoeff(
  57. fit.shm_coeff, 90, unit_octahedron, basis_type="tournier07")
  58. npt.assert_raises(ValueError,
  59. ProbabilisticDirectionGetter.from_shcoeff,
  60. fit.shm_coeff, 90, unit_octahedron,
  61. basis_type="not a basis")
  62. def test_DeterministicMaximumDirectionGetter():
  63. # Test the DeterministicMaximumDirectionGetter
  64. direction = unit_octahedron.vertices[-1].copy()
  65. point = np.zeros(3)
  66. N = unit_octahedron.theta.shape[0]
  67. # No valid direction
  68. pmf = np.zeros((3, 3, 3, N))
  69. dg = DeterministicMaximumDirectionGetter.from_pmf(pmf, 90,
  70. unit_octahedron)
  71. state = dg.get_direction(point, direction)
  72. npt.assert_equal(state, 1)
  73. # Test BF #1566 - bad condition in DeterministicMaximumDirectionGetter
  74. pmf = np.zeros((3, 3, 3, N))
  75. pmf[0, 0, 0, 0] = 1
  76. dg = DeterministicMaximumDirectionGetter.from_pmf(pmf, 0,
  77. unit_octahedron)
  78. state = dg.get_direction(point, direction)
  79. npt.assert_equal(state, 1)