PageRenderTime 398ms CodeModel.GetById 40ms app.highlight 286ms RepoModel.GetById 56ms app.codeStats 1ms

/rpython/jit/metainterp/test/test_recursive.py

https://bitbucket.org/kcr/pypy
Python | 1271 lines | 1269 code | 2 blank | 0 comment | 3 complexity | ff12caff748119a841a4cfd4e16d535f MD5 | raw file
   1import py
   2from rpython.rlib.jit import JitDriver, hint, set_param
   3from rpython.rlib.jit import unroll_safe, dont_look_inside, promote
   4from rpython.rlib.objectmodel import we_are_translated
   5from rpython.rlib.debug import fatalerror
   6from rpython.jit.metainterp.test.support import LLJitMixin, OOJitMixin
   7from rpython.jit.codewriter.policy import StopAtXPolicy
   8from rpython.rtyper.annlowlevel import hlstr
   9from rpython.jit.metainterp.warmspot import get_stats
  10
  11class RecursiveTests:
  12
  13    def test_simple_recursion(self):
  14        myjitdriver = JitDriver(greens=[], reds=['n', 'm'])
  15        def f(n):
  16            m = n - 2
  17            while True:
  18                myjitdriver.jit_merge_point(n=n, m=m)
  19                n -= 1
  20                if m == n:
  21                    return main(n) * 2
  22                myjitdriver.can_enter_jit(n=n, m=m)
  23        def main(n):
  24            if n > 0:
  25                return f(n+1)
  26            else:
  27                return 1
  28        res = self.meta_interp(main, [20], enable_opts='')
  29        assert res == main(20)
  30        self.check_history(call=0)
  31
  32    def test_simple_recursion_with_exc(self):
  33        myjitdriver = JitDriver(greens=[], reds=['n', 'm'])
  34        class Error(Exception):
  35            pass
  36        
  37        def f(n):
  38            m = n - 2
  39            while True:
  40                myjitdriver.jit_merge_point(n=n, m=m)
  41                n -= 1
  42                if n == 10:
  43                    raise Error
  44                if m == n:
  45                    try:
  46                        return main(n) * 2
  47                    except Error:
  48                        return 2
  49                myjitdriver.can_enter_jit(n=n, m=m)
  50        def main(n):
  51            if n > 0:
  52                return f(n+1)
  53            else:
  54                return 1
  55        res = self.meta_interp(main, [20], enable_opts='')
  56        assert res == main(20)
  57
  58    def test_recursion_three_times(self):
  59        myjitdriver = JitDriver(greens=[], reds=['n', 'm', 'total'])
  60        def f(n):
  61            m = n - 3
  62            total = 0
  63            while True:
  64                myjitdriver.jit_merge_point(n=n, m=m, total=total)
  65                n -= 1
  66                total += main(n)
  67                if m == n:
  68                    return total + 5
  69                myjitdriver.can_enter_jit(n=n, m=m, total=total)
  70        def main(n):
  71            if n > 0:
  72                return f(n)
  73            else:
  74                return 1
  75        print
  76        for i in range(1, 11):
  77            print '%3d %9d' % (i, f(i))
  78        res = self.meta_interp(main, [10], enable_opts='')
  79        assert res == main(10)
  80        self.check_enter_count_at_most(11)
  81
  82    def test_bug_1(self):
  83        myjitdriver = JitDriver(greens=[], reds=['n', 'i', 'stack'])
  84        def opaque(n, i):
  85            if n == 1 and i == 19:
  86                for j in range(20):
  87                    res = f(0)      # recurse repeatedly, 20 times
  88                    assert res == 0
  89        def f(n):
  90            stack = [n]
  91            i = 0
  92            while i < 20:
  93                myjitdriver.can_enter_jit(n=n, i=i, stack=stack)
  94                myjitdriver.jit_merge_point(n=n, i=i, stack=stack)
  95                opaque(n, i)
  96                i += 1
  97            return stack.pop()
  98        res = self.meta_interp(f, [1], enable_opts='', repeat=2,
  99                               policy=StopAtXPolicy(opaque))
 100        assert res == 1
 101
 102    def get_interpreter(self, codes):
 103        ADD = "0"
 104        JUMP_BACK = "1"
 105        CALL = "2"
 106        EXIT = "3"
 107
 108        def getloc(i, code):
 109            return 'code="%s", i=%d' % (code, i)
 110
 111        jitdriver = JitDriver(greens = ['i', 'code'], reds = ['n'],
 112                              get_printable_location = getloc)
 113
 114        def interpret(codenum, n, i):
 115            code = codes[codenum]
 116            while i < len(code):
 117                jitdriver.jit_merge_point(n=n, i=i, code=code)
 118                op = code[i]
 119                if op == ADD:
 120                    n += 1
 121                    i += 1
 122                elif op == CALL:
 123                    n = interpret(1, n, 1)
 124                    i += 1
 125                elif op == JUMP_BACK:
 126                    if n > 20:
 127                        return 42
 128                    i -= 2
 129                    jitdriver.can_enter_jit(n=n, i=i, code=code)
 130                elif op == EXIT:
 131                    return n
 132                else:
 133                    raise NotImplementedError
 134            return n
 135
 136        return interpret
 137
 138    def test_inline(self):
 139        code = "021"
 140        subcode = "00"
 141
 142        codes = [code, subcode]
 143        f = self.get_interpreter(codes)
 144
 145        assert self.meta_interp(f, [0, 0, 0], enable_opts='') == 42
 146        self.check_resops(call_may_force=1, int_add=1, call=0)
 147        assert self.meta_interp(f, [0, 0, 0], enable_opts='',
 148                                inline=True) == 42
 149        self.check_resops(call=0, int_add=2, call_may_force=0,
 150                          guard_no_exception=0)
 151
 152    def test_inline_jitdriver_check(self):
 153        code = "021"
 154        subcode = "100"
 155        codes = [code, subcode]
 156
 157        f = self.get_interpreter(codes)
 158
 159        assert self.meta_interp(f, [0, 0, 0], enable_opts='',
 160                                inline=True) == 42
 161        # the call is fully inlined, because we jump to subcode[1], thus
 162        # skipping completely the JUMP_BACK in subcode[0]
 163        self.check_resops(call=0, call_may_force=0, call_assembler=0)
 164
 165    def test_guard_failure_in_inlined_function(self):
 166        def p(pc, code):
 167            code = hlstr(code)
 168            return "%s %d %s" % (code, pc, code[pc])
 169        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
 170                                get_printable_location=p)
 171        def f(code, n):
 172            pc = 0
 173            while pc < len(code):
 174
 175                myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
 176                op = code[pc]
 177                if op == "-":
 178                    n -= 1
 179                elif op == "c":
 180                    n = f("---i---", n)
 181                elif op == "i":
 182                    if n % 5 == 1:
 183                        return n
 184                elif op == "l":
 185                    if n > 0:
 186                        myjitdriver.can_enter_jit(n=n, code=code, pc=0)
 187                        pc = 0
 188                        continue
 189                else:
 190                    assert 0
 191                pc += 1
 192            return n
 193        def main(n):
 194            return f("c-l", n)
 195        print main(100)
 196        res = self.meta_interp(main, [100], enable_opts='', inline=True)
 197        assert res == 0
 198
 199    def test_guard_failure_and_then_exception_in_inlined_function(self):
 200        def p(pc, code):
 201            code = hlstr(code)
 202            return "%s %d %s" % (code, pc, code[pc])
 203        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n', 'flag'],
 204                                get_printable_location=p)
 205        def f(code, n):
 206            pc = 0
 207            flag = False
 208            while pc < len(code):
 209
 210                myjitdriver.jit_merge_point(n=n, code=code, pc=pc, flag=flag)
 211                op = code[pc]
 212                if op == "-":
 213                    n -= 1
 214                elif op == "c":
 215                    try:
 216                        n = f("---ir---", n)
 217                    except Exception:
 218                        return n
 219                elif op == "i":
 220                    if n < 200:
 221                        flag = True
 222                elif op == "r":
 223                    if flag:
 224                        raise Exception
 225                elif op == "l":
 226                    if n > 0:
 227                        myjitdriver.can_enter_jit(n=n, code=code, pc=0, flag=flag)
 228                        pc = 0
 229                        continue
 230                else:
 231                    assert 0
 232                pc += 1
 233            return n
 234        def main(n):
 235            return f("c-l", n)
 236        print main(1000)
 237        res = self.meta_interp(main, [1000], enable_opts='', inline=True)
 238        assert res == main(1000)
 239
 240    def test_exception_in_inlined_function(self):
 241        def p(pc, code):
 242            code = hlstr(code)
 243            return "%s %d %s" % (code, pc, code[pc])
 244        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
 245                                get_printable_location=p)
 246
 247        class Exc(Exception):
 248            pass
 249        
 250        def f(code, n):
 251            pc = 0
 252            while pc < len(code):
 253
 254                myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
 255                op = code[pc]
 256                if op == "-":
 257                    n -= 1
 258                elif op == "c":
 259                    try:
 260                        n = f("---i---", n)
 261                    except Exc:
 262                        pass
 263                elif op == "i":
 264                    if n % 5 == 1:
 265                        raise Exc
 266                elif op == "l":
 267                    if n > 0:
 268                        myjitdriver.can_enter_jit(n=n, code=code, pc=0)
 269                        pc = 0
 270                        continue
 271                else:
 272                    assert 0
 273                pc += 1
 274            return n
 275        def main(n):
 276            return f("c-l", n)
 277        res = self.meta_interp(main, [100], enable_opts='', inline=True)
 278        assert res == main(100)
 279
 280    def test_recurse_during_blackholing(self):
 281        # this passes, if the blackholing shortcut for calls is turned off
 282        # it fails, it is very delicate in terms of parameters,
 283        # bridge/loop creation order
 284        def p(pc, code):
 285            code = hlstr(code)
 286            return "%s %d %s" % (code, pc, code[pc])
 287        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
 288                                get_printable_location=p)
 289        
 290        def f(code, n):
 291            pc = 0
 292            while pc < len(code):
 293
 294                myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
 295                op = code[pc]
 296                if op == "-":
 297                    n -= 1
 298                elif op == "c":
 299                    if n < 70 and n % 3 == 1:
 300                        n = f("--", n)
 301                elif op == "l":
 302                    if n > 0:
 303                        myjitdriver.can_enter_jit(n=n, code=code, pc=0)
 304                        pc = 0
 305                        continue
 306                else:
 307                    assert 0
 308                pc += 1
 309            return n
 310        def main(n):
 311            set_param(None, 'threshold', 3)
 312            set_param(None, 'trace_eagerness', 5)            
 313            return f("c-l", n)
 314        expected = main(100)
 315        res = self.meta_interp(main, [100], enable_opts='', inline=True)
 316        assert res == expected
 317
 318    def check_max_trace_length(self, length):
 319        for loop in get_stats().loops:
 320            assert len(loop.operations) <= length + 5 # because we only check once per metainterp bytecode
 321            for op in loop.operations:
 322                if op.is_guard() and hasattr(op.getdescr(), '_debug_suboperations'):
 323                    assert len(op.getdescr()._debug_suboperations) <= length + 5
 324
 325    def test_inline_trace_limit(self):
 326        myjitdriver = JitDriver(greens=[], reds=['n'])
 327        def recursive(n):
 328            if n > 0:
 329                return recursive(n - 1) + 1
 330            return 0
 331        def loop(n):            
 332            set_param(myjitdriver, "threshold", 10)
 333            pc = 0
 334            while n:
 335                myjitdriver.can_enter_jit(n=n)
 336                myjitdriver.jit_merge_point(n=n)
 337                n = recursive(n)
 338                n -= 1
 339            return n
 340        TRACE_LIMIT = 66
 341        res = self.meta_interp(loop, [100], enable_opts='', inline=True, trace_limit=TRACE_LIMIT)
 342        assert res == 0
 343        self.check_max_trace_length(TRACE_LIMIT)
 344        self.check_enter_count_at_most(10) # maybe
 345        self.check_aborted_count(7)
 346
 347    def test_trace_limit_bridge(self):
 348        def recursive(n):
 349            if n > 0:
 350                return recursive(n - 1) + 1
 351            return 0
 352        myjitdriver = JitDriver(greens=[], reds=['n'])
 353        def loop(n):
 354            set_param(None, "threshold", 4)
 355            set_param(None, "trace_eagerness", 2)
 356            while n:
 357                myjitdriver.can_enter_jit(n=n)
 358                myjitdriver.jit_merge_point(n=n)
 359                if n % 5 == 0:
 360                    n -= 1
 361                if n < 50:
 362                    n = recursive(n)
 363                n -= 1
 364            return n
 365        TRACE_LIMIT = 20
 366        res = self.meta_interp(loop, [100], enable_opts='', inline=True, trace_limit=TRACE_LIMIT)
 367        self.check_max_trace_length(TRACE_LIMIT)
 368        self.check_aborted_count(8)
 369        self.check_enter_count_at_most(30)
 370
 371    def test_trace_limit_with_exception_bug(self):
 372        myjitdriver = JitDriver(greens=[], reds=['n'])
 373        @unroll_safe
 374        def do_stuff(n):
 375            while n > 0:
 376                n -= 1
 377            raise ValueError
 378        def loop(n):
 379            pc = 0
 380            while n > 80:
 381                myjitdriver.can_enter_jit(n=n)
 382                myjitdriver.jit_merge_point(n=n)
 383                try:
 384                    do_stuff(n)
 385                except ValueError:
 386                    # the trace limit is checked when we arrive here, and we
 387                    # have the exception still in last_exc_value_box at this
 388                    # point -- so when we abort because of a trace too long,
 389                    # the exception is passed to the blackhole interp and
 390                    # incorrectly re-raised from here
 391                    pass
 392                n -= 1
 393            return n
 394        TRACE_LIMIT = 66
 395        res = self.meta_interp(loop, [100], trace_limit=TRACE_LIMIT)
 396        assert res == 80
 397
 398    def test_max_failure_args(self):
 399        FAILARGS_LIMIT = 10
 400        jitdriver = JitDriver(greens = [], reds = ['i', 'n', 'o'])
 401
 402        class A(object):
 403            def __init__(self, i0, i1, i2, i3, i4, i5, i6, i7, i8, i9):
 404                self.i0 = i0
 405                self.i1 = i1
 406                self.i2 = i2
 407                self.i3 = i3
 408                self.i4 = i4
 409                self.i5 = i5
 410                self.i6 = i6
 411                self.i7 = i7
 412                self.i8 = i8
 413                self.i9 = i9
 414                
 415        
 416        def loop(n):
 417            i = 0
 418            o = A(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
 419            while i < n:
 420                jitdriver.can_enter_jit(o=o, i=i, n=n)
 421                jitdriver.jit_merge_point(o=o, i=i, n=n)
 422                o = A(i, i + 1, i + 2, i + 3, i + 4, i + 5,
 423                      i + 6, i + 7, i + 8, i + 9)
 424                i += 1
 425            return o
 426
 427        res = self.meta_interp(loop, [20], failargs_limit=FAILARGS_LIMIT,
 428                               listops=True)
 429        self.check_aborted_count(5)
 430
 431    def test_max_failure_args_exc(self):
 432        FAILARGS_LIMIT = 10
 433        jitdriver = JitDriver(greens = [], reds = ['i', 'n', 'o'])
 434
 435        class A(object):
 436            def __init__(self, i0, i1, i2, i3, i4, i5, i6, i7, i8, i9):
 437                self.i0 = i0
 438                self.i1 = i1
 439                self.i2 = i2
 440                self.i3 = i3
 441                self.i4 = i4
 442                self.i5 = i5
 443                self.i6 = i6
 444                self.i7 = i7
 445                self.i8 = i8
 446                self.i9 = i9
 447                
 448        
 449        def loop(n):
 450            i = 0
 451            o = A(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
 452            while i < n:
 453                jitdriver.can_enter_jit(o=o, i=i, n=n)
 454                jitdriver.jit_merge_point(o=o, i=i, n=n)
 455                o = A(i, i + 1, i + 2, i + 3, i + 4, i + 5,
 456                      i + 6, i + 7, i + 8, i + 9)
 457                i += 1
 458            raise ValueError
 459
 460        def main(n):
 461            try:
 462                loop(n)
 463                return 1
 464            except ValueError:
 465                return 0
 466
 467        res = self.meta_interp(main, [20], failargs_limit=FAILARGS_LIMIT,
 468                               listops=True)
 469        assert not res
 470        self.check_aborted_count(5)        
 471
 472    def test_set_param_inlining(self):
 473        myjitdriver = JitDriver(greens=[], reds=['n', 'recurse'])
 474        def loop(n, recurse=False):
 475            while n:
 476                myjitdriver.jit_merge_point(n=n, recurse=recurse)
 477                n -= 1
 478                if not recurse:
 479                    loop(10, True)
 480                    myjitdriver.can_enter_jit(n=n, recurse=recurse)
 481            return n
 482        TRACE_LIMIT = 66
 483 
 484        def main(inline):
 485            set_param(None, "threshold", 10)
 486            set_param(None, 'function_threshold', 60)
 487            if inline:
 488                set_param(None, 'inlining', True)
 489            else:
 490                set_param(None, 'inlining', False)
 491            return loop(100)
 492
 493        res = self.meta_interp(main, [0], enable_opts='', trace_limit=TRACE_LIMIT)
 494        self.check_resops(call=0, call_may_force=1)
 495
 496        res = self.meta_interp(main, [1], enable_opts='', trace_limit=TRACE_LIMIT)
 497        self.check_resops(call=0, call_may_force=0)
 498
 499    def test_trace_from_start(self):
 500        def p(pc, code):
 501            code = hlstr(code)
 502            return "'%s' at %d: %s" % (code, pc, code[pc])
 503        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
 504                                get_printable_location=p)
 505        
 506        def f(code, n):
 507            pc = 0
 508            while pc < len(code):
 509
 510                myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
 511                op = code[pc]
 512                if op == "+":
 513                    n += 7
 514                elif op == "-":
 515                    n -= 1
 516                elif op == "c":
 517                    n = f('---', n)
 518                elif op == "l":
 519                    if n > 0:
 520                        myjitdriver.can_enter_jit(n=n, code=code, pc=1)
 521                        pc = 1
 522                        continue
 523                else:
 524                    assert 0
 525                pc += 1
 526            return n
 527        def g(m):
 528            if m > 1000000:
 529                f('', 0)
 530            result = 0
 531            for i in range(m):
 532                result += f('+-cl--', i)
 533        res = self.meta_interp(g, [50], backendopt=True)
 534        assert res == g(50)
 535        py.test.skip("tracing from start is by now only longer enabled "
 536                     "if a trace gets too big")
 537        self.check_tree_loop_count(3)
 538        self.check_history(int_add=1)
 539
 540    def test_dont_inline_huge_stuff(self):
 541        def p(pc, code):
 542            code = hlstr(code)
 543            return "%s %d %s" % (code, pc, code[pc])
 544        myjitdriver = JitDriver(greens=['pc', 'code'], reds=['n'],
 545                                get_printable_location=p)
 546        
 547        def f(code, n):
 548            pc = 0
 549            while pc < len(code):
 550
 551                myjitdriver.jit_merge_point(n=n, code=code, pc=pc)
 552                op = code[pc]
 553                if op == "-":
 554                    n -= 1
 555                elif op == "c":
 556                    f('--------------------', n)
 557                elif op == "l":
 558                    if n > 0:
 559                        myjitdriver.can_enter_jit(n=n, code=code, pc=0)
 560                        pc = 0
 561                        continue
 562                else:
 563                    assert 0
 564                pc += 1
 565            return n
 566        def g(m):
 567            set_param(None, 'inlining', True)
 568            # carefully chosen threshold to make sure that the inner function
 569            # cannot be inlined, but the inner function on its own is small
 570            # enough
 571            set_param(None, 'trace_limit', 40)
 572            if m > 1000000:
 573                f('', 0)
 574            result = 0
 575            for i in range(m):
 576                result += f('-c-----------l-', i+100)
 577        self.meta_interp(g, [10], backendopt=True)
 578        self.check_aborted_count(1)
 579        self.check_resops(call=0, call_assembler=2)        
 580        self.check_jitcell_token_count(2)
 581
 582    def test_directly_call_assembler(self):
 583        driver = JitDriver(greens = ['codeno'], reds = ['i'],
 584                           get_printable_location = lambda codeno : str(codeno))
 585
 586        def portal(codeno):
 587            i = 0
 588            while i < 10:
 589                driver.can_enter_jit(codeno = codeno, i = i)
 590                driver.jit_merge_point(codeno = codeno, i = i)
 591                if codeno == 2:
 592                    portal(1)
 593                i += 1
 594
 595        self.meta_interp(portal, [2], inline=True)
 596        self.check_history(call_assembler=1)
 597
 598    def test_recursion_cant_call_assembler_directly(self):
 599        driver = JitDriver(greens = ['codeno'], reds = ['i', 'j'],
 600                           get_printable_location = lambda codeno : str(codeno))
 601
 602        def portal(codeno, j):
 603            i = 1
 604            while 1:
 605                driver.jit_merge_point(codeno=codeno, i=i, j=j)
 606                if (i >> 1) == 1:
 607                    if j == 0:
 608                        return
 609                    portal(2, j - 1)
 610                elif i == 5:
 611                    return
 612                i += 1
 613                driver.can_enter_jit(codeno=codeno, i=i, j=j)
 614
 615        portal(2, 5)
 616
 617        from rpython.jit.metainterp import compile, pyjitpl
 618        pyjitpl._warmrunnerdesc = None
 619        trace = []
 620        def my_ctc(*args):
 621            looptoken = original_ctc(*args)
 622            trace.append(looptoken)
 623            return looptoken
 624        original_ctc = compile.compile_tmp_callback
 625        try:
 626            compile.compile_tmp_callback = my_ctc
 627            self.meta_interp(portal, [2, 5], inline=True)
 628            self.check_resops(call_may_force=0, call_assembler=2)
 629        finally:
 630            compile.compile_tmp_callback = original_ctc
 631        # check that we made a temporary callback
 632        assert len(trace) == 1
 633        # and that we later redirected it to something else
 634        try:
 635            redirected = pyjitpl._warmrunnerdesc.cpu._redirected_call_assembler
 636        except AttributeError:
 637            pass    # not the llgraph backend
 638        else:
 639            print redirected
 640            assert redirected.keys() == trace
 641
 642    def test_recursion_cant_call_assembler_directly_with_virtualizable(self):
 643        # exactly the same logic as the previous test, but with 'frame.j'
 644        # instead of just 'j'
 645        class Frame(object):
 646            _virtualizable2_ = ['j']
 647            def __init__(self, j):
 648                self.j = j
 649
 650        driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
 651                           virtualizables = ['frame'],
 652                           get_printable_location = lambda codeno : str(codeno))
 653
 654        def portal(codeno, frame):
 655            i = 1
 656            while 1:
 657                driver.jit_merge_point(codeno=codeno, i=i, frame=frame)
 658                if (i >> 1) == 1:
 659                    if frame.j == 0:
 660                        return
 661                    portal(2, Frame(frame.j - 1))
 662                elif i == 5:
 663                    return
 664                i += 1
 665                driver.can_enter_jit(codeno=codeno, i=i, frame=frame)
 666
 667        def main(codeno, j):
 668            portal(codeno, Frame(j))
 669
 670        main(2, 5)
 671
 672        from rpython.jit.metainterp import compile, pyjitpl
 673        pyjitpl._warmrunnerdesc = None
 674        trace = []
 675        def my_ctc(*args):
 676            looptoken = original_ctc(*args)
 677            trace.append(looptoken)
 678            return looptoken
 679        original_ctc = compile.compile_tmp_callback
 680        try:
 681            compile.compile_tmp_callback = my_ctc
 682            self.meta_interp(main, [2, 5], inline=True)
 683            self.check_resops(call_may_force=0, call_assembler=2)
 684        finally:
 685            compile.compile_tmp_callback = original_ctc
 686        # check that we made a temporary callback
 687        assert len(trace) == 1
 688        # and that we later redirected it to something else
 689        try:
 690            redirected = pyjitpl._warmrunnerdesc.cpu._redirected_call_assembler
 691        except AttributeError:
 692            pass    # not the llgraph backend
 693        else:
 694            print redirected
 695            assert redirected.keys() == trace
 696
 697    def test_directly_call_assembler_return(self):
 698        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
 699                           get_printable_location = lambda codeno : str(codeno))
 700
 701        def portal(codeno):
 702            i = 0
 703            k = codeno
 704            while i < 10:
 705                driver.can_enter_jit(codeno = codeno, i = i, k = k)
 706                driver.jit_merge_point(codeno = codeno, i = i, k = k)
 707                if codeno == 2:
 708                    k = portal(1)
 709                i += 1
 710            return k
 711
 712        self.meta_interp(portal, [2], inline=True)
 713        self.check_history(call_assembler=1)
 714
 715    def test_directly_call_assembler_raise(self):
 716
 717        class MyException(Exception):
 718            def __init__(self, x):
 719                self.x = x
 720        
 721        driver = JitDriver(greens = ['codeno'], reds = ['i'],
 722                           get_printable_location = lambda codeno : str(codeno))
 723
 724        def portal(codeno):
 725            i = 0
 726            while i < 10:
 727                driver.can_enter_jit(codeno = codeno, i = i)
 728                driver.jit_merge_point(codeno = codeno, i = i)
 729                if codeno == 2:
 730                    try:
 731                        portal(1)
 732                    except MyException, me:
 733                        i += me.x
 734                i += 1
 735            if codeno == 1:
 736                raise MyException(1)
 737
 738        self.meta_interp(portal, [2], inline=True)
 739        self.check_history(call_assembler=1)        
 740
 741    def test_directly_call_assembler_fail_guard(self):
 742        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
 743                           get_printable_location = lambda codeno : str(codeno))
 744
 745        def portal(codeno, k):
 746            i = 0
 747            while i < 10:
 748                driver.can_enter_jit(codeno=codeno, i=i, k=k)
 749                driver.jit_merge_point(codeno=codeno, i=i, k=k)
 750                if codeno == 2:
 751                    k += portal(1, k)
 752                elif k > 40:
 753                    if i % 2:
 754                        k += 1
 755                    else:
 756                        k += 2
 757                k += 1
 758                i += 1
 759            return k
 760
 761        res = self.meta_interp(portal, [2, 0], inline=True)
 762        assert res == 13542
 763
 764    def test_directly_call_assembler_virtualizable(self):
 765        class Thing(object):
 766            def __init__(self, val):
 767                self.val = val
 768        
 769        class Frame(object):
 770            _virtualizable2_ = ['thing']
 771        
 772        driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
 773                           virtualizables = ['frame'],
 774                           get_printable_location = lambda codeno : str(codeno))
 775
 776        def main(codeno):
 777            frame = Frame()
 778            frame.thing = Thing(0)
 779            portal(codeno, frame)
 780            return frame.thing.val
 781
 782        def portal(codeno, frame):
 783            i = 0
 784            while i < 10:
 785                driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
 786                driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
 787                nextval = frame.thing.val
 788                if codeno == 0:
 789                    subframe = Frame()
 790                    subframe.thing = Thing(nextval)
 791                    nextval = portal(1, subframe)
 792                frame.thing = Thing(nextval + 1)
 793                i += 1
 794            return frame.thing.val
 795
 796        res = self.meta_interp(main, [0], inline=True)
 797        assert res == main(0)
 798
    def test_directly_call_assembler_virtualizable_reset_token(self):
        """Check that a subframe's vable_token is reset after a recursive call.

        portal() recurses three levels deep (codeno 0 -> 1 -> 2, because it
        only recurses while codeno < 2).  After every recursive call,
        check_frame() inspects the subframe that was just used.
        """
        from rpython.rtyper.lltypesystem import lltype
        from rpython.rlib.debug import llinterpcall

        class Thing(object):
            def __init__(self, val):
                self.val = val

        class Frame(object):
            # 'thing' is the only field of the virtualizable frame.
            _virtualizable2_ = ['thing']

        driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                           virtualizables = ['frame'],
                           get_printable_location = lambda codeno : str(codeno))

        @dont_look_inside
        def check_frame(subframe):
            # Marked @dont_look_inside so the JIT does not trace into it.
            # The low-level check only runs when we_are_translated() is true.
            # When main(0) is executed as plain Python at the end of this
            # test, the check is skipped.
            if we_are_translated():
                llinterpcall(lltype.Void, check_ll_frame, subframe)
        def check_ll_frame(ll_subframe):
            # This is called with the low-level Struct that is the frame.
            # Check that the vable_token was correctly reset to TOKEN_NONE.
            # For this test to catch failures, it needs three levels of
            # recursion.  The vable_token of the subframe at level 2 is set
            # to a non-zero value only when doing the call to level 3.
            # This used to fail when the test was run via
            # rpython.jit.backend.x86.test.test_recursive.
            from rpython.jit.metainterp.virtualizable import TOKEN_NONE
            assert ll_subframe.vable_token == TOKEN_NONE

        def main(codeno):
            frame = Frame()
            frame.thing = Thing(0)
            portal(codeno, frame)
            return frame.thing.val

        def portal(codeno, frame):
            i = 0
            while i < 5:
                driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
                driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
                nextval = frame.thing.val
                if codeno < 2:
                    # Recurse one level deeper with a fresh virtualizable,
                    # then verify its token once the callee has returned.
                    subframe = Frame()
                    subframe.thing = Thing(nextval)
                    nextval = portal(codeno + 1, subframe)
                    check_frame(subframe)
                frame.thing = Thing(nextval + 1)
                i += 1
            return frame.thing.val

        res = self.meta_interp(main, [0], inline=True)
        assert res == main(0)
 852
    def test_directly_call_assembler_virtualizable_force1(self):
        """An inner portal call writes into the outer virtualizable frame.

        main() stores the top-level frame in somewhere_else.frame.  The
        codeno 0 portal recurses into portal(1, subframe).  Once that inner
        portal sees frame.thing.val > 40, it calls change().  change()
        replaces somewhere_else.frame.thing, which is the field of the
        *outer* (codeno 0) frame, not of the subframe.

        StopAtXPolicy(change) presumably stops the JIT from looking inside
        change().
        """
        class Thing(object):
            def __init__(self, val):
                self.val = val

        class Frame(object):
            _virtualizable2_ = ['thing']

        driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                           virtualizables = ['frame'],
                           get_printable_location = lambda codeno : str(codeno))
        class SomewhereElse(object):
            pass

        # Global holder that gives change() a second path to the top frame.
        somewhere_else = SomewhereElse()

        def change(newthing):
            somewhere_else.frame.thing = newthing

        def main(codeno):
            frame = Frame()
            somewhere_else.frame = frame
            frame.thing = Thing(0)
            portal(codeno, frame)
            return frame.thing.val

        def portal(codeno, frame):
            print 'ENTER:', codeno, frame.thing.val
            i = 0
            while i < 10:
                driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
                driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
                nextval = frame.thing.val
                if codeno == 0:
                    subframe = Frame()
                    subframe.thing = Thing(nextval)
                    nextval = portal(1, subframe)
                elif codeno == 1:
                    if frame.thing.val > 40:
                        # Modifies the outer frame while this inner
                        # portal is still running.
                        change(Thing(13))
                        nextval = 13
                else:
                    # Only codeno 0 and 1 are ever used by this test.
                    fatalerror("bad codeno = " + str(codeno))
                frame.thing = Thing(nextval + 1)
                i += 1
            print 'LEAVE:', codeno, frame.thing.val
            return frame.thing.val

        res = self.meta_interp(main, [0], inline=True,
                               policy=StopAtXPolicy(change))
        assert res == main(0)
 904
    def test_directly_call_assembler_virtualizable_with_array(self):
        """Recursive portal calls with a virtualizable that has an array field.

        Frame's virtualizable fields are all items of the list 'l'
        ('l[*]') and the integer index 's'.  Each loop iteration reads
        l[s], bumps s, optionally recurses (codeno 0 only), reads l[s]
        again and then restores s.
        """
        myjitdriver = JitDriver(greens = ['codeno'], reds = ['n', 'x', 'frame'],
                                virtualizables = ['frame'])

        class Frame(object):
            _virtualizable2_ = ['l[*]', 's']

            def __init__(self, l, s):
                # Hint that the fields of this newly created virtualizable
                # are accessed directly here (presumably bypassing the
                # JIT's virtualizable tracking inside the constructor).
                self = hint(self, access_directly=True,
                            fresh_virtualizable=True)
                self.l = l
                self.s = s

        def main(codeno, n, a):
            frame = Frame([a, a+1, a+2, a+3], 0)
            return f(codeno, n, a, frame)

        def f(codeno, n, a, frame):
            x = 0
            while n > 0:
                myjitdriver.can_enter_jit(codeno=codeno, frame=frame, n=n, x=x)
                myjitdriver.jit_merge_point(codeno=codeno, frame=frame, n=n,
                                            x=x)
                # Make the trace specialize on the current value of frame.s.
                frame.s = promote(frame.s)
                n -= 1
                s = frame.s
                # Non-negative index for the list access below (presumably
                # needed so the index is known to be >= 0 in RPython).
                assert s >= 0
                x += frame.l[s]
                frame.s += 1
                if codeno == 0:
                    subframe = Frame([n, n+1, n+2, n+3], 0)
                    x += f(1, 10, 1, subframe)
                s = frame.s
                assert s >= 0
                x += frame.l[s]
                x += len(frame.l)
                # Undo the increment so frame.s is the same at each
                # iteration start.
                frame.s -= 1
            return x

        res = self.meta_interp(main, [0, 10, 1], listops=True, inline=True)
        assert res == main(0, 10, 1)
 946
    def test_directly_call_assembler_virtualizable_force_blackhole(self):
        """An inner portal's residual call conditionally writes the outer frame.

        For codeno != 0, portal() always calls change().  change() decides
        internally (arg > 30) whether to overwrite somewhere_else.frame.thing,
        which is the top-level frame set up by main(), and returns the value
        to use next.  The test name suggests this targets forcing the
        virtualizable from the blackhole interpreter; TODO confirm.
        """
        class Thing(object):
            def __init__(self, val):
                self.val = val

        class Frame(object):
            _virtualizable2_ = ['thing']

        driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                           virtualizables = ['frame'],
                           get_printable_location = lambda codeno : str(codeno))
        class SomewhereElse(object):
            pass

        somewhere_else = SomewhereElse()

        def change(newthing, arg):
            # Kept opaque to the JIT via StopAtXPolicy(change) below
            # (presumably).  Returns arg unchanged, or 13 after overwriting
            # the top-level frame's 'thing'.
            print arg
            if arg > 30:
                somewhere_else.frame.thing = newthing
                arg = 13
            return arg

        def main(codeno):
            frame = Frame()
            somewhere_else.frame = frame
            frame.thing = Thing(0)
            portal(codeno, frame)
            return frame.thing.val

        def portal(codeno, frame):
            i = 0
            while i < 10:
                driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
                driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
                nextval = frame.thing.val
                if codeno == 0:
                    subframe = Frame()
                    subframe.thing = Thing(nextval)
                    nextval = portal(1, subframe)
                else:
                    nextval = change(Thing(13), frame.thing.val)
                frame.thing = Thing(nextval + 1)
                i += 1
            return frame.thing.val

        res = self.meta_interp(main, [0], inline=True,
                               policy=StopAtXPolicy(change))
        assert res == main(0)
 996
    def test_assembler_call_red_args(self):
        """Recursive portal call whose green argument comes from a residual call.

        With codeno 2, portal() calls portal(residual(k), k).  residual()
        returns 1 while k <= 150 and 0 afterwards, so the inner codeno is
        computed at run time rather than being a constant in the trace.
        The inner portal adds 1 per iteration for codeno 1, or 2 for
        codeno 0.  Expect exactly 4 call_assembler operations.
        """
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                           get_printable_location = lambda codeno : str(codeno))

        def residual(k):
            if k > 150:
                return 0
            return 1

        def portal(codeno, k):
            i = 0
            while i < 15:
                driver.can_enter_jit(codeno=codeno, i=i, k=k)
                driver.jit_merge_point(codeno=codeno, i=i, k=k)
                if codeno == 2:
                    k += portal(residual(k), k)
                if codeno == 0:
                    k += 2
                elif codeno == 1:
                    k += 1
                i += 1
            return k

        # StopAtXPolicy(residual) presumably keeps residual() as an opaque
        # call, so its result is not known while tracing.
        res = self.meta_interp(portal, [2, 0], inline=True,
                               policy=StopAtXPolicy(residual))
        assert res == portal(2, 0)
        self.check_resops(call_assembler=4)
