PageRenderTime 97ms CodeModel.GetById 20ms RepoModel.GetById 0ms app.codeStats 1ms

/rpython/jit/metainterp/optimizeopt/test/test_guard.py

https://bitbucket.org/pypy/pypy/
Python | 335 lines | 317 code | 14 blank | 4 comment | 13 complexity | e6d4a69bf247b6a086fd95094eab38eb MD5 | raw file
Possible License(s): AGPL-3.0, BSD-3-Clause, Apache-2.0
  1. import py
  2. from rpython.jit.metainterp import compile
  3. from rpython.jit.metainterp.history import (TargetToken, JitCellToken,
  4. TreeLoop, Const)
  5. from rpython.jit.metainterp.optimizeopt.util import equaloplists
  6. from rpython.jit.metainterp.optimizeopt.vector import (Pack,
  7. NotAProfitableLoop, VectorizingOptimizer)
  8. from rpython.jit.metainterp.optimizeopt.dependency import (Node,
  9. DependencyGraph, IndexVar)
  10. from rpython.jit.metainterp.optimizeopt.guard import (GuardStrengthenOpt,
  11. Guard)
  12. from rpython.jit.metainterp.optimizeopt.test.test_util import LLtypeMixin
  13. from rpython.jit.metainterp.optimizeopt.test.test_schedule import SchedulerBaseTest
  14. from rpython.jit.metainterp.optimizeopt.test.test_vecopt import (FakeMetaInterpStaticData,
  15. FakeJitDriverStaticData, FakeLoopInfo)
  16. from rpython.jit.metainterp.resoperation import (rop,
  17. ResOperation, InputArgInt)
  18. from rpython.jit.tool.oparser_model import get_model
  19. class FakeMemoryRef(object):
  20. def __init__(self, array, iv):
  21. self.index_var = iv
  22. self.array = array
  23. def is_adjacent_to(self, other):
  24. if self.array is not other.array:
  25. return False
  26. iv = self.index_var
  27. ov = other.index_var
  28. val = (int(str(ov.var)[1:]) - int(str(iv.var)[1:]))
  29. # i0 and i1 are adjacent
  30. # i1 and i0 ...
  31. # but not i0, i2
  32. # ...
  33. return abs(val) == 1
  34. class FakeOp(object):
  35. def __init__(self, cmpop):
  36. self.boolinverse = ResOperation(cmpop, [box(0), box(0)], None).boolinverse
  37. self.cmpop = cmpop
  38. def getopnum(self):
  39. return self.cmpop
  40. def getarg(self, index):
  41. if index == 0:
  42. return 'lhs'
  43. elif index == 1:
  44. return 'rhs'
  45. else:
  46. assert 0
  47. class FakeResOp(object):
  48. def __init__(self, opnum):
  49. self.opnum = opnum
  50. def getopnum(self):
  51. return self.opnum
  52. def box(value):
  53. return InputArgInt(value)
  54. def const(value):
  55. return Const._new(value)
  56. def iv(value, coeff=(1,1,0)):
  57. var = IndexVar(value)
  58. var.coefficient_mul = coeff[0]
  59. var.coefficient_div = coeff[1]
  60. var.constant = coeff[2]
  61. return var
  62. def guard(opnum):
  63. def guard_impl(cmpop, lhs, rhs):
  64. guard = Guard(0, FakeResOp(opnum), FakeOp(cmpop), {'lhs': lhs, 'rhs': rhs})
  65. return guard
  66. return guard_impl
  67. guard_true = guard(rop.GUARD_TRUE)
  68. guard_false = guard(rop.GUARD_FALSE)
  69. del guard
  70. class GuardBaseTest(SchedulerBaseTest):
  71. def optguards(self, loop, user_code=False):
  72. info = FakeLoopInfo(loop)
  73. info.snapshot(loop)
  74. for op in loop.operations:
  75. if op.is_guard():
  76. op.setdescr(compile.CompileLoopVersionDescr())
  77. dep = DependencyGraph(loop)
  78. opt = GuardStrengthenOpt(dep.index_vars)
  79. opt.propagate_all_forward(info, loop, user_code)
  80. return opt
  81. def assert_guard_count(self, loop, count):
  82. guard = 0
  83. for op in loop.operations + loop.prefix:
  84. if op.is_guard():
  85. guard += 1
  86. if guard != count:
  87. self.debug_print_operations(loop)
  88. assert guard == count
  89. def assert_contains_sequence(self, loop, instr):
  90. class Glob(object):
  91. next = None
  92. prev = None
  93. def __repr__(self):
  94. return '*'
  95. from rpython.jit.tool.oparser import OpParser, default_fail_descr
  96. parser = OpParser(instr, self.cpu, self.namespace, None, default_fail_descr, True, None)
  97. parser.vars = { arg.repr_short(arg._repr_memo) : arg for arg in loop.inputargs}
  98. operations = []
  99. last_glob = None
  100. prev_op = None
  101. for line in instr.splitlines():
  102. line = line.strip()
  103. if line.startswith("#") or \
  104. line == "":
  105. continue
  106. if line.startswith("..."):
  107. last_glob = Glob()
  108. last_glob.prev = prev_op
  109. operations.append(last_glob)
  110. continue
  111. op = parser.parse_next_op(line)
  112. if last_glob is not None:
  113. last_glob.next = op
  114. last_glob = None
  115. operations.append(op)
  116. def check(op, candidate, rename):
  117. m = 0
  118. if isinstance(candidate, Glob):
  119. if candidate.next is None:
  120. return 0 # consumes the rest
  121. if op.getopnum() != candidate.next.getopnum():
  122. return 0
  123. m = 1
  124. candidate = candidate.next
  125. if op.getopnum() == candidate.getopnum():
  126. for i,arg in enumerate(op.getarglist()):
  127. oarg = candidate.getarg(i)
  128. if arg in rename:
  129. assert rename[arg].same_box(oarg)
  130. else:
  131. rename[arg] = oarg
  132. if not op.returns_void():
  133. rename[op] = candidate
  134. m += 1
  135. return m
  136. return 0
  137. j = 0
  138. rename = {}
  139. ops = loop.finaloplist()
  140. for i, op in enumerate(ops):
  141. candidate = operations[j]
  142. j += check(op, candidate, rename)
  143. if isinstance(operations[-1], Glob):
  144. assert j == len(operations)-1, self.debug_print_operations(loop)
  145. else:
  146. assert j == len(operations), self.debug_print_operations(loop)
  147. def test_basic(self):
  148. loop1 = self.parse_trace("""
  149. i10 = int_lt(i1, 42)
  150. guard_true(i10) []
  151. i101 = int_add(i1, 1)
  152. i102 = int_lt(i101, 42)
  153. guard_true(i102) []
  154. """)
  155. opt = self.optguards(loop1)
  156. self.assert_guard_count(loop1, 1)
  157. self.assert_contains_sequence(loop1, """
  158. ...
  159. i101 = int_add(i1, 1)
  160. i12 = int_lt(i101, 42)
  161. guard_true(i12) []
  162. ...
  163. """)
  164. def test_basic_sub(self):
  165. loop1 = self.parse_trace("""
  166. i10 = int_gt(i1, 42)
  167. guard_true(i10) []
  168. i101 = int_sub(i1, 1)
  169. i12 = int_gt(i101, 42)
  170. guard_true(i12) []
  171. """)
  172. opt = self.optguards(loop1)
  173. self.assert_guard_count(loop1, 1)
  174. self.assert_contains_sequence(loop1, """
  175. ...
  176. i101 = int_sub(i1, 1)
  177. i12 = int_gt(i101, 42)
  178. guard_true(i12) []
  179. ...
  180. """)
  181. def test_basic_mul(self):
  182. loop1 = self.parse_trace("""
  183. i10 = int_mul(i1, 4)
  184. i20 = int_lt(i10, 42)
  185. guard_true(i20) []
  186. i12 = int_add(i10, 1)
  187. i13 = int_lt(i12, 42)
  188. guard_true(i13) []
  189. """)
  190. opt = self.optguards(loop1)
  191. self.assert_guard_count(loop1, 1)
  192. self.assert_contains_sequence(loop1, """
  193. ...
  194. i101 = int_mul(i1, 4)
  195. i12 = int_add(i101, 1)
  196. i13 = int_lt(i12, 42)
  197. guard_true(i13) []
  198. ...
  199. """)
  200. def test_compare(self):
  201. key = box(1)
  202. incomparable = (False, 0)
  203. # const const
  204. assert iv(const(42)).compare(iv(const(42))) == (True, 0)
  205. assert iv(const(-400)).compare(iv(const(-200))) == (True, -200)
  206. assert iv(const(0)).compare(iv(const(-1))) == (True, 1)
  207. # var const
  208. assert iv(key, coeff=(1,1,0)).compare(iv(const(42))) == incomparable
  209. assert iv(key, coeff=(5,70,500)).compare(iv(const(500))) == incomparable
  210. # var var
  211. assert iv(key, coeff=(1,1,0)).compare(iv(key,coeff=(1,1,0))) == (True, 0)
  212. assert iv(key, coeff=(1,7,0)).compare(iv(key,coeff=(1,7,0))) == (True, 0)
  213. assert iv(key, coeff=(4,7,0)).compare(iv(key,coeff=(3,7,0))) == incomparable
  214. assert iv(key, coeff=(14,7,0)).compare(iv(key,coeff=(2,1,0))) == (True, 0)
  215. assert iv(key, coeff=(14,7,33)).compare(iv(key,coeff=(2,1,0))) == (True, 33)
  216. assert iv(key, coeff=(15,5,33)).compare(iv(key,coeff=(3,1,33))) == (True, 0)
  217. def test_imply_basic(self):
  218. key = box(1)
  219. # if x < 42 <=> x < 42
  220. g1 = guard_true(rop.INT_LT, iv(key, coeff=(1,1,0)), iv(const(42)))
  221. g2 = guard_true(rop.INT_LT, iv(key, coeff=(1,1,0)), iv(const(42)))
  222. assert g1.implies(g2)
  223. assert g2.implies(g1)
  224. # if x+1 < 42 => x < 42
  225. g1 = guard_true(rop.INT_LT, iv(key, coeff=(1,1,1)), iv(const(42)))
  226. g2 = guard_true(rop.INT_LT, iv(key, coeff=(1,1,0)), iv(const(42)))
  227. assert g1.implies(g2)
  228. assert not g2.implies(g1)
  229. # if x+2 < 42 => x < 39
  230. # counter: 39+2 < 42 => 39 < 39
  231. g1 = guard_true(rop.INT_LT, iv(key, coeff=(1,1,2)), iv(const(42)))
  232. g2 = guard_true(rop.INT_LT, iv(key, coeff=(1,1,0)), iv(const(39)))
  233. assert not g1.implies(g2)
  234. assert not g2.implies(g1)
  235. # if x+2 <= 42 => x <= 43
  236. g1 = guard_true(rop.INT_LE, iv(key, coeff=(1,1,2)), iv(const(42)))
  237. g2 = guard_true(rop.INT_LE, iv(key, coeff=(1,1,0)), iv(const(43)))
  238. assert g1.implies(g2)
  239. assert not g2.implies(g1)
  240. # if x*13/3+1 <= 0 => x*13/3 <= -1
  241. # is true, but the implies method is not smart enough
  242. g1 = guard_true(rop.INT_LE, iv(key, coeff=(13,3,1)), iv(const(0)))
  243. g2 = guard_true(rop.INT_LE, iv(key, coeff=(13,3,0)), iv(const(-1)))
  244. assert not g1.implies(g2)
  245. assert not g2.implies(g1)
  246. # > or >=
  247. # if x > -55 => x*2 > -44
  248. # counter: -44 > -55 (True) => -88 > -44 (False)
  249. g1 = guard_true(rop.INT_GT, iv(key, coeff=(1,1,0)), iv(const(-55)))
  250. g2 = guard_true(rop.INT_GT, iv(key, coeff=(2,1,0)), iv(const(-44)))
  251. assert not g1.implies(g2)
  252. assert not g2.implies(g1)
  253. # if x*2/2 > -44 => x*2/2 > -55
  254. g1 = guard_true(rop.INT_GE, iv(key, coeff=(2,2,0)), iv(const(-44)))
  255. g2 = guard_true(rop.INT_GE, iv(key, coeff=(2,2,0)), iv(const(-55)))
  256. assert g1.implies(g2)
  257. assert not g2.implies(g1)
  258. def test_imply_coeff(self):
  259. key = box(1)
  260. key2 = box(2)
  261. # if x > y * 9/3 => x > y
  262. # counter: x = -2, y = -1, -2 > -3 => -2 > -1, True => False
  263. g1 = guard_true(rop.INT_GT, iv(key, coeff=(1,1,0)), iv(box(1),coeff=(9,3,0)))
  264. g2 = guard_true(rop.INT_GT, iv(key, coeff=(1,1,0)), iv(box(1),coeff=(1,1,0)))
  265. assert not g1.implies(g2)
  266. assert not g2.implies(g1)
  267. # if x > y * 15/5 <=> x > y * 3
  268. g1 = guard_true(rop.INT_GT, iv(key, coeff=(1,1,0)), iv(key2,coeff=(15,5,0)))
  269. g2 = guard_true(rop.INT_GT, iv(key, coeff=(1,1,0)), iv(key2,coeff=(3,1,0)))
  270. assert g1.implies(g2)
  271. assert g2.implies(g1)
  272. # x >= y => x*3-5 >= y
  273. # counter: 1 >= 0 => 1*3-5 >= 0 == -2 >= 0, True => False
  274. g1 = guard_true(rop.INT_GE, iv(key, coeff=(1,1,0)), iv(key2))
  275. g2 = guard_true(rop.INT_GE, iv(key, coeff=(3,1,-5)), iv(key2))
  276. assert not g1.implies(g2)
  277. assert not g2.implies(g1)
  278. # guard false inverst >= to <
  279. # x < y => x*3-5 < y
  280. # counter: 3 < 4 => 3*3-5 < 4 == 4 < 4, True => False
  281. g1 = guard_false(rop.INT_GE, iv(key, coeff=(1,1,0)), iv(key2))
  282. g2 = guard_false(rop.INT_GE, iv(key, coeff=(3,1,-5)), iv(key2))
  283. assert not g1.implies(g2)
  284. assert not g2.implies(g1)
  285. # x <= y => x*3-5 > y
  286. # counter: 3 < 4 => 3*3-5 < 4 == 4 < 4, True => False
  287. g1 = guard_false(rop.INT_GT, iv(key, coeff=(1,1,0)), iv(key2))
  288. g2 = guard_true(rop.INT_GT, iv(key, coeff=(3,1,-5)), iv(key2))
  289. assert not g1.implies(g2)
  290. assert not g2.implies(g1)
  291. def test_collapse(self):
  292. loop1 = self.parse_trace("""
  293. i10 = int_gt(i1, 42)
  294. guard_true(i10) []
  295. i11 = int_add(i1, 1)
  296. i12 = int_gt(i11, i2)
  297. guard_true(i12) []
  298. """)
  299. opt = self.optguards(loop1, True)
  300. self.assert_guard_count(loop1, 2)
  301. self.assert_contains_sequence(loop1, """
  302. ...
  303. i100 = int_ge(42, i2)
  304. guard_true(i100) []
  305. ...
  306. i40 = int_gt(i1, 42)
  307. guard_true(i40) []
  308. ...
  309. """)
  310. class Test(GuardBaseTest, LLtypeMixin):
  311. pass