PageRenderTime 134ms CodeModel.GetById 59ms app.highlight 65ms RepoModel.GetById 1ms app.codeStats 1ms

/Lib/compiler/pyassem.py

http://unladen-swallow.googlecode.com/
Python | 747 lines | 683 code | 25 blank | 39 comment | 39 complexity | 5caa399f90a36685b776fc9e06e9e732 MD5 | raw file
  1"""A flow graph representation for Python bytecode"""
  2
  3import dis
  4import types
  5import sys
  6
  7from compiler import misc
  8from compiler.consts \
  9     import CO_OPTIMIZED, CO_NEWLOCALS, CO_VARARGS, CO_VARKEYWORDS
 10
 11class FlowGraph:
 12    def __init__(self):
 13        self.current = self.entry = Block()
 14        self.exit = Block("exit")
 15        self.blocks = misc.Set()
 16        self.blocks.add(self.entry)
 17        self.blocks.add(self.exit)
 18
 19    def startBlock(self, block):
 20        if self._debug:
 21            if self.current:
 22                print "end", repr(self.current)
 23                print "    next", self.current.next
 24                print "    prev", self.current.prev
 25                print "   ", self.current.get_children()
 26            print repr(block)
 27        self.current = block
 28
 29    def nextBlock(self, block=None):
 30        # XXX think we need to specify when there is implicit transfer
 31        # from one block to the next.  might be better to represent this
 32        # with explicit JUMP_ABSOLUTE instructions that are optimized
 33        # out when they are unnecessary.
 34        #
 35        # I think this strategy works: each block has a child
 36        # designated as "next" which is returned as the last of the
 37        # children.  because the nodes in a graph are emitted in
 38        # reverse post order, the "next" block will always be emitted
 39        # immediately after its parent.
 40        # Worry: maintaining this invariant could be tricky
 41        if block is None:
 42            block = self.newBlock()
 43
 44        # Note: If the current block ends with an unconditional control
 45        # transfer, then it is techically incorrect to add an implicit
 46        # transfer to the block graph. Doing so results in code generation
 47        # for unreachable blocks.  That doesn't appear to be very common
 48        # with Python code and since the built-in compiler doesn't optimize
 49        # it out we don't either.
 50        self.current.addNext(block)
 51        self.startBlock(block)
 52
 53    def newBlock(self):
 54        b = Block()
 55        self.blocks.add(b)
 56        return b
 57
 58    def startExitBlock(self):
 59        self.startBlock(self.exit)
 60
 61    _debug = 0
 62
 63    def _enable_debug(self):
 64        self._debug = 1
 65
 66    def _disable_debug(self):
 67        self._debug = 0
 68
 69    def emit(self, *inst):
 70        if self._debug:
 71            print "\t", inst
 72        if len(inst) == 2 and isinstance(inst[1], Block):
 73            self.current.addOutEdge(inst[1])
 74        self.current.emit(inst)
 75
 76    def getBlocksInOrder(self):
 77        """Return the blocks in reverse postorder
 78
 79        i.e. each node appears before all of its successors
 80        """
 81        order = order_blocks(self.entry, self.exit)
 82        return order
 83
 84    def getBlocks(self):
 85        return self.blocks.elements()
 86
 87    def getRoot(self):
 88        """Return nodes appropriate for use with dominator"""
 89        return self.entry
 90
 91    def getContainedGraphs(self):
 92        l = []
 93        for b in self.getBlocks():
 94            l.extend(b.getContainedGraphs())
 95        return l
 96
 97
 98def order_blocks(start_block, exit_block):
 99    """Order blocks so that they are emitted in the right order"""