1024
    def test_inline_without_hitting_the_loop(self):
        """A recursive call that never reaches its own loop is fully inlined.

        portal(20) returns 1 on its first iteration (codeno 20 takes the
        final else branch), so the inner call never loops.  The outer
        portal(0) adds 10 per round (codeno 0..9) and returns once
        i > 63, i.e. 70.  No call_assembler should remain.
        """
        driver = JitDriver(greens = ['codeno'], reds = ['i'],
                           get_printable_location = lambda codeno : str(codeno))

        def portal(codeno):
            i = 0
            while True:
                driver.jit_merge_point(codeno=codeno, i=i)
                if codeno < 10:
                    i += portal(20)
                    codeno += 1
                elif codeno == 10:
                    if i > 63:
                        return i
                    codeno = 0
                    # Loop back to codeno 0: the only can_enter_jit point.
                    driver.can_enter_jit(codeno=codeno, i=i)
                else:
                    return 1

        assert portal(0) == 70
        res = self.meta_interp(portal, [0], inline=True)
        assert res == 70
        self.check_resops(call_assembler=0)
1048
    def test_inline_with_hitting_the_loop_sometimes(self):
        """Recursive calls that only sometimes reach their own loop.

        portal(codeno, k) recurses with portal(codeno + 5, k + 1) and
        returns 1 immediately once k > 2.  Depending on codeno + 5, the
        inner call either runs its own loop (codeno <= 10) or returns 1
        straight away (codeno > 10).  The exit threshold of the loop
        depends on the depth k via [-1, 2000, 63][k].  Starting from k=1,
        index 0 is never reached.
        """
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                           get_printable_location = lambda codeno : str(codeno))

        def portal(codeno, k):
            if k > 2:
                return 1
            i = 0
            while True:
                driver.jit_merge_point(codeno=codeno, i=i, k=k)
                if codeno < 10:
                    i += portal(codeno + 5, k+1)
                    codeno += 1
                elif codeno == 10:
                    if i > [-1, 2000, 63][k]:
                        return i
                    codeno = 0
                    driver.can_enter_jit(codeno=codeno, i=i, k=k)
                else:
                    return 1

        assert portal(0, 1) == 2095
        res = self.meta_interp(portal, [0, 1], inline=True)
        assert res == 2095
        self.check_resops(call_assembler=12)
