PageRenderTime 69ms CodeModel.GetById 21ms app.highlight 40ms RepoModel.GetById 1ms app.codeStats 0ms

/pypy/module/pypyjit/test_pypy_c/model.py

https://bitbucket.org/pypy/pypy/
Python | 560 lines | 557 code | 2 blank | 1 comment | 4 complexity | ba13e61f400902783ec14d799fc28ae3 MD5 | raw file
  1import py
  2import sys
  3import re
  4import os.path
  5try:
  6    from _pytest.assertion import newinterpret
  7except ImportError:   # e.g. Python 2.5
  8    newinterpret = None
  9from rpython.tool.jitlogparser.parser import (SimpleParser, Function,
 10                                              TraceForOpcode)
 11from rpython.tool.jitlogparser.storage import LoopStorage
 12
 13
 14def find_ids_range(code):
 15    """
 16    Parse the given function and return a dictionary mapping "ids" to
 17    "line ranges".  Ids are identified by comments with a special syntax::
 18
 19        # "myid" corresponds to the whole line
 20        print 'foo' # ID: myid
 21    """
 22    result = {}
 23    start_lineno = code.co.co_firstlineno
 24    for i, line in enumerate(py.code.Source(code.source)):
 25        m = re.search('# ID: (\w+)', line)
 26        if m:
 27            name = m.group(1)
 28            lineno = start_lineno+i
 29            result[name] = xrange(lineno, lineno+1)
 30    return result
 31
 32def find_ids(code):
 33    """
 34    Parse the given function and return a dictionary mapping "ids" to
 35    "opcodes".
 36    """
 37    ids = {}
 38    ranges = find_ids_range(code)
 39    for name, linerange in ranges.iteritems():
 40        opcodes = [opcode for opcode in code.opcodes
 41                   if opcode.lineno in linerange]
 42        ids[name] = opcodes
 43    return ids
 44
 45
 46class Log(object):
 47    def __init__(self, rawtraces):
 48        storage = LoopStorage()
 49        traces = [SimpleParser.parse_from_input(rawtrace) for rawtrace in rawtraces]
 50        traces = storage.reconnect_loops(traces)
 51        self.loops = [TraceWithIds.from_trace(trace, storage) for trace in traces]
 52
 53    def _filter(self, loop, is_entry_bridge=False):
 54        if is_entry_bridge == '*':
 55            return loop
 56        assert is_entry_bridge in (True, False)
 57        return PartialTraceWithIds(loop, is_entry_bridge)
 58
 59    def loops_by_filename(self, filename, **kwds):
 60        """
 61        Return all loops which start in the file ``filename``
 62        """
 63        return [self._filter(loop, **kwds)  for loop in self.loops
 64                if loop.filename == filename]
 65
 66    def loops_by_id(self, id, **kwds):
 67        """
 68        Return all loops which contain the ID ``id``
 69        """
 70        return [self._filter(loop, **kwds) for loop in self.loops
 71                if loop.has_id(id)]
 72
 73    @classmethod
 74    def opnames(self, oplist):
 75        return [op.name for op in oplist]
 76
 77class TraceWithIds(Function):
 78
 79    def __init__(self, *args, **kwds):
 80        Function.__init__(self, *args, **kwds)
 81        self.ids = {}
 82        self.code = self.chunks[0].getcode()
 83        if not self.code and len(self.chunks)>1 and \
 84               isinstance(self.chunks[1], TraceForOpcode):
 85            # First chunk might be missing the debug_merge_point op
 86            self.code = self.chunks[1].getcode()
 87        if self.code:
 88            self.compute_ids(self.ids)
 89
 90    @classmethod
 91    def from_trace(cls, trace, storage):
 92        res = cls.from_operations(trace.operations, storage)
 93        return res
 94
 95    def flatten_chunks(self):
 96        """
 97        return a flat sequence of TraceForOpcode objects, including the ones
 98        inside inlined functions
 99        """