100    # Rules:
101    # - when a block has a next block, the next block must be emitted just after
102    # - when a block has followers (relative jumps), it must be emitted before
103    #   them
104    # - all reachable blocks must be emitted
105    order = []
106
107    # Find all the blocks to be emitted.
108    remaining = set()
109    todo = [start_block]
110    while todo:
111        b = todo.pop()
112        if b in remaining:
113            continue
114        remaining.add(b)
115        for c in b.get_children():
116            if c not in remaining:
117                todo.append(c)
118
119    # A block is dominated by another block if that block must be emitted
120    # before it.
121    dominators = {}
122    for b in remaining:
123        if __debug__ and b.next:
124            assert b is b.next[0].prev[0], (b, b.next)
125        # Make sure every block appears in dominators, even if no
126        # other block must precede it.
127        dominators.setdefault(b, set())
128        # preceeding blocks dominate following blocks
129        for c in b.get_followers():
130            while 1:
131                dominators.setdefault(c, set()).add(b)
132                # Any block that has a next pointer leading to c is also
133                # dominated because the whole chain will be emitted at once.
134                # Walk backwards and add them all.
135                if c.prev and c.prev[0] is not b:
136                    c = c.prev[0]
137                else:
138                    break
139
140    def find_next():
141        # Find a block that can be emitted next.
142        for b in remaining:
143            for c in dominators[b]:
144                if c in remaining:
145                    break # can't emit yet, dominated by a remaining block
146            else:
147                return b
148        assert 0, 'circular dependency, cannot find next block'
149
150    b = start_block
151    while 1:
152        order.append(b)
153        remaining.discard(b)
154        if b.next:
155            b = b.next[0]
156            continue
157        elif b is not exit_block and not b.has_unconditional_transfer():
158            order.append(exit_block)
159        if not remaining:
160            break
161        b = find_next()
162    return order
163
164
class Block:
    """A basic block: a list of instructions plus links to other blocks.

    Attributes:
        insts    -- instruction tuples in emission order, each either
                    (opname,) or (opname, arg)
        outEdges -- set of blocks registered through addOutEdge()
        label    -- optional label shown by repr()/str()
        bid      -- identifier taken from the class-wide counter
        next     -- list holding at most one block (see addNext)
        prev     -- list holding at most one block, the one whose
                    ``next`` is this block
    """

    # Number of blocks created so far; supplies the next ``bid``.
    _count = 0

    def __init__(self, label=''):
        self.insts = []
        self.outEdges = set()
        self.label = label
        self.bid = Block._count
        self.next = []
        self.prev = []
        Block._count = Block._count + 1

    def __repr__(self):
        if self.label:
            return "<block %s id=%d>" % (self.label, self.bid)
        else:
            return "<block id=%d>" % (self.bid)

    def __str__(self):
        insts = [str(inst) for inst in self.insts]
        return "<block %s %d:\n%s>" % (self.label, self.bid,
                                       '\n'.join(insts))

    def emit(self, inst):
        """Append the instruction tuple inst to this block."""
        self.insts.append(inst)

    def getInstructions(self):
        """Return the (live) list of instruction tuples."""
        return self.insts

    def addOutEdge(self, block):
        """Record block as a successor reached through an instruction."""
        self.outEdges.add(block)

    def addNext(self, block):
        """Link block as this block's single "next" block.

        Asserts that this block had no next block yet and that block
        had no predecessor yet.
        """
        self.next.append(block)
        assert len(self.next) == 1, [str(b) for b in self.next]
        block.prev.append(self)
        assert len(block.prev) == 1, [str(b) for b in block.prev]

    _uncond_transfer = ('RETURN_VALUE', 'RAISE_VARARGS_ZERO',
                        'RAISE_VARARGS_ONE', 'RAISE_VARARGS_TWO',
                        'RAISE_VARARGS_THREE', 'JUMP_ABSOLUTE',
                        'JUMP_FORWARD', 'CONTINUE_LOOP',
                        )

    def has_unconditional_transfer(self):
        """Return True if this block ends with an unconditional transfer.

        This means there is no risk for the bytecode executer to go past
        this block's bytecode.  Returns False for a block without
        instructions.

        Only the opcode name of the last instruction is inspected, so
        one-element instructions such as ('RETURN_VALUE',) are
        recognised as well as two-element ones.
        """
        try:
            op = self.insts[-1][0]
        except IndexError:
            # No instructions at all, or an empty instruction tuple.
            return False
        return op in self._uncond_transfer

    def get_children(self):
        """Return the out-edge blocks followed by the next block, if any."""
        return list(self.outEdges) + self.next

    def get_followers(self):
        """Get the whole set of followers, including the next block."""
        followers = set(self.next)
        # Blocks that must be emitted *after* this one, because of
        # bytecode offsets (e.g. relative jumps) pointing to them.
        for inst in self.insts:
            if inst[0] in PyFlowGraph.hasjrel:
                followers.add(inst[1])
        return followers

    def getContainedGraphs(self):
        """Return all graphs contained within this block.

        For example, the arguments to #@make_function will contain a
        reference to the graph for the function body.
        """
        contained = []
        for inst in self.insts:
            if len(inst) == 1:
                continue
            op = inst[1]
            # Any argument exposing a ``graph`` attribute contributes it.
            if hasattr(op, 'graph'):
                contained.append(op.graph)
        return contained
