PageRenderTime 54ms CodeModel.GetById 16ms RepoModel.GetById 1ms app.codeStats 0ms

/rpython/jit/metainterp/test/test_recursive.py

https://bitbucket.org/pypy/pypy/
Python | 1337 lines | 1335 code | 2 blank | 0 comment | 4 complexity | fdfc102a999e5ea6cdc589e79e04a623 MD5 | raw file
Possible License(s): AGPL-3.0, BSD-3-Clause, Apache-2.0
  1. import py
  2. from rpython.rlib.jit import JitDriver, hint, set_param
  3. from rpython.rlib.jit import unroll_safe, dont_look_inside, promote
  4. from rpython.rlib.objectmodel import we_are_translated
  5. from rpython.rlib.debug import fatalerror
  6. from rpython.jit.metainterp.test.support import LLJitMixin
  7. from rpython.jit.codewriter.policy import StopAtXPolicy
  8. from rpython.rtyper.annlowlevel import hlstr
  9. from rpython.jit.metainterp.warmspot import get_stats
  10. from rpython.jit.backend.llsupport import codemap
  11. class RecursiveTests:
  def test_simple_recursion(self):
      # main(n) calls f(n + 1); f() makes two passes through its JIT loop
      # and then recurses into main() one level lower, doubling the
      # result.  Hence main(0) == 1 and main(n) == 2 * main(n - 1).
      myjitdriver = JitDriver(greens=[], reds=['n', 'm'])
      def f(n):
          m = n - 2
          while True:
              myjitdriver.jit_merge_point(n=n, m=m)
              n -= 1
              # true on the second pass, once n has been decremented twice
              if m == n:
                  return main(n) * 2
              myjitdriver.can_enter_jit(n=n, m=m)
      def main(n):
          if n > 0:
              return f(n+1)
          else:
              return 1
      res = self.meta_interp(main, [20], enable_opts='')
      assert res == main(20)
      # the recorded history is expected to contain no call_i operation
      self.check_history(call_i=0)
  def test_simple_recursion_with_exc(self):
      # main(n) calls f(n + 1), and f() recurses back into main() after
      # two loop passes.  An f() whose counter reaches 10 raises Error
      # before it can recurse.  The exception propagates out through
      # main() into the calling f(), whose except clause returns 2
      # instead of doubling the recursive result.
      myjitdriver = JitDriver(greens=[], reds=['n', 'm'])
      class Error(Exception):
          pass
      def f(n):
          m = n - 2
          while True:
              myjitdriver.jit_merge_point(n=n, m=m)
              n -= 1
              if n == 10:
                  raise Error
              if m == n:
                  try:
                      return main(n) * 2
                  except Error:
                      return 2
              myjitdriver.can_enter_jit(n=n, m=m)
      def main(n):
          if n > 0:
              return f(n+1)
          else:
              return 1
      res = self.meta_interp(main, [20], enable_opts='')
      assert res == main(20)
  def test_recursion_three_times(self):
      # f(n) makes three passes through its JIT loop.  On each pass it
      # calls main() on the decremented n and accumulates the result, so
      # every recursion level fans out into three recursive calls.
      myjitdriver = JitDriver(greens=[], reds=['n', 'm', 'total'])
      def f(n):
          m = n - 3
          total = 0
          while True:
              myjitdriver.jit_merge_point(n=n, m=m, total=total)
              n -= 1
              total += main(n)
              if m == n:
                  return total + 5
              myjitdriver.can_enter_jit(n=n, m=m, total=total)
      def main(n):
          if n > 0:
              return f(n)
          else:
              return 1
      # debug output only: a table of untranslated f(i) values (not checked)
      print
      for i in range(1, 11):
          print '%3d %9d' % (i, f(i))
      res = self.meta_interp(main, [10], enable_opts='')
      assert res == main(10)
      self.check_enter_count_at_most(11)
  def test_bug_1(self):
      # Regression test.  opaque() is given to StopAtXPolicy (presumably so
      # the JIT leaves it as a residual call).  On the last iteration of
      # f(1) (i == 19) it re-enters the JIT-driven f() 20 times with n == 0.
      # Each of those calls returns 0, since its stack is [0].
      myjitdriver = JitDriver(greens=[], reds=['n', 'i', 'stack'])
      def opaque(n, i):
          if n == 1 and i == 19:
              for j in range(20):
                  res = f(0) # recurse repeatedly, 20 times
                  assert res == 0
      def f(n):
          stack = [n]
          i = 0
          while i < 20:
              myjitdriver.can_enter_jit(n=n, i=i, stack=stack)
              myjitdriver.jit_merge_point(n=n, i=i, stack=stack)
              opaque(n, i)
              i += 1
          return stack.pop()
      res = self.meta_interp(f, [1], enable_opts='', repeat=2,
                             policy=StopAtXPolicy(opaque))
      assert res == 1
  def get_interpreter(self, codes):
      # Build a tiny interpreter over string "bytecodes" for the inlining
      # tests.  codes[codenum] is a program of one-character opcodes:
      #   "0" ADD       -- n += 1
      #   "1" JUMP_BACK -- return 42 once n > 20, else jump back two opcodes
      #   "2" CALL      -- recursively interpret codes[1], starting at index 1
      #   "3" EXIT      -- return n
      # Returns interpret(codenum, n, i), which runs codes[codenum] from
      # index i with accumulator n.  Falling off the end returns n.
      ADD = "0"
      JUMP_BACK = "1"
      CALL = "2"
      EXIT = "3"
      def getloc(i, code):
          return 'code="%s", i=%d' % (code, i)
      jitdriver = JitDriver(greens = ['i', 'code'], reds = ['n'],
                            get_printable_location = getloc)
      def interpret(codenum, n, i):
          code = codes[codenum]
          while i < len(code):
              jitdriver.jit_merge_point(n=n, i=i, code=code)
              op = code[i]
              if op == ADD:
                  n += 1
                  i += 1
              elif op == CALL:
                  # note: the callee always starts at index 1, not 0
                  n = interpret(1, n, 1)
                  i += 1
              elif op == JUMP_BACK:
                  if n > 20:
                      return 42
                  i -= 2
                  jitdriver.can_enter_jit(n=n, i=i, code=code)
              elif op == EXIT:
                  return n
              else:
                  raise NotImplementedError
          return n
      return interpret
  def test_inline(self):
      # "021" is ADD, CALL, JUMP_BACK.  CALL runs subcode "00" from index 1,
      # which is a single ADD.
      code = "021"
      subcode = "00"
      codes = [code, subcode]
      f = self.get_interpreter(codes)
      assert self.meta_interp(f, [0, 0, 0], enable_opts='') == 42
      # without inline=True: the recursive call stays a call_may_force_i
      # and only the caller's own int_add appears
      self.check_resops(call_may_force_i=1, int_add=1, call=0)
      assert self.meta_interp(f, [0, 0, 0], enable_opts='',
                              inline=True) == 42
      # with inline=True: no call remains and the callee's int_add is
      # counted as well
      self.check_resops(call=0, int_add=2, call_may_force_i=0,
                        guard_no_exception=0)
  def test_inline_jitdriver_check(self):
      # subcode "100" starts with a JUMP_BACK at index 0, but CALL enters
      # it at index 1 (two ADDs), so the callee never reaches that opcode.
      code = "021"
      subcode = "100"
      codes = [code, subcode]
      f = self.get_interpreter(codes)
      assert self.meta_interp(f, [0, 0, 0], enable_opts='',
                              inline=True) == 42
      # the call is fully inlined, because we jump to subcode[1], thus
      # skipping completely the JUMP_BACK in subcode[0]
      self.check_resops(call=0, call_may_force=0, call_assembler=0)
  def test_guard_failure_in_inlined_function(self):
      # p() is the printable-location hook: it renders code, pc and the
      # current opcode.  It presumably receives a low-level string, hence
      # the hlstr() conversion.
      def p(pc, code):
          code = hlstr(code)
          return "%s %d %s" % (code, pc, code[pc])
      myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                              get_printable_location=p)
      # f() interprets a string:
      #   "-" decrements n
      #   "c" recursively runs "---i---" on n
      #   "i" returns n early when n % 5 == 1
      #   "l" jumps back to pc 0 while n > 0
      # The early exit in the callee depends on the value of n, so it does
      # not behave the same on every outer-loop iteration.
      def f(code, n):
          pc = 0
          while pc < len(code):
              myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
              op = code[pc]
              if op == "-":
                  n -= 1
              elif op == "c":
                  n = f("---i---", n)
              elif op == "i":
                  if n % 5 == 1:
                      return n
              elif op == "l":
                  if n > 0:
                      myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                      pc = 0
                      continue
              else:
                  assert 0
              pc += 1
          return n
      def main(n):
          return f("c-l", n)
      # debug output of the untranslated result
      print main(100)
      res = self.meta_interp(main, [100], enable_opts='', inline=True)
      assert res == 0
  def test_guard_failure_and_then_exception_in_inlined_function(self):
      # p() is the printable-location hook (code, pc, current opcode).
      def p(pc, code):
          code = hlstr(code)
          return "%s %d %s" % (code, pc, code[pc])
      myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n', 'flag'],
                              get_printable_location=p)
      # In the callee "---ir---", "i" sets flag once n < 200, and "r" then
      # raises Exception.  The caller's "c" handler catches it and returns
      # its own n; the callee's decrements are lost because n is not
      # reassigned.  The outer program "c-l" repeats the call while n > 0.
      def f(code, n):
          pc = 0
          flag = False
          while pc < len(code):
              myjitdriver.jit_merge_point(n=n, code=code, pc=pc, flag=flag)
              op = code[pc]
              if op == "-":
                  n -= 1
              elif op == "c":
                  try:
                      n = f("---ir---", n)
                  except Exception:
                      return n
              elif op == "i":
                  if n < 200:
                      flag = True
              elif op == "r":
                  if flag:
                      raise Exception
              elif op == "l":
                  if n > 0:
                      myjitdriver.can_enter_jit(n=n, code=code, pc=0, flag=flag)
                      pc = 0
                      continue
              else:
                  assert 0
              pc += 1
          return n
      def main(n):
          return f("c-l", n)
      # debug output of the untranslated result
      print main(1000)
      res = self.meta_interp(main, [1000], enable_opts='', inline=True)
      assert res == main(1000)
  def test_exception_in_inlined_function(self):
      # p() is the printable-location hook (code, pc, current opcode).
      def p(pc, code):
          code = hlstr(code)
          return "%s %d %s" % (code, pc, code[pc])
      myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                              get_printable_location=p)
      class Exc(Exception):
          pass
      # The callee "---i---" raises Exc at "i" when n % 5 == 1.  The caller
      # swallows it, and in that case the call leaves n unchanged.
      def f(code, n):
          pc = 0
          while pc < len(code):
              myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
              op = code[pc]
              if op == "-":
                  n -= 1
              elif op == "c":
                  try:
                      n = f("---i---", n)
                  except Exc:
                      pass
              elif op == "i":
                  if n % 5 == 1:
                      raise Exc
              elif op == "l":
                  if n > 0:
                      myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                      pc = 0
                      continue
              else:
                  assert 0
              pc += 1
          return n
      def main(n):
          return f("c-l", n)
      res = self.meta_interp(main, [100], enable_opts='', inline=True)
      assert res == main(100)
  def test_recurse_during_blackholing(self):
      # This passes if the blackholing shortcut for calls is turned off;
      # otherwise it fails.  It is very delicate with respect to the JIT
      # parameters and to the order in which bridges and loops are created.
      def p(pc, code):
          code = hlstr(code)
          return "%s %d %s" % (code, pc, code[pc])
      myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                              get_printable_location=p)
      # "c" recurses into "--" only for some values of n (n < 70 and
      # n % 3 == 1), so the call is taken on some iterations only.
      def f(code, n):
          pc = 0
          while pc < len(code):
              myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
              op = code[pc]
              if op == "-":
                  n -= 1
              elif op == "c":
                  if n < 70 and n % 3 == 1:
                      n = f("--", n)
              elif op == "l":
                  if n > 0:
                      myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                      pc = 0
                      continue
              else:
                  assert 0
              pc += 1
          return n
      def main(n):
          # the exact parameter values matter (see the comment at the top)
          set_param(None, 'threshold', 3)
          set_param(None, 'trace_eagerness', 5)
          return f("c-l", n)
      expected = main(100)
      res = self.meta_interp(main, [100], enable_opts='', inline=True)
      assert res == expected
  def check_max_trace_length(self, length):
      """Assert that recorded traces respect a trace limit of ``length``.

      Every loop in get_stats().loops, and the ``_debug_suboperations``
      list of every guard descr that has one, may hold at most
      ``length + 5`` operations.  The slack of 5 is there because the
      trace limit is only checked once per metainterp bytecode.
      """
      allowed = length + 5
      for compiled in get_stats().loops:
          ops = compiled.operations
          assert len(ops) <= allowed
          for guard in (op for op in ops if op.is_guard()):
              descr = guard.getdescr()
              if hasattr(descr, '_debug_suboperations'):
                  assert len(descr._debug_suboperations) <= allowed
  def test_inline_trace_limit(self):
      # recursive(n) is a plain (non-portal) recursive function that
      # returns n for n >= 0, so each loop iteration only decrements n.
      # With inline=True the recursion ends up in the traces, which must
      # stay within TRACE_LIMIT; some tracing attempts are aborted.
      myjitdriver = JitDriver(greens=[], reds=['n'])
      def recursive(n):
          if n > 0:
              return recursive(n - 1) + 1
          return 0
      def loop(n):
          set_param(myjitdriver, "threshold", 10)
          # (an unused local `pc = 0` that used to live here was removed)
          while n:
              myjitdriver.can_enter_jit(n=n)
              myjitdriver.jit_merge_point(n=n)
              n = recursive(n)
              n -= 1
          return n
      TRACE_LIMIT = 66
      res = self.meta_interp(loop, [100], enable_opts='', inline=True, trace_limit=TRACE_LIMIT)
      assert res == 0
      # no loop and no guard suboperation list may exceed the limit
      # (plus the slack allowed by check_max_trace_length)
      self.check_max_trace_length(TRACE_LIMIT)
      self.check_enter_count_at_most(10) # maybe
      self.check_aborted_count(6)
  def test_trace_limit_bridge(self):
      # recursive(n) returns n for n >= 0.  It is only called once n < 50,
      # so the traces (and the bridges created for that branch) get long
      # and must stay within TRACE_LIMIT.
      def recursive(n):
          if n > 0:
              return recursive(n - 1) + 1
          return 0
      myjitdriver = JitDriver(greens=[], reds=['n'])
      def loop(n):
          set_param(None, "threshold", 4)
          set_param(None, "trace_eagerness", 2)
          while n:
              myjitdriver.can_enter_jit(n=n)
              myjitdriver.jit_merge_point(n=n)
              if n % 5 == 0:
                  n -= 1
              if n < 50:
                  n = recursive(n)
              n -= 1
          return n
      TRACE_LIMIT = 20
      res = self.meta_interp(loop, [100], enable_opts='', inline=True, trace_limit=TRACE_LIMIT)
      # n only ever decreases by 1 or 2 and the loop exits at exactly 0,
      # so the result must be 0 (previously `res` was never checked)
      assert res == 0
      self.check_max_trace_length(TRACE_LIMIT)
      self.check_aborted_count(8)
      self.check_enter_count_at_most(30)
  def test_trace_limit_with_exception_bug(self):
      # do_stuff() is unroll_safe, so its countdown loop is unrolled into
      # the trace, which makes the trace exceed TRACE_LIMIT.  It always
      # ends by raising ValueError, which loop() catches.
      myjitdriver = JitDriver(greens=[], reds=['n'])
      @unroll_safe
      def do_stuff(n):
          while n > 0:
              n -= 1
          raise ValueError
      def loop(n):
          # (an unused local `pc = 0` that used to live here was removed)
          while n > 80:
              myjitdriver.can_enter_jit(n=n)
              myjitdriver.jit_merge_point(n=n)
              try:
                  do_stuff(n)
              except ValueError:
                  # the trace limit is checked when we arrive here, and we
                  # have the exception still in last_exc_value_box at this
                  # point -- so when we abort because of a trace too long,
                  # the exception is passed to the blackhole interp and
                  # incorrectly re-raised from here
                  pass
              n -= 1
          return n
      TRACE_LIMIT = 66
      res = self.meta_interp(loop, [100], trace_limit=TRACE_LIMIT)
      assert res == 80
  def test_max_failure_args(self):
      # Every iteration allocates an A carrying ten int fields.  Together
      # with the reds i and n, a guard would presumably need more than
      # FAILARGS_LIMIT fail arguments, so tracing is expected to abort
      # (checked via check_aborted_count).
      FAILARGS_LIMIT = 10
      jitdriver = JitDriver(greens = [], reds = ['i', 'n', 'o'])
      class A(object):
          def __init__(self, i0, i1, i2, i3, i4, i5, i6, i7, i8, i9):
              self.i0 = i0
              self.i1 = i1
              self.i2 = i2
              self.i3 = i3
              self.i4 = i4
              self.i5 = i5
              self.i6 = i6
              self.i7 = i7
              self.i8 = i8
              self.i9 = i9
      def loop(n):
          i = 0
          o = A(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
          while i < n:
              jitdriver.can_enter_jit(o=o, i=i, n=n)
              jitdriver.jit_merge_point(o=o, i=i, n=n)
              o = A(i, i + 1, i + 2, i + 3, i + 4, i + 5,
                    i + 6, i + 7, i + 8, i + 9)
              i += 1
          return o
      # NOTE(review): the returned A instance is not inspected
      res = self.meta_interp(loop, [20], failargs_limit=FAILARGS_LIMIT,
                             listops=True)
      self.check_aborted_count(4)
  def test_max_failure_args_exc(self):
      # Same loop shape as in a failargs-limit test: a ten-field A is
      # allocated each iteration.  Here, however, loop() ends by raising
      # ValueError, which main() turns into a return value of 0.
      FAILARGS_LIMIT = 10
      jitdriver = JitDriver(greens = [], reds = ['i', 'n', 'o'])
      class A(object):
          def __init__(self, i0, i1, i2, i3, i4, i5, i6, i7, i8, i9):
              self.i0 = i0
              self.i1 = i1
              self.i2 = i2
              self.i3 = i3
              self.i4 = i4
              self.i5 = i5
              self.i6 = i6
              self.i7 = i7
              self.i8 = i8
              self.i9 = i9
      def loop(n):
          i = 0
          o = A(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
          while i < n:
              jitdriver.can_enter_jit(o=o, i=i, n=n)
              jitdriver.jit_merge_point(o=o, i=i, n=n)
              o = A(i, i + 1, i + 2, i + 3, i + 4, i + 5,
                    i + 6, i + 7, i + 8, i + 9)
              i += 1
          raise ValueError
      def main(n):
          try:
              loop(n)
              return 1
          except ValueError:
              return 0
      res = self.meta_interp(main, [20], failargs_limit=FAILARGS_LIMIT,
                             listops=True)
      # the exception must still reach main(), which returns 0
      assert not res
      self.check_aborted_count(4)
  def test_set_param_inlining(self):
      # The outer loop(100) calls loop(10, True) on every iteration.
      # main(inline) toggles the 'inlining' JIT parameter: with inlining
      # off, the inner call must stay a call_may_force_i; with it on, no
      # call may remain.
      myjitdriver = JitDriver(greens=[], reds=['n', 'recurse'])
      def loop(n, recurse=False):
          while n:
              myjitdriver.jit_merge_point(n=n, recurse=recurse)
              n -= 1
              if not recurse:
                  loop(10, True)
              myjitdriver.can_enter_jit(n=n, recurse=recurse)
          return n
      TRACE_LIMIT = 66
      def main(inline):
          set_param(None, "threshold", 10)
          set_param(None, 'function_threshold', 60)
          if inline:
              set_param(None, 'inlining', True)
          else:
              set_param(None, 'inlining', False)
          return loop(100)
      # loop() always counts n down to 0, so both runs must return 0
      # (previously the results were never checked)
      res = self.meta_interp(main, [0], enable_opts='', trace_limit=TRACE_LIMIT)
      assert res == 0
      self.check_resops(call=0, call_may_force_i=1)
      res = self.meta_interp(main, [1], enable_opts='', trace_limit=TRACE_LIMIT)
      assert res == 0
      self.check_resops(call=0, call_may_force=0)
  def test_trace_from_start(self):
      # p() is the printable-location hook (code, pc, current opcode).
      def p(pc, code):
          code = hlstr(code)
          return "'%s' at %d: %s" % (code, pc, code[pc])
      myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                              get_printable_location=p)
      # "l" jumps back to pc 1, not 0, so the leading "+" of a program runs
      # only once per call of f().
      def f(code, n):
          pc = 0
          while pc < len(code):
              myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
              op = code[pc]
              if op == "+":
                  n += 7
              elif op == "-":
                  n -= 1
              elif op == "c":
                  n = f('---', n)
              elif op == "l":
                  if n > 0:
                      myjitdriver.can_enter_jit(n=n, code=code, pc=1)
                      pc = 1
                      continue
              else:
                  assert 0
              pc += 1
          return n
      def g(m):
          if m > 1000000:
              f('', 0)
          result = 0
          for i in range(m):
              result += f('+-cl--', i)
          # without this return, g() returned None and the assertion
          # below compared None == None, checking nothing
          return result
      res = self.meta_interp(g, [50], backendopt=True)
      assert res == g(50)
      # everything below the skip is currently unreachable
      py.test.skip("tracing from start is by now only longer enabled "
                   "if a trace gets too big")
      self.check_tree_loop_count(3)
      self.check_history(int_add=1)
  def test_dont_inline_huge_stuff(self):
      # The "c" opcode calls f() on a long run of "-" (the result is
      # discarded).  With the chosen trace_limit, inlining that callee into
      # the caller's trace is too big and is aborted.  The callee is
      # expected to be reached via call_assembler instead.
      def p(pc, code):
          code = hlstr(code)
          return "%s %d %s" % (code, pc, code[pc])
      myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                              get_printable_location=p,
                              is_recursive=True)
      def f(code, n):
          pc = 0
          while pc < len(code):
              myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
              op = code[pc]
              if op == "-":
                  n -= 1
              elif op == "c":
                  f('--------------------', n)
              elif op == "l":
                  if n > 0:
                      myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                      pc = 0
                      continue
              else:
                  assert 0
              pc += 1
          return n
      def g(m):
          set_param(None, 'inlining', True)
          # carefully chosen threshold to make sure that the inner function
          # cannot be inlined, but the inner function on its own is small
          # enough
          set_param(None, 'trace_limit', 40)
          if m > 1000000:
              f('', 0)
          # NOTE(review): result is accumulated but never returned; harmless
          # here because the test does not check g()'s return value
          result = 0
          for i in range(m):
              result += f('-c-----------l-', i+100)
      self.meta_interp(g, [10], backendopt=True)
      self.check_aborted_count(1)
      self.check_resops(call=0, call_assembler_i=2)
      self.check_jitcell_token_count(2)
  def test_directly_call_assembler(self):
      # portal(2) calls portal(1) on each of its 10 iterations.  The
      # recursive portal call is expected to become a single
      # call_assembler_n in the history.
      driver = JitDriver(greens = ['codeno'], reds = ['i'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno):
          i = 0
          while i < 10:
              driver.can_enter_jit(codeno = codeno, i = i)
              driver.jit_merge_point(codeno = codeno, i = i)
              if codeno == 2:
                  portal(1)
              i += 1
      self.meta_interp(portal, [2], inline=True)
      self.check_history(call_assembler_n=1)
  def test_recursion_cant_call_assembler_directly(self):
      # portal() recurses into portal(2, j - 1) while i is 2 or 3
      # ((i >> 1) == 1), down to j == 0; it returns once i reaches 5.
      # Per the test name, the recursive call cannot target finished
      # assembler directly.  The test checks that exactly one temporary
      # callback is compiled and that it is later redirected.
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'j'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno, j):
          i = 1
          while 1:
              driver.jit_merge_point(codeno=codeno, i=i, j=j)
              if (i >> 1) == 1:
                  if j == 0:
                      return
                  portal(2, j - 1)
              elif i == 5:
                  return
              i += 1
              driver.can_enter_jit(codeno=codeno, i=i, j=j)
      # sanity run, untranslated
      portal(2, 5)
      from rpython.jit.metainterp import compile, pyjitpl
      pyjitpl._warmrunnerdesc = None
      trace = []
      # wrap compile_tmp_callback to record every loop token it returns;
      # the original is restored in the finally clause
      def my_ctc(*args):
          looptoken = original_ctc(*args)
          trace.append(looptoken)
          return looptoken
      original_ctc = compile.compile_tmp_callback
      try:
          compile.compile_tmp_callback = my_ctc
          self.meta_interp(portal, [2, 5], inline=True)
          self.check_resops(call_may_force=0, call_assembler_n=2)
      finally:
          compile.compile_tmp_callback = original_ctc
      # check that we made a temporary callback
      assert len(trace) == 1
      # and that we later redirected it to something else
      try:
          redirected = pyjitpl._warmrunnerdesc.cpu._redirected_call_assembler
      except AttributeError:
          pass # not the llgraph backend
      else:
          print redirected
          assert redirected.keys() == trace
  def test_recursion_cant_call_assembler_directly_with_virtualizable(self):
      # exactly the same logic as the previous test, but with 'frame.j'
      # instead of just 'j'
      # (Frame is a virtualizable whose only listed field is 'j'; each
      # recursive call gets a fresh Frame.)
      class Frame(object):
          _virtualizable_ = ['j']
          def __init__(self, j):
              self.j = j
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                         virtualizables = ['frame'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno, frame):
          i = 1
          while 1:
              driver.jit_merge_point(codeno=codeno, i=i, frame=frame)
              if (i >> 1) == 1:
                  if frame.j == 0:
                      return
                  portal(2, Frame(frame.j - 1))
              elif i == 5:
                  return
              i += 1
              driver.can_enter_jit(codeno=codeno, i=i, frame=frame)
      def main(codeno, j):
          portal(codeno, Frame(j))
      # sanity run, untranslated
      main(2, 5)
      from rpython.jit.metainterp import compile, pyjitpl
      pyjitpl._warmrunnerdesc = None
      trace = []
      # wrap compile_tmp_callback to record every loop token it returns;
      # the original is restored in the finally clause
      def my_ctc(*args):
          looptoken = original_ctc(*args)
          trace.append(looptoken)
          return looptoken
      original_ctc = compile.compile_tmp_callback
      try:
          compile.compile_tmp_callback = my_ctc
          self.meta_interp(main, [2, 5], inline=True)
          self.check_resops(call_may_force=0, call_assembler_n=2)
      finally:
          compile.compile_tmp_callback = original_ctc
      # check that we made a temporary callback
      assert len(trace) == 1
      # and that we later redirected it to something else
      try:
          redirected = pyjitpl._warmrunnerdesc.cpu._redirected_call_assembler
      except AttributeError:
          pass # not the llgraph backend
      else:
          print redirected
          assert redirected.keys() == trace
  def test_directly_call_assembler_return(self):
      # Like a plain call_assembler test, but the recursive portal call
      # returns a value (k), so the history should contain one
      # call_assembler_i.
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno):
          i = 0
          k = codeno
          while i < 10:
              driver.can_enter_jit(codeno = codeno, i = i, k = k)
              driver.jit_merge_point(codeno = codeno, i = i, k = k)
              if codeno == 2:
                  k = portal(1)
              i += 1
          return k
      self.meta_interp(portal, [2], inline=True)
      self.check_history(call_assembler_i=1)
  def test_directly_call_assembler_raise(self):
      # portal(1) runs its full loop and then raises MyException(1).  The
      # caller portal(2) catches it and adds the payload to i, so the
      # call through call_assembler must propagate the exception.
      class MyException(Exception):
          def __init__(self, x):
              self.x = x
      driver = JitDriver(greens = ['codeno'], reds = ['i'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno):
          i = 0
          while i < 10:
              driver.can_enter_jit(codeno = codeno, i = i)
              driver.jit_merge_point(codeno = codeno, i = i)
              if codeno == 2:
                  try:
                      portal(1)
                  except MyException as me:
                      i += me.x
              i += 1
          if codeno == 1:
              raise MyException(1)
      self.meta_interp(portal, [2], inline=True)
      self.check_history(call_assembler_n=1)
  def test_directly_call_assembler_fail_guard(self):
      # In portal(1, k), the k > 40 branch (and its i % 2 sub-branch) is
      # taken only once k has grown past 40.  Code compiled while k was
      # small therefore meets new paths later.  The final value 13542 is
      # the untranslated result of portal(2, 0).
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno, k):
          i = 0
          while i < 10:
              driver.can_enter_jit(codeno=codeno, i=i, k=k)
              driver.jit_merge_point(codeno=codeno, i=i, k=k)
              if codeno == 2:
                  k += portal(1, k)
              elif k > 40:
                  if i % 2:
                      k += 1
                  else:
                      k += 2
              k += 1
              i += 1
          return k
      res = self.meta_interp(portal, [2, 0], inline=True)
      assert res == 13542
  def test_directly_call_assembler_virtualizable(self):
      # Frame is a virtualizable with one listed field, 'thing'.
      # portal(0, ...) creates a fresh subframe for each recursive
      # portal(1, ...) call and reads subframe.thing back after the call,
      # so the callee's updates to its frame must be visible to the caller.
      class Thing(object):
          def __init__(self, val):
              self.val = val
      class Frame(object):
          _virtualizable_ = ['thing']
      driver = JitDriver(greens = ['codeno'], reds = ['i', 's', 'frame'],
                         virtualizables = ['frame'],
                         get_printable_location = lambda codeno : str(codeno))
      def main(codeno):
          frame = Frame()
          frame.thing = Thing(0)
          result = portal(codeno, frame)
          return result
      def portal(codeno, frame):
          i = 0
          s = 0
          while i < 10:
              driver.can_enter_jit(frame=frame, codeno=codeno, i=i, s=s)
              driver.jit_merge_point(frame=frame, codeno=codeno, i=i, s=s)
              nextval = frame.thing.val
              if codeno == 0:
                  subframe = Frame()
                  subframe.thing = Thing(nextval)
                  nextval = portal(1, subframe)
                  s += subframe.thing.val
              frame.thing = Thing(nextval + 1)
              i += 1
          return frame.thing.val + s
      res = self.meta_interp(main, [0], inline=True)
      self.check_resops(call=0, cond_call=2)
      assert res == main(0)
  def test_directly_call_assembler_virtualizable_reset_token(self):
      # Currently skipped unconditionally; everything below is kept for
      # reference.
      py.test.skip("not applicable any more, I think")
      from rpython.rtyper.lltypesystem import lltype
      from rpython.rlib.debug import llinterpcall
      class Thing(object):
          def __init__(self, val):
              self.val = val
      class Frame(object):
          _virtualizable_ = ['thing']
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                         virtualizables = ['frame'],
                         get_printable_location = lambda codeno : str(codeno))
      # When translated, hand the low-level subframe to check_ll_frame via
      # llinterpcall; untranslated it does nothing.
      @dont_look_inside
      def check_frame(subframe):
          if we_are_translated():
              llinterpcall(lltype.Void, check_ll_frame, subframe)
      def check_ll_frame(ll_subframe):
          # This is called with the low-level Struct that is the frame.
          # Check that the vable_token was correctly reset to zero.
          # Note that in order for that test to catch failures, it needs
          # three levels of recursion: the vable_token of the subframe
          # at the level 2 is set to a non-zero value when doing the
          # call to the level 3 only.  This used to fail when the test
          # is run via rpython.jit.backend.x86.test.test_recursive.
          from rpython.jit.metainterp.virtualizable import TOKEN_NONE
          assert ll_subframe.vable_token == TOKEN_NONE
      def main(codeno):
          frame = Frame()
          frame.thing = Thing(0)
          portal(codeno, frame)
          return frame.thing.val
      def portal(codeno, frame):
          i = 0
          while i < 5:
              driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
              driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
              nextval = frame.thing.val
              # codeno 0 -> 1 -> 2 gives the three recursion levels needed
              if codeno < 2:
                  subframe = Frame()
                  subframe.thing = Thing(nextval)
                  nextval = portal(codeno + 1, subframe)
                  check_frame(subframe)
              frame.thing = Thing(nextval + 1)
              i += 1
          return frame.thing.val
      res = self.meta_interp(main, [0], inline=True)
      assert res == main(0)
  def test_directly_call_assembler_virtualizable_force1(self):
      # change() is excluded from tracing via StopAtXPolicy.  It writes to
      # the frame through a global reference (somewhere_else.frame) rather
      # than through the virtualizable red variable.  It is called from
      # the codeno == 1 level once frame.thing.val exceeds 40.
      class Thing(object):
          def __init__(self, val):
              self.val = val
      class Frame(object):
          _virtualizable_ = ['thing']
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                         virtualizables = ['frame'],
                         get_printable_location = lambda codeno : str(codeno))
      class SomewhereElse(object):
          pass
      somewhere_else = SomewhereElse()
      def change(newthing):
          somewhere_else.frame.thing = newthing
      def main(codeno):
          frame = Frame()
          somewhere_else.frame = frame
          frame.thing = Thing(0)
          portal(codeno, frame)
          return frame.thing.val
      def portal(codeno, frame):
          print 'ENTER:', codeno, frame.thing.val
          i = 0
          while i < 10:
              driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
              driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
              nextval = frame.thing.val
              if codeno == 0:
                  subframe = Frame()
                  subframe.thing = Thing(nextval)
                  nextval = portal(1, subframe)
              elif codeno == 1:
                  if frame.thing.val > 40:
                      change(Thing(13))
                      nextval = 13
              else:
                  fatalerror("bad codeno = " + str(codeno))
              frame.thing = Thing(nextval + 1)
              i += 1
          print 'LEAVE:', codeno, frame.thing.val
          return frame.thing.val
      res = self.meta_interp(main, [0], inline=True,
                             policy=StopAtXPolicy(change))
      assert res == main(0)
  def test_directly_call_assembler_virtualizable_with_array(self):
      # Frame is a virtualizable with an array field ('l[*]') and an int
      # field ('s').  f(0, ...) recursively calls f(1, ...) with a fresh
      # subframe on every iteration.
      myjitdriver = JitDriver(greens = ['codeno'], reds = ['n', 'x', 'frame'],
                              virtualizables = ['frame'])
      class Frame(object):
          _virtualizable_ = ['l[*]', 's']
          def __init__(self, l, s):
              self = hint(self, access_directly=True,
                          fresh_virtualizable=True)
              self.l = l
              self.s = s
      def main(codeno, n, a):
          frame = Frame([a, a+1, a+2, a+3], 0)
          return f(codeno, n, a, frame)
      def f(codeno, n, a, frame):
          x = 0
          while n > 0:
              myjitdriver.can_enter_jit(codeno=codeno, frame=frame, n=n, x=x)
              myjitdriver.jit_merge_point(codeno=codeno, frame=frame, n=n,
                                          x=x)
              frame.s = promote(frame.s)
              n -= 1
              # the asserts give the index a known non-negative range
              s = frame.s
              assert s >= 0
              x += frame.l[s]
              frame.s += 1
              if codeno == 0:
                  subframe = Frame([n, n+1, n+2, n+3], 0)
                  x += f(1, 10, 1, subframe)
              s = frame.s
              assert s >= 0
              x += frame.l[s]
              x += len(frame.l)
              # restore frame.s, so every iteration starts at index 0
              frame.s -= 1
          return x
      res = self.meta_interp(main, [0, 10, 1], listops=True, inline=True)
      assert res == main(0, 10, 1)
  def test_directly_call_assembler_virtualizable_force_blackhole(self):
      # change() is excluded from tracing via StopAtXPolicy.  Once arg > 30,
      # it overwrites the frame's 'thing' through a global reference and
      # returns 13.  It is called on every iteration at the codeno != 0
      # level.
      class Thing(object):
          def __init__(self, val):
              self.val = val
      class Frame(object):
          _virtualizable_ = ['thing']
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                         virtualizables = ['frame'],
                         get_printable_location = lambda codeno : str(codeno))
      class SomewhereElse(object):
          pass
      somewhere_else = SomewhereElse()
      def change(newthing, arg):
          print arg
          if arg > 30:
              somewhere_else.frame.thing = newthing
              arg = 13
          return arg
      def main(codeno):
          frame = Frame()
          somewhere_else.frame = frame
          frame.thing = Thing(0)
          portal(codeno, frame)
          return frame.thing.val
      def portal(codeno, frame):
          i = 0
          while i < 10:
              driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
              driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
              nextval = frame.thing.val
              if codeno == 0:
                  subframe = Frame()
                  subframe.thing = Thing(nextval)
                  nextval = portal(1, subframe)
              else:
                  nextval = change(Thing(13), frame.thing.val)
              frame.thing = Thing(nextval + 1)
              i += 1
          return frame.thing.val
      res = self.meta_interp(main, [0], inline=True,
                             policy=StopAtXPolicy(change))
      assert res == main(0)
  def test_assembler_call_red_args(self):
      # The green argument of the recursive call, codeno, comes from the
      # residual function residual(k).  It is 1 while k <= 150 and 0
      # afterwards, so the callee's green value is only known at runtime.
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                         get_printable_location = lambda codeno : str(codeno))
      def residual(k):
          if k > 150:
              return 0
          return 1
      def portal(codeno, k):
          i = 0
          while i < 15:
              driver.can_enter_jit(codeno=codeno, i=i, k=k)
              driver.jit_merge_point(codeno=codeno, i=i, k=k)
              if codeno == 2:
                  k += portal(residual(k), k)
              if codeno == 0:
                  k += 2
              elif codeno == 1:
                  k += 1
              i += 1
          return k
      res = self.meta_interp(portal, [2, 0], inline=True,
                             policy=StopAtXPolicy(residual))
      assert res == portal(2, 0)
      self.check_resops(call_assembler_i=4)
  def test_inline_without_hitting_the_loop(self):
      # The callee portal(20) takes the codeno > 10 branch and returns 1
      # immediately, without reaching can_enter_jit.  Each pass over
      # codeno 0..9 therefore adds 10 to i, and the loop exits at
      # codeno == 10 once i > 63, i.e. with i == 70.
      driver = JitDriver(greens = ['codeno'], reds = ['i'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno):
          i = 0
          while True:
              driver.jit_merge_point(codeno=codeno, i=i)
              if codeno < 10:
                  i += portal(20)
                  codeno += 1
              elif codeno == 10:
                  if i > 63:
                      return i
                  codeno = 0
                  driver.can_enter_jit(codeno=codeno, i=i)
              else:
                  return 1
      assert portal(0) == 70
      res = self.meta_interp(portal, [0], inline=True)
      assert res == 70
      # the callee is fully inlined: no call_assembler expected
      self.check_resops(call_assembler=0)
  def test_inline_with_hitting_the_loop_sometimes(self):
      # Nested calls start at codeno + 5 (5..14) with depth k + 1.
      # Callees starting at codeno 5..10 reach the codeno == 10 branch and
      # may loop; callees starting above 10 return 1 at once.  Depth k > 2
      # returns 1 immediately, and the exit threshold for i depends on k.
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                         get_printable_location = lambda codeno : str(codeno))
      def portal(codeno, k):
          if k > 2:
              return 1
          i = 0
          while True:
              driver.jit_merge_point(codeno=codeno, i=i, k=k)
              if codeno < 10:
                  i += portal(codeno + 5, k+1)
                  codeno += 1
              elif codeno == 10:
                  if i > [-1, 2000, 63][k]:
                      return i
                  codeno = 0
                  driver.can_enter_jit(codeno=codeno, i=i, k=k)
              else:
                  return 1
      assert portal(0, 1) == 2095
      res = self.meta_interp(portal, [0, 1], inline=True)
      assert res == 2095
      self.check_resops(call_assembler_i=12)
  def test_inline_with_hitting_the_loop_sometimes_exc(self):
      # Same control flow as the value-returning variant, but every
      # "return v" is replaced by "raise GotValue(v)".  Callers catch
      # GotValue and use its payload, so results travel through
      # exceptions across the call_assembler boundary.
      driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                         get_printable_location = lambda codeno : str(codeno))
      class GotValue(Exception):
          def __init__(self, result):
              self.result = result
      def portal(codeno, k):
          if k > 2:
              raise GotValue(1)
          i = 0
          while True:
              driver.jit_merge_point(codeno=codeno, i=i, k=k)
              if codeno < 10:
                  try:
                      portal(codeno + 5, k+1)
                  except GotValue as e:
                      i += e.result
                  codeno += 1
              elif codeno == 10:
                  if i > [-1, 2000, 63][k]:
                      raise GotValue(i)
                  codeno = 0
                  driver.can_enter_jit(codeno=codeno, i=i, k=k)
              else:
                  raise GotValue(1)
      def main(codeno, k):
          try:
              portal(codeno, k)
          except GotValue as e:
              return e.result
      assert main(0, 1) == 2095
      res = self.meta_interp(main, [0, 1], inline=True)
      assert res == 2095
      self.check_resops(call_assembler_n=12)
    def test_inline_recursion_limit(self):
        # portal() first calls f(), which is @dont_look_inside (opaque to
        # the JIT) and sets the max_unroll_recursion parameter to 10.
        # With loop=False, portal(threshold, False, i) enters its loop and
        # recurses with i + 1 until i > threshold.  So threshold=10 builds
        # one more nested looping level than threshold=9.  The test expects
        # two call_assembler_i ops for threshold=10 and none for
        # threshold=9, presumably because only the deeper chain exceeds the
        # recursion limit of 10.
        driver = JitDriver(greens = ["threshold", "loop"], reds=["i"])
        @dont_look_inside
        def f():
            set_param(driver, "max_unroll_recursion", 10)
        def portal(threshold, loop, i):
            f()
            if i > threshold:
                return i
            while True:
                driver.jit_merge_point(threshold=threshold, loop=loop, i=i)
                if loop:
                    # outer driver loop: start a fresh recursion chain at 0
                    portal(threshold, False, 0)
                else:
                    # inner chain: recurse one level deeper, then return
                    portal(threshold, False, i + 1)
                    return i
                if i > 10:
                    return 1
                i += 1
                driver.can_enter_jit(threshold=threshold, loop=loop, i=i)
        # each run compares the JIT result with plain interpretation
        res1 = portal(10, True, 0)
        res2 = self.meta_interp(portal, [10, True, 0], inline=True)
        assert res1 == res2
        self.check_resops(call_assembler_i=2)
        res1 = portal(9, True, 0)
        res2 = self.meta_interp(portal, [9, True, 0], inline=True)
        assert res1 == res2
        self.check_resops(call_assembler=0)
    def test_handle_jitexception_in_portal(self):
        # a test for _handle_jitexception_in_portal in blackhole.py
        # can_enter_jit is reached only through two helper calls
        # (intermediate -> do_can_enter_jit), and only when i == 9.  Each
        # top-level portal(64, ...) recurses once into portal(96, ...) when
        # i reaches 10.  The result string records the call order:
        # 'ABCDEFGHI' from codeno 64, then 'abcdefghij' from the recursive
        # codeno 96 call, then 'J'.  The JIT is run under several trace
        # limits, and each run must still produce that string.
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'str'],
                           get_printable_location = lambda codeno: str(codeno))
        def do_can_enter_jit(codeno, i, str):
            i = (i+1)-1    # some operations
            driver.can_enter_jit(codeno=codeno, i=i, str=str)
        def intermediate(codeno, i, str):
            if i == 9:
                do_can_enter_jit(codeno, i, str)
        def portal(codeno, str):
            i = value.initial
            while i < 10:
                intermediate(codeno, i, str)
                driver.jit_merge_point(codeno=codeno, i=i, str=str)
                i += 1
                if codeno == 64 and i == 10:
                    str = portal(96, str)
                str += chr(codeno+i)
            return str
        class Value:
            # initial is -1 at definition time and set to 0 by main();
            # presumably so the start value is not a constant -- confirm
            initial = -1
        value = Value()
        def main():
            value.initial = 0
            return (portal(64, '') +
                    portal(64, '') +
                    portal(64, '') +
                    portal(64, '') +
                    portal(64, ''))
        # sanity-check the expected value without the JIT first
        assert main() == 'ABCDEFGHIabcdefghijJ' * 5
        for tlimit in [95, 90, 102]:
            print 'tlimit =', tlimit
            res = self.meta_interp(main, [], inline=True, trace_limit=tlimit)
            # the result is compared through its .chars sequence
            assert ''.join(res.chars) == 'ABCDEFGHIabcdefghijJ' * 5
    def test_handle_jitexception_in_portal_returns_void(self):
        # a test for _handle_jitexception_in_portal in blackhole.py
        # Same setup as test_handle_jitexception_in_portal, except that
        # portal() discards the recursive call's result and has no return
        # statement, so the portal returns None.  There is no result to
        # compare; the test only checks that meta_interp() completes under
        # each trace limit.
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'str'],
                           get_printable_location = lambda codeno: str(codeno))
        def do_can_enter_jit(codeno, i, str):
            i = (i+1)-1    # some operations
            driver.can_enter_jit(codeno=codeno, i=i, str=str)
        def intermediate(codeno, i, str):
            if i == 9:
                do_can_enter_jit(codeno, i, str)
        def portal(codeno, str):
            i = value.initial
            while i < 10:
                intermediate(codeno, i, str)
                driver.jit_merge_point(codeno=codeno, i=i, str=str)
                i += 1
                if codeno == 64 and i == 10:
                    portal(96, str)
                str += chr(codeno+i)
        class Value:
            # initial is -1 at definition time and set to 0 by main()
            initial = -1
        value = Value()
        def main():
            value.initial = 0
            portal(64, '')
            portal(64, '')
            portal(64, '')
            portal(64, '')
            portal(64, '')
        # run once without the JIT first
        main()
        for tlimit in [95, 90, 102]:
            print 'tlimit =', tlimit
            self.meta_interp(main, [], inline=True, trace_limit=tlimit)
    def test_no_duplicates_bug(self):
        # Regression test (per its name, for a "duplicates" bug).  The
        # outer call portal(0, 10) recurses with portal(i, i) for
        # i = 10..1.  Every such inner call has codeno > 0 and breaks out
        # right after the merge point.  Note that can_enter_jit comes before
        # jit_merge_point in the loop body.  The test only checks that
        # meta_interp() completes.
        driver = JitDriver(greens = ['codeno'], reds = ['i'],
                           get_printable_location = lambda codeno: str(codeno))
        def portal(codeno, i):
            while i > 0:
                driver.can_enter_jit(codeno=codeno, i=i)
                driver.jit_merge_point(codeno=codeno, i=i)
                if codeno > 0:
                    break
                portal(i, i)
                i -= 1
        self.meta_interp(portal, [0, 10], inline=True)
  1099. def test_trace_from_start_always(self):
  1100. from rpython.rlib.nonconst import NonConstant
  1101. driver = JitDriver(greens = ['c'], reds = ['i', 'v'])
  1102. def portal(c, i, v):
  1103. while i > 0:
  1104. driver.jit_merge_point(c=c, i=i, v=v)
  1105. portal(c, i - 1, v)
  1106. if v:
  1107. driver.can_enter_jit(c=c, i=i, v=v)
  1108. break
  1109. def main(c, i, _set_param, v):
  1110. if _set_param:
  1111. set_param(driver, 'function_threshold', 0)
  1112. portal(c, i, v)
  1113. self.meta_interp(main, [10, 10, False, False], inline=True)
  1114. self.check_jitcell_token_count(1)
  1115. self.check_trace_count(1)
  1116. self.meta_interp(main, [3, 10, True, False], inline=True)
  1117. self.check_jitcell_token_count(0)
  1118. self.check_trace_count(0)
    def test_trace_from_start_does_not_prevent_inlining(self):
        # The outer portal (bc == 0) calls portal(1, 8, 0) on every
        # iteration.  That inner call has bc == 1 and returns immediately.
        # The loop's back edge (can_enter_jit) is hit each time c wraps from
        # 10 to 0, until i >= 100.  The traces must contain no residual call
        # or call_may_force, i.e. the recursive call has to be inlined.
        driver = JitDriver(greens = ['c', 'bc'], reds = ['i'])
        def portal(bc, c, i):
            while True:
                driver.jit_merge_point(c=c, bc=bc, i=i)
                if bc == 0:
                    portal(1, 8, 0)
                    c += 1
                else:
                    return
                if c == 10: # bc == 0
                    c = 0
                    if i >= 100:
                        return
                    driver.can_enter_jit(c=c, bc=bc, i=i)
                i += 1
        self.meta_interp(portal, [0, 0, 0], inline=True)
        self.check_resops(call_may_force=0, call=0)
    def test_dont_repeatedly_trace_from_the_same_guard(self):
        # portal(0) starts at i = -10 and loops while i <= 0, which makes
        # that loop hot first.  When the `i <= 0` guard finally fails, it
        # recurses with level + 1.  Deeper levels start at i = 0, so they
        # fail the same guard on their first iteration, down to level 25,
        # which returns 42.  The JIT must not keep starting new traces from
        # that one guard at every level.
        driver = JitDriver(greens = [], reds = ['level', 'i'])
        def portal(level):
            if level == 0:
                i = -10
            else:
                i = 0
            #
            while True:
                driver.jit_merge_point(level=level, i=i)
                if level == 25:
                    return 42
                i += 1
                if i <= 0:      # <- guard
                    continue    # first make a loop
                else:
                    # then we fail the guard above, doing a recursive call,
                    # which will itself fail the same guard above, and so on
                    return portal(level + 1)
        self.meta_interp(portal, [0])
        self.check_trace_count_at_most(2)     # and not, e.g., 24
    def test_get_unique_id(self):
        # Temporarily wraps CodemapStorage.register_codemap so that the
        # (start, size) of every registered codemap is recorded in lst.
        # It then runs a recursive is_recursive=True driver with a
        # get_unique_id callback and passes lst to the
        # check_get_unique_id() hook.  The original register_codemap is
        # restored in `finally`, even if the run fails.
        lst = []
        # Python 2 tuple-parameter syntax: the second argument is unpacked
        # into (start, size, l) before being forwarded unchanged.
        def reg_codemap(self, (start, size, l)):
            lst.append((start, size))
            old_reg_codemap(self, (start, size, l))
        old_reg_codemap = codemap.CodemapStorage.register_codemap
        try:
            codemap.CodemapStorage.register_codemap = reg_codemap
            def get_unique_id(pc, code):
                return (code + 1) * 2
            driver = JitDriver(greens=["pc", "code"], reds='auto',
                               get_unique_id=get_unique_id, is_recursive=True)
            def f(pc, code):
                i = 0
                while i < 10:
                    driver.jit_merge_point(pc=pc, code=code)
                    pc += 1
                    if pc == 3:
                        # only the outermost call (code == 1) recurses
                        if code == 1:
                            f(0, 0)
                        pc = 0
                    i += 1
            self.meta_interp(f, [0, 1], inline=True)
            self.check_get_unique_id(lst) # overloaded on assembler backends
        finally:
            codemap.CodemapStorage.register_codemap = old_reg_codemap
    def check_get_unique_id(self, lst):
        # Hook called by test_get_unique_id with the list of (start, size)
        # pairs captured from register_codemap.  It does nothing here; per
        # the call-site comment, assembler backends override it to inspect
        # lst.
        pass
class TestLLtype(RecursiveTests, LLJitMixin):
    # Concrete test class: combines the RecursiveTests test methods with
    # LLJitMixin (imported from rpython.jit.metainterp.test.support),
    # presumably the provider of meta_interp() and the check_* helpers
    # those tests call -- confirm in support.py.
    pass