PageRenderTime 60ms CodeModel.GetById 20ms RepoModel.GetById 0ms app.codeStats 0ms

/rpython/jit/metainterp/test/test_recursive.py

https://bitbucket.org/kcr/pypy
Python | 1271 lines | 1269 code | 2 blank | 0 comment | 4 complexity | ff12caff748119a841a4cfd4e16d535f MD5 | raw file
Possible License(s): Apache-2.0
  1. import py
  2. from rpython.rlib.jit import JitDriver, hint, set_param
  3. from rpython.rlib.jit import unroll_safe, dont_look_inside, promote
  4. from rpython.rlib.objectmodel import we_are_translated
  5. from rpython.rlib.debug import fatalerror
  6. from rpython.jit.metainterp.test.support import LLJitMixin, OOJitMixin
  7. from rpython.jit.codewriter.policy import StopAtXPolicy
  8. from rpython.rtyper.annlowlevel import hlstr
  9. from rpython.jit.metainterp.warmspot import get_stats
  10. class RecursiveTests:
  def test_simple_recursion(self):
      """Mutual recursion f -> main -> f, traced with no optimizations.

      The JIT result must equal the untranslated result, and the recorded
      history is expected to contain no plain `call` operations.
      """
      myjitdriver = JitDriver(greens=[], reds=['n', 'm'])
      def f(n):
          # Iterates twice (n counts down until it equals n - 2),
          # then recurses through main().
          m = n - 2
          while True:
              myjitdriver.jit_merge_point(n=n, m=m)
              n -= 1
              if m == n:
                  return main(n) * 2
              myjitdriver.can_enter_jit(n=n, m=m)
      def main(n):
          # Recursion bottoms out at n <= 0.
          if n > 0:
              return f(n+1)
          else:
              return 1
      res = self.meta_interp(main, [20], enable_opts='')
      assert res == main(20)
      self.check_history(call=0)
  def test_simple_recursion_with_exc(self):
      """Like test_simple_recursion, but an exception unwinds one level.

      The frame whose countdown reaches 10 raises Error. Its caller catches
      the error around the recursive call and returns 2 instead.
      """
      myjitdriver = JitDriver(greens=[], reds=['n', 'm'])
      class Error(Exception):
          pass
      def f(n):
          m = n - 2
          while True:
              myjitdriver.jit_merge_point(n=n, m=m)
              n -= 1
              if n == 10:
                  raise Error
              if m == n:
                  try:
                      return main(n) * 2
                  except Error:
                      # the recursive call raised: replace its result
                      return 2
              myjitdriver.can_enter_jit(n=n, m=m)
      def main(n):
          if n > 0:
              return f(n+1)
          else:
              return 1
      res = self.meta_interp(main, [20], enable_opts='')
      assert res == main(20)
  def test_recursion_three_times(self):
      """Each loop iteration of f recurses via main (three iterations per f).

      Checks the JIT result and bounds how often the JIT is entered.
      """
      myjitdriver = JitDriver(greens=[], reds=['n', 'm', 'total'])
      def f(n):
          m = n - 3
          total = 0
          while True:
              myjitdriver.jit_merge_point(n=n, m=m, total=total)
              n -= 1
              total += main(n)
              if m == n:
                  return total + 5
              myjitdriver.can_enter_jit(n=n, m=m, total=total)
      def main(n):
          if n > 0:
              return f(n)
          else:
              return 1
      # Debugging aid: print the untranslated results for small inputs.
      print
      for i in range(1, 11):
          print '%3d %9d' % (i, f(i))
      res = self.meta_interp(main, [10], enable_opts='')
      assert res == main(10)
      self.check_enter_count_at_most(11)
  def test_bug_1(self):
      """Regression test: a residual (non-traced) helper recurses into the portal.

      `opaque` is excluded from tracing via StopAtXPolicy. On the last
      iteration of the outer f(1) it calls f(0) twenty times.
      """
      myjitdriver = JitDriver(greens=[], reds=['n', 'i', 'stack'])
      def opaque(n, i):
          if n == 1 and i == 19:
              for j in range(20):
                  res = f(0)      # recurse repeatedly, 20 times
                  assert res == 0
      def f(n):
          # `stack` is a red list that is only popped at the end; f returns n.
          stack = [n]
          i = 0
          while i < 20:
              myjitdriver.can_enter_jit(n=n, i=i, stack=stack)
              myjitdriver.jit_merge_point(n=n, i=i, stack=stack)
              opaque(n, i)
              i += 1
          return stack.pop()
      res = self.meta_interp(f, [1], enable_opts='', repeat=2,
                             policy=StopAtXPolicy(opaque))
      assert res == 1
  def get_interpreter(self, codes):
      """Build a tiny bytecode interpreter over the strings in `codes`.

      Opcodes (single characters):
        "0" ADD        n += 1
        "1" JUMP_BACK  return 42 if n > 20, else jump back two opcodes
        "2" CALL       run codes[1] starting at index 1 and use its result as n
        "3" EXIT       return n
      Returns the `interpret(codenum, n, i)` function.
      """
      ADD = "0"
      JUMP_BACK = "1"
      CALL = "2"
      EXIT = "3"
      def getloc(i, code):
          return 'code="%s", i=%d' % (code, i)
      jitdriver = JitDriver(greens = ['i', 'code'], reds = ['n'],
                            get_printable_location = getloc)
      def interpret(codenum, n, i):
          code = codes[codenum]
          while i < len(code):
              jitdriver.jit_merge_point(n=n, i=i, code=code)
              op = code[i]
              if op == ADD:
                  n += 1
                  i += 1
              elif op == CALL:
                  # Sub-calls always enter codes[1] at index 1, not 0.
                  n = interpret(1, n, 1)
                  i += 1
              elif op == JUMP_BACK:
                  if n > 20:
                      return 42
                  i -= 2
                  jitdriver.can_enter_jit(n=n, i=i, code=code)
              elif op == EXIT:
                  return n
              else:
                  raise NotImplementedError
          return n
      return interpret
  def test_inline(self):
      """Compare the loop without and with inlining of the CALL opcode.

      code "021" = ADD, CALL, JUMP_BACK; subcode "00" is entered at index 1,
      so the callee executes a single ADD.
      """
      code = "021"
      subcode = "00"
      codes = [code, subcode]
      f = self.get_interpreter(codes)
      assert self.meta_interp(f, [0, 0, 0], enable_opts='') == 42
      # Without inlining: the CALL stays as one call_may_force and only the
      # outer ADD appears as int_add.
      self.check_resops(call_may_force=1, int_add=1, call=0)
      assert self.meta_interp(f, [0, 0, 0], enable_opts='',
                              inline=True) == 42
      # With inlining: no call remains and both ADDs show up as int_add.
      self.check_resops(call=0, int_add=2, call_may_force=0,
                        guard_no_exception=0)
  def test_inline_jitdriver_check(self):
      """A callee whose loop-back opcode is never executed is fully inlined."""
      code = "021"
      subcode = "100"
      codes = [code, subcode]
      f = self.get_interpreter(codes)
      assert self.meta_interp(f, [0, 0, 0], enable_opts='',
                              inline=True) == 42
      # the call is fully inlined, because we jump to subcode[1], thus
      # skipping completely the JUMP_BACK in subcode[0]
      self.check_resops(call=0, call_may_force=0, call_assembler=0)
  def test_guard_failure_in_inlined_function(self):
      """An inlined sub-program returns early, depending on a data-dependent test.

      Mini-language: '-' decrements n, 'c' runs "---i---" on n,
      'i' returns n early when n % 5 == 1, 'l' jumps back to pc 0 while n > 0.
      """
      def p(pc, code):
          code = hlstr(code)
          return "%s %d %s" % (code, pc, code[pc])
      myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                              get_printable_location=p)
      def f(code, n):
          pc = 0
          while pc < len(code):
              myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
              op = code[pc]
              if op == "-":
                  n -= 1
              elif op == "c":
                  n = f("---i---", n)
              elif op == "i":
                  if n % 5 == 1:
                      return n
              elif op == "l":
                  if n > 0:
                      myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                      pc = 0
                      continue
              else:
                  assert 0    # unknown opcode
              pc += 1
          return n
      def main(n):
          return f("c-l", n)
      print main(100)
      res = self.meta_interp(main, [100], enable_opts='', inline=True)
      assert res == 0     # main(100) also evaluates to 0 untranslated
  def test_guard_failure_and_then_exception_in_inlined_function(self):
      """An inlined callee sets a flag on a data-dependent path and then raises.

      In the callee "---ir---": 'i' sets flag once n < 200, and 'r' raises
      if flag is set. The caller's 'c' catches the exception and returns n.
      """
      def p(pc, code):
          code = hlstr(code)
          return "%s %d %s" % (code, pc, code[pc])
      myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n', 'flag'],
                              get_printable_location=p)
      def f(code, n):
          pc = 0
          flag = False
          while pc < len(code):
              myjitdriver.jit_merge_point(n=n, code=code, pc=pc, flag=flag)
              op = code[pc]
              if op == "-":
                  n -= 1
              elif op == "c":
                  try:
                      n = f("---ir---", n)
                  except Exception:
                      # n still holds the value from before the call
                      return n
              elif op == "i":
                  if n < 200:
                      flag = True
              elif op == "r":
                  if flag:
                      raise Exception
              elif op == "l":
                  if n > 0:
                      myjitdriver.can_enter_jit(n=n, code=code, pc=0, flag=flag)
                      pc = 0
                      continue
              else:
                  assert 0    # unknown opcode
              pc += 1
          return n
      def main(n):
          return f("c-l", n)
      print main(1000)
      res = self.meta_interp(main, [1000], enable_opts='', inline=True)
      assert res == main(1000)
  def test_exception_in_inlined_function(self):
      """An inlined callee raises Exc, which the caller swallows.

      When the callee raises, the caller's n is not updated, so the callee's
      decrements are discarded.
      """
      def p(pc, code):
          code = hlstr(code)
          return "%s %d %s" % (code, pc, code[pc])
      myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                              get_printable_location=p)
      class Exc(Exception):
          pass
      def f(code, n):
          pc = 0
          while pc < len(code):
              myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
              op = code[pc]
              if op == "-":
                  n -= 1
              elif op == "c":
                  try:
                      n = f("---i---", n)
                  except Exc:
                      pass
              elif op == "i":
                  if n % 5 == 1:
                      raise Exc
              elif op == "l":
                  if n > 0:
                      myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                      pc = 0
                      continue
              else:
                  assert 0    # unknown opcode
              pc += 1
          return n
      def main(n):
          return f("c-l", n)
      res = self.meta_interp(main, [100], enable_opts='', inline=True)
      assert res == main(100)
  def test_recurse_during_blackholing(self):
      """Recursive portal calls occur on a path that may run in the blackhole
      interpreter. The JIT parameters are set inside main() on purpose.
      """
      # this passes, if the blackholing shortcut for calls is turned off
      # it fails, it is very delicate in terms of parameters,
      # bridge/loop creation order
      # NOTE(review): the wording above is ambiguous about which setting
      # passes; treat threshold/trace_eagerness below as load-bearing.
      def p(pc, code):
          code = hlstr(code)
          return "%s %d %s" % (code, pc, code[pc])
      myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                              get_printable_location=p)
      def f(code, n):
          pc = 0
          while pc < len(code):
              myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
              op = code[pc]
              if op == "-":
                  n -= 1
              elif op == "c":
                  # recurse only on some iterations, once n has dropped below 70
                  if n < 70 and n % 3 == 1:
                      n = f("--", n)
              elif op == "l":
                  if n > 0:
                      myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                      pc = 0
                      continue
              else:
                  assert 0    # unknown opcode
              pc += 1
          return n
      def main(n):
          set_param(None, 'threshold', 3)
          set_param(None, 'trace_eagerness', 5)
          return f("c-l", n)
      expected = main(100)
      res = self.meta_interp(main, [100], enable_opts='', inline=True)
      assert res == expected
  def check_max_trace_length(self, length):
      """Assert that every recorded loop, and every guard's recorded
      `_debug_suboperations`, stays within `length` operations plus a
      slack of 5.
      """
      # The slack exists because the limit is only checked once per
      # metainterp bytecode, so a trace can overshoot slightly.
      allowed = length + 5
      for loop in get_stats().loops:
          operations = loop.operations
          assert len(operations) <= allowed
          for op in operations:
              if not op.is_guard():
                  continue
              if not hasattr(op.getdescr(), '_debug_suboperations'):
                  continue
              assert len(op.getdescr()._debug_suboperations) <= allowed
  def test_inline_trace_limit(self):
      """Inlining a deep recursion must respect the trace limit.

      `recursive(n)` recurses n levels deep and returns n. The loop body
      therefore inlines a deep call chain, which is expected to hit
      TRACE_LIMIT and abort tracing (checked by check_aborted_count).
      """
      myjitdriver = JitDriver(greens=[], reds=['n'])
      def recursive(n):
          if n > 0:
              return recursive(n - 1) + 1
          return 0
      def loop(n):
          set_param(myjitdriver, "threshold", 10)
          # (an unused local `pc = 0` used to sit here; it was dead code)
          while n:
              myjitdriver.can_enter_jit(n=n)
              myjitdriver.jit_merge_point(n=n)
              n = recursive(n)
              n -= 1
          return n
      TRACE_LIMIT = 66
      res = self.meta_interp(loop, [100], enable_opts='', inline=True, trace_limit=TRACE_LIMIT)
      assert res == 0
      self.check_max_trace_length(TRACE_LIMIT)
      self.check_enter_count_at_most(10) # maybe
      self.check_aborted_count(7)
  def test_trace_limit_bridge(self):
      """Bridges that inline a recursive call must respect the trace limit.

      The `recursive` call only happens once n < 50, so it first appears
      on a later path (presumably compiled as a bridge) rather than in the
      initial loop.
      """
      def recursive(n):
          if n > 0:
              return recursive(n - 1) + 1
          return 0
      myjitdriver = JitDriver(greens=[], reds=['n'])
      def loop(n):
          set_param(None, "threshold", 4)
          set_param(None, "trace_eagerness", 2)
          while n:
              myjitdriver.can_enter_jit(n=n)
              myjitdriver.jit_merge_point(n=n)
              if n % 5 == 0:
                  n -= 1
              if n < 50:
                  n = recursive(n)   # returns n unchanged
              n -= 1
          return n
      TRACE_LIMIT = 20
      res = self.meta_interp(loop, [100], enable_opts='', inline=True, trace_limit=TRACE_LIMIT)
      # `loop` can only leave `while n:` with n == 0, so check the result
      # too (it used to be computed and ignored).
      assert res == 0
      self.check_max_trace_length(TRACE_LIMIT)
      self.check_aborted_count(8)
      self.check_enter_count_at_most(30)
  def test_trace_limit_with_exception_bug(self):
      """Aborting a too-long trace must not re-raise an exception that was
      already caught.

      `do_stuff` is unroll_safe and loops n (> 80) times before raising, so
      a single iteration is expected to exceed TRACE_LIMIT.
      """
      myjitdriver = JitDriver(greens=[], reds=['n'])
      @unroll_safe
      def do_stuff(n):
          while n > 0:
              n -= 1
          raise ValueError
      def loop(n):
          # (an unused local `pc = 0` used to sit here; it was dead code)
          while n > 80:
              myjitdriver.can_enter_jit(n=n)
              myjitdriver.jit_merge_point(n=n)
              try:
                  do_stuff(n)
              except ValueError:
                  # the trace limit is checked when we arrive here, and we
                  # have the exception still in last_exc_value_box at this
                  # point -- so when we abort because of a trace too long,
                  # the exception is passed to the blackhole interp and
                  # incorrectly re-raised from here
                  pass
              n -= 1
          return n
      TRACE_LIMIT = 66
      res = self.meta_interp(loop, [100], trace_limit=TRACE_LIMIT)
      assert res == 80
  def test_max_failure_args(self):
      """Traces whose guards would need too many fail args get aborted.

      Each iteration allocates an A with ten int fields. Presumably, together
      with the reds i and n, this exceeds FAILARGS_LIMIT (10) live values at
      a guard. TODO confirm against the failargs_limit implementation.
      """
      FAILARGS_LIMIT = 10
      jitdriver = JitDriver(greens = [], reds = ['i', 'n', 'o'])
      class A(object):
          def __init__(self, i0, i1, i2, i3, i4, i5, i6, i7, i8, i9):
              self.i0 = i0
              self.i1 = i1
              self.i2 = i2
              self.i3 = i3
              self.i4 = i4
              self.i5 = i5
              self.i6 = i6
              self.i7 = i7
              self.i8 = i8
              self.i9 = i9
      def loop(n):
          i = 0
          o = A(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
          while i < n:
              jitdriver.can_enter_jit(o=o, i=i, n=n)
              jitdriver.jit_merge_point(o=o, i=i, n=n)
              o = A(i, i + 1, i + 2, i + 3, i + 4, i + 5,
                    i + 6, i + 7, i + 8, i + 9)
              i += 1
          return o
      # Only the abort count is checked; the returned instance is not.
      res = self.meta_interp(loop, [20], failargs_limit=FAILARGS_LIMIT,
                             listops=True)
      self.check_aborted_count(5)
  def test_max_failure_args_exc(self):
      """Same as test_max_failure_args, but the loop exits by raising.

      main() returns 0 when loop() raises ValueError, which it always does
      after the while loop, so the result must be falsy.
      """
      FAILARGS_LIMIT = 10
      jitdriver = JitDriver(greens = [], reds = ['i', 'n', 'o'])
      class A(object):
          def __init__(self, i0, i1, i2, i3, i4, i5, i6, i7, i8, i9):
              self.i0 = i0
              self.i1 = i1
              self.i2 = i2
              self.i3 = i3
              self.i4 = i4
              self.i5 = i5
              self.i6 = i6
              self.i7 = i7
              self.i8 = i8
              self.i9 = i9
      def loop(n):
          i = 0
          o = A(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
          while i < n:
              jitdriver.can_enter_jit(o=o, i=i, n=n)
              jitdriver.jit_merge_point(o=o, i=i, n=n)
              o = A(i, i + 1, i + 2, i + 3, i + 4, i + 5,
                    i + 6, i + 7, i + 8, i + 9)
              i += 1
          raise ValueError
      def main(n):
          try:
              loop(n)
              return 1
          except ValueError:
              return 0
      res = self.meta_interp(main, [20], failargs_limit=FAILARGS_LIMIT,
                             listops=True)
      assert not res
      self.check_aborted_count(5)
  def test_set_param_inlining(self):
      """The 'inlining' JIT parameter can be toggled at run time via set_param.

      With inlining off (main(0)) the nested loop() call stays as one
      call_may_force. With inlining on (main(1)) no call remains.
      """
      myjitdriver = JitDriver(greens=[], reds=['n', 'recurse'])
      def loop(n, recurse=False):
          while n:
              myjitdriver.jit_merge_point(n=n, recurse=recurse)
              n -= 1
              if not recurse:
                  loop(10, True)
              myjitdriver.can_enter_jit(n=n, recurse=recurse)
          return n
      TRACE_LIMIT = 66
      def main(inline):
          set_param(None, "threshold", 10)
          set_param(None, 'function_threshold', 60)
          if inline:
              set_param(None, 'inlining', True)
          else:
              set_param(None, 'inlining', False)
          return loop(100)
      res = self.meta_interp(main, [0], enable_opts='', trace_limit=TRACE_LIMIT)
      # loop() only exits `while n:` with n == 0; the result used to be ignored.
      assert res == 0
      self.check_resops(call=0, call_may_force=1)
      res = self.meta_interp(main, [1], enable_opts='', trace_limit=TRACE_LIMIT)
      assert res == 0
      self.check_resops(call=0, call_may_force=0)
  def test_trace_from_start(self):
      """Run a small loop language with a nested call, with backendopt enabled.

      Mini-language: '+' adds 7, '-' subtracts 1, 'c' runs '---' on n,
      'l' jumps back to pc 1 while n > 0. Only the result is checked; the
      loop-shape checks after py.test.skip are currently disabled.
      """
      def p(pc, code):
          code = hlstr(code)
          return "'%s' at %d: %s" % (code, pc, code[pc])
      myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                              get_printable_location=p)
      def f(code, n):
          pc = 0
          while pc < len(code):
              myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
              op = code[pc]
              if op == "+":
                  n += 7
              elif op == "-":
                  n -= 1
              elif op == "c":
                  n = f('---', n)
              elif op == "l":
                  if n > 0:
                      myjitdriver.can_enter_jit(n=n, code=code, pc=1)
                      pc = 1
                      continue
              else:
                  assert 0    # unknown opcode
              pc += 1
          return n
      def g(m):
          # Never taken for the m used below.
          # NOTE(review): presumably here so f is also seen with an empty code.
          if m > 1000000:
              f('', 0)
          result = 0
          for i in range(m):
              result += f('+-cl--', i)
          # Without this return, g returned None and the assertion below
          # compared None == None, which always passed.
          return result
      res = self.meta_interp(g, [50], backendopt=True)
      assert res == g(50)
      py.test.skip("tracing from start is by now only longer enabled "
                   "if a trace gets too big")
      self.check_tree_loop_count(3)
      self.check_history(int_add=1)
  def test_dont_inline_huge_stuff(self):
      """A callee too large to inline within the trace limit is called instead.

      The 'c' opcode runs a 20-op sub-program and discards its result. With
      trace_limit=40 the test expects one aborted trace, two call_assembler
      operations and two jitcell tokens.
      """
      def p(pc, code):
          code = hlstr(code)
          return "%s %d %s" % (code, pc, code[pc])
      myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                              get_printable_location=p)
      def f(code, n):
          pc = 0
          while pc < len(code):
              myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
              op = code[pc]
              if op == "-":
                  n -= 1
              elif op == "c":
                  f('--------------------', n)   # result deliberately unused
              elif op == "l":
                  if n > 0:
                      myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                      pc = 0
                      continue
              else:
                  assert 0    # unknown opcode
              pc += 1
          return n
      def g(m):
          set_param(None, 'inlining', True)
          # carefully chosen threshold to make sure that the inner function
          # cannot be inlined, but the inner function on its own is small
          # enough
          set_param(None, 'trace_limit', 40)
          if m > 1000000:
              f('', 0)
          # `result` is accumulated but not returned; only JIT counters are
          # checked below.
          result = 0
          for i in range(m):
              result += f('-c-----------l-', i+100)
      self.meta_interp(g, [10], backendopt=True)
      self.check_aborted_count(1)
      self.check_resops(call=0, call_assembler=2)
      self.check_jitcell_token_count(2)
  def test_directly_call_assembler(self):
      """A recursive portal call with a different green key becomes one
      call_assembler in the recorded history.
      """
      driver = JitDriver(greens = ['codeno'], reds = ['i'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno):
          i = 0
          while i < 10:
              driver.can_enter_jit(codeno = codeno, i = i)
              driver.jit_merge_point(codeno = codeno, i = i)
              if codeno == 2:
                  portal(1)   # the inner portal has its own loop
              i += 1
      self.meta_interp(portal, [2], inline=True)
      self.check_history(call_assembler=1)
  def test_recursion_cant_call_assembler_directly(self):
      """Recursing into the same green key while it is still being compiled.

      compile.compile_tmp_callback is wrapped to record every temporary
      callback token it returns. The test expects exactly one such token.
      On the llgraph backend, it also expects that token to be the only key
      in the CPU's `_redirected_call_assembler`.
      """
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'j'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno, j):
          i = 1
          while 1:
              driver.jit_merge_point(codeno=codeno, i=i, j=j)
              # (i >> 1) == 1 holds for i in {2, 3}: recurse twice per level
              if (i >> 1) == 1:
                  if j == 0:
                      return
                  portal(2, j - 1)
              elif i == 5:
                  return
              i += 1
              driver.can_enter_jit(codeno=codeno, i=i, j=j)
      portal(2, 5)    # sanity run, untranslated
      from rpython.jit.metainterp import compile, pyjitpl
      pyjitpl._warmrunnerdesc = None
      trace = []
      def my_ctc(*args):
          looptoken = original_ctc(*args)
          trace.append(looptoken)
          return looptoken
      original_ctc = compile.compile_tmp_callback
      try:
          # Monkeypatch; always restored in `finally`.
          compile.compile_tmp_callback = my_ctc
          self.meta_interp(portal, [2, 5], inline=True)
          self.check_resops(call_may_force=0, call_assembler=2)
      finally:
          compile.compile_tmp_callback = original_ctc
      # check that we made a temporary callback
      assert len(trace) == 1
      # and that we later redirected it to something else
      try:
          redirected = pyjitpl._warmrunnerdesc.cpu._redirected_call_assembler
      except AttributeError:
          pass    # not the llgraph backend
      else:
          print redirected
          assert redirected.keys() == trace
  def test_recursion_cant_call_assembler_directly_with_virtualizable(self):
      """Same scenario as test_recursion_cant_call_assembler_directly, but
      the recursion depth lives in a virtualizable Frame field.
      """
      # exactly the same logic as the previous test, but with 'frame.j'
      # instead of just 'j'
      class Frame(object):
          _virtualizable2_ = ['j']
          def __init__(self, j):
              self.j = j
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                         virtualizables = ['frame'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno, frame):
          i = 1
          while 1:
              driver.jit_merge_point(codeno=codeno, i=i, frame=frame)
              # (i >> 1) == 1 holds for i in {2, 3}
              if (i >> 1) == 1:
                  if frame.j == 0:
                      return
                  portal(2, Frame(frame.j - 1))
              elif i == 5:
                  return
              i += 1
              driver.can_enter_jit(codeno=codeno, i=i, frame=frame)
      def main(codeno, j):
          portal(codeno, Frame(j))
      main(2, 5)      # sanity run, untranslated
      from rpython.jit.metainterp import compile, pyjitpl
      pyjitpl._warmrunnerdesc = None
      trace = []
      def my_ctc(*args):
          looptoken = original_ctc(*args)
          trace.append(looptoken)
          return looptoken
      original_ctc = compile.compile_tmp_callback
      try:
          # Monkeypatch; always restored in `finally`.
          compile.compile_tmp_callback = my_ctc
          self.meta_interp(main, [2, 5], inline=True)
          self.check_resops(call_may_force=0, call_assembler=2)
      finally:
          compile.compile_tmp_callback = original_ctc
      # check that we made a temporary callback
      assert len(trace) == 1
      # and that we later redirected it to something else
      try:
          redirected = pyjitpl._warmrunnerdesc.cpu._redirected_call_assembler
      except AttributeError:
          pass    # not the llgraph backend
      else:
          print redirected
          assert redirected.keys() == trace
  def test_directly_call_assembler_return(self):
      """A call_assembler'd recursive portal call returns a value the caller uses.

      portal(1) never changes k, so it returns 1, and portal(2) ends up
      returning that value.
      """
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno):
          i = 0
          k = codeno
          while i < 10:
              driver.can_enter_jit(codeno = codeno, i = i, k = k)
              driver.jit_merge_point(codeno = codeno, i = i, k = k)
              if codeno == 2:
                  k = portal(1)
              i += 1
          return k
      res = self.meta_interp(portal, [2], inline=True)
      self.check_history(call_assembler=1)
      # The returned value is the point of this test; it used to be ignored.
      assert res == portal(2)
  def test_directly_call_assembler_raise(self):
      """A call_assembler'd recursive portal call propagates an exception.

      portal(1) runs its loop and then raises MyException(1). The caller
      catches it and adds me.x to i.
      """
      class MyException(Exception):
          def __init__(self, x):
              self.x = x
      driver = JitDriver(greens = ['codeno'], reds = ['i'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno):
          i = 0
          while i < 10:
              driver.can_enter_jit(codeno = codeno, i = i)
              driver.jit_merge_point(codeno = codeno, i = i)
              if codeno == 2:
                  try:
                      portal(1)
                  except MyException, me:
                      i += me.x
              i += 1
          if codeno == 1:
              raise MyException(1)
      self.meta_interp(portal, [2], inline=True)
      self.check_history(call_assembler=1)
  def test_directly_call_assembler_fail_guard(self):
      """A guard fails inside the called assembler: the callee takes a
      different path once k > 40.
      """
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno, k):
          i = 0
          while i < 10:
              driver.can_enter_jit(codeno=codeno, i=i, k=k)
              driver.jit_merge_point(codeno=codeno, i=i, k=k)
              if codeno == 2:
                  k += portal(1, k)
              elif k > 40:
                  # only reached in the callee (codeno == 1), once k grows
                  if i % 2:
                      k += 1
                  else:
                      k += 2
              k += 1
              i += 1
          return k
      res = self.meta_interp(portal, [2, 0], inline=True)
      assert res == 13542     # equals portal(2, 0) run untranslated
  def test_directly_call_assembler_virtualizable(self):
      """call_assembler with a virtualizable frame: the outer portal (codeno 0)
      passes a fresh subframe to the inner portal (codeno 1).
      """
      class Thing(object):
          def __init__(self, val):
              self.val = val
      class Frame(object):
          _virtualizable2_ = ['thing']
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                         virtualizables = ['frame'],
                         get_printable_location = lambda codeno : str(codeno))
      def main(codeno):
          frame = Frame()
          frame.thing = Thing(0)
          portal(codeno, frame)
          return frame.thing.val
      def portal(codeno, frame):
          i = 0
          while i < 10:
              driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
              driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
              nextval = frame.thing.val
              if codeno == 0:
                  subframe = Frame()
                  subframe.thing = Thing(nextval)
                  nextval = portal(1, subframe)
              frame.thing = Thing(nextval + 1)
              i += 1
          return frame.thing.val
      res = self.meta_interp(main, [0], inline=True)
      assert res == main(0)
  def test_directly_call_assembler_virtualizable_reset_token(self):
      """After a recursive portal call returns, the subframe's vable_token
      must equal TOKEN_NONE. Uses three recursion levels (codeno 0, 1, 2).
      """
      from rpython.rtyper.lltypesystem import lltype
      from rpython.rlib.debug import llinterpcall
      class Thing(object):
          def __init__(self, val):
              self.val = val
      class Frame(object):
          _virtualizable2_ = ['thing']
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                         virtualizables = ['frame'],
                         get_printable_location = lambda codeno : str(codeno))
      @dont_look_inside
      def check_frame(subframe):
          # Only checks when translated; forwards to check_ll_frame via
          # llinterpcall.
          if we_are_translated():
              llinterpcall(lltype.Void, check_ll_frame, subframe)
      def check_ll_frame(ll_subframe):
          # This is called with the low-level Struct that is the frame.
          # Check that the vable_token was correctly reset to zero.
          # Note that in order for that test to catch failures, it needs
          # three levels of recursion: the vable_token of the subframe
          # at the level 2 is set to a non-zero value when doing the
          # call to the level 3 only.  This used to fail when the test
          # is run via rpython.jit.backend.x86.test.test_recursive.
          from rpython.jit.metainterp.virtualizable import TOKEN_NONE
          assert ll_subframe.vable_token == TOKEN_NONE
      def main(codeno):
          frame = Frame()
          frame.thing = Thing(0)
          portal(codeno, frame)
          return frame.thing.val
      def portal(codeno, frame):
          i = 0
          while i < 5:
              driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
              driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
              nextval = frame.thing.val
              if codeno < 2:
                  subframe = Frame()
                  subframe.thing = Thing(nextval)
                  nextval = portal(codeno + 1, subframe)
                  check_frame(subframe)
              frame.thing = Thing(nextval + 1)
              i += 1
          return frame.thing.val
      res = self.meta_interp(main, [0], inline=True)
      assert res == main(0)
  def test_directly_call_assembler_virtualizable_force1(self):
      """A residual call inside the callee writes to the OUTERMOST frame.

      `change` (excluded from tracing via StopAtXPolicy) assigns to
      somewhere_else.frame. That is main's frame, not the subframe the inner
      portal is running on.
      """
      class Thing(object):
          def __init__(self, val):
              self.val = val
      class Frame(object):
          _virtualizable2_ = ['thing']
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                         virtualizables = ['frame'],
                         get_printable_location = lambda codeno : str(codeno))
      class SomewhereElse(object):
          pass
      somewhere_else = SomewhereElse()
      def change(newthing):
          somewhere_else.frame.thing = newthing
      def main(codeno):
          frame = Frame()
          somewhere_else.frame = frame
          frame.thing = Thing(0)
          portal(codeno, frame)
          return frame.thing.val
      def portal(codeno, frame):
          print 'ENTER:', codeno, frame.thing.val
          i = 0
          while i < 10:
              driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
              driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
              nextval = frame.thing.val
              if codeno == 0:
                  subframe = Frame()
                  subframe.thing = Thing(nextval)
                  nextval = portal(1, subframe)
              elif codeno == 1:
                  if frame.thing.val > 40:
                      change(Thing(13))
                      nextval = 13
              else:
                  fatalerror("bad codeno = " + str(codeno))
              frame.thing = Thing(nextval + 1)
              i += 1
          print 'LEAVE:', codeno, frame.thing.val
          return frame.thing.val
      res = self.meta_interp(main, [0], inline=True,
                             policy=StopAtXPolicy(change))
      assert res == main(0)
  def test_directly_call_assembler_virtualizable_with_array(self):
      """call_assembler with a virtualizable that has an array field ('l[*]')
      and a promoted index field 's'.
      """
      myjitdriver = JitDriver(greens = ['codeno'], reds = ['n', 'x', 'frame'],
                              virtualizables = ['frame'])
      class Frame(object):
          _virtualizable2_ = ['l[*]', 's']
          def __init__(self, l, s):
              self = hint(self, access_directly=True,
                          fresh_virtualizable=True)
              self.l = l
              self.s = s
      def main(codeno, n, a):
          frame = Frame([a, a+1, a+2, a+3], 0)
          return f(codeno, n, a, frame)
      def f(codeno, n, a, frame):
          x = 0
          while n > 0:
              myjitdriver.can_enter_jit(codeno=codeno, frame=frame, n=n, x=x)
              myjitdriver.jit_merge_point(codeno=codeno, frame=frame, n=n,
                                          x=x)
              frame.s = promote(frame.s)
              n -= 1
              # frame.s is incremented here and decremented at the end of
              # the iteration, so it starts each iteration at the same value.
              s = frame.s
              assert s >= 0
              x += frame.l[s]
              frame.s += 1
              if codeno == 0:
                  subframe = Frame([n, n+1, n+2, n+3], 0)
                  x += f(1, 10, 1, subframe)
              s = frame.s
              assert s >= 0
              x += frame.l[s]
              x += len(frame.l)
              frame.s -= 1
          return x
      res = self.meta_interp(main, [0, 10, 1], listops=True, inline=True)
      assert res == main(0, 10, 1)
  def test_directly_call_assembler_virtualizable_force_blackhole(self):
      """The callee's residual call sometimes writes to the outermost frame.

      `change` only modifies somewhere_else.frame (main's frame) when
      arg > 30, so the write happens on some iterations and not on others.
      """
      class Thing(object):
          def __init__(self, val):
              self.val = val
      class Frame(object):
          _virtualizable2_ = ['thing']
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                         virtualizables = ['frame'],
                         get_printable_location = lambda codeno : str(codeno))
      class SomewhereElse(object):
          pass
      somewhere_else = SomewhereElse()
      def change(newthing, arg):
          print arg
          if arg > 30:
              somewhere_else.frame.thing = newthing
              arg = 13
          return arg
      def main(codeno):
          frame = Frame()
          somewhere_else.frame = frame
          frame.thing = Thing(0)
          portal(codeno, frame)
          return frame.thing.val
      def portal(codeno, frame):
          i = 0
          while i < 10:
              driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
              driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
              nextval = frame.thing.val
              if codeno == 0:
                  subframe = Frame()
                  subframe.thing = Thing(nextval)
                  nextval = portal(1, subframe)
              else:
                  nextval = change(Thing(13), frame.thing.val)
              frame.thing = Thing(nextval + 1)
              i += 1
          return frame.thing.val
      res = self.meta_interp(main, [0], inline=True,
                             policy=StopAtXPolicy(change))
      assert res == main(0)
  def test_assembler_call_red_args(self):
      """The green `codeno` of the recursive call comes from a residual call.

      `residual(k)` is excluded from tracing, so the callee's codeno
      (0 or 1) is only known at run time.
      """
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                         get_printable_location = lambda codeno : str(codeno))
      def residual(k):
          if k > 150:
              return 0
          return 1
      def portal(codeno, k):
          i = 0
          while i < 15:
              driver.can_enter_jit(codeno=codeno, i=i, k=k)
              driver.jit_merge_point(codeno=codeno, i=i, k=k)
              if codeno == 2:
                  k += portal(residual(k), k)
              if codeno == 0:
                  k += 2
              elif codeno == 1:
                  k += 1
              i += 1
          return k
      res = self.meta_interp(portal, [2, 0], inline=True,
                             policy=StopAtXPolicy(residual))
      assert res == portal(2, 0)
      self.check_resops(call_assembler=4)
  def test_inline_without_hitting_the_loop(self):
      """The recursive call portal(20) returns 1 at once, without reaching the
      can_enter_jit loop-back, so no call_assembler is expected.
      """
      driver = JitDriver(greens = ['codeno'], reds = ['i'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno):
          i = 0
          while True:
              driver.jit_merge_point(codeno=codeno, i=i)
              if codeno < 10:
                  i += portal(20)
                  codeno += 1
              elif codeno == 10:
                  if i > 63:
                      return i
                  codeno = 0
                  driver.can_enter_jit(codeno=codeno, i=i)
              else:
                  # codeno > 10 (e.g. 20): return immediately
                  return 1
      assert portal(0) == 70
      res = self.meta_interp(portal, [0], inline=True)
      assert res == 70
      self.check_resops(call_assembler=0)
  def test_inline_with_hitting_the_loop_sometimes(self):
      """Recursive calls reach the loop-back only for some codeno values.

      The call passes codeno + 5. Callees with codeno < 10 or == 10 run the
      loop; callees with codeno > 10 return 1 immediately. `k` is the
      recursion depth, and k > 2 stops the recursion.
      """
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno, k):
          if k > 2:
              return 1
          i = 0
          while True:
              driver.jit_merge_point(codeno=codeno, i=i, k=k)
              if codeno < 10:
                  i += portal(codeno + 5, k+1)
                  codeno += 1
              elif codeno == 10:
                  # exit threshold per depth k (index 0 is unused here,
                  # since the test starts at k=1)
                  if i > [-1, 2000, 63][k]:
                      return i
                  codeno = 0
                  driver.can_enter_jit(codeno=codeno, i=i, k=k)
              else:
                  return 1
      assert portal(0, 1) == 2095
      res = self.meta_interp(portal, [0, 1], inline=True)
      assert res == 2095
      self.check_resops(call_assembler=12)
  def test_inline_with_hitting_the_loop_sometimes_exc(self):
      """Same control flow as test_inline_with_hitting_the_loop_sometimes, but
      every result is delivered by raising GotValue instead of returning.
      """
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                         get_printable_location = lambda codeno : str(codeno))
      class GotValue(Exception):
          def __init__(self, result):
              self.result = result
      def portal(codeno, k):
          if k > 2:
              raise GotValue(1)
          i = 0
          while True:
              driver.jit_merge_point(codeno=codeno, i=i, k=k)
              if codeno < 10:
                  try:
                      portal(codeno + 5, k+1)
                  except GotValue, e:
                      i += e.result
                  codeno += 1
              elif codeno == 10:
                  if i > [-1, 2000, 63][k]:
                      raise GotValue(i)
                  codeno = 0
                  driver.can_enter_jit(codeno=codeno, i=i, k=k)
              else:
                  raise GotValue(1)
      def main(codeno, k):
          # portal() never returns normally; its result arrives as GotValue.
          try:
              portal(codeno, k)
          except GotValue, e:
              return e.result
      assert main(0, 1) == 2095
      res = self.meta_interp(main, [0, 1], inline=True)
      assert res == 2095
      self.check_resops(call_assembler=12)
  def test_handle_jitexception_in_portal(self):
      """can_enter_jit reached from a helper, with a string-returning portal.

      can_enter_jit is called from do_can_enter_jit, reached via
      intermediate() only when i == 9, rather than from portal directly.
      portal(64, s) appends 'A'..'I' (chr(64+1)..chr(64+9)).  At i == 10 it
      recurses as portal(96, s), which appends 'a'..'j', and then appends
      'J'.  main() concatenates five such calls.  The JIT run is repeated
      under several trace_limit values.
      """
      # a test for _handle_jitexception_in_portal in blackhole.py
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'str'],
                         get_printable_location = lambda codeno: str(codeno))
      def do_can_enter_jit(codeno, i, str):
          i = (i+1)-1 # some operations (value-preserving round-trip)
          driver.can_enter_jit(codeno=codeno, i=i, str=str)
      def intermediate(codeno, i, str):
          if i == 9:
              do_can_enter_jit(codeno, i, str)
      def portal(codeno, str):
          i = value.initial
          while i < 10:
              intermediate(codeno, i, str)
              driver.jit_merge_point(codeno=codeno, i=i, str=str)
              i += 1
              if codeno == 64 and i == 10:
                  str = portal(96, str)
              str += chr(codeno+i)
          return str
      # NOTE(review): initial is -1 at class level and main() sets it to 0
      # before use; presumably this keeps the loop start from being seen as
      # a constant during translation -- confirm.
      class Value:
          initial = -1
      value = Value()
      def main():
          value.initial = 0
          return (portal(64, '') +
                  portal(64, '') +
                  portal(64, '') +
                  portal(64, '') +
                  portal(64, ''))
      # check the expected value on the plain (non-JIT) function first
      assert main() == 'ABCDEFGHIabcdefghijJ' * 5
      # NOTE(review): the specific limits are presumably chosen so tracing is
      # aborted at different points of the nested portal call -- confirm.
      for tlimit in [95, 90, 102]:
          print 'tlimit =', tlimit
          res = self.meta_interp(main, [], inline=True, trace_limit=tlimit)
          # the JIT returns a string object whose .chars are joined here
          assert ''.join(res.chars) == 'ABCDEFGHIabcdefghijJ' * 5
  def test_handle_jitexception_in_portal_returns_void(self):
      """Like test_handle_jitexception_in_portal, but portal returns None.

      The recursive portal(96, str) result is discarded and portal has no
      return statement, so neither portal nor main returns a value.  The
      test makes no assertion on results: it only checks that main() runs
      untranslated and that meta_interp completes under each trace_limit.
      """
      # a test for _handle_jitexception_in_portal in blackhole.py
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'str'],
                         get_printable_location = lambda codeno: str(codeno))
      def do_can_enter_jit(codeno, i, str):
          i = (i+1)-1 # some operations (value-preserving round-trip)
          driver.can_enter_jit(codeno=codeno, i=i, str=str)
      def intermediate(codeno, i, str):
          if i == 9:
              do_can_enter_jit(codeno, i, str)
      def portal(codeno, str):
          i = value.initial
          while i < 10:
              intermediate(codeno, i, str)
              driver.jit_merge_point(codeno=codeno, i=i, str=str)
              i += 1
              if codeno == 64 and i == 10:
                  portal(96, str)  # result deliberately unused (void portal)
              str += chr(codeno+i)
      # NOTE(review): initial is -1 at class level and main() sets it to 0
      # before use; presumably to keep it non-constant -- confirm.
      class Value:
          initial = -1
      value = Value()
      def main():
          value.initial = 0
          portal(64, '')
          portal(64, '')
          portal(64, '')
          portal(64, '')
          portal(64, '')
      main()
      for tlimit in [95, 90, 102]:
          print 'tlimit =', tlimit
          self.meta_interp(main, [], inline=True, trace_limit=tlimit)
  def test_no_duplicates_bug(self):
      """Regression test for recursive portal calls with inline=True.

      portal(0, 10) loops i = 10..1 and calls portal(i, i) each time.
      Because codeno > 0 in those inner calls, each one breaks out right
      after its first jit_merge_point.  The test only checks that
      meta_interp completes.
      """
      driver = JitDriver(greens = ['codeno'], reds = ['i'],
                         get_printable_location = lambda codeno: str(codeno))
      def portal(codeno, i):
          while i > 0:
              # NOTE(review): can_enter_jit precedes jit_merge_point here;
              # presumably deliberate to reproduce the original bug -- confirm.
              driver.can_enter_jit(codeno=codeno, i=i)
              driver.jit_merge_point(codeno=codeno, i=i)
              if codeno > 0:
                  break
              portal(i, i)
              i -= 1
      self.meta_interp(portal, [0, 10], inline=True)
  def test_trace_from_start_always(self):
      """Check trace counts for a portal that never reaches can_enter_jit.

      Both runs pass v=False, so portal() never calls can_enter_jit.  Each
      call executes the loop body once (recursing with i - 1) and then
      breaks.  The first run uses default JIT parameters and expects one
      jitcell token and one trace.  The second run first sets
      function_threshold to 0 and expects no tokens and no traces.
      """
      driver = JitDriver(greens = ['c'], reds = ['i', 'v'])
      def portal(c, i, v):
          while i > 0:
              driver.jit_merge_point(c=c, i=i, v=v)
              portal(c, i - 1, v)
              if v:
                  driver.can_enter_jit(c=c, i=i, v=v)
              # the loop body runs at most once per call
              break
      def main(c, i, _set_param, v):
          if _set_param:
              set_param(driver, 'function_threshold', 0)
          portal(c, i, v)
      self.meta_interp(main, [10, 10, False, False], inline=True)
      self.check_jitcell_token_count(1)
      self.check_trace_count(1)
      self.meta_interp(main, [3, 10, True, False], inline=True)
      self.check_jitcell_token_count(0)
      self.check_trace_count(0)
  def test_trace_from_start_does_not_prevent_inlining(self):
      """Check that the nested portal(1, 8, 0) call leaves no residual call.

      With bc == 0 the outer portal calls portal(1, 8, 0) on every
      iteration.  That inner call has bc == 1 and returns right after its
      jit_merge_point.  The outer loop wraps c back to 0 every 10
      iterations and returns once i >= 100 at that point.  With inline=True
      the resulting code is expected to contain no call_may_force and no
      call operations (presumably because the inner call gets inlined).
      """
      # greens are listed as ['c', 'bc'] while portal takes (bc, c, i);
      # jit_merge_point/can_enter_jit are called with keywords, so order is moot
      driver = JitDriver(greens = ['c', 'bc'], reds = ['i'])
      def portal(bc, c, i):
          while True:
              driver.jit_merge_point(c=c, bc=bc, i=i)
              if bc == 0:
                  portal(1, 8, 0)
                  c += 1
              else:
                  return
              if c == 10: # bc == 0
                  c = 0
                  if i >= 100:
                      return
                  driver.can_enter_jit(c=c, bc=bc, i=i)
              i += 1
      self.meta_interp(portal, [0, 0, 0], inline=True)
      self.check_resops(call_may_force=0, call=0)
  def test_dont_repeatedly_trace_from_the_same_guard(self):
      """Recursion that keeps failing the same guard must not retrace each time.

      At level 0, i starts at -10, so the loop spins 10 times via
      `continue`.  Once i > 0 it recurses with level + 1.  Every deeper level
      starts at i = 0 and recurses on its first iteration, until level 25
      returns 42.  The test caps the total trace count at 2.  The portal has
      no can_enter_jit and meta_interp is called without inline=True.
      """
      driver = JitDriver(greens = [], reds = ['level', 'i'])
      def portal(level):
          if level == 0:
              i = -10
          else:
              i = 0
          #
          while True:
              driver.jit_merge_point(level=level, i=i)
              if level == 25:
                  return 42
              i += 1
              if i <= 0:      # <- guard
                  continue    # first make a loop
              else:
                  # then we fail the guard above, doing a recursive call,
                  # which will itself fail the same guard above, and so on
                  return portal(level + 1)
      self.meta_interp(portal, [0])
      self.check_trace_count_at_most(2) # and not, e.g., 24
  class TestLLtype(RecursiveTests, LLJitMixin):
      """Run every RecursiveTests case with LLJitMixin mixed in."""
  class TestOOtype(RecursiveTests, OOJitMixin):
      """Run every RecursiveTests case with OOJitMixin mixed in."""