247
248# flags for code objects
249
# The flow graph is transformed in place and is always in exactly one of
# these stages; getCode() asserts the sequence RAW, FLAT, CONV, DONE.
RAW, FLAT, CONV, DONE = "RAW", "FLAT", "CONV", "DONE"
255
256class PyFlowGraph(FlowGraph):
257    super_init = FlowGraph.__init__
258
259    def __init__(self, name, filename, args=(), optimized=0, klass=None):
260        self.super_init()
261        self.name = name
262        self.filename = filename
263        self.docstring = None
264        self.args = args # XXX
265        self.argcount = getArgCount(args)
266        self.klass = klass
267        if optimized:
268            self.flags = CO_OPTIMIZED | CO_NEWLOCALS
269        else:
270            self.flags = 0
271        self.consts = []
272        self.names = []
273        # Free variables found by the symbol table scan, including
274        # variables used only in nested scopes, are included here.
275        self.freevars = []
276        self.cellvars = []
277        # The closure list is used to track the order of cell
278        # variables and free variables in the resulting code object.
279        # The offsets used by LOAD_CLOSURE/LOAD_DEREF refer to both
280        # kinds of variables.
281        self.closure = []
282        self.varnames = list(args) or []
283        for i in range(len(self.varnames)):
284            var = self.varnames[i]
285            if isinstance(var, TupleArg):
286                self.varnames[i] = var.getName()
287        self.stage = RAW
288
289    def setDocstring(self, doc):
290        self.docstring = doc
291
292    def setFlag(self, flag):
293        self.flags = self.flags | flag
294        if flag == CO_VARARGS:
295            self.argcount = self.argcount - 1
296
297    def checkFlag(self, flag):
298        if self.flags & flag:
299            return 1
300
301    def setFreeVars(self, names):
302        self.freevars = list(names)
303
304    def setCellVars(self, names):
305        self.cellvars = names
306
307    def getCode(self):
308        """Get a Python code object"""
309        assert self.stage == RAW
310        self.computeStackDepth()
311        self.flattenGraph()
312        assert self.stage == FLAT
313        self.convertArgs()
314        assert self.stage == CONV
315        self.makeByteCode()
316        assert self.stage == DONE
317        return self.newCodeObject()
318
319    def dump(self, io=None):
320        if io:
321            save = sys.stdout
322            sys.stdout = io
323        pc = 0
324        for t in self.insts:
325            opname = t[0]
326            if opname == "SET_LINENO":
327                print
328            if len(t) == 1:
329                print "\t", "%3d" % pc, opname
330                pc = pc + 1
331            else:
332                print "\t", "%3d" % pc, opname, t[1]
333                pc = pc + 3
334        if io:
335            sys.stdout = save
336
337    def computeStackDepth(self):
338        """Compute the max stack depth.
339
340        Approach is to compute the stack effect of each basic block.
341        Then find the path through the code with the largest total
342        effect.
343        """
344        depth = {}
345        exit = None
346        for b in self.getBlocks():
347            depth[b] = findDepth(b.getInstructions())
348
349        seen = {}
350
351        def max_depth(b, d):
352            if b in seen:
353                return d
354            seen[b] = 1
355            d = d + depth[b]
356            children = b.get_children()
357            if children:
358                return max([max_depth(c, d) for c in children])
359            else:
360                if not b.label == "exit":
361                    return max_depth(self.exit, d)
362                else:
363                    return d
364
365        self.stacksize = max_depth(self.entry, 0)
366
367    def flattenGraph(self):
368        """Arrange the blocks in order and resolve jumps"""
369        assert self.stage == RAW
370        self.insts = insts = []
371        pc = 0
372        begin = {}
373        end = {}
374        for b in self.getBlocksInOrder():
375            begin[b] = pc
376            for inst in b.getInstructions():
377                insts.append(inst)
378                if len(inst) == 1:
379                    pc = pc + 1
380                elif inst[0] != "SET_LINENO":
381                    # arg takes 2 bytes
382                    pc = pc + 3
383            end[b] = pc
384        pc = 0
385        for i in range(len(insts)):
386            inst = insts[i]
387            if len(inst) == 1:
388                pc = pc + 1
389            elif inst[0] != "SET_LINENO":
390                pc = pc + 3
391            opname = inst[0]
392            if opname in self.hasjrel:
393                oparg = inst[1]
394                offset = begin[oparg] - pc
395                insts[i] = opname, offset
396            elif opname in self.hasjabs:
397                insts[i] = opname, begin[inst[1]]
398        self.stage = FLAT
399
400    hasjrel = set()
401    for i in dis.hasjrel:
402        hasjrel.add(dis.opname[i])
403    hasjabs = set()
404    for i in dis.hasjabs:
405        hasjabs.add(dis.opname[i])
406
407    def convertArgs(self):
408        """Convert arguments from symbolic to concrete form"""
409        assert self.stage == FLAT
410        self.consts.insert(0, self.docstring)
411        self.sort_cellvars()
412        for i in range(len(self.insts)):
413            t = self.insts[i]
414            if len(t) == 2:
415                opname, oparg = t
416                conv = self._converters.get(opname, None)
417                if conv:
418                    self.insts[i] = opname, conv(self, oparg)
419        self.stage = CONV
420
421    def sort_cellvars(self):
422        """Sort cellvars in the order of varnames and prune from freevars.
423        """
424        cells = {}
425        for name in self.cellvars:
426            cells[name] = 1
427        self.cellvars = [name for name in self.varnames
428                         if name in cells]
429        for name in self.cellvars:
430            del cells[name]
431        self.cellvars = self.cellvars + cells.keys()
432        self.closure = self.cellvars + self.freevars
433
434    def _lookupName(self, name, list):
435        """Return index of name in list, appending if necessary
436
437        This routine uses a list instead of a dictionary, because a
438        dictionary can't store two different keys if the keys have the
439        same value but different types, e.g. 2 and 2L.  The compiler
440        must treat these two separately, so it does an explicit type
441        comparison before comparing the values.
442        """
443        t = type(name)
444        for i in range(len(list)):
445            if t == type(list[i]) and list[i] == name:
446                return i
447        end = len(list)
448        list.append(name)
449        return end
450
451    _converters = {}
452    def _convert_LOAD_CONST(self, arg):
453        if hasattr(arg, 'getCode'):
454            arg = arg.getCode()
455        return self._lookupName(arg, self.consts)
456
457    def _convert_LOAD_FAST(self, arg):
458        self._lookupName(arg, self.names)
459        return self._lookupName(arg, self.varnames)
460    _convert_STORE_FAST = _convert_LOAD_FAST
461    _convert_DELETE_FAST = _convert_LOAD_FAST
462
463    def _convert_LOAD_NAME(self, arg):
464        if self.klass is None:
465            self._lookupName(arg, self.varnames)
466        return self._lookupName(arg, self.names)
467
468    def _convert_NAME(self, arg):
469        if self.klass is None:
470            self._lookupName(arg, self.varnames)
471        return self._lookupName(arg, self.names)
472    _convert_STORE_NAME = _convert_NAME
473    _convert_DELETE_NAME = _convert_NAME
474    _convert_STORE_ATTR = _convert_NAME
475    _convert_LOAD_ATTR = _convert_NAME
476    _convert_DELETE_ATTR = _convert_NAME
477    _convert_LOAD_GLOBAL = _convert_NAME
478    _convert_STORE_GLOBAL = _convert_NAME
479    _convert_DELETE_GLOBAL = _convert_NAME
480
481    def _convert_DEREF(self, arg):
482        self._lookupName(arg, self.names)
483        self._lookupName(arg, self.varnames)
484        return self._lookupName(arg, self.closure)
485    _convert_LOAD_DEREF = _convert_DEREF
486    _convert_STORE_DEREF = _convert_DEREF
487
488    def _convert_LOAD_CLOSURE(self, arg):
489        self._lookupName(arg, self.varnames)
490        return self._lookupName(arg, self.closure)
491
492    _cmp = list(dis.cmp_op)
493    def _convert_COMPARE_OP(self, arg):
494        return self._cmp.index(arg)
495
496    # similarly for other opcodes...
497
498    for name, obj in locals().items():
499        if name[:9] == "_convert_":
500            opname = name[9:]
501            _converters[opname] = obj
502    del name, obj, opname
503
504    def makeByteCode(self):
505        assert self.stage == CONV
506        self.lnotab = lnotab = LineAddrTable()
507        for t in self.insts:
508            opname = t[0]
509            if len(t) == 1:
510                lnotab.addCode(self.opnum[opname])
511            else:
512                oparg = t[1]
513                if opname == "SET_LINENO":
514                    lnotab.nextLine(oparg)
515                    continue
516                hi, lo = twobyte(oparg)
517                try:
518                    lnotab.addCode(self.opnum[opname], lo, hi)
519                except ValueError:
520                    print opname, oparg
521                    print self.opnum[opname], lo, hi
522                    raise
523        self.stage = DONE
524
525    opnum = {}
526    for num in range(len(dis.opname)):
527        opnum[dis.opname[num]] = num
528    del num
529
530    def newCodeObject(self):
531        assert self.stage == DONE
532        if (self.flags & CO_NEWLOCALS) == 0:
533            nlocals = 0
534        else:
535            nlocals = len(self.varnames)
536        argcount = self.argcount
537        if self.flags & CO_VARKEYWORDS:
538            argcount = argcount - 1
539        return types.CodeType(argcount, nlocals, self.stacksize, self.flags,
540                        self.lnotab.getCode(), self.getConsts(),
541                        tuple(self.names), tuple(self.varnames),
542                        self.filename, self.name, self.lnotab.firstline,
543                        self.lnotab.getTable(), tuple(self.freevars),
544                        tuple(self.cellvars))
545
546    def getConsts(self):
547        """Return a tuple for the const slot of the code object
548
549        Must convert references to code (#@make_function) to code
550        objects recursively.
551        """
552        l = []
553        for elt in self.consts:
554            if isinstance(elt, PyFlowGraph):
555                elt = elt.getCode()
556            l.append(elt)
557        return tuple(l)
558
def isJump(opname):
    """Return 1 if opname starts with 'JUMP', otherwise None."""
    if opname.startswith('JUMP'):
        return 1
    return None
