/dipy/denoise/tests/test_gibbs.py

https://github.com/arokem/dipy
Python | 265 lines | 175 code | 55 blank | 35 comment | 0 complexity | aa8f05a82824925e57dfb2b4e5c971f3 MD5 | raw file
  1. import numpy as np
  2. from dipy.denoise.gibbs import (_gibbs_removal_1d, _gibbs_removal_2d,
  3. gibbs_removal, _image_tv)
  4. from numpy.testing import (assert_, assert_array_almost_equal, assert_raises)
  5. def setup_module():
  6. """Module-level setup"""
  7. global image_gibbs, image_gt, image_cor, Nre
  8. # Produce a 2D image
  9. Nori = 32
  10. image = np.zeros((6 * Nori, 6 * Nori))
  11. image[Nori: 2 * Nori, Nori: 2 * Nori] = 1
  12. image[Nori: 2 * Nori, 4 * Nori: 5 * Nori] = 1
  13. image[2 * Nori: 3 * Nori, Nori: 3 * Nori] = 1
  14. image[3 * Nori: 4 * Nori, 2 * Nori: 3 * Nori] = 2
  15. image[3 * Nori: 4 * Nori, 4 * Nori: 5 * Nori] = 1
  16. image[4 * Nori: 5 * Nori, 3 * Nori: 5 * Nori] = 3
  17. # Corrupt image with gibbs ringing
  18. c = np.fft.fft2(image)
  19. c = np.fft.fftshift(c)
  20. c_crop = c[48:144, 48:144]
  21. image_gibbs = abs(np.fft.ifft2(c_crop)/4)
  22. # Produce ground truth
  23. Nre = 16
  24. image_gt = np.zeros((6 * Nre, 6 * Nre))
  25. image_gt[Nre: 2 * Nre, Nre: 2 * Nre] = 1
  26. image_gt[Nre: 2 * Nre, 4 * Nre: 5 * Nre] = 1
  27. image_gt[2 * Nre: 3 * Nre, Nre: 3 * Nre] = 1
  28. image_gt[3 * Nre: 4 * Nre, 2 * Nre: 3 * Nre] = 2
  29. image_gt[3 * Nre: 4 * Nre, 4 * Nre: 5 * Nre] = 1
  30. image_gt[4 * Nre: 5 * Nre, 3 * Nre: 5 * Nre] = 3
  31. # Suppressing gibbs artefacts
  32. image_cor = _gibbs_removal_2d(image_gibbs)
  33. def test_parallel():
  34. # Only relevant for 3d or 4d inputs
  35. # Make input data
  36. input_2d = image_gibbs.copy()
  37. input_3d = np.stack([input_2d, input_2d], axis=2)
  38. input_4d = np.stack([input_3d, input_3d], axis=3)
  39. # Test 3d case
  40. output_3d_parallel = gibbs_removal(input_3d, inplace=False,
  41. num_processes=2)
  42. output_3d_no_parallel = gibbs_removal(
  43. input_3d, inplace=False, num_processes=1
  44. )
  45. assert_array_almost_equal(output_3d_parallel, output_3d_no_parallel)
  46. # Test 4d case
  47. output_4d_parallel = gibbs_removal(input_4d, inplace=False,
  48. num_processes=2)
  49. output_4d_no_parallel = gibbs_removal(
  50. input_4d, inplace=False, num_processes=1
  51. )
  52. assert_array_almost_equal(output_4d_parallel, output_4d_no_parallel)
  53. # Test num_processes=None case
  54. output_4d_all_cpu = gibbs_removal(
  55. input_4d, inplace=False, num_processes=None
  56. )
  57. assert_array_almost_equal(output_4d_all_cpu, output_4d_no_parallel)
  58. def test_inplace():
  59. # Make input data
  60. input_2d = image_gibbs.copy()
  61. input_3d = np.stack([input_2d, input_2d], axis=2)
  62. input_4d = np.stack([input_3d, input_3d], axis=3)
  63. # Test 2d cases
  64. output_2d = gibbs_removal(input_2d, inplace=False)
  65. assert_raises(
  66. AssertionError, assert_array_almost_equal, input_2d, output_2d
  67. )
  68. output_2d = gibbs_removal(input_2d, inplace=True)
  69. assert_array_almost_equal(input_2d, output_2d)
  70. # Test 3d case
  71. output_3d = gibbs_removal(input_3d, inplace=False)
  72. assert_raises(
  73. AssertionError, assert_array_almost_equal, input_3d, output_3d
  74. )
  75. output_3d = gibbs_removal(input_3d, inplace=True)
  76. assert_array_almost_equal(input_3d, output_3d)
  77. # Test 4d case
  78. output_4d = gibbs_removal(input_4d, inplace=False)
  79. assert_raises(
  80. AssertionError, assert_array_almost_equal, input_4d, output_4d
  81. )
  82. output_4d = gibbs_removal(input_4d, inplace=True)
  83. assert_array_almost_equal(input_4d, output_4d)
  84. def test_gibbs_2d():
  85. # Correction of gibbs ringing have to be closer to gt than denoised image
  86. diff_raw = np.mean(abs(image_gibbs - image_gt))
  87. diff_cor = np.mean(abs(image_cor - image_gt))
  88. assert_(diff_raw > diff_cor)
  89. # Test if gibbs_removal works for 2D data
  90. image_cor2 = gibbs_removal(image_gibbs, inplace=False)
  91. assert_array_almost_equal(image_cor2, image_cor)
  92. def test_gibbs_3d():
  93. image3d = np.zeros((6 * Nre, 6 * Nre, 2))
  94. image3d[:, :, 0] = image_gibbs
  95. image3d[:, :, 1] = image_gibbs
  96. image3d_cor = gibbs_removal(image3d, 2)
  97. assert_array_almost_equal(image3d_cor[:, :, 0], image_cor)
  98. assert_array_almost_equal(image3d_cor[:, :, 1], image_cor)
  99. def test_gibbs_4d():
  100. image4d = np.zeros((6 * Nre, 6 * Nre, 2, 2))
  101. image4d[:, :, 0, 0] = image_gibbs
  102. image4d[:, :, 1, 0] = image_gibbs
  103. image4d[:, :, 0, 1] = image_gibbs
  104. image4d[:, :, 1, 1] = image_gibbs
  105. image4d_cor = gibbs_removal(image4d)
  106. assert_array_almost_equal(image4d_cor[:, :, 0, 0], image_cor)
  107. assert_array_almost_equal(image4d_cor[:, :, 1, 0], image_cor)
  108. assert_array_almost_equal(image4d_cor[:, :, 0, 1], image_cor)
  109. assert_array_almost_equal(image4d_cor[:, :, 1, 1], image_cor)
  110. def test_swapped_gibbs_2d():
  111. # 2D case: In this case slice_axis is a dummy variable. Since data is
  112. # already a single 2D image, to axis swapping is required
  113. image_cor0 = gibbs_removal(image_gibbs, slice_axis=0, inplace=False)
  114. assert_array_almost_equal(image_cor0, image_cor)
  115. image_cor1 = gibbs_removal(image_gibbs, slice_axis=1, inplace=False)
  116. assert_array_almost_equal(image_cor1, image_cor)
  117. image_cor2 = gibbs_removal(image_gibbs, slice_axis=2, inplace=False)
  118. assert_array_almost_equal(image_cor2, image_cor)
  119. def test_swapped_gibbs_3d():
  120. image3d = np.zeros((6 * Nre, 2, 6 * Nre))
  121. image3d[:, 0, :] = image_gibbs
  122. image3d[:, 1, :] = image_gibbs
  123. image3d_cor = gibbs_removal(image3d, slice_axis=1)
  124. assert_array_almost_equal(image3d_cor[:, 0, :], image_cor)
  125. assert_array_almost_equal(image3d_cor[:, 1, :], image_cor)
  126. image3d = np.zeros((2, 6 * Nre, 6 * Nre))
  127. image3d[0, :, :] = image_gibbs
  128. image3d[1, :, :] = image_gibbs
  129. image3d_cor = gibbs_removal(image3d, slice_axis=0)
  130. assert_array_almost_equal(image3d_cor[0, :, :], image_cor)
  131. assert_array_almost_equal(image3d_cor[1, :, :], image_cor)
  132. def test_swapped_gibbs_4d():
  133. image4d = np.zeros((2, 6 * Nre, 6 * Nre, 2))
  134. image4d[0, :, :, 0] = image_gibbs
  135. image4d[1, :, :, 0] = image_gibbs
  136. image4d[0, :, :, 1] = image_gibbs
  137. image4d[1, :, :, 1] = image_gibbs
  138. image4d_cor = gibbs_removal(image4d, slice_axis=0)
  139. assert_array_almost_equal(image4d_cor[0, :, :, 0], image_cor)
  140. assert_array_almost_equal(image4d_cor[1, :, :, 0], image_cor)
  141. assert_array_almost_equal(image4d_cor[0, :, :, 1], image_cor)
  142. assert_array_almost_equal(image4d_cor[1, :, :, 1], image_cor)
  143. def test_gibbs_errors():
  144. assert_raises(ValueError, gibbs_removal, np.ones((2, 2, 2, 2, 2)))
  145. assert_raises(ValueError, gibbs_removal, np.ones(2))
  146. assert_raises(ValueError, gibbs_removal, np.ones((2, 2, 2)), 3)
  147. assert_raises(TypeError, gibbs_removal, image_gibbs.copy(), inplace="True")
  148. # Test for valid num_processes
  149. assert_raises(
  150. TypeError, gibbs_removal, image_gibbs.copy(), num_processes="1"
  151. )
  152. assert_raises(
  153. ValueError, gibbs_removal, image_gibbs.copy(), num_processes=0
  154. )
  155. # Test for valid input dimensionality
  156. assert_raises(ValueError, gibbs_removal, np.ones(2)) # 1D
  157. assert_raises(ValueError, gibbs_removal, np.ones((2, 2, 2, 2, 2))) # 5D
  158. def test_gibbs_subfunction():
  159. # This complementary test is to make sure that Gibbs suppression
  160. # sub-functions are properly implemented
  161. # Testing correction along axis 0
  162. image_a0 = _gibbs_removal_1d(image_gibbs, axis=0)
  163. # After this step tv along axis 0 should provide lower values than along
  164. # axis 1
  165. tv0_a0_r, tv0_a0_l = _image_tv(image_a0, axis=0)
  166. tv0_a0 = np.minimum(tv0_a0_r, tv0_a0_l)
  167. tv1_a0_r, tv1_a0_l = _image_tv(image_a0, axis=1)
  168. tv1_a0 = np.minimum(tv1_a0_r, tv1_a0_l)
  169. # Let's check that
  170. mean_tv0 = np.mean(abs(tv0_a0))
  171. mean_tv1 = np.mean(abs(tv1_a0))
  172. assert_(mean_tv0 < mean_tv1)
  173. # Testing correction along axis 1
  174. image_a1 = _gibbs_removal_1d(image_gibbs, axis=1)
  175. # After this step tv along axis 1 should provide higher values than along
  176. # axis 0
  177. tv0_a1_r, tv0_a1_l = _image_tv(image_a1, axis=0)
  178. tv0_a1 = np.minimum(tv0_a1_r, tv0_a1_l)
  179. tv1_a1_r, tv1_a1_l = _image_tv(image_a1, axis=1)
  180. tv1_a1 = np.minimum(tv1_a1_r, tv1_a1_l)
  181. # Let's check that
  182. mean_tv0 = np.mean(abs(tv0_a1))
  183. mean_tv1 = np.mean(abs(tv1_a1))
  184. assert_(mean_tv0 > mean_tv1)
  185. def test_non_square_image():
  186. # Produce non-square 2D image
  187. Nori = 32
  188. img = np.zeros((6 * Nori, 6 * Nori))
  189. img[Nori: 2 * Nori, Nori: 2 * Nori] = 1
  190. img[2 * Nori: 3 * Nori, Nori: 3 * Nori] = 1
  191. img[3 * Nori: 4 * Nori, 2 * Nori: 3 * Nori] = 2
  192. img[4 * Nori: 5 * Nori, 3 * Nori: 5 * Nori] = 3
  193. # Corrupt image with gibbs ringing
  194. c = np.fft.fft2(img)
  195. c = np.fft.fftshift(c)
  196. c_crop = c[48:144, :]
  197. img_gibbs = abs(np.fft.ifft2(c_crop)/2)
  198. # Produce ground truth
  199. Nre = 16
  200. img_gt = np.zeros((6 * Nre, 6 * Nori))
  201. img_gt[Nre: 2 * Nre, Nori: 2 * Nori] = 1
  202. img_gt[2 * Nre: 3 * Nre, Nori: 3 * Nori] = 1
  203. img_gt[3 * Nre: 4 * Nre, 2 * Nori: 3 * Nori] = 2
  204. img_gt[4 * Nre: 5 * Nre, 3 * Nori: 5 * Nori] = 3
  205. # Suppressing gibbs artefacts
  206. img_cor = gibbs_removal(img_gibbs, inplace=False)
  207. # Correction of gibbs ringing have to be closer to gt than denoised image
  208. diff_raw = np.mean(abs(img_gibbs - img_gt))
  209. diff_cor = np.mean(abs(img_cor - img_gt))
  210. assert_(diff_raw > diff_cor)