1074
    def test_inline_with_hitting_the_loop_sometimes_exc(self):
        """Same recursion shape as the non-exception variant, via exceptions.

        Every result of portal() is delivered by raising GotValue(result)
        instead of returning it.  Callers catch the exception and read
        e.result.  main() unwraps the outermost GotValue.  The expected
        value (2095) and the call_assembler count (12) are the same as for
        the return-based version.
        """
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
                           get_printable_location = lambda codeno : str(codeno))
        class GotValue(Exception):
            def __init__(self, result):
                self.result = result

        def portal(codeno, k):
            if k > 2:
                raise GotValue(1)
            i = 0
            while True:
                driver.jit_merge_point(codeno=codeno, i=i, k=k)
                if codeno < 10:
                    try:
                        portal(codeno + 5, k+1)
                    except GotValue, e:
                        i += e.result
                    codeno += 1
                elif codeno == 10:
                    if i > [-1, 2000, 63][k]:
                        raise GotValue(i)
                    codeno = 0
                    driver.can_enter_jit(codeno=codeno, i=i, k=k)
                else:
                    raise GotValue(1)

        def main(codeno, k):
            # portal() never returns normally; its result is the exception.
            try:
                portal(codeno, k)
            except GotValue, e:
                return e.result

        assert main(0, 1) == 2095
        res = self.meta_interp(main, [0, 1], inline=True)
        assert res == 2095
        self.check_resops(call_assembler=12)