562
class TupleArg:
    """Marks a nested-tuple parameter in a function's argument list."""

    def __init__(self, count, names):
        self.count, self.names = count, names

    def getName(self):
        """Return the name built from the count, a dot and the number."""
        return "." + "%d" % self.count

    def __repr__(self):
        fields = (self.count, self.names)
        return "TupleArg(%s, %s)" % fields
572
def getArgCount(args):
    """Return len(args) minus the names nested inside TupleArg entries.

    For every TupleArg in args the number of names obtained by
    flattening its ``names`` is subtracted from the length of args.
    """
    nested = sum(len(misc.flatten(arg.names))
                 for arg in args if isinstance(arg, TupleArg))
    return len(args) - nested
581
def twobyte(val):
    """Split the int val into a (high, low) pair with low in 0..255."""
    assert isinstance(val, int)
    return val >> 8, val & 0xFF
586
class LineAddrTable:
    """Accumulates bytecode together with its line number table (lnotab).

    The lnotab format is documented in compile.c.  In short: for each
    SET_LINENO after the first one, pairs of bytes are added to the
    table.  The first byte of a pair is the distance in bytecode bytes
    since the previous SET_LINENO, the second the distance in line
    numbers.  A distance above 255 is spread over several pairs -- see
    compile.c for the delicate details.
    """

    def __init__(self):
        self.code = []          # one-character strings, one per byte
        self.codeOffset = 0     # number of bytes added so far
        self.firstline = 0      # first line seen; 0 means none yet
        self.lastline = 0       # line of the last recorded entry
        self.lastoff = 0        # code offset of the last recorded entry
        self.lnotab = []        # flat list of (addr, line) delta bytes

    def addCode(self, *args):
        """Append each of args as one byte of bytecode."""
        for byte in args:
            self.code.append(chr(byte))
        self.codeOffset += len(args)

    def nextLine(self, lineno):
        """Record that the code added from now on belongs to lineno."""
        if self.firstline == 0:
            self.firstline = self.lastline = lineno
            return

        # compute deltas
        addr = self.codeOffset - self.lastoff
        line = lineno - self.lastline
        # Python assumes that lineno always increases with
        # increasing bytecode address (lnotab is unsigned char).
        # Depending on when SET_LINENO instructions are emitted
        # this is not always true.  Consider the code:
        #     a = (1,
        #          b)
        # In the bytecode stream, the assignment to "a" occurs
        # after the loading of "b".  This works with the C Python
        # compiler because it only generates a SET_LINENO instruction
        # for the assignment.
        if line < 0:
            return

        table = self.lnotab
        while addr > 255:
            table.extend((255, 0))
            addr -= 255
        while line > 255:
            table.extend((addr, 255))
            line -= 255
            addr = 0
        if addr > 0 or line > 0:
            table.extend((addr, line))
        self.lastline = lineno
        self.lastoff = self.codeOffset

    def getCode(self):
        """Return the accumulated bytecode as one string."""
        return ''.join(self.code)

    def getTable(self):
        """Return the line number table as one string."""
        return ''.join([chr(delta) for delta in self.lnotab])
