PageRenderTime 27ms CodeModel.GetById 18ms RepoModel.GetById 0ms app.codeStats 0ms

/dipy/viz/tests/test_fury.py

https://github.com/nipy/dipy
Python | 447 lines | 319 code | 85 blank | 43 comment | 31 complexity | 0eb4c6b668e9946075052bf743d01dab MD5 | raw file
Possible License(s): BSD-3-Clause
  1. import warnings
  2. import pytest
  3. import numpy as np
  4. import numpy.testing as npt
  5. from dipy.testing.decorators import use_xvfb
  6. from dipy.utils.optpkg import optional_package
  7. from dipy.data import get_sphere
  8. from dipy.align.reslice import reslice
  9. from dipy.data import read_stanford_labels
  10. from dipy.reconst.shm import CsaOdfModel
  11. from dipy.data import default_sphere
  12. from dipy.direction import peaks_from_model
  13. from dipy.tracking import utils
  14. from dipy.tracking.stopping_criterion import \
  15. ThresholdStoppingCriterion
  16. from dipy.tracking.local_tracking import LocalTracking
  17. from dipy.reconst.shm import descoteaux07_legacy_msg, sh_to_sf_matrix
  18. from dipy.align.tests.test_streamlinear import fornix_streamlines
  19. from dipy.tracking.streamline import (center_streamlines,
  20. transform_streamlines)
  21. from dipy.reconst.dti import color_fa, fractional_anisotropy
  22. fury, has_fury, setup_module = optional_package('fury')
  23. if has_fury:
  24. from dipy.viz import actor, window, colormap
  25. skip_it = use_xvfb == 'skip'
  26. @pytest.mark.skipif(skip_it or not has_fury,
  27. reason="Needs xvfb")
  28. def test_slicer():
  29. scene = window.Scene()
  30. data = (255 * np.random.rand(50, 50, 50))
  31. affine = np.diag([1, 3, 2, 1])
  32. data2, affine2 = reslice(data, affine, zooms=(1, 3, 2),
  33. new_zooms=(1, 1, 1))
  34. slicer = actor.slicer(data2, affine2, interpolation='linear')
  35. slicer.display(None, None, 25)
  36. scene.add(slicer)
  37. scene.reset_camera()
  38. scene.reset_clipping_range()
  39. # window.show(scene, reset_camera=False)
  40. arr = window.snapshot(scene, offscreen=True)
  41. report = window.analyze_snapshot(arr, find_objects=True)
  42. npt.assert_equal(report.objects, 1)
  43. npt.assert_array_equal([1, 3, 2] * np.array(data.shape),
  44. np.array(slicer.shape))
  45. @pytest.mark.skipif(skip_it or not has_fury,
  46. reason="Needs xvfb")
  47. def test_contour_from_roi():
  48. hardi_img, gtab, labels_img = read_stanford_labels()
  49. data = np.asanyarray(hardi_img.dataobj)
  50. labels = np.asanyarray(labels_img.dataobj)
  51. affine = hardi_img.affine
  52. white_matter = (labels == 1) | (labels == 2)
  53. with warnings.catch_warnings():
  54. warnings.filterwarnings(
  55. "ignore", message=descoteaux07_legacy_msg,
  56. category=PendingDeprecationWarning)
  57. csa_model = CsaOdfModel(gtab, sh_order=6)
  58. csa_peaks = peaks_from_model(csa_model, data, default_sphere,
  59. relative_peak_threshold=.8,
  60. min_separation_angle=45,
  61. mask=white_matter)
  62. classifier = ThresholdStoppingCriterion(csa_peaks.gfa, .25)
  63. seed_mask = labels == 2
  64. seeds = utils.seeds_from_mask(seed_mask, density=[1, 1, 1],
  65. affine=affine)
  66. # Initialization of LocalTracking.
  67. # The computation happens in the next step.
  68. streamlines = LocalTracking(csa_peaks, classifier, seeds, affine,
  69. step_size=2)
  70. # Compute streamlines and store as a list.
  71. streamlines = list(streamlines)
  72. # Prepare the display objects.
  73. streamlines_actor = actor.line(
  74. streamlines, colormap.line_colors(streamlines))
  75. seedroi_actor = actor.contour_from_roi(seed_mask, affine,
  76. [0, 1, 1], 0.5)
  77. # Create the 3d display.
  78. sc = window.Scene()
  79. sc2 = window.Scene()
  80. sc.add(streamlines_actor)
  81. arr3 = window.snapshot(sc, 'test_surface3.png', offscreen=True)
  82. report3 = window.analyze_snapshot(arr3, find_objects=True)
  83. sc2.add(streamlines_actor)
  84. sc2.add(seedroi_actor)
  85. arr4 = window.snapshot(sc2, 'test_surface4.png', offscreen=True)
  86. report4 = window.analyze_snapshot(arr4, find_objects=True)
  87. # assert that the seed ROI rendering is not far
  88. # away from the streamlines (affine error)
  89. npt.assert_equal(report3.objects, report4.objects)
  90. # window.show(sc)
  91. # window.show(sc2)
  92. @pytest.mark.skipif(skip_it or not has_fury,
  93. reason="Needs xvfb")
  94. def test_bundle_maps():
  95. scene = window.Scene()
  96. bundle = fornix_streamlines()
  97. bundle, _ = center_streamlines(bundle)
  98. mat = np.array([[1, 0, 0, 100],
  99. [0, 1, 0, 100],
  100. [0, 0, 1, 100],
  101. [0, 0, 0, 1.]])
  102. bundle = transform_streamlines(bundle, mat)
  103. # metric = np.random.rand(*(200, 200, 200))
  104. metric = 100 * np.ones((200, 200, 200))
  105. # add lower values
  106. metric[100, :, :] = 100 * 0.5
  107. # create a nice orange-red colormap
  108. lut = actor.colormap_lookup_table(scale_range=(0., 100.),
  109. hue_range=(0., 0.1),
  110. saturation_range=(1, 1),
  111. value_range=(1., 1))
  112. line = actor.line(bundle, metric, linewidth=0.1, lookup_colormap=lut)
  113. scene.add(line)
  114. scene.add(actor.scalar_bar(lut, ' '))
  115. report = window.analyze_scene(scene)
  116. npt.assert_almost_equal(report.actors, 1)
  117. # window.show(scene)
  118. scene.clear()
  119. nb_points = np.sum([len(b) for b in bundle])
  120. values = 100 * np.random.rand(nb_points)
  121. # values[:nb_points/2] = 0
  122. line = actor.streamtube(bundle, values, linewidth=0.1, lookup_colormap=lut)
  123. scene.add(line)
  124. # window.show(scene)
  125. report = window.analyze_scene(scene)
  126. npt.assert_equal(report.actors_classnames[0], 'vtkLODActor')
  127. scene.clear()
  128. colors = np.random.rand(nb_points, 3)
  129. # values[:nb_points/2] = 0
  130. line = actor.line(bundle, colors, linewidth=2)
  131. scene.add(line)
  132. # window.show(scene)
  133. report = window.analyze_scene(scene)
  134. npt.assert_equal(report.actors_classnames[0], 'vtkLODActor')
  135. # window.show(scene)
  136. arr = window.snapshot(scene)
  137. report2 = window.analyze_snapshot(arr)
  138. npt.assert_equal(report2.objects, 1)
  139. # try other input options for colors
  140. scene.clear()
  141. actor.line(bundle, (1., 0.5, 0))
  142. actor.line(bundle, np.arange(len(bundle)))
  143. actor.line(bundle)
  144. colors = [np.random.rand(*b.shape) for b in bundle]
  145. actor.line(bundle, colors=colors)
  146. @pytest.mark.skipif(skip_it or not has_fury,
  147. reason="Needs xvfb")
  148. def test_odf_slicer(interactive=False):
  149. # Prepare our data
  150. sphere = get_sphere('repulsion100')
  151. shape = (11, 11, 11, sphere.vertices.shape[0])
  152. odfs = np.ones(shape)
  153. affine = np.array([[2.0, 0.0, 0.0, 3.0],
  154. [0.0, 2.0, 0.0, 3.0],
  155. [0.0, 0.0, 2.0, 1.0],
  156. [0.0, 0.0, 0.0, 1.0]])
  157. mask = np.ones(odfs.shape[:3], bool)
  158. mask[:4, :4, :4] = False
  159. # Test that affine and mask work
  160. odf_actor = actor.odf_slicer(odfs, sphere=sphere, affine=affine, mask=mask,
  161. scale=.25, colormap='blues')
  162. k = 2
  163. I, J, _ = odfs.shape[:3]
  164. odf_actor.display_extent(0, I - 1, 0, J - 1, k, k)
  165. scene = window.Scene()
  166. scene.add(odf_actor)
  167. scene.reset_camera()
  168. scene.reset_clipping_range()
  169. if interactive:
  170. window.show(scene, reset_camera=False)
  171. arr = window.snapshot(scene)
  172. report = window.analyze_snapshot(arr, find_objects=True)
  173. npt.assert_equal(report.objects, 11 * 11 - 16)
  174. # Test that global colormap works
  175. odf_actor = actor.odf_slicer(odfs, sphere=sphere, mask=mask, scale=.25,
  176. colormap='blues', norm=False, global_cm=True)
  177. scene.clear()
  178. scene.add(odf_actor)
  179. scene.reset_camera()
  180. scene.reset_clipping_range()
  181. if interactive:
  182. window.show(scene)
  183. # Test that the most basic odf_slicer instantiation works
  184. odf_actor = actor.odf_slicer(odfs)
  185. scene.clear()
  186. scene.add(odf_actor)
  187. scene.reset_camera()
  188. scene.reset_clipping_range()
  189. if interactive:
  190. window.show(scene)
  191. # Test that odf_slicer.display works properly
  192. scene.clear()
  193. scene.add(odf_actor)
  194. scene.add(actor.axes((11, 11, 11)))
  195. for i in range(11):
  196. odf_actor.display(i, None, None)
  197. if interactive:
  198. window.show(scene)
  199. for j in range(11):
  200. odf_actor.display(None, j, None)
  201. if interactive:
  202. window.show(scene)
  203. # With mask equal to zero everything should be black
  204. mask = np.zeros(odfs.shape[:3])
  205. odf_actor = actor.odf_slicer(odfs, sphere=sphere, mask=mask,
  206. scale=.25, colormap='blues',
  207. norm=False, global_cm=True)
  208. scene.clear()
  209. scene.add(odf_actor)
  210. scene.reset_camera()
  211. scene.reset_clipping_range()
  212. if interactive:
  213. window.show(scene)
  214. # global_cm=True with colormap=None should raise an error
  215. npt.assert_raises(IOError, actor.odf_slicer, odfs, sphere=sphere,
  216. mask=None, scale=.25, colormap=None, norm=False,
  217. global_cm=True)
  218. # Dimension mismatch between sphere vertices and number
  219. # of SF coefficients will raise an error.
  220. npt.assert_raises(ValueError, actor.odf_slicer, odfs, mask=None,
  221. sphere=get_sphere('repulsion200'), scale=.25)
  222. # colormap=None and global_cm=False results in directionally encoded colors
  223. odf_actor = actor.odf_slicer(odfs, sphere=sphere, mask=None,
  224. scale=.25, colormap=None,
  225. norm=False, global_cm=False)
  226. scene.clear()
  227. scene.add(odf_actor)
  228. scene.reset_camera()
  229. scene.reset_clipping_range()
  230. if interactive:
  231. window.show(scene)
  232. # Test that SH coefficients input works
  233. with warnings.catch_warnings():
  234. warnings.filterwarnings(
  235. "ignore", message=descoteaux07_legacy_msg,
  236. category=PendingDeprecationWarning)
  237. B = sh_to_sf_matrix(sphere, sh_order=4, return_inv=False)
  238. odfs = np.zeros((11, 11, 11, B.shape[0]))
  239. odfs[..., 0] = 1.0
  240. odf_actor = actor.odf_slicer(odfs, sphere=sphere, B_matrix=B)
  241. scene.clear()
  242. scene.add(odf_actor)
  243. scene.reset_camera()
  244. scene.reset_clipping_range()
  245. if interactive:
  246. window.show(scene)
  247. # Dimension mismatch between sphere vertices and dimension of
  248. # B matrix will raise an error.
  249. npt.assert_raises(ValueError, actor.odf_slicer, odfs, mask=None,
  250. sphere=get_sphere('repulsion200'))
  251. # Test that constant colormap color works. Also test that sphere
  252. # normals are oriented correctly. Will show purple spheres with
  253. # a white contour.
  254. odf_contour = actor.odf_slicer(odfs, sphere=sphere, B_matrix=B,
  255. colormap=(255, 255, 255))
  256. odf_contour.GetProperty().SetAmbient(1.0)
  257. odf_contour.GetProperty().SetFrontfaceCulling(True)
  258. odf_actor = actor.odf_slicer(odfs, sphere=sphere, B_matrix=B,
  259. colormap=(255, 0, 255), scale=0.4)
  260. scene.clear()
  261. scene.add(odf_contour)
  262. scene.add(odf_actor)
  263. scene.reset_camera()
  264. scene.reset_clipping_range()
  265. if interactive:
  266. window.show(scene)
  267. # Test that we can change the sphere on an active actor
  268. new_sphere = get_sphere('symmetric362')
  269. with warnings.catch_warnings():
  270. warnings.filterwarnings(
  271. "ignore", message=descoteaux07_legacy_msg,
  272. category=PendingDeprecationWarning)
  273. new_B = sh_to_sf_matrix(new_sphere, sh_order=4, return_inv=False)
  274. odf_actor.update_sphere(new_sphere.vertices, new_sphere.faces, new_B)
  275. if interactive:
  276. window.show(scene)
  277. del odf_actor
  278. del odfs
  279. @pytest.mark.skipif(skip_it or not has_fury,
  280. reason="Needs xvfb")
  281. def test_tensor_slicer(interactive=False):
  282. evals = np.array([1.4, .35, .35]) * 10 ** (-3)
  283. evecs = np.eye(3)
  284. mevals = np.zeros((3, 2, 4, 3))
  285. mevecs = np.zeros((3, 2, 4, 3, 3))
  286. mevals[..., :] = evals
  287. mevecs[..., :, :] = evecs
  288. sphere = get_sphere('symmetric724')
  289. affine = np.eye(4)
  290. scene = window.Scene()
  291. tensor_actor = actor.tensor_slicer(mevals, mevecs, affine=affine,
  292. sphere=sphere, scale=.3, opacity=0.4)
  293. _, J, K = mevals.shape[:3]
  294. scene.add(tensor_actor)
  295. scene.reset_camera()
  296. scene.reset_clipping_range()
  297. tensor_actor.display_extent(0, 1, 0, J, 0, K)
  298. if interactive:
  299. window.show(scene, reset_camera=False)
  300. tensor_actor.GetProperty().SetOpacity(1.0)
  301. if interactive:
  302. window.show(scene, reset_camera=False)
  303. npt.assert_equal(scene.GetActors().GetNumberOfItems(), 1)
  304. # Test extent
  305. big_extent = scene.GetActors().GetLastActor().GetBounds()
  306. big_extent_x = abs(big_extent[1] - big_extent[0])
  307. tensor_actor.display(x=2)
  308. if interactive:
  309. window.show(scene, reset_camera=False)
  310. small_extent = scene.GetActors().GetLastActor().GetBounds()
  311. small_extent_x = abs(small_extent[1] - small_extent[0])
  312. npt.assert_equal(big_extent_x > small_extent_x, True)
  313. # Test empty mask
  314. empty_actor = actor.tensor_slicer(mevals, mevecs, affine=affine,
  315. mask=np.zeros(mevals.shape[:3]),
  316. sphere=sphere, scale=.3)
  317. npt.assert_equal(empty_actor.GetMapper(), None)
  318. # Test mask
  319. mask = np.ones(mevals.shape[:3])
  320. mask[:2, :3, :3] = 0
  321. cfa = color_fa(fractional_anisotropy(mevals), mevecs)
  322. tensor_actor = actor.tensor_slicer(mevals, mevecs, affine=affine,
  323. mask=mask, scalar_colors=cfa,
  324. sphere=sphere, scale=.3)
  325. scene.clear()
  326. scene.add(tensor_actor)
  327. scene.reset_camera()
  328. scene.reset_clipping_range()
  329. if interactive:
  330. window.show(scene, reset_camera=False)
  331. mask_extent = scene.GetActors().GetLastActor().GetBounds()
  332. mask_extent_x = abs(mask_extent[1] - mask_extent[0])
  333. npt.assert_equal(big_extent_x > mask_extent_x, True)
  334. # test display
  335. tensor_actor.display()
  336. current_extent = scene.GetActors().GetLastActor().GetBounds()
  337. current_extent_x = abs(current_extent[1] - current_extent[0])
  338. npt.assert_equal(big_extent_x > current_extent_x, True)
  339. if interactive:
  340. window.show(scene, reset_camera=False)
  341. tensor_actor.display(y=1)
  342. current_extent = scene.GetActors().GetLastActor().GetBounds()
  343. current_extent_y = abs(current_extent[3] - current_extent[2])
  344. big_extent_y = abs(big_extent[3] - big_extent[2])
  345. npt.assert_equal(big_extent_y > current_extent_y, True)
  346. if interactive:
  347. window.show(scene, reset_camera=False)
  348. tensor_actor.display(z=1)
  349. current_extent = scene.GetActors().GetLastActor().GetBounds()
  350. current_extent_z = abs(current_extent[5] - current_extent[4])
  351. big_extent_z = abs(big_extent[5] - big_extent[4])
  352. npt.assert_equal(big_extent_z > current_extent_z, True)
  353. if interactive:
  354. window.show(scene, reset_camera=False)
  355. # Test error handling of the method when
  356. # incompatible dimension of mevals and evecs are passed.
  357. mevals = np.zeros((3, 2, 3))
  358. mevecs = np.zeros((3, 2, 4, 3, 3))
  359. with npt.assert_raises(RuntimeError):
  360. tensor_actor = actor.tensor_slicer(mevals, mevecs, affine=affine,
  361. mask=mask, scalar_colors=cfa,
  362. sphere=sphere, scale=.3)