/specializers/stencil/tests/stencil_cache_block_test.py

https://github.com/richardxia/asp · Python · 195 lines · 152 code · 26 blank · 17 comment · 12 complexity · b495944513a6d5154b663ae21f170078 MD5 · raw file

  1. import unittest2 as unittest
  2. from asp.codegen.cpp_ast import *
  3. from stencil_convert import *
  4. from stencil_optimize_cpp import *
  5. from stencil_kernel import *
  6. class StencilConvertASTTests(unittest.TestCase):
  7. def setUp(self):
  8. class IdentityKernel(StencilKernel):
  9. def kernel(self, in_grid, out_grid):
  10. for x in out_grid.interior_points():
  11. for y in in_grid.neighbors(x, 1):
  12. out_grid[x] = out_grid[x] + in_grid[y]
  13. self.kernel = IdentityKernel()
  14. self.in_grid = StencilGrid([130,130])
  15. self.in_grids = [self.in_grid]
  16. self.out_grid = StencilGrid([130,130])
  17. self.model = python_func_to_unrolled_model(IdentityKernel.kernel, self.in_grids, self.out_grid)
  18. def test_StencilConvertAST_array_macro_use(self):
  19. import asp.codegen.cpp_ast as cpp_ast
  20. result = StencilConvertAST(self.model, self.in_grids, self.out_grid).gen_array_macro('in_grid',
  21. [cpp_ast.CNumber(3),
  22. cpp_ast.CNumber(4)])
  23. self.assertEqual(str(result), "_in_grid_array_macro(3, 4)")
  24. def test_whole_thing(self):
  25. import numpy
  26. for i in [1,2,3]:
  27. self.in_grid.data = numpy.ones([130,130])
  28. self.out_grid.data = numpy.zeros([130,130])
  29. self.kernel.kernel(self.in_grid, self.out_grid)
  30. # print self.kernel.mod.db.get("kernel")
  31. self.assertEqual(self.out_grid[5,5],4.0)
  32. for x in xrange(1,128):
  33. for y in xrange(1,128):
  34. self.assertAlmostEqual(self.out_grid[x,y], 4.0)
  35. def test_whole_thing_in_3D(self):
  36. import numpy
  37. for i in [1,2,3]:
  38. self.out_grid = StencilGrid([130,130,130])
  39. self.in_grid = StencilGrid([130,130,130])
  40. self.in_grid.data = numpy.ones([130,130,130])
  41. self.out_grid.data = numpy.zeros([130,130,130])
  42. self.kernel.kernel(self.in_grid, self.out_grid)
  43. self.assertEqual(self.out_grid[5,5,5],6.0)
  44. for x in xrange(1,128):
  45. for y in xrange(1,128):
  46. for z in xrange(1,128):
  47. self.assertAlmostEqual(self.out_grid[x,y,z], 6.0)
  48. class StencilConvertASTBlockedTests(unittest.TestCase):
  49. def setUp(self):
  50. class IdentityKernel(StencilKernel):
  51. def kernel(self, in_grid, out_grid):
  52. for x in out_grid.interior_points():
  53. for y in in_grid.neighbors(x, 1):
  54. out_grid[x] = out_grid[x] + in_grid[y]
  55. self.kernel = IdentityKernel()
  56. self.in_grid = StencilGrid([10,10])
  57. self.in_grids = [self.in_grid]
  58. self.out_grid = StencilGrid([10,10])
  59. self.model = python_func_to_unrolled_model(IdentityKernel.kernel, self.in_grids, self.out_grid)
  60. self.base_variant = Converter(model, input_grids, output_grid).run()
  61. # def test_gen_loops(self):
  62. # converter = StencilConvertASTBlocked(self.model, self.in_grids, self.out_grid, block_factor=(2,1))
  63. # result = converter.gen_loops(self.model)
  64. # wanted = """for (int x1x1 = 1; (x1x1 <= 8); x1x1 = (x1x1 + (1 * 2)))
  65. # {
  66. # for (int x1 = x1x1; (x1 <= min((x1x1 + 1),8)); x1 = (x1 + 1))
  67. # {
  68. # #pragma ivdep
  69. # for(intx2=1;(x2<=8);x2=(x2+1))
  70. # {
  71. # }
  72. # }
  73. # }"""
  74. # self.assertEqual(wanted.replace(' ',''), str(result[1]).replace(' ',''))
  75. class CacheBlockerTests(unittest.TestCase):
  76. def test_2d(self):
  77. loop = For("i",
  78. CNumber(0),
  79. CNumber(7),
  80. CNumber(1),
  81. Block(contents=[For("j",
  82. CNumber(0),
  83. CNumber(3),
  84. CNumber(1),
  85. Block(contents=[Assign(CName("v"), CName("i"))]))]))
  86. wanted = """for (int ii = 0; (ii <= 7); ii = (ii + (1 * 2)))
  87. {
  88. for (int jj = 0; (jj <= 3); jj = (jj + (1 * 2)))
  89. {
  90. for (int i = ii; (i <= min((ii + 1),7)); i = (i + 1))
  91. {
  92. for (int j = jj; (j <= min((jj + 1),3)); j = (j + 1))
  93. {
  94. v = i;
  95. }
  96. }
  97. }
  98. }"""
  99. self.assertEqual(str(StencilCacheBlocker().block(loop, (2, 2))).replace(' ',''), wanted.replace(' ',''))
  100. def test_3d(self):
  101. loop = For("i",
  102. CNumber(0),
  103. CNumber(7),
  104. CNumber(1),
  105. Block(contents=[For("j",
  106. CNumber(0),
  107. CNumber(3),
  108. CNumber(1),
  109. Block(contents=[For("k",
  110. CNumber(0),
  111. CNumber(4),
  112. CNumber(1),
  113. Block(contents=[Assign(CName("v"), CName("i"))]))]))]))
  114. #print StencilCacheBlocker().block(loop, (2,2,3))
  115. wanted = """for (int ii = 0; (ii <= 7); ii = (ii + (1 * 2)))
  116. {
  117. for (int jj = 0; (jj <= 3); jj = (jj + (1 * 2)))
  118. {
  119. for (int kk = 0; (kk <= 4); kk = (kk + (1 * 3)))
  120. {
  121. for (int i = ii; (i <= min((ii + 1),7)); i = (i + 1))
  122. {
  123. for (int j = jj; (j <= min((jj + 1),3)); j = (j + 1))
  124. {
  125. for (int k = kk; (k <= min((kk + 2),4)); k = (k + 1))
  126. {
  127. v = i;
  128. }\n}\n}\n}\n}\n}"""
  129. self.assertEqual(str(StencilCacheBlocker().block(loop, (2,2,3))).replace(' ',''),
  130. wanted.replace(' ', ''))
  131. def test_rivera_blocking(self):
  132. loop = For("i",
  133. CNumber(0),
  134. CNumber(7),
  135. CNumber(1),
  136. Block(contents=[For("j",
  137. CNumber(0),
  138. CNumber(3),
  139. CNumber(1),
  140. Block(contents=[For("k",
  141. CNumber(0),
  142. CNumber(4),
  143. CNumber(1),
  144. Block(contents=[Assign(CName("v"), CName("i"))]))]))]))
  145. #print StencilCacheBlocker().block(loop, (2,2,0))
  146. wanted = """for (int ii = 0; (ii <= 7); ii = (ii + (1 * 2)))
  147. {
  148. for (int jj = 0; (jj <= 3); jj = (jj + (1 * 2)))
  149. {
  150. for (int i = ii; (i <= min((ii + 1),7)); i = (i + 1))
  151. {
  152. for (int j = jj; (j <= min((jj + 1),3)); j = (j + 1))
  153. {
  154. for (int k = 0; (k <= 4); k = (k + 1))
  155. {
  156. v = i;
  157. }
  158. }
  159. }
  160. }
  161. }"""
  162. self.assertEqual(str(StencilCacheBlocker().block(loop, (2,2,0))).replace(' ',''),
  163. wanted.replace(' ', ''))
  164. def python_func_to_unrolled_model(func, in_grids, out_grid):
  165. python_ast = ast.parse(inspect.getsource(func).lstrip())
  166. model = StencilPythonFrontEnd().parse(python_ast)
  167. return StencilUnrollNeighborIter(model, in_grids, out_grid).run()
  168. if __name__ == '__main__':
  169. unittest.main()