652
653class StackDepthTracker:
654    # XXX 1. need to keep track of stack depth on jumps
655    # XXX 2. at least partly as a result, this code is broken
656
657    def findDepth(self, insts, debug=0):
658        depth = 0
659        maxDepth = 0
660        for i in insts:
661            opname = i[0]
662            if debug:
663                print i,
664            delta = self.effect.get(opname, None)
665            if delta is not None:
666                depth = depth + delta
667            else:
668                # now check patterns
669                for pat, pat_delta in self.patterns:
670                    if opname[:len(pat)] == pat:
671                        delta = pat_delta
672                        depth = depth + delta
673                        break
674                # if we still haven't found a match
675                if delta is None:
676                    meth = getattr(self, opname, None)
677                    if meth is not None:
678                        depth = depth + meth(i[1])
679            if depth > maxDepth:
680                maxDepth = depth
681            if debug:
682                print depth, maxDepth
683        return maxDepth
684
685    effect = {
686        'POP_TOP': -1,
687        'DUP_TOP': 1,
688        'DUP_TOP_TWO': 2,
689        'DUP_TOP_THREE': 3,
690        'LIST_APPEND': -2,
691        'SLICE_LEFT': -1,
692        'SLICE_RIGHT': -1,
693        'SLICE_BOTH': -2,
694        'STORE_SLICE_NONE': -1,
695        'STORE_SLICE_LEFT': -2,
696        'STORE_SLICE_RIGHT': -2,
697        'STORE_SLICE_BOTH': -3,
698        'DELETE_SLICE_NONE': -1,
699        'DELETE_SLICE_LEFT': -2,
700        'DELETE_SLICE_RIGHT': -2,
701        'DELETE_SLICE_BOTH': -3,
702        'STORE_SUBSCR': -3,
703        'DELETE_SUBSCR': -2,
704        'RETURN_VALUE': -1,
705        'YIELD_VALUE': -1,
706        'BUILD_SLICE_TWO': -1,
707        'BUILD_SLICE_THREE': -2,
708        'STORE_NAME': -1,
709        'STORE_ATTR': -2,
710        'DELETE_ATTR': -1,
711        'STORE_GLOBAL': -1,
712        'BUILD_MAP': 1,
713        'COMPARE_OP': -1,
714        'STORE_FAST': -1,
715        'LOAD_ATTR': 0, # unlike other loads
716        # close enough...
717        'SETUP_EXCEPT': 3,
718        'SETUP_FINALLY': 3,
719        'FOR_ITER': 1,
720        'WITH_CLEANUP': -1,
721        }
722    # use pattern match
723    patterns = [
724        ('BINARY_', -1),
725        ('LOAD_', 1),
726        ]
727
728    def UNPACK_SEQUENCE(self, count):
729        return count-1
730    def BUILD_TUPLE(self, count):
731        return -count+1
732    def BUILD_LIST(self, count):
733        return -count+1
734    def CALL_FUNCTION(self, argc):
735        hi, lo = divmod(argc, 256)
736        return -(lo + hi * 2)
737    def CALL_FUNCTION_VAR(self, argc):
738        return self.CALL_FUNCTION(argc)-1
739    def CALL_FUNCTION_KW(self, argc):
740        return self.CALL_FUNCTION(argc)-1
741    def CALL_FUNCTION_VAR_KW(self, argc):
742        return self.CALL_FUNCTION(argc)-2
743    def MAKE_CLOSURE(self, argc):
744        # XXX need to account for free variables too!
745        return -argc
746
747findDepth = StackDepthTracker().findDepth