1112
    def test_handle_jitexception_in_portal(self):
        # a test for _handle_jitexception_in_portal in blackhole.py
        #
        # can_enter_jit is not called directly from portal()'s loop.  It is
        # reached through intermediate() -> do_can_enter_jit(), and only
        # when i == 9.  portal(64, s) appends 'A'..'I'.  At i == 10 it
        # recurses into portal(96, ...), which appends 'a'..'j'.  Finally
        # it appends 'J'.
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'str'],
                           get_printable_location = lambda codeno: str(codeno))
        def do_can_enter_jit(codeno, i, str):
            i = (i+1)-1    # some operations
            driver.can_enter_jit(codeno=codeno, i=i, str=str)
        def intermediate(codeno, i, str):
            if i == 9:
                do_can_enter_jit(codeno, i, str)
        def portal(codeno, str):
            i = value.initial
            while i < 10:
                intermediate(codeno, i, str)
                driver.jit_merge_point(codeno=codeno, i=i, str=str)
                i += 1
                if codeno == 64 and i == 10:
                    str = portal(96, str)
                str += chr(codeno+i)
            return str
        # The start value of i is read from a mutable attribute (-1 by
        # default, set to 0 by main()) rather than written as a literal.
        # Presumably this keeps it from being a constant; TODO confirm.
        class Value:
            initial = -1
        value = Value()
        def main():
            value.initial = 0
            return (portal(64, '') +
                    portal(64, '') +
                    portal(64, '') +
                    portal(64, '') +
                    portal(64, ''))
        assert main() == 'ABCDEFGHIabcdefghijJ' * 5
        # Repeat under several trace limits.  These are presumably chosen
        # to abort tracing at different points; TODO confirm.
        for tlimit in [95, 90, 102]:
            print 'tlimit =', tlimit
            res = self.meta_interp(main, [], inline=True, trace_limit=tlimit)
            # res is the JIT-run string result; compare its characters.
            assert ''.join(res.chars) == 'ABCDEFGHIabcdefghijJ' * 5
