# Project: pypy — file /pypy/module/pypyjit/test/test_jit_hook.py
#
# Language: Python; lines: 259
# MD5 hash: db5bfec6adfb36ee076a6024d38820c9; estimated cost: $5,353
# Repository: https://bitbucket.org/pypy/pypy/
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import py
from pypy.interpreter.gateway import interp2app
from pypy.interpreter.pycode import PyCode
from rpython.jit.metainterp.history import JitCellToken, ConstInt, ConstPtr,\
     BasicFailDescr
from rpython.jit.metainterp.resoperation import rop
from rpython.jit.metainterp.logger import Logger
from rpython.rtyper.annlowlevel import (cast_instance_to_base_ptr,
                                      cast_base_ptr_to_instance)
from rpython.rtyper.lltypesystem import lltype, llmemory
from rpython.rtyper.rclass import OBJECT
from pypy.module.pypyjit.interp_jit import pypyjitdriver
from pypy.module.pypyjit.hooks import pypy_hooks
from rpython.jit.tool.oparser import parse
from rpython.jit.metainterp.typesystem import llhelper
from rpython.rlib.jit import JitDebugInfo, AsmInfo, Counters


class MockJitDriverSD(object):
    """Minimal stand-in for a jitdriver's static data, as used by the hooks.

    Only ``warmstate.get_location_str`` and ``jitdriver`` are provided.
    """

    class warmstate(object):
        @staticmethod
        def get_location_str(boxes):
            """Return the ``co_name`` of the PyCode held in ``boxes[2]``.

            In these tests the third green-key entry is a ``ConstPtr``
            wrapping the GC reference of the test function's code object.
            """
            code_ref = boxes[2].getref_base()
            obj_ptr = lltype.cast_opaque_ptr(lltype.Ptr(OBJECT), code_ref)
            code = cast_base_ptr_to_instance(PyCode, obj_ptr)
            return code.co_name

    # The real PyPy jitdriver, so hook info reports the 'pypyjit' name.
    jitdriver = pypyjitdriver


class MockSD(object):
    """Fake metainterp static data handed to ``Logger``.

    Exposes just ``jitdrivers_sd`` and a ``cpu`` whose ``ts`` is the
    low-level type-system helper.
    """

    jitdrivers_sd = [MockJitDriverSD]

    class cpu(object):
        ts = llhelper


