PageRenderTime 56ms CodeModel.GetById 20ms RepoModel.GetById 1ms app.codeStats 0ms

/rpython/jit/metainterp/test/test_recursive.py

https://bitbucket.org/kkris/pypy
Python | 1270 lines | 1268 code | 2 blank | 0 comment | 4 complexity | 14540eee973084d11b5515e1de41efff MD5 | raw file
  1. import py
  2. from rpython.rlib.jit import JitDriver, hint, set_param
  3. from rpython.rlib.jit import unroll_safe, dont_look_inside, promote
  4. from rpython.rlib.objectmodel import we_are_translated
  5. from rpython.rlib.debug import fatalerror
  6. from rpython.jit.metainterp.test.support import LLJitMixin, OOJitMixin
  7. from rpython.jit.codewriter.policy import StopAtXPolicy
  8. from rpython.rtyper.annlowlevel import hlstr
  9. from rpython.jit.metainterp.warmspot import get_stats
  10. class RecursiveTests:
    def test_simple_recursion(self):
        """Recursion that re-enters the JIT-driven loop of f() through main().

        Only f() contains a jit_merge_point; main() is an ordinary function
        that calls back into f(), so every recursion level starts a fresh
        f() loop.
        """
        myjitdriver = JitDriver(greens=[], reds=['n', 'm'])
        def f(n):
            m = n - 2
            while True:
                myjitdriver.jit_merge_point(n=n, m=m)
                n -= 1
                if m == n:
                    # reached after two iterations: leave by recursing
                    return main(n) * 2
                myjitdriver.can_enter_jit(n=n, m=m)
        def main(n):
            if n > 0:
                return f(n+1)
            else:
                return 1
        res = self.meta_interp(main, [20], enable_opts='')
        # the JIT run must agree with plain untranslated execution
        assert res == main(20)
        self.check_history(call=0)
    def test_simple_recursion_with_exc(self):
        """Like test_simple_recursion, but with an exception unwinding a level.

        f() raises Error as soon as n reaches 10.  The f() one recursion
        level up catches it around its main() call and returns 2 instead.
        """
        myjitdriver = JitDriver(greens=[], reds=['n', 'm'])
        class Error(Exception):
            pass
        def f(n):
            m = n - 2
            while True:
                myjitdriver.jit_merge_point(n=n, m=m)
                n -= 1
                if n == 10:
                    raise Error
                if m == n:
                    try:
                        return main(n) * 2
                    except Error:
                        # the deeper level hit n == 10
                        return 2
                myjitdriver.can_enter_jit(n=n, m=m)
        def main(n):
            if n > 0:
                return f(n+1)
            else:
                return 1
        res = self.meta_interp(main, [20], enable_opts='')
        assert res == main(20)
    def test_recursion_three_times(self):
        """f() calls main() on every loop iteration (a tree of recursion).

        The loop in f() runs three iterations (m = n - 3).  Each iteration
        adds main(n) into the red variable `total`.
        """
        myjitdriver = JitDriver(greens=[], reds=['n', 'm', 'total'])
        def f(n):
            m = n - 3
            total = 0
            while True:
                myjitdriver.jit_merge_point(n=n, m=m, total=total)
                n -= 1
                total += main(n)
                if m == n:
                    return total + 5
                myjitdriver.can_enter_jit(n=n, m=m, total=total)
        def main(n):
            if n > 0:
                return f(n)
            else:
                return 1
        # debugging output only: a table of f(i) for i = 1..10
        print
        for i in range(1, 11):
            print '%3d %9d' % (i, f(i))
        res = self.meta_interp(main, [10], enable_opts='')
        assert res == main(10)
        self.check_enter_count_at_most(11)
    def test_bug_1(self):
        """Regression test: re-entering the portal repeatedly from residual code.

        opaque() is hidden from the JIT via StopAtXPolicy.  On the last
        iteration of f(1) (i == 19) it calls f(0) twenty times.  f(0)
        returns the value pushed on its own `stack`, i.e. 0.
        """
        myjitdriver = JitDriver(greens=[], reds=['n', 'i', 'stack'])
        def opaque(n, i):
            if n == 1 and i == 19:
                for j in range(20):
                    res = f(0)      # recurse repeatedly, 20 times
                    assert res == 0
        def f(n):
            stack = [n]
            i = 0
            while i < 20:
                myjitdriver.can_enter_jit(n=n, i=i, stack=stack)
                myjitdriver.jit_merge_point(n=n, i=i, stack=stack)
                opaque(n, i)
                i += 1
            return stack.pop()
        res = self.meta_interp(f, [1], enable_opts='', repeat=2,
                               policy=StopAtXPolicy(opaque))
        assert res == 1
    def get_interpreter(self, codes):
        """Build a tiny recursive bytecode interpreter over the strings in codes.

        Opcodes (one character each):
          '0' ADD       -- n += 1
          '1' JUMP_BACK -- return 42 if n > 20, otherwise move i back by
                           two and allow entering the JIT
          '2' CALL      -- run codes[1] recursively, starting at index 1
          '3' EXIT      -- return n

        Returns interpret(codenum, n, i); its JIT greens are `i` and `code`.
        """
        ADD = "0"
        JUMP_BACK = "1"
        CALL = "2"
        EXIT = "3"
        def getloc(i, code):
            return 'code="%s", i=%d' % (code, i)
        jitdriver = JitDriver(greens = ['i', 'code'], reds = ['n'],
                              get_printable_location = getloc)
        def interpret(codenum, n, i):
            code = codes[codenum]
            while i < len(code):
                jitdriver.jit_merge_point(n=n, i=i, code=code)
                op = code[i]
                if op == ADD:
                    n += 1
                    i += 1
                elif op == CALL:
                    # always calls codes[1], entering it at position 1
                    n = interpret(1, n, 1)
                    i += 1
                elif op == JUMP_BACK:
                    if n > 20:
                        return 42
                    i -= 2
                    jitdriver.can_enter_jit(n=n, i=i, code=code)
                elif op == EXIT:
                    return n
                else:
                    raise NotImplementedError
            return n
        return interpret
    def test_inline(self):
        """Compare traces of a recursive call with and without inlining."""
        # code "021" = ADD, CALL, JUMP_BACK.  CALL enters subcode at index 1,
        # so each call executes exactly one ADD.
        code = "021"
        subcode = "00"
        codes = [code, subcode]
        f = self.get_interpreter(codes)
        # without inlining: the call stays residual (call_may_force) and only
        # the caller's ADD appears as an int_add
        assert self.meta_interp(f, [0, 0, 0], enable_opts='') == 42
        self.check_resops(call_may_force=1, int_add=1, call=0)
        # with inlining: the callee's ADD is traced as well and no call remains
        assert self.meta_interp(f, [0, 0, 0], enable_opts='',
                                inline=True) == 42
        self.check_resops(call=0, int_add=2, call_may_force=0,
                          guard_no_exception=0)
    def test_inline_jitdriver_check(self):
        """An inlined callee that never reaches its own JUMP_BACK is fully inlined."""
        code = "021"
        subcode = "100"
        codes = [code, subcode]
        f = self.get_interpreter(codes)
        assert self.meta_interp(f, [0, 0, 0], enable_opts='',
                                inline=True) == 42
        # the call is fully inlined, because we jump to subcode[1], thus
        # skipping completely the JUMP_BACK in subcode[0]
        self.check_resops(call=0, call_may_force=0, call_assembler=0)
    def test_guard_failure_in_inlined_function(self):
        """An inlined callee takes a rarely-seen early-return path.

        Mini-language for f():
          '-' n -= 1
          'c' n = f("---i---", n)
          'i' return n early when n % 5 == 1
          'l' jump back to pc 0 while n > 0
        Judging by the test name, the early return inside the inlined callee
        is expected to show up as a guard failure.
        """
        def p(pc, code):
            code = hlstr(code)
            return "%s %d %s" % (code, pc, code[pc])
        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                                get_printable_location=p)
        def f(code, n):
            pc = 0
            while pc < len(code):
                myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
                op = code[pc]
                if op == "-":
                    n -= 1
                elif op == "c":
                    n = f("---i---", n)
                elif op == "i":
                    if n % 5 == 1:
                        return n
                elif op == "l":
                    if n > 0:
                        myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                        pc = 0
                        continue
                else:
                    assert 0
                pc += 1
            return n
        def main(n):
            return f("c-l", n)
        print main(100)
        res = self.meta_interp(main, [100], enable_opts='', inline=True)
        assert res == 0
    def test_guard_failure_and_then_exception_in_inlined_function(self):
        """An inlined callee sets a flag on a rare path and then raises.

        Inside f("---ir---", n), op 'i' sets `flag` once n < 200.  The
        following op 'r' then raises Exception.  The caller's 'c' op catches
        it and returns n.
        """
        def p(pc, code):
            code = hlstr(code)
            return "%s %d %s" % (code, pc, code[pc])
        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n', 'flag'],
                                get_printable_location=p)
        def f(code, n):
            pc = 0
            flag = False
            while pc < len(code):
                myjitdriver.jit_merge_point(n=n, code=code, pc=pc, flag=flag)
                op = code[pc]
                if op == "-":
                    n -= 1
                elif op == "c":
                    try:
                        n = f("---ir---", n)
                    except Exception:
                        return n
                elif op == "i":
                    if n < 200:
                        flag = True
                elif op == "r":
                    if flag:
                        raise Exception
                elif op == "l":
                    if n > 0:
                        myjitdriver.can_enter_jit(n=n, code=code, pc=0, flag=flag)
                        pc = 0
                        continue
                else:
                    assert 0
                pc += 1
            return n
        def main(n):
            return f("c-l", n)
        print main(1000)
        res = self.meta_interp(main, [1000], enable_opts='', inline=True)
        assert res == main(1000)
    def test_exception_in_inlined_function(self):
        """An exception raised inside an inlined callee is caught by the caller.

        In f("---i---", n), op 'i' raises Exc when n % 5 == 1.  The caller's
        'c' op swallows it and keeps its own n unchanged for that call.
        """
        def p(pc, code):
            code = hlstr(code)
            return "%s %d %s" % (code, pc, code[pc])
        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                                get_printable_location=p)
        class Exc(Exception):
            pass
        def f(code, n):
            pc = 0
            while pc < len(code):
                myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
                op = code[pc]
                if op == "-":
                    n -= 1
                elif op == "c":
                    try:
                        n = f("---i---", n)
                    except Exc:
                        pass
                elif op == "i":
                    if n % 5 == 1:
                        raise Exc
                elif op == "l":
                    if n > 0:
                        myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                        pc = 0
                        continue
                else:
                    assert 0
                pc += 1
            return n
        def main(n):
            return f("c-l", n)
        res = self.meta_interp(main, [100], enable_opts='', inline=True)
        assert res == main(100)
    def test_recurse_during_blackholing(self):
        """Recursive calls that only happen on some iterations (n < 70, n % 3 == 1).

        main() lowers the JIT threshold to 3 and sets trace_eagerness to 5
        before running f("c-l", n).
        """
        # this passes, if the blackholing shortcut for calls is turned off
        # it fails, it is very delicate in terms of parameters,
        # bridge/loop creation order
        def p(pc, code):
            code = hlstr(code)
            return "%s %d %s" % (code, pc, code[pc])
        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                                get_printable_location=p)
        def f(code, n):
            pc = 0
            while pc < len(code):
                myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
                op = code[pc]
                if op == "-":
                    n -= 1
                elif op == "c":
                    if n < 70 and n % 3 == 1:
                        n = f("--", n)
                elif op == "l":
                    if n > 0:
                        myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                        pc = 0
                        continue
                else:
                    assert 0
                pc += 1
            return n
        def main(n):
            set_param(None, 'threshold', 3)
            set_param(None, 'trace_eagerness', 5)
            return f("c-l", n)
        expected = main(100)
        res = self.meta_interp(main, [100], enable_opts='', inline=True)
        assert res == expected
  289. def check_max_trace_length(self, length):
  290. for loop in get_stats().loops:
  291. assert len(loop.operations) <= length + 5 # because we only check once per metainterp bytecode
  292. for op in loop.operations:
  293. if op.is_guard() and hasattr(op.getdescr(), '_debug_suboperations'):
  294. assert len(op.getdescr()._debug_suboperations) <= length + 5
    def test_inline_trace_limit(self):
        """Inlining a deep recursion must respect trace_limit.

        recursive(n) returns n (it adds 1 on each of n nested calls), so
        loop() just counts n down by one per iteration.  Tracing through the
        recursion makes traces long; the test expects 7 aborted traces and no
        compiled loop longer than TRACE_LIMIT (+ slack).
        """
        myjitdriver = JitDriver(greens=[], reds=['n'])
        def recursive(n):
            if n > 0:
                return recursive(n - 1) + 1
            return 0
        def loop(n):
            set_param(myjitdriver, "threshold", 10)
            pc = 0  # NOTE(review): unused; left as-is since this test is trace-shape sensitive
            while n:
                myjitdriver.can_enter_jit(n=n)
                myjitdriver.jit_merge_point(n=n)
                n = recursive(n)
                n -= 1
            return n
        TRACE_LIMIT = 66
        res = self.meta_interp(loop, [100], enable_opts='', inline=True, trace_limit=TRACE_LIMIT)
        assert res == 0
        self.check_max_trace_length(TRACE_LIMIT)
        self.check_enter_count_at_most(10) # maybe
        self.check_aborted_count(7)
  316. def test_trace_limit_bridge(self):
  317. def recursive(n):
  318. if n > 0:
  319. return recursive(n - 1) + 1
  320. return 0
  321. myjitdriver = JitDriver(greens=[], reds=['n'])
  322. def loop(n):
  323. set_param(None, "threshold", 4)
  324. set_param(None, "trace_eagerness", 2)
  325. while n:
  326. myjitdriver.can_enter_jit(n=n)
  327. myjitdriver.jit_merge_point(n=n)
  328. if n % 5 == 0:
  329. n -= 1
  330. if n < 50:
  331. n = recursive(n)
  332. n -= 1
  333. return n
  334. TRACE_LIMIT = 20
  335. res = self.meta_interp(loop, [100], enable_opts='', inline=True, trace_limit=TRACE_LIMIT)
  336. self.check_max_trace_length(TRACE_LIMIT)
  337. self.check_aborted_count(8)
  338. self.check_enter_count_at_most(30)
    def test_trace_limit_with_exception_bug(self):
        """Aborting a too-long trace must not re-raise an already-handled exception.

        do_stuff(n) is unroll_safe: its counting loop is unrolled into the
        trace, and it always raises ValueError at the end.
        """
        myjitdriver = JitDriver(greens=[], reds=['n'])
        @unroll_safe
        def do_stuff(n):
            while n > 0:
                n -= 1
            raise ValueError
        def loop(n):
            pc = 0  # NOTE(review): unused; left as-is since this test is trace-shape sensitive
            while n > 80:
                myjitdriver.can_enter_jit(n=n)
                myjitdriver.jit_merge_point(n=n)
                try:
                    do_stuff(n)
                except ValueError:
                    # the trace limit is checked when we arrive here, and we
                    # have the exception still in last_exc_value_box at this
                    # point -- so when we abort because of a trace too long,
                    # the exception is passed to the blackhole interp and
                    # incorrectly re-raised from here
                    pass
                n -= 1
            return n
        TRACE_LIMIT = 66
        res = self.meta_interp(loop, [100], trace_limit=TRACE_LIMIT)
        assert res == 80
    def test_max_failure_args(self):
        """Run with failargs_limit=10 while ten-field objects are live.

        Each iteration allocates an A with ten integer fields.  The test
        expects exactly 5 aborted traces.  Presumably the aborts happen
        because guards would need more fail arguments than the limit -- TODO
        confirm against the failargs_limit implementation.
        """
        FAILARGS_LIMIT = 10
        jitdriver = JitDriver(greens = [], reds = ['i', 'n', 'o'])
        class A(object):
            def __init__(self, i0, i1, i2, i3, i4, i5, i6, i7, i8, i9):
                self.i0 = i0
                self.i1 = i1
                self.i2 = i2
                self.i3 = i3
                self.i4 = i4
                self.i5 = i5
                self.i6 = i6
                self.i7 = i7
                self.i8 = i8
                self.i9 = i9
        def loop(n):
            i = 0
            o = A(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
            while i < n:
                jitdriver.can_enter_jit(o=o, i=i, n=n)
                jitdriver.jit_merge_point(o=o, i=i, n=n)
                o = A(i, i + 1, i + 2, i + 3, i + 4, i + 5,
                      i + 6, i + 7, i + 8, i + 9)
                i += 1
            return o
        res = self.meta_interp(loop, [20], failargs_limit=FAILARGS_LIMIT,
                               listops=True)
        self.check_aborted_count(5)
    def test_max_failure_args_exc(self):
        """Same as test_max_failure_args, but the loop finishes by raising.

        loop() always raises ValueError after its last iteration.  main()
        turns that into a 0 result, which the test checks (`assert not res`).
        """
        FAILARGS_LIMIT = 10
        jitdriver = JitDriver(greens = [], reds = ['i', 'n', 'o'])
        class A(object):
            def __init__(self, i0, i1, i2, i3, i4, i5, i6, i7, i8, i9):
                self.i0 = i0
                self.i1 = i1
                self.i2 = i2
                self.i3 = i3
                self.i4 = i4
                self.i5 = i5
                self.i6 = i6
                self.i7 = i7
                self.i8 = i8
                self.i9 = i9
        def loop(n):
            i = 0
            o = A(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
            while i < n:
                jitdriver.can_enter_jit(o=o, i=i, n=n)
                jitdriver.jit_merge_point(o=o, i=i, n=n)
                o = A(i, i + 1, i + 2, i + 3, i + 4, i + 5,
                      i + 6, i + 7, i + 8, i + 9)
                i += 1
            raise ValueError
        def main(n):
            try:
                loop(n)
                return 1
            except ValueError:
                return 0
        res = self.meta_interp(main, [20], failargs_limit=FAILARGS_LIMIT,
                               listops=True)
        assert not res
        self.check_aborted_count(5)
  428. def test_set_param_inlining(self):
  429. myjitdriver = JitDriver(greens=[], reds=['n', 'recurse'])
  430. def loop(n, recurse=False):
  431. while n:
  432. myjitdriver.jit_merge_point(n=n, recurse=recurse)
  433. n -= 1
  434. if not recurse:
  435. loop(10, True)
  436. myjitdriver.can_enter_jit(n=n, recurse=recurse)
  437. return n
  438. TRACE_LIMIT = 66
  439. def main(inline):
  440. set_param(None, "threshold", 10)
  441. set_param(None, 'function_threshold', 60)
  442. if inline:
  443. set_param(None, 'inlining', True)
  444. else:
  445. set_param(None, 'inlining', False)
  446. return loop(100)
  447. res = self.meta_interp(main, [0], enable_opts='', trace_limit=TRACE_LIMIT)
  448. self.check_resops(call=0, call_may_force=1)
  449. res = self.meta_interp(main, [1], enable_opts='', trace_limit=TRACE_LIMIT)
  450. self.check_resops(call=0, call_may_force=0)
  451. def test_trace_from_start(self):
  452. def p(pc, code):
  453. code = hlstr(code)
  454. return "'%s' at %d: %s" % (code, pc, code[pc])
  455. myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
  456. get_printable_location=p)
  457. def f(code, n):
  458. pc = 0
  459. while pc < len(code):
  460. myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
  461. op = code[pc]
  462. if op == "+":
  463. n += 7
  464. elif op == "-":
  465. n -= 1
  466. elif op == "c":
  467. n = f('---', n)
  468. elif op == "l":
  469. if n > 0:
  470. myjitdriver.can_enter_jit(n=n, code=code, pc=1)
  471. pc = 1
  472. continue
  473. else:
  474. assert 0
  475. pc += 1
  476. return n
  477. def g(m):
  478. if m > 1000000:
  479. f('', 0)
  480. result = 0
  481. for i in range(m):
  482. result += f('+-cl--', i)
  483. res = self.meta_interp(g, [50], backendopt=True)
  484. assert res == g(50)
  485. py.test.skip("tracing from start is by now only longer enabled "
  486. "if a trace gets too big")
  487. self.check_tree_loop_count(3)
  488. self.check_history(int_add=1)
    def test_dont_inline_huge_stuff(self):
        """A callee too long to inline under trace_limit=40 is called instead.

        The test expects one aborted trace, call_assembler=2, no residual
        'call', and two jitcell tokens.
        """
        def p(pc, code):
            code = hlstr(code)
            return "%s %d %s" % (code, pc, code[pc])
        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
                                get_printable_location=p)
        def f(code, n):
            pc = 0
            while pc < len(code):
                myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
                op = code[pc]
                if op == "-":
                    n -= 1
                elif op == "c":
                    # the result of the 20-op callee is intentionally discarded
                    f('--------------------', n)
                elif op == "l":
                    if n > 0:
                        myjitdriver.can_enter_jit(n=n, code=code, pc=0)
                        pc = 0
                        continue
                else:
                    assert 0
                pc += 1
            return n
        def g(m):
            set_param(None, 'inlining', True)
            # carefully chosen threshold to make sure that the inner function
            # cannot be inlined, but the inner function on its own is small
            # enough
            set_param(None, 'trace_limit', 40)
            if m > 1000000:
                f('', 0)
            result = 0
            for i in range(m):
                result += f('-c-----------l-', i+100)
        self.meta_interp(g, [10], backendopt=True)
        self.check_aborted_count(1)
        self.check_resops(call=0, call_assembler=2)
        self.check_jitcell_token_count(2)
    def test_directly_call_assembler(self):
        """portal(2) calls portal(1) on each iteration.

        With inline=True the recorded history is expected to contain a
        single call_assembler.
        """
        driver = JitDriver(greens = ['codeno'], reds = ['i'],
                           get_printable_location = lambda codeno : str(codeno))
        def portal(codeno):
            i = 0
            while i < 10:
                driver.can_enter_jit(codeno = codeno, i = i)
                driver.jit_merge_point(codeno = codeno, i = i)
                if codeno == 2:
                    portal(1)
                i += 1
        self.meta_interp(portal, [2], inline=True)
        self.check_history(call_assembler=1)
    def test_recursion_cant_call_assembler_directly(self):
        """A recursive call whose target has no compiled loop yet.

        Such a call must go through a temporary callback.  The test wraps
        compile.compile_tmp_callback to record each loop token it creates,
        and restores the original in `finally`.
        """
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'j'],
                           get_printable_location = lambda codeno : str(codeno))
        def portal(codeno, j):
            i = 1
            while 1:
                driver.jit_merge_point(codeno=codeno, i=i, j=j)
                if (i >> 1) == 1:
                    if j == 0:
                        return
                    portal(2, j - 1)
                elif i == 5:
                    return
                i += 1
                driver.can_enter_jit(codeno=codeno, i=i, j=j)
        # sanity run without the JIT
        portal(2, 5)
        from rpython.jit.metainterp import compile, pyjitpl
        pyjitpl._warmrunnerdesc = None
        trace = []
        def my_ctc(*args):
            looptoken = original_ctc(*args)
            trace.append(looptoken)
            return looptoken
        original_ctc = compile.compile_tmp_callback
        try:
            compile.compile_tmp_callback = my_ctc
            self.meta_interp(portal, [2, 5], inline=True)
            self.check_resops(call_may_force=0, call_assembler=2)
        finally:
            compile.compile_tmp_callback = original_ctc
        # check that we made a temporary callback
        assert len(trace) == 1
        # and that we later redirected it to something else
        try:
            redirected = pyjitpl._warmrunnerdesc.cpu._redirected_call_assembler
        except AttributeError:
            pass    # not the llgraph backend
        else:
            print redirected
            assert redirected.keys() == trace
    def test_recursion_cant_call_assembler_directly_with_virtualizable(self):
        """Temporary-callback variant where the red counter lives in a virtualizable Frame."""
        # exactly the same logic as the previous test, but with 'frame.j'
        # instead of just 'j'
        class Frame(object):
            _virtualizable2_ = ['j']
            def __init__(self, j):
                self.j = j
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                           virtualizables = ['frame'],
                           get_printable_location = lambda codeno : str(codeno))
        def portal(codeno, frame):
            i = 1
            while 1:
                driver.jit_merge_point(codeno=codeno, i=i, frame=frame)
                if (i >> 1) == 1:
                    if frame.j == 0:
                        return
                    portal(2, Frame(frame.j - 1))
                elif i == 5:
                    return
                i += 1
                driver.can_enter_jit(codeno=codeno, i=i, frame=frame)
        def main(codeno, j):
            portal(codeno, Frame(j))
        # sanity run without the JIT
        main(2, 5)
        from rpython.jit.metainterp import compile, pyjitpl
        pyjitpl._warmrunnerdesc = None
        trace = []
        def my_ctc(*args):
            looptoken = original_ctc(*args)
            trace.append(looptoken)
            return looptoken
        original_ctc = compile.compile_tmp_callback
        try:
            compile.compile_tmp_callback = my_ctc
            self.meta_interp(main, [2, 5], inline=True)
            self.check_resops(call_may_force=0, call_assembler=2)
        finally:
            compile.compile_tmp_callback = original_ctc
        # check that we made a temporary callback
        assert len(trace) == 1
        # and that we later redirected it to something else
        try:
            redirected = pyjitpl._warmrunnerdesc.cpu._redirected_call_assembler
        except AttributeError:
            pass    # not the llgraph backend
        else:
            print redirected
            assert redirected.keys() == trace
  630. def test_directly_call_assembler_return(self):
  631. driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
  632. get_printable_location = lambda codeno : str(codeno))
  633. def portal(codeno):
  634. i = 0
  635. k = codeno
  636. while i < 10:
  637. driver.can_enter_jit(codeno = codeno, i = i, k = k)
  638. driver.jit_merge_point(codeno = codeno, i = i, k = k)
  639. if codeno == 2:
  640. k = portal(1)
  641. i += 1
  642. return k
  643. self.meta_interp(portal, [2], inline=True)
  644. self.check_history(call_assembler=1)
    def test_directly_call_assembler_raise(self):
        """An exception propagates out of a call_assembler'ed portal.

        portal(1) raises MyException(1) after finishing its loop.  portal(2)
        catches it on every iteration and adds me.x to its counter.
        """
        class MyException(Exception):
            def __init__(self, x):
                self.x = x
        driver = JitDriver(greens = ['codeno'], reds = ['i'],
                           get_printable_location = lambda codeno : str(codeno))
        def portal(codeno):
            i = 0
            while i < 10:
                driver.can_enter_jit(codeno = codeno, i = i)
                driver.jit_merge_point(codeno = codeno, i = i)
                if codeno == 2:
                    try:
                        portal(1)
                    except MyException, me:
                        i += me.x
                i += 1
            if codeno == 1:
                raise MyException(1)
        self.meta_interp(portal, [2], inline=True)
        self.check_history(call_assembler=1)
    def test_directly_call_assembler_fail_guard(self):
        """The called portal takes a new path (k > 40) after being compiled.

        Per the test name, this path is expected to fail a guard inside the
        callee.
        """
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                           get_printable_location = lambda codeno : str(codeno))
        def portal(codeno, k):
            i = 0
            while i < 10:
                driver.can_enter_jit(codeno=codeno, i=i, k=k)
                driver.jit_merge_point(codeno=codeno, i=i, k=k)
                if codeno == 2:
                    k += portal(1, k)
                elif k > 40:
                    if i % 2:
                        k += 1
                    else:
                        k += 2
                k += 1
                i += 1
            return k
        res = self.meta_interp(portal, [2, 0], inline=True)
        # 13542 == portal(2, 0) evaluated without the JIT
        assert res == 13542
    def test_directly_call_assembler_virtualizable(self):
        """call_assembler where caller and callee each have a virtualizable Frame.

        portal(0) builds a fresh subframe seeded with its current value,
        calls portal(1, subframe), and stores the result + 1 back into its
        own frame.
        """
        class Thing(object):
            def __init__(self, val):
                self.val = val
        class Frame(object):
            _virtualizable2_ = ['thing']
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                           virtualizables = ['frame'],
                           get_printable_location = lambda codeno : str(codeno))
        def main(codeno):
            frame = Frame()
            frame.thing = Thing(0)
            portal(codeno, frame)
            return frame.thing.val
        def portal(codeno, frame):
            i = 0
            while i < 10:
                driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
                driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
                nextval = frame.thing.val
                if codeno == 0:
                    subframe = Frame()
                    subframe.thing = Thing(nextval)
                    nextval = portal(1, subframe)
                frame.thing = Thing(nextval + 1)
                i += 1
            return frame.thing.val
        res = self.meta_interp(main, [0], inline=True)
        assert res == main(0)
    def test_directly_call_assembler_virtualizable_reset_token(self):
        """After a call_assembler returns, the subframe's vable_token is zero.

        portal recurses three levels deep (codeno 0 -> 1 -> 2).  After each
        call, check_frame() (hidden from the JIT) inspects the subframe, but
        only when translated.
        """
        from rpython.rtyper.lltypesystem import lltype
        from rpython.rlib.debug import llinterpcall
        class Thing(object):
            def __init__(self, val):
                self.val = val
        class Frame(object):
            _virtualizable2_ = ['thing']
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                           virtualizables = ['frame'],
                           get_printable_location = lambda codeno : str(codeno))
        @dont_look_inside
        def check_frame(subframe):
            if we_are_translated():
                llinterpcall(lltype.Void, check_ll_frame, subframe)
        def check_ll_frame(ll_subframe):
            # This is called with the low-level Struct that is the frame.
            # Check that the vable_token was correctly reset to zero.
            # Note that in order for that test to catch failures, it needs
            # three levels of recursion: the vable_token of the subframe
            # at the level 2 is set to a non-zero value when doing the
            # call to the level 3 only.  This used to fail when the test
            # is run via rpython.jit.backend.x86.test.test_recursive.
            assert ll_subframe.vable_token == 0
        def main(codeno):
            frame = Frame()
            frame.thing = Thing(0)
            portal(codeno, frame)
            return frame.thing.val
        def portal(codeno, frame):
            i = 0
            while i < 5:
                driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
                driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
                nextval = frame.thing.val
                if codeno < 2:
                    subframe = Frame()
                    subframe.thing = Thing(nextval)
                    nextval = portal(codeno + 1, subframe)
                    check_frame(subframe)
                frame.thing = Thing(nextval + 1)
                i += 1
            return frame.thing.val
        res = self.meta_interp(main, [0], inline=True)
        assert res == main(0)
    def test_directly_call_assembler_virtualizable_force1(self):
        """A callee writes to the caller's virtualizable through residual code.

        main() stores the outermost frame in somewhere_else.frame.  Inside
        portal(1, ...), once frame.thing.val > 40, the residual change()
        (StopAtXPolicy) replaces somewhere_else.frame.thing, i.e. the
        caller's field.  Judging by the test name, this forces the caller's
        virtualizable.
        """
        class Thing(object):
            def __init__(self, val):
                self.val = val
        class Frame(object):
            _virtualizable2_ = ['thing']
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                           virtualizables = ['frame'],
                           get_printable_location = lambda codeno : str(codeno))
        class SomewhereElse(object):
            pass
        somewhere_else = SomewhereElse()
        def change(newthing):
            somewhere_else.frame.thing = newthing
        def main(codeno):
            frame = Frame()
            somewhere_else.frame = frame
            frame.thing = Thing(0)
            portal(codeno, frame)
            return frame.thing.val
        def portal(codeno, frame):
            print 'ENTER:', codeno, frame.thing.val
            i = 0
            while i < 10:
                driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
                driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
                nextval = frame.thing.val
                if codeno == 0:
                    subframe = Frame()
                    subframe.thing = Thing(nextval)
                    nextval = portal(1, subframe)
                elif codeno == 1:
                    if frame.thing.val > 40:
                        change(Thing(13))
                        nextval = 13
                else:
                    fatalerror("bad codeno = " + str(codeno))
                frame.thing = Thing(nextval + 1)
                i += 1
            print 'LEAVE:', codeno, frame.thing.val
            return frame.thing.val
        res = self.meta_interp(main, [0], inline=True,
                               policy=StopAtXPolicy(change))
        assert res == main(0)
    def test_directly_call_assembler_virtualizable_with_array(self):
        """call_assembler with a virtualizable holding an array field (l[*]) and an index s.

        frame.s is promoted each iteration, so it can be used as a constant
        index into frame.l within the trace.  f(0, ...) recursively runs
        f(1, 10, 1, subframe) on a fresh Frame.
        """
        myjitdriver = JitDriver(greens = ['codeno'], reds = ['n', 'x', 'frame'],
                                virtualizables = ['frame'])
        class Frame(object):
            _virtualizable2_ = ['l[*]', 's']
            def __init__(self, l, s):
                self = hint(self, access_directly=True,
                            fresh_virtualizable=True)
                self.l = l
                self.s = s
        def main(codeno, n, a):
            frame = Frame([a, a+1, a+2, a+3], 0)
            return f(codeno, n, a, frame)
        def f(codeno, n, a, frame):
            x = 0
            while n > 0:
                myjitdriver.can_enter_jit(codeno=codeno, frame=frame, n=n, x=x)
                myjitdriver.jit_merge_point(codeno=codeno, frame=frame, n=n,
                                            x=x)
                frame.s = promote(frame.s)
                n -= 1
                s = frame.s
                assert s >= 0
                x += frame.l[s]
                frame.s += 1
                if codeno == 0:
                    subframe = Frame([n, n+1, n+2, n+3], 0)
                    x += f(1, 10, 1, subframe)
                s = frame.s
                assert s >= 0
                x += frame.l[s]
                x += len(frame.l)
                frame.s -= 1
            return x
        res = self.meta_interp(main, [0, 10, 1], listops=True, inline=True)
        assert res == main(0, 10, 1)
    def test_directly_call_assembler_virtualizable_force_blackhole(self):
        """Like ..._force1, but the residual write happens on every callee iteration path.

        For codeno != 0, portal always calls the residual change(Thing(13),
        val).  change() prints its argument.  When the argument is > 30 it
        overwrites the outermost frame's `thing` and returns 13; otherwise it
        returns the argument unchanged.
        """
        class Thing(object):
            def __init__(self, val):
                self.val = val
        class Frame(object):
            _virtualizable2_ = ['thing']
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                           virtualizables = ['frame'],
                           get_printable_location = lambda codeno : str(codeno))
        class SomewhereElse(object):
            pass
        somewhere_else = SomewhereElse()
        def change(newthing, arg):
            print arg
            if arg > 30:
                somewhere_else.frame.thing = newthing
                arg = 13
            return arg
        def main(codeno):
            frame = Frame()
            somewhere_else.frame = frame
            frame.thing = Thing(0)
            portal(codeno, frame)
            return frame.thing.val
        def portal(codeno, frame):
            i = 0
            while i < 10:
                driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
                driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
                nextval = frame.thing.val
                if codeno == 0:
                    subframe = Frame()
                    subframe.thing = Thing(nextval)
                    nextval = portal(1, subframe)
                else:
                    nextval = change(Thing(13), frame.thing.val)
                frame.thing = Thing(nextval + 1)
                i += 1
            return frame.thing.val
        res = self.meta_interp(main, [0], inline=True,
                               policy=StopAtXPolicy(change))
        assert res == main(0)
    def test_assembler_call_red_args(self):
        """The callee's green argument (codeno) is computed from a red value.

        residual(k), hidden from the JIT, returns 1 while k <= 150 and 0
        afterwards.  The recursive call therefore switches between two
        different portal versions.
        """
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                           get_printable_location = lambda codeno : str(codeno))
        def residual(k):
            if k > 150:
                return 0
            return 1
        def portal(codeno, k):
            i = 0
            while i < 15:
                driver.can_enter_jit(codeno=codeno, i=i, k=k)
                driver.jit_merge_point(codeno=codeno, i=i, k=k)
                if codeno == 2:
                    k += portal(residual(k), k)
                if codeno == 0:
                    k += 2
                elif codeno == 1:
                    k += 1
                i += 1
            return k
        res = self.meta_interp(portal, [2, 0], inline=True,
                               policy=StopAtXPolicy(residual))
        assert res == portal(2, 0)
        self.check_resops(call_assembler=4)
    def test_inline_without_hitting_the_loop(self):
        """The callee returns before reaching can_enter_jit, so it is always inlined.

        portal(20) takes the `else` branch on its first iteration and returns
        1.  Each round of codeno 0..9 therefore adds 10 to i, and the outer
        call returns once i > 63, i.e. 70.
        """
        driver = JitDriver(greens = ['codeno'], reds = ['i'],
                           get_printable_location = lambda codeno : str(codeno))
        def portal(codeno):
            i = 0
            while True:
                driver.jit_merge_point(codeno=codeno, i=i)
                if codeno < 10:
                    i += portal(20)
                    codeno += 1
                elif codeno == 10:
                    if i > 63:
                        return i
                    codeno = 0
                    driver.can_enter_jit(codeno=codeno, i=i)
                else:
                    return 1
        assert portal(0) == 70
        res = self.meta_interp(portal, [0], inline=True)
        assert res == 70
        self.check_resops(call_assembler=0)
    def test_inline_with_hitting_the_loop_sometimes(self):
        """Nested portals, of which only some run their loop long enough to be compiled.

        Recursion depth is bounded by k (portal returns 1 when k > 2).  The
        exit threshold per depth is [-1, 2000, 63][k].  The test expects 12
        call_assembler operations.
        """
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                           get_printable_location = lambda codeno : str(codeno))
        def portal(codeno, k):
            if k > 2:
                return 1
            i = 0
            while True:
                driver.jit_merge_point(codeno=codeno, i=i, k=k)
                if codeno < 10:
                    i += portal(codeno + 5, k+1)
                    codeno += 1
                elif codeno == 10:
                    if i > [-1, 2000, 63][k]:
                        return i
                    codeno = 0
                    driver.can_enter_jit(codeno=codeno, i=i, k=k)
                else:
                    return 1
        assert portal(0, 1) == 2095
        res = self.meta_interp(portal, [0, 1], inline=True)
        assert res == 2095
        self.check_resops(call_assembler=12)
    def test_inline_with_hitting_the_loop_sometimes_exc(self):
        """Same as test_inline_with_hitting_the_loop_sometimes, but results travel via exceptions.

        Every return of the previous test is replaced by raising
        GotValue(result).  main() unwraps the final value.
        """
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                           get_printable_location = lambda codeno : str(codeno))
        class GotValue(Exception):
            def __init__(self, result):
                self.result = result
        def portal(codeno, k):
            if k > 2:
                raise GotValue(1)
            i = 0
            while True:
                driver.jit_merge_point(codeno=codeno, i=i, k=k)
                if codeno < 10:
                    try:
                        portal(codeno + 5, k+1)
                    except GotValue, e:
                        i += e.result
                    codeno += 1
                elif codeno == 10:
                    if i > [-1, 2000, 63][k]:
                        raise GotValue(i)
                    codeno = 0
                    driver.can_enter_jit(codeno=codeno, i=i, k=k)
                else:
                    raise GotValue(1)
        def main(codeno, k):
            try:
                portal(codeno, k)
            except GotValue, e:
                return e.result
        assert main(0, 1) == 2095
        res = self.meta_interp(main, [0, 1], inline=True)
        assert res == 2095
        self.check_resops(call_assembler=12)
    def test_handle_jitexception_in_portal(self):
        """Portal whose can_enter_jit is reached only through helper calls.

        can_enter_jit lives in do_can_enter_jit(), which intermediate() calls
        only when i == 9; it is not called directly in the portal loop.
        portal(64, ...) recurses once into portal(96, ...) when i reaches 10.
        main() is checked untranslated and then under meta_interp with
        several trace limits; the expected string is the same every time.
        """
        # a test for _handle_jitexception_in_portal in blackhole.py
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'str'],
                           get_printable_location = lambda codeno: str(codeno))
        def do_can_enter_jit(codeno, i, str):
            i = (i+1)-1    # some operations
            driver.can_enter_jit(codeno=codeno, i=i, str=str)
        def intermediate(codeno, i, str):
            if i == 9:
                do_can_enter_jit(codeno, i, str)
        def portal(codeno, str):
            i = value.initial
            while i < 10:
                intermediate(codeno, i, str)
                driver.jit_merge_point(codeno=codeno, i=i, str=str)
                i += 1
                # one nested portal call, which appends chr(97)..chr(106)
                if codeno == 64 and i == 10:
                    str = portal(96, str)
                str += chr(codeno+i)
            return str
        # mutable holder for portal's start value; main() sets it to 0 before
        # any portal() call.  NOTE(review): presumably read from an instance
        # so the start value of i is not a compile-time constant -- confirm
        class Value:
            initial = -1
        value = Value()
        def main():
            value.initial = 0
            return (portal(64, '') +
                    portal(64, '') +
                    portal(64, '') +
                    portal(64, '') +
                    portal(64, ''))
        assert main() == 'ABCDEFGHIabcdefghijJ' * 5
        # NOTE(review): these limits presumably make tracing stop at
        # different points of the nested calls -- confirm before changing
        for tlimit in [95, 90, 102]:
            print 'tlimit =', tlimit
            res = self.meta_interp(main, [], inline=True, trace_limit=tlimit)
            # the meta_interp result is compared via its .chars sequence
            assert ''.join(res.chars) == 'ABCDEFGHIabcdefghijJ' * 5
    def test_handle_jitexception_in_portal_returns_void(self):
        """Variant of test_handle_jitexception_in_portal with a void portal.

        portal() has no return statement.  The result of the nested
        portal(96, str) call and of every call in main() is discarded.  The
        test only checks that main() runs, both untranslated and under
        meta_interp with each trace limit.
        """
        # a test for _handle_jitexception_in_portal in blackhole.py
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'str'],
                           get_printable_location = lambda codeno: str(codeno))
        def do_can_enter_jit(codeno, i, str):
            i = (i+1)-1    # some operations
            driver.can_enter_jit(codeno=codeno, i=i, str=str)
        def intermediate(codeno, i, str):
            if i == 9:
                do_can_enter_jit(codeno, i, str)
        def portal(codeno, str):
            i = value.initial
            while i < 10:
                intermediate(codeno, i, str)
                driver.jit_merge_point(codeno=codeno, i=i, str=str)
                i += 1
                if codeno == 64 and i == 10:
                    portal(96, str)
                str += chr(codeno+i)
        # mutable holder for portal's start value; main() sets it to 0 first
        class Value:
            initial = -1
        value = Value()
        def main():
            value.initial = 0
            portal(64, '')
            portal(64, '')
            portal(64, '')
            portal(64, '')
            portal(64, '')
        main()
        for tlimit in [95, 90, 102]:
            print 'tlimit =', tlimit
            self.meta_interp(main, [], inline=True, trace_limit=tlimit)
    def test_no_duplicates_bug(self):
        """Regression test: passes if meta_interp(portal, [0, 10]) completes.

        In the loop, can_enter_jit comes before jit_merge_point.  With
        codeno == 0 the outer call recursively calls portal(i, i) for i = 10
        down to 1.  Each of those calls has codeno > 0, so it breaks out
        right after its first merge point.
        NOTE(review): the duplicated thing the name refers to is not visible
        here -- check the VCS history before changing this repro.
        """
        driver = JitDriver(greens = ['codeno'], reds = ['i'],
                           get_printable_location = lambda codeno: str(codeno))
        def portal(codeno, i):
            while i > 0:
                driver.can_enter_jit(codeno=codeno, i=i)
                driver.jit_merge_point(codeno=codeno, i=i)
                if codeno > 0:
                    break
                portal(i, i)
                i -= 1
        self.meta_interp(portal, [0, 10], inline=True)
  1064. def test_trace_from_start_always(self):
  1065. from rpython.rlib.nonconst import NonConstant
  1066. driver = JitDriver(greens = ['c'], reds = ['i', 'v'])
  1067. def portal(c, i, v):
  1068. while i > 0:
  1069. driver.jit_merge_point(c=c, i=i, v=v)
  1070. portal(c, i - 1, v)
  1071. if v:
  1072. driver.can_enter_jit(c=c, i=i, v=v)
  1073. break
  1074. def main(c, i, _set_param, v):
  1075. if _set_param:
  1076. set_param(driver, 'function_threshold', 0)
  1077. portal(c, i, v)
  1078. self.meta_interp(main, [10, 10, False, False], inline=True)
  1079. self.check_jitcell_token_count(1)
  1080. self.check_trace_count(1)
  1081. self.meta_interp(main, [3, 10, True, False], inline=True)
  1082. self.check_jitcell_token_count(0)
  1083. self.check_trace_count(0)
    def test_trace_from_start_does_not_prevent_inlining(self):
        """Nested portal call in a bc == 0 loop must not survive as a call op.

        portal(0, 0, 0) loops with bc == 0 and calls portal(1, 8, 0) on
        every iteration.  That inner call has bc == 1, so it returns right
        after its merge point.  can_enter_jit is reached each time c wraps
        back to 0, and the outer call returns at such a point once
        i >= 100.  The check requires no call_may_force and no call
        operations, presumably because the inner portal call got inlined.
        """
        driver = JitDriver(greens = ['c', 'bc'], reds = ['i'])
        def portal(bc, c, i):
            while True:
                driver.jit_merge_point(c=c, bc=bc, i=i)
                if bc == 0:
                    portal(1, 8, 0)
                    c += 1
                else:
                    return
                if c == 10: # bc == 0
                    c = 0
                    if i >= 100:
                        return
                    driver.can_enter_jit(c=c, bc=bc, i=i)
                i += 1
        self.meta_interp(portal, [0, 0, 0], inline=True)
        self.check_resops(call_may_force=0, call=0)
    def test_dont_repeatedly_trace_from_the_same_guard(self):
        """A guard failing once per recursion level is not retraced every time.

        portal(0) starts at i = -10 and loops while i <= 0, giving the JIT a
        loop to compile.  It then recurses into portal(level + 1).  Every
        deeper level starts at i = 0, so it takes the other side of that
        same guard on its first iteration and recurses again, until level 25
        returns 42.  The number of traces must stay at most 2.
        """
        driver = JitDriver(greens = [], reds = ['level', 'i'])
        def portal(level):
            if level == 0:
                i = -10
            else:
                i = 0
            #
            while True:
                driver.jit_merge_point(level=level, i=i)
                if level == 25:
                    return 42
                i += 1
                if i <= 0:      # <- guard
                    continue    # first make a loop
                else:
                    # then we fail the guard above, doing a recursive call,
                    # which will itself fail the same guard above, and so on
                    return portal(level + 1)
        self.meta_interp(portal, [0])
        self.check_trace_count_at_most(2)   # and not, e.g., 24
  1123. class TestLLtype(RecursiveTests, LLJitMixin):
  1124. pass
  1125. class TestOOtype(RecursiveTests, OOJitMixin):
  1126. pass