1148
    def test_handle_jitexception_in_portal_returns_void(self):
        # a test for _handle_jitexception_in_portal in blackhole.py
        #
        # Same control flow as the string-returning variant, but portal()
        # returns nothing, so only successful completion is checked.  The
        # recursive call's result is discarded.  can_enter_jit is reached
        # indirectly via intermediate() -> do_can_enter_jit() when i == 9.
        driver = JitDriver(greens = ['codeno'], reds = ['i', 'str'],
                           get_printable_location = lambda codeno: str(codeno))
        def do_can_enter_jit(codeno, i, str):
            i = (i+1)-1    # some operations
            driver.can_enter_jit(codeno=codeno, i=i, str=str)
        def intermediate(codeno, i, str):
            if i == 9:
                do_can_enter_jit(codeno, i, str)
        def portal(codeno, str):
            i = value.initial
            while i < 10:
                intermediate(codeno, i, str)
                driver.jit_merge_point(codeno=codeno, i=i, str=str)
                i += 1
                if codeno == 64 and i == 10:
                    portal(96, str)
                str += chr(codeno+i)
        # Start value of i, read from a mutable attribute (set to 0 by main()).
        class Value:
            initial = -1
        value = Value()
        def main():
            value.initial = 0
            portal(64, '')
            portal(64, '')
            portal(64, '')
            portal(64, '')
            portal(64, '')
        # Sanity run as plain Python before running under the JIT.
        main()
        for tlimit in [95, 90, 102]:
            print 'tlimit =', tlimit
            self.meta_interp(main, [], inline=True, trace_limit=tlimit)