100        for chunk in self.chunks:
101            if isinstance(chunk, TraceForOpcode):
102                yield chunk
103            else:
104                for subchunk in chunk.flatten_chunks():
105                    yield subchunk
106
107    def compute_ids(self, ids):
108        #
109        # 1. compute the ids of self, i.e. the outer function
110        id2opcodes = find_ids(self.code)
111        all_my_opcodes = self.get_set_of_opcodes()
112        for id, opcodes in id2opcodes.iteritems():
113            if not opcodes:
114                continue
115            target_opcodes = set(opcodes)
116            if all_my_opcodes.intersection(target_opcodes):
117                ids[id] = opcodes
118        #
119        # 2. compute the ids of all the inlined functions
120        for chunk in self.chunks:
121            if isinstance(chunk, TraceWithIds) and chunk.code:
122                chunk.compute_ids(ids)
123
124    def get_set_of_opcodes(self):
125        result = set()
126        for chunk in self.chunks:
127            if isinstance(chunk, TraceForOpcode):
128                opcode = chunk.getopcode()
129                result.add(opcode)
130        return result
131
132    def has_id(self, id):
133        return id in self.ids
134
135    def _ops_for_chunk(self, chunk, include_guard_not_invalidated):
136        for op in chunk.operations:
137            if op.name not in ('debug_merge_point', 'enter_portal_frame',
138                               'leave_portal_frame') and \
139                (op.name != 'guard_not_invalidated' or include_guard_not_invalidated):
140                yield op
141
142    def _allops(self, opcode=None, include_guard_not_invalidated=True):
143        opcode_name = opcode
144        for chunk in self.flatten_chunks():
145            opcode = chunk.getopcode()
146            if opcode_name is None or \
147                   (opcode and opcode.__class__.__name__ == opcode_name):
148                for op in self._ops_for_chunk(chunk, include_guard_not_invalidated):
149                    yield op
150            else:
151                for op in chunk.operations:
152                    if op.name == 'label':
153                        yield op
154
155    def allops(self, *args, **kwds):
156        return list(self._allops(*args, **kwds))
157
158    def format_ops(self, id=None, **kwds):
159        if id is None:
160            ops = self.allops(**kwds)
161        else:
162            ops = self.ops_by_id(id, **kwds)
163        return '\n'.join(map(str, ops))
164
165    def print_ops(self, *args, **kwds):
166        print self.format_ops(*args, **kwds)
167
168    def _ops_by_id(self, id, include_guard_not_invalidated=True, opcode=None):
169        opcode_name = opcode
170        target_opcodes = self.ids[id]
171        loop_ops = self.allops(opcode)
172        for chunk in self.flatten_chunks():
173            opcode = chunk.getopcode()
174            if opcode in target_opcodes and (opcode_name is None or
175                                             opcode.__class__.__name__ == opcode_name):
176                for op in self._ops_for_chunk(chunk, include_guard_not_invalidated):
177                    if op in loop_ops:
178                        yield op
179
180    def ops_by_id(self, *args, **kwds):
181        return list(self._ops_by_id(*args, **kwds))
182
183    def match(self, expected_src, **kwds):
184        ops = self.allops()
185        matcher = OpMatcher(ops)
186        return matcher.match(expected_src, **kwds)
187
188    def match_by_id(self, id, expected_src, ignore_ops=[], **kwds):
189        ops = list(self.ops_by_id(id, **kwds))
190        matcher = OpMatcher(ops, id)
191        return matcher.match(expected_src, ignore_ops=ignore_ops)
192
class PartialTraceWithIds(TraceWithIds):
    """
    A view on an existing trace whose allops() returns only a part of the
    operations: the ones between the first two labels if
    ``is_entry_bridge`` is true, the ones after the last label otherwise.
    The chunks, ids, filename and code are read from the wrapped trace.
    """

    def __init__(self, trace, is_entry_bridge=False):
        self.trace = trace
        self.is_entry_bridge = is_entry_bridge

    # read-only attributes delegated to the wrapped trace
    chunks = property(lambda self: self.trace.chunks)
    ids = property(lambda self: self.trace.ids)
    filename = property(lambda self: self.trace.filename)
    code = property(lambda self: self.trace.code)

    def allops(self, *args, **kwds):
        if not self.is_entry_bridge:
            return self.simple_loop_ops(*args, **kwds)
        return self.entry_bridge_ops(*args, **kwds)

    def simple_loop_ops(self, *args, **kwds):
        """
        Return the operations which follow the last label; the last
        operation of the last chunk must be a jump to that label
        """
        ops = list(self._allops(*args, **kwds))
        labels = [op for op in ops if op.name == 'label']
        final_op = self.chunks[-1].operations[-1]
        assert final_op.name == 'jump'
        assert final_op.getdescr() == labels[-1].getdescr()
        after_label = ops.index(labels[-1]) + 1
        return ops[after_label:]

    def entry_bridge_ops(self, *args, **kwds):
        """
        Return the operations between the first and the second label
        """
        ops = list(self._allops(*args, **kwds))
        labels = [op for op in ops if op.name == 'label']
        start = ops.index(labels[0]) + 1
        stop = ops.index(labels[1])
        return ops[start:stop]