class AppTestJitHook(object):
    """App-level tests for the ``pypyjit`` compile/abort hook API.

    ``setup_class`` builds a small parsed trace and exposes wrapped
    interpreter-level helpers (``on_compile``, ``on_compile_bridge``,
    ``on_abort``, ``on_optimize``) that fire the matching ``pypy_hooks``
    callbacks, so app-level test code can trigger them directly.
    Every test that installs a hook removes it again in a ``finally`` so
    that hook state does not leak from one test into the next.
    """
    spaceconfig = dict(usemodules=('pypyjit',))

    def setup_class(cls):
        if cls.runappdirect:
            py.test.skip("Can't run this test with -A")
        w_f = cls.space.appexec([], """():
        def function():
            pass
        return function
        """)
        cls.w_f = w_f
        ll_code = cast_instance_to_base_ptr(w_f.code)
        code_gcref = lltype.cast_opaque_ptr(llmemory.GCREF, ll_code)
        logger = Logger(MockSD())

        # Trace: int_add, a debug_merge_point referring to function's code,
        # then two guards.
        oplist = parse("""
        [i1, i2, p2]
        i3 = int_add(i1, i2)
        debug_merge_point(0, 0, 0, 0, 0, ConstPtr(ptr0))
        guard_nonnull(p2) []
        guard_true(i3) []
        """, namespace={'ptr0': code_gcref}).operations
        greenkey = [ConstInt(0), ConstInt(0), ConstPtr(code_gcref)]
        # Fake assembler offsets: every operation except the
        # debug_merge_point (index 1) is mapped to its own index.
        offset = {}
        for i, op in enumerate(oplist):
            if i != 1:
                offset[op] = i

        class FailDescr(BasicFailDescr):
            def get_jitcounter_hash(self):
                from rpython.rlib.rarithmetic import r_uint
                return r_uint(13)

        oplist[-1].setdescr(FailDescr())
        oplist[-2].setdescr(FailDescr())

        token = JitCellToken()
        token.number = 0
        di_loop = JitDebugInfo(MockJitDriverSD, logger, token, oplist, 'loop',
                   greenkey)
        di_loop_optimize = JitDebugInfo(MockJitDriverSD, logger, JitCellToken(),
                                        oplist, 'loop', greenkey)
        di_loop.asminfo = AsmInfo(offset, 0x42, 12)
        di_bridge = JitDebugInfo(MockJitDriverSD, logger, JitCellToken(),
                                 oplist, 'bridge', fail_descr=FailDescr())
        di_bridge.asminfo = AsmInfo(offset, 0, 0)

        def interp_on_compile():
            # cls.oplist is refreshed from orig_oplist by setup_method.
            di_loop.oplist = cls.oplist
            pypy_hooks.after_compile(di_loop)

        def interp_on_compile_bridge():
            pypy_hooks.after_compile_bridge(di_bridge)

        def interp_on_optimize():
            di_loop_optimize.oplist = cls.oplist
            pypy_hooks.before_compile(di_loop_optimize)

        def interp_on_abort():
            # Pass a MockSD instance, consistent with `logger` above.
            pypy_hooks.on_abort(Counters.ABORT_TOO_LONG, pypyjitdriver,
                                greenkey, 'blah', Logger(MockSD()), [])

        space = cls.space
        cls.w_on_compile = space.wrap(interp2app(interp_on_compile))
        cls.w_on_compile_bridge = space.wrap(interp2app(interp_on_compile_bridge))
        cls.w_on_abort = space.wrap(interp2app(interp_on_abort))
        cls.w_int_add_num = space.wrap(rop.INT_ADD)
        cls.w_dmp_num = space.wrap(rop.DEBUG_MERGE_POINT)
        cls.w_on_optimize = space.wrap(interp2app(interp_on_optimize))
        cls.orig_oplist = oplist
        cls.w_sorted_keys = space.wrap(sorted(Counters.counter_names))

    def setup_method(self, meth):
        # Give each test its own shallow copy of the trace operations.
        self.__class__.oplist = self.orig_oplist[:]

    def test_on_compile(self):
        import pypyjit
        infos = []

        def hook(info):
            infos.append(info)

        # Compiling before the hook is installed must not reach `hook`.
        self.on_compile()
        pypyjit.set_compile_hook(hook)
        try:
            assert not infos
            self.on_compile()
            assert len(infos) == 1
            info = infos[0]
            assert info.jitdriver_name == 'pypyjit'
            assert info.greenkey[0].co_name == 'function'
            assert info.greenkey[1] == 0
            assert info.greenkey[2] == False
            assert info.loop_no == 0
            assert info.type == 'loop'
            assert info.asmaddr == 0x42
            assert info.asmlen == 12
            raises(TypeError, 'info.bridge_no')
            assert len(info.operations) == 4
            int_add = info.operations[0]
            dmp = info.operations[1]
            assert isinstance(dmp, pypyjit.DebugMergePoint)
            assert dmp.pycode is self.f.func_code
            assert dmp.greenkey == (self.f.func_code, 0, False)
            assert dmp.call_depth == 0
            assert dmp.call_id == 0
            # The debug_merge_point was given no offset in setup_class.
            assert dmp.offset == -1
            assert int_add.name == 'int_add'
            assert int_add.offset == 0
            self.on_compile_bridge()
            expected = ('<JitLoopInfo pypyjit, 4 operations, starting at '
                        '<(%s, 0, False)>>' % repr(self.f.func_code))
            assert repr(infos[0]) == expected
            assert len(infos) == 2
        finally:
            pypyjit.set_compile_hook(None)
        # With the hook removed, further compiles are not recorded.
        self.on_compile()
        assert len(infos) == 2

    def test_on_compile_exception(self):
        import pypyjit, sys, cStringIO

        def hook(*args):
            1/0

        # An exception raised by the hook is reported on stderr instead of
        # propagating out of the compile callback.
        s = cStringIO.StringIO()
        prev = sys.stderr
        sys.stderr = s
        pypyjit.set_compile_hook(hook)
        try:
            self.on_compile()
        finally:
            # Never leave the always-failing hook installed.
            pypyjit.set_compile_hook(None)
            sys.stderr = prev
        assert 'jit hook' in s.getvalue()
        assert 'ZeroDivisionError' in s.getvalue()

    def test_on_compile_crashes(self):
        import pypyjit
        loops = []

        def hook(loop):
            loops.append(loop)

        pypyjit.set_compile_hook(hook)
        try:
            self.on_compile()
        finally:
            pypyjit.set_compile_hook(None)
        loop = loops[0]
        op = loop.operations[2]
        assert op.name == 'guard_nonnull'

    def test_non_reentrant(self):
        import pypyjit
        calls = []

        def hook(*args):
            calls.append(None)
            # Re-triggering compiles from inside the hook must not recurse
            # into the hook again.
            self.on_compile()
            self.on_compile_bridge()

        pypyjit.set_compile_hook(hook)
        try:
            self.on_compile()
            assert len(calls) == 1  # and did not crash
            self.on_compile_bridge()
            assert len(calls) == 2  # and did not crash
        finally:
            pypyjit.set_compile_hook(None)

    def test_on_compile_types(self):
        import pypyjit
        infos = []

        def hook(info):
            infos.append(info)

        pypyjit.set_compile_hook(hook)
        try:
            self.on_compile()
        finally:
            pypyjit.set_compile_hook(None)
        op = infos[0].operations[1]
        assert isinstance(op, pypyjit.ResOperation)
        assert 'function' in repr(op)

    def test_on_abort(self):
        import pypyjit
        aborts = []

        def hook(jitdriver_name, greenkey, reason, operations):
            aborts.append((jitdriver_name, reason, operations))

        pypyjit.set_abort_hook(hook)
        try:
            self.on_abort()
        finally:
            pypyjit.set_abort_hook(None)
        assert aborts == [('pypyjit', 'ABORT_TOO_LONG', [])]

    def test_creation(self):
        from pypyjit import ResOperation

        op = ResOperation("int_add", -1, "int_add(1, 2)")
        assert op.name == 'int_add'
        assert repr(op) == "int_add(1, 2)"

    def test_creation_dmp(self):
        from pypyjit import DebugMergePoint

        def f():
            pass

        op = DebugMergePoint("debug_merge_point", 'repr', 'pypyjit', 2, 3, (f.func_code, 0, 0))
        assert op.bytecode_no == 0
        assert op.pycode is f.func_code
        assert repr(op) == 'repr'
        assert op.jitdriver_name == 'pypyjit'
        assert op.name == 'debug_merge_point'
        assert op.call_depth == 2
        assert op.call_id == 3
        # A greenkey that is not (code, ...) has no accessible pycode.
        op = DebugMergePoint('debug_merge_point', 'repr', 'notmain',
                             5, 4, ('str',))
        raises(AttributeError, 'op.pycode')
        assert op.call_depth == 5

    def test_get_stats_snapshot(self):
        skip("a bit no idea how to test it")
        from pypyjit import get_stats_snapshot

        stats = get_stats_snapshot() # we can't do much here, unfortunately
        assert stats.w_loop_run_times == []
        assert isinstance(stats.w_counters, dict)
        assert sorted(stats.w_counters.keys()) == self.sorted_keys
Back to Top