1182
    def test_no_duplicates_bug(self):
        """Regression test: the JIT run of this recursive portal must complete.

        portal(0, 10) iterates i = 10..1 and calls portal(i, i) each time.
        With codeno > 0, that inner call breaks out on its first iteration.
        Nothing is asserted beyond meta_interp finishing.  The "duplicates"
        in the name refers to a past bug not visible here; TODO document.
        """
        driver = JitDriver(greens = ['codeno'], reds = ['i'],
                           get_printable_location = lambda codeno: str(codeno))
        def portal(codeno, i):
            while i > 0:
                driver.can_enter_jit(codeno=codeno, i=i)
                driver.jit_merge_point(codeno=codeno, i=i)
                if codeno > 0:
                    break
                portal(i, i)
                i -= 1
        self.meta_interp(portal, [0, 10], inline=True)
1195
1196    def test_trace_from_start_always(self):
1197        from rpython.rlib.nonconst import NonConstant
1198        
1199        driver = JitDriver(greens = ['c'], reds = ['i', 'v'])
1200
1201        def portal(c, i, v):
1202            while i > 0:
1203                driver.jit_merge_point(c=c, i=i, v=v)
1204                portal(c, i - 1, v)
1205                if v:
1206                    driver.can_enter_jit(c=c, i=i, v=v)
1207                break
1208
1209        def main(c, i, _set_param, v):
1210            if _set_param:
1211                set_param(driver, 'function_threshold', 0)
1212            portal(c, i, v)
1213
1214        self.meta_interp(main, [10, 10, False, False], inline=True)
1215        self.check_jitcell_token_count(1)
1216        self.check_trace_count(1)
1217        self.meta_interp(main, [3, 10, True, False], inline=True)
1218        self.check_jitcell_token_count(0)
1219        self.check_trace_count(0)
1220
    def test_trace_from_start_does_not_prevent_inlining(self):
        """A recursive call that returns immediately must be fully inlined.

        portal(bc=0, ...) calls portal(1, 8, 0), which returns on its first
        iteration because bc != 0.  The resulting loop must contain no
        residual calls at all (neither call_may_force nor call).
        """
        driver = JitDriver(greens = ['c', 'bc'], reds = ['i'])

        def portal(bc, c, i):
            while True:
                driver.jit_merge_point(c=c, bc=bc, i=i)
                if bc == 0:
                    portal(1, 8, 0)
                    c += 1
                else:
                    return
                if c == 10: # bc == 0 here: the other branch returned above
                    c = 0
                    if i >= 100:
                        return
                    driver.can_enter_jit(c=c, bc=bc, i=i)
                i += 1

        self.meta_interp(portal, [0, 0, 0], inline=True)
        self.check_resops(call_may_force=0, call=0)
1241
    def test_dont_repeatedly_trace_from_the_same_guard(self):
        """A guard that fails at every recursion level is not retraced each time.

        Level 0 starts at i = -10, so the ``i <= 0`` guard first forms a
        loop.  Each deeper level starts at i = 0 and fails that guard
        immediately, recursing until level 25 returns 42.  At most two
        traces may be produced.
        """
        driver = JitDriver(greens = [], reds = ['level', 'i'])

        def portal(level):
            if level == 0:
                i = -10
            else:
                i = 0
            #
            while True:
                driver.jit_merge_point(level=level, i=i)
                if level == 25:
                    return 42
                i += 1
                if i <= 0:      # <- guard
                    continue    # first make a loop
                else:
                    # then we fail the guard above, doing a recursive call,
                    # which will itself fail the same guard above, and so on
                    return portal(level + 1)

        self.meta_interp(portal, [0])
        self.check_trace_count_at_most(2)   # and not, e.g., 24
1265
1266
class TestLLtype(RecursiveTests, LLJitMixin):
    """Run the RecursiveTests cases with the LLJitMixin (lltype) test support."""
1269
class TestOOtype(RecursiveTests, OOJitMixin):
    """Run the RecursiveTests cases with the OOJitMixin (ootype) test support."""