235    
236    
class InvalidMatch(Exception):
    """
    Raised by OpMatcher._assert() when a check fails.  ``self.msg`` holds
    an explanation of the failed check, computed by re-interpreting the
    source of the ``self._assert(...)`` statement in ``frame``, or a
    placeholder text if this is not possible.
    """
    opindex = None

    def __init__(self, message, frame):
        Exception.__init__(self, message)
        # copied and adapted from pytest's magic AssertionError
        f = py.code.Frame(frame)
        source = None
        try:
            fullsource = f.code.fullsource
            if fullsource is not None:
                try:
                    statement = fullsource.getstatement(f.lineno)
                except IndexError:
                    pass
                else:
                    source = str(statement.deindent()).strip()
        except py.error.ENOENT:
            pass
        if not (source and source.startswith('self._assert(')
                and newinterpret):
            self.msg = "<could not determine information>"
            return
        # turn "self._assert(x, 'foo')" into "assert x, 'foo'": replace the
        # call with the keyword and drop the trailing ')'
        as_assert = source.replace('self._assert(', 'assert ')[:-1]
        self.msg = newinterpret.interpret(as_assert, f, should_fail=True)
262
263
class OpMatcher(object):
    """
    Match a list of operations against an "expected source", i.e. a text
    with one expected operation per line, written as::

        resvar = opname(arg, arg, ..., descr=DESCR)   # comment

    Variable names are matched up to a consistent renaming (recorded in
    ``self.alpha_map``).  Special syntax in the expected source:

      - ``_`` as a variable accepts anything, ``#`` accepts any number;
      - ``...`` as the last argument accepts any extra arguments;
      - ``descr=...`` accepts any descr, otherwise the expected descr is
        compared for equality and then used as a regular expression;
      - a line ``...`` skips operations until the next line matches;
      - lines between ``{{{`` and ``}}}`` match in any order;
      - ``guard_not_invalidated?`` is an optional operation;
      - ``--TICK--``, ``--THREAD-TICK--``, ``--EXC-TICK--`` and
        ``--ISINF--(x)`` are macros, see preprocess_expected_src().
    """

    def __init__(self, ops, id=None):
        self.ops = ops
        self.id = id
        self.src = '\n'.join(map(str, ops))
        # maps a variable name of the actual operations to the name used
        # for it in the expected source
        self.alpha_map = {}

    @classmethod
    def parse_ops(cls, src):
        """
        Parse every line of ``src`` with parse_op() and return the list of
        the results, without the empty lines and terminated by a special
        '--end--' operation.
        """
        ops = [cls.parse_op(line) for line in src.splitlines()]
        ops.append(('--end--', None, [], '...', True))
        return [op for op in ops if op is not None]

    @classmethod
    def parse_op(cls, line):
        """
        Parse one line of expected source.  Return None for an empty line,
        the line itself for the markers '...', '{{{' and '}}}', and
        otherwise a tuple ``(opname, resvar, args, descr, mandatory)``
        where ``mandatory`` is False only for 'guard_not_invalidated?'.
        """
        # strip comment after '#', but not if it appears inside parentheses
        if '#' in line:
            nested = 0
            for i, c in enumerate(line):
                if c == '(':
                    nested += 1
                elif c == ')':
                    assert nested > 0, "more ')' than '(' in %r" % (line,)
                    nested -= 1
                elif c == '#' and nested == 0:
                    line = line[:i]
                    break
        #
        if line.strip() == 'guard_not_invalidated?':
            return 'guard_not_invalidated', None, [], '...', False
        # find the resvar, if any
        if ' = ' in line:
            resvar, _, line = line.partition(' = ')
            resvar = resvar.strip()
        else:
            resvar = None
        line = line.strip()
        if not line:
            return None
        if line in ('...', '{{{', '}}}'):
            return line
        opname, _, args = line.partition('(')
        opname = opname.strip()
        assert args.endswith(')')
        # a real list (not the result of map()), so that the comparison
        # with [''] below also works where map() returns an iterator
        args = [arg.strip() for arg in args[:-1].split(',')]
        if args == ['']:
            args = []
        if args and args[-1].startswith('descr='):
            descr = args.pop()
            descr = descr[len('descr='):]
        else:
            descr = None
        return opname, resvar, args, descr, True

    @classmethod
    def preprocess_expected_src(cls, src):
        """
        Expand the macros --TICK--, --THREAD-TICK--, --EXC-TICK-- and
        --ISINF-- which can appear in ``src``, and return the result.
        """
        # all loops decrement the tick-counter at the end. The rpython code is
        # in jump_absolute() in pypyjit/interp.py. The string --TICK-- is
        # replaced with the corresponding operations, so that tests don't have
        # to repeat it every time
        ticker_check = """
            guard_not_invalidated?
            ticker0 = getfield_raw_i(#, descr=<FieldS pypysig_long_struct.c_value .*>)
            ticker_cond0 = int_lt(ticker0, 0)
            guard_false(ticker_cond0, descr=...)
        """
        src = src.replace('--TICK--', ticker_check)
        #
        # this is the ticker check generated if we have threads
        thread_ticker_check = """
            guard_not_invalidated?
            ticker0 = getfield_raw_i(#, descr=<FieldS pypysig_long_struct.c_value .*>)
            ticker1 = int_sub(ticker0, #)
            setfield_raw(#, ticker1, descr=<FieldS pypysig_long_struct.c_value .*>)
            ticker_cond0 = int_lt(ticker1, 0)
            guard_false(ticker_cond0, descr=...)
        """
        src = src.replace('--THREAD-TICK--', thread_ticker_check)
        #
        # this is the ticker check generated in PyFrame.handle_operation_error
        exc_ticker_check = """
            ticker2 = getfield_raw_i(#, descr=<FieldS pypysig_long_struct.c_value .*>)
            ticker_cond1 = int_lt(ticker2, 0)
            guard_false(ticker_cond1, descr=...)
        """
        src = src.replace('--EXC-TICK--', exc_ticker_check)
        #
        # ISINF is done as a macro; fix it here
        def expand_isinf(m):
            # "res = --ISINF--(arg)" becomes two operations using a
            # temporary variable named after arg.  The expansion is done by
            # a function because the equivalent replacement template needs
            # the escape '\B', which newer versions of 're' reject.
            res, arg = m.group(1), m.group(2)
            tmp = arg + '\\B999'
            return '%s = float_add(%s, ...)\n%s = float_eq(%s, %s)' % (
                tmp, arg, res, tmp, arg)
        r = re.compile(r'(\w+) = --ISINF--[(](\w+)[)]')
        src = r.sub(expand_isinf, src)
        return src

    @classmethod
    def is_const(cls, v1):
        """Return True if ``v1`` is a string starting with 'ConstClass('"""
        return isinstance(v1, str) and v1.startswith('ConstClass(')

    @staticmethod
    def as_numeric_const(v1):
        """
        Return ('int', value) or ('float', value) if the string ``v1`` is
        a number, and None otherwise.
        """
        try:
            return ('int', int(v1))
        except ValueError:
            pass
        if '.' in v1:
            try:
                return ('float', float(v1))
            except ValueError:
                pass
        return None

    def match_var(self, v1, exp_v2):
        """
        Return True if the actual variable ``v1`` matches the expected
        variable ``exp_v2``.  The first time that a (non-constant) ``v1``
        is seen it is bound to ``exp_v2`` in ``self.alpha_map``; afterwards
        it matches only the same expected name.
        """
        assert v1 != '_'
        if exp_v2 == '_':           # accept anything
            return True
        if exp_v2 is None:
            return v1 is None
        assert exp_v2 != '...'      # bogus use of '...' in the expected code
        n1 = self.as_numeric_const(v1)
        if exp_v2 == '#':           # accept any (integer or float) number
            return n1 is not None
        n2 = self.as_numeric_const(exp_v2)
        if n1 is not None or n2 is not None:
            # at least one is a number; check that both are, and are equal
            return n1 == n2
        if self.is_const(v1) or self.is_const(exp_v2):
            # prefix match, ignoring the last character of both
            return v1[:-1].startswith(exp_v2[:-1])
        if v1 not in self.alpha_map:
            self.alpha_map[v1] = exp_v2
        return self.alpha_map[v1] == exp_v2

    def match_descr(self, descr, exp_descr):
        """
        Return True if ``descr`` matches ``exp_descr``: the two are equal,
        or ``exp_descr`` is '...', or ``exp_descr`` is a regular expression
        which matches the beginning of ``descr``.  Raise InvalidMatch
        otherwise, also when only one of the two is None.
        """
        if descr == exp_descr or exp_descr == '...':
            return True
        self._assert(exp_descr is not None and descr is not None
                     and re.match(exp_descr, descr), "descr mismatch")
        return True

    def _assert(self, cond, message):
        """
        Raise InvalidMatch if ``cond`` is false.  The exception inspects
        the source of the calling statement, so this must be called as a
        statement of its own, written ``self._assert(cond, message)``.
        """
        if not cond:
            raise InvalidMatch(message, frame=sys._getframe(1))

    def match_op(self, op, exp_op):
        """
        Check the actual operation ``op`` (or the string '--end--' if the
        operations are over) against ``exp_op``, a tuple as returned by
        parse_op().  Raise InvalidMatch if they do not match.
        """
        exp_opname, exp_res, exp_args, exp_descr, _ = exp_op
        if exp_opname == '--end--':
            self._assert(op == '--end--', 'got more ops than expected')
            return
        self._assert(op != '--end--', 'got less ops than expected')
        self._assert(op.name == exp_opname, "operation mismatch")
        # NOTE: the result of this call is not checked, i.e. a mismatch on
        # the result variable does not make the match fail; the call can
        # still bind op.res in the alpha_map
        self.match_var(op.res, exp_res)
        if exp_args[-1:] == ['...']:      # exp_args ends with '...'
            exp_args = exp_args[:-1]
            self._assert(len(op.args) >= len(exp_args), "not enough arguments")
        else:
            self._assert(len(op.args) == len(exp_args), "wrong number of arguments")
        for arg, exp_arg in zip(op.args, exp_args):
            self._assert(self.match_var(arg, exp_arg), "variable mismatch: %r instead of %r" % (arg, exp_arg))
        self.match_descr(op.descr, exp_descr)

    def _next_op(self, iter_ops, ignore_ops=()):
        """
        Return the next operation of ``iter_ops`` whose name is not in
        ``ignore_ops``, or the string '--end--' if there is none.
        """
        try:
            while True:
                op = next(iter_ops)
                if op.name not in ignore_ops:
                    break
        except StopIteration:
            return '--end--'
        return op

    def try_match(self, op, exp_op):
        """
        Like match_op(), but return True or False instead of raising; if
        the match fails, the alpha-renaming map is left unmodified.
        """
        saved_alpha_map = self.alpha_map.copy()
        try:
            self.match_op(op, exp_op)
        except InvalidMatch:
            # it did not match: rollback the alpha_map
            self.alpha_map = saved_alpha_map
            return False
        else:
            return True

    def match_until(self, until_op, iter_ops):
        """
        Consume ``iter_ops`` until an operation matching ``until_op`` is
        found, and return it.  Raise InvalidMatch if the operations end
        first.  No operation is ignored while doing this.
        """
        while True:
            op = self._next_op(iter_ops)
            if self.try_match(op, until_op):
                # it matched! The '...' operator ends here
                return op
            self._assert(op != '--end--',
                         'nothing in the end of the loop matches %r' %
                          (until_op,))

    def match_any_order(self, iter_exp_ops, iter_ops, ignore_ops):
        """
        Handle a '{{{' block: read the expected operations up to '}}}' and
        match them, in any order, against the next operations.
        """
        exp_ops = []
        for exp_op in iter_exp_ops:
            if exp_op == '}}}':
                break
            exp_ops.append(exp_op)
        else:
            assert 0, "'{{{' not followed by '}}}'"
        while exp_ops:
            op = self._next_op(iter_ops, ignore_ops=ignore_ops)
            # match 'op' against any of the exp_ops; the first successful
            # match is kept, and the exp_op gets removed from the list
            for i, exp_op in enumerate(exp_ops):
                if self.try_match(op, exp_op):
                    del exp_ops[i]
                    break
            else:
                self._assert(0,
                    "operation %r not found within the {{{ }}} block" % (op,))

    def match_loop(self, expected_ops, ignore_ops):
        """
        Match ``self.ops`` against ``expected_ops`` (as returned by
        parse_ops()), skipping the operations whose name is in
        ``ignore_ops``.  Raise InvalidMatch on the first mismatch, with its
        ``opindex`` attribute set from the position reached in the
        operations.

        A note about partial matching: the '...' operator is non-greedy,
        i.e. it matches all the operations until it finds one that matches
        what is after the '...'.  The '{{{' and '}}}' operators mark a
        group of lines that can match in any order.
        """
        iter_exp_ops = iter(expected_ops)
        iter_ops = RevertableIterator(self.ops)
        for exp_op in iter_exp_ops:
            try:
                if exp_op == '...':
                    # loop until we find an operation which matches
                    try:
                        exp_op = next(iter_exp_ops)
                    except StopIteration:
                        # the ... is the last line in the expected_ops, so we just
                        # return because it matches everything until the end
                        return
                    op = self.match_until(exp_op, iter_ops)
                elif exp_op == '{{{':
                    self.match_any_order(iter_exp_ops, iter_ops, ignore_ops)
                    continue
                else:
                    op = self._next_op(iter_ops, ignore_ops=ignore_ops)
                try:
                    self.match_op(op, exp_op)
                except InvalidMatch:
                    if type(exp_op) is str or exp_op[4] is not False:
                        raise
                    # else: optional operation.  Put 'op' back, so that it
                    # is tried against the next exp_op
                    iter_ops.revert_one()
                    continue
            except InvalidMatch as e:
                e.opindex = iter_ops.index - 1
                raise

    def match(self, expected_src, ignore_ops=None):
        """
        Match the operations against ``expected_src``, ignoring the
        operations whose name is in ``ignore_ops`` (default: none).  Return
        True if they match; otherwise print a report of the mismatch and
        re-raise the InvalidMatch exception.
        """
        if ignore_ops is None:
            ignore_ops = []

        def format_src(src, opindex=None):
            # re-indent 'src', marking the line number 'opindex' if given
            if src is None:
                return ''
            text = str(py.code.Source(src).deindent().indent())
            lines = text.splitlines(True)
            if opindex is not None and 0 <= opindex <= len(lines):
                lines.insert(opindex, '\n\t===== HERE =====\n')
            return ''.join(lines)
        #
        expected_src = self.preprocess_expected_src(expected_src)
        expected_ops = self.parse_ops(expected_src)
        try:
            self.match_loop(expected_ops, ignore_ops)
        except InvalidMatch as e:
            # every print() below gets a single argument, so that the
            # output is the same with the print statement too
            print('@' * 40)
            print("Loops don't match")
            print("=================")
            print('loop id = %r' % (self.id,))
            print(e.args)
            print(e.msg)
            print('')
            print("Ignore ops: %s" % (ignore_ops,))
            print("Got:")
            print(format_src(self.src, e.opindex))
            print('')
            print("Expected:")
            print(format_src(expected_src))
            raise     # always propagate the exception in case of mismatch
        else:
            return True
545
546
class RevertableIterator(object):
    """
    An iterator over ``sequence`` which can be moved back by one position
    with revert_one().  ``self.index`` is the position of the item that
    the next call will return.
    """

    def __init__(self, sequence):
        self.sequence = sequence
        self.index = 0

    def __iter__(self):
        return self

    def next(self):
        """
        Return the item at the current position and advance by one; raise
        StopIteration when the position is past the end of the sequence.
        """
        index = self.index
        # the index is advanced even if the sequence is exhausted, so after
        # a StopIteration it is greater than len(self.sequence)
        self.index = index + 1
        if index >= len(self.sequence):
            raise StopIteration
        return self.sequence[index]

    # the same method under the name used by the newer iterator protocol,
    # so that next(it) and for-loops work there too
    __next__ = next

    def revert_one(self):
        """Move back by one position."""
        self.index -= 1