PageRenderTime 798ms CodeModel.GetById 161ms app.highlight 413ms RepoModel.GetById 211ms app.codeStats 0ms

/Parser/asdl_c.py

http://unladen-swallow.googlecode.com/
Python | 1210 lines | 1177 code | 12 blank | 21 comment | 20 complexity | 599a148b59405c45e28c56ccad0d03b4 MD5 | raw file
   1#! /usr/bin/env python
   2"""Generate C code from an ASDL description."""
   3
   4# TO DO
   5# handle fields that have a type but no name
   6
   7import os, sys
   8
   9import asdl
  10
  11TABSIZE = 8
  12MAX_COL = 80
  13
  14def get_c_type(name):
  15    """Return a string for the C name of the type.
  16
  17    This function special cases the default types provided by asdl:
  18    identifier, string, int, bool.
  19    """
  20    # XXX ack!  need to figure out where Id is useful and where string
  21    if isinstance(name, asdl.Id):
  22        name = name.value
  23    if name in asdl.builtin_types:
  24        return name
  25    else:
  26        return "%s_ty" % name
  27
  28def reflow_lines(s, depth):
  29    """Reflow the line s indented depth tabs.
  30
  31    Return a sequence of lines where no line extends beyond MAX_COL
  32    when properly indented.  The first line is properly indented based
  33    exclusively on depth * TABSIZE.  All following lines -- these are
  34    the reflowed lines generated by this function -- start at the same
  35    column as the first character beyond the opening { in the first
  36    line.
  37    """
  38    size = MAX_COL - depth * TABSIZE
  39    if len(s) < size:
  40        return [s]
  41
  42    lines = []
  43    cur = s
  44    padding = ""
  45    while len(cur) > size:
  46        i = cur.rfind(' ', 0, size)
  47        # XXX this should be fixed for real
  48        if i == -1 and 'GeneratorExp' in cur:
  49            i = size + 3
  50        assert i != -1, "Impossible line %d to reflow: %r" % (size, s)
  51        lines.append(padding + cur[:i])
  52        if len(lines) == 1:
  53            # find new size based on brace
  54            j = cur.find('{', 0, i)
  55            if j >= 0:
  56                j += 2 # account for the brace and the space after it
  57                size -= j
  58                padding = " " * j
  59            else:
  60                j = cur.find('(', 0, i)
  61                if j >= 0:
  62                    j += 1 # account for the paren (no space after it)
  63                    size -= j
  64                    padding = " " * j
  65        cur = cur[i+1:]
  66    else:
  67        lines.append(padding + cur)
  68    return lines
  69
  70def is_simple(sum):
  71    """Return True if a sum is a simple.
  72
  73    A sum is simple if its types have no fields, e.g.
  74    unaryop = Invert | Not | UAdd | USub
  75    """
  76    for t in sum.types:
  77        if t.fields:
  78            return False
  79    return True
  80
  81
  82class EmitVisitor(asdl.VisitorBase):
  83    """Visit that emits lines"""
  84
  85    def __init__(self, file):
  86        self.file = file
  87        super(EmitVisitor, self).__init__()
  88
  89    def emit(self, s, depth, reflow=1):
  90        # XXX reflow long lines?
  91        if reflow:
  92            lines = reflow_lines(s, depth)
  93        else:
  94            lines = [s]
  95        for line in lines:
  96            line = (" " * TABSIZE * depth) + line + "\n"
  97            self.file.write(line)
  98
  99
 100class TypeDefVisitor(EmitVisitor):
 101    def visitModule(self, mod):
 102        for dfn in mod.dfns:
 103            self.visit(dfn)
 104
 105    def visitType(self, type, depth=0):
 106        self.visit(type.value, type.name, depth)
 107
 108    def visitSum(self, sum, name, depth):
 109        if is_simple(sum):
 110            self.simple_sum(sum, name, depth)
 111        else:
 112            self.sum_with_constructors(sum, name, depth)
 113
 114    def simple_sum(self, sum, name, depth):
 115        enum = []
 116        for i in range(len(sum.types)):
 117            type = sum.types[i]
 118            enum.append("%s=%d" % (type.name, i + 1))
 119        enums = ", ".join(enum)
 120        ctype = get_c_type(name)
 121        s = "typedef enum _%s { %s } %s;" % (name, enums, ctype)
 122        self.emit(s, depth)
 123        self.emit("", depth)
 124
 125    def sum_with_constructors(self, sum, name, depth):
 126        ctype = get_c_type(name)
 127        s = "typedef struct _%(name)s *%(ctype)s;" % locals()
 128        self.emit(s, depth)
 129        self.emit("", depth)
 130
 131    def visitProduct(self, product, name, depth):
 132        ctype = get_c_type(name)
 133        s = "typedef struct _%(name)s *%(ctype)s;" % locals()
 134        self.emit(s, depth)
 135        self.emit("", depth)
 136
 137
 138class StructVisitor(EmitVisitor):
 139    """Visitor to generate typdefs for AST."""
 140
 141    def visitModule(self, mod):
 142        for dfn in mod.dfns:
 143            self.visit(dfn)
 144
 145    def visitType(self, type, depth=0):
 146        self.visit(type.value, type.name, depth)
 147
 148    def visitSum(self, sum, name, depth):
 149        if not is_simple(sum):
 150            self.sum_with_constructors(sum, name, depth)
 151
 152    def sum_with_constructors(self, sum, name, depth):
 153        def emit(s, depth=depth):
 154            self.emit(s % sys._getframe(1).f_locals, depth)
 155        enum = []
 156        for i in range(len(sum.types)):
 157            type = sum.types[i]
 158            enum.append("%s_kind=%d" % (type.name, i + 1))
 159
 160        emit("enum _%(name)s_kind {" + ", ".join(enum) + "};")
 161
 162        emit("struct _%(name)s {")
 163        emit("enum _%(name)s_kind kind;", depth + 1)
 164        emit("union {", depth + 1)
 165        for t in sum.types:
 166            self.visit(t, depth + 2)
 167        emit("} v;", depth + 1)
 168        for field in sum.attributes:
 169            # rudimentary attribute handling
 170            type = str(field.type)
 171            assert type in asdl.builtin_types, type
 172            emit("%s %s;" % (type, field.name), depth + 1);
 173        emit("};")
 174        emit("")
 175
 176    def visitConstructor(self, cons, depth):
 177        if cons.fields:
 178            self.emit("struct {", depth)
 179            for f in cons.fields:
 180                self.visit(f, depth + 1)
 181            self.emit("} %s;" % cons.name, depth)
 182            self.emit("", depth)
 183        else:
 184            # XXX not sure what I want here, nothing is probably fine
 185            pass
 186
 187    def visitField(self, field, depth):
 188        # XXX need to lookup field.type, because it might be something
 189        # like a builtin...
 190        ctype = get_c_type(field.type)
 191        name = field.name
 192        if field.seq:
 193            if field.type.value in ('cmpop',):
 194                self.emit("asdl_int_seq *%(name)s;" % locals(), depth)
 195            else:
 196                self.emit("asdl_seq *%(name)s;" % locals(), depth)
 197        else:
 198            self.emit("%(ctype)s %(name)s;" % locals(), depth)
 199
 200    def visitProduct(self, product, name, depth):
 201        self.emit("struct _%(name)s {" % locals(), depth)
 202        for f in product.fields:
 203            self.visit(f, depth + 1)
 204        self.emit("};", depth)
 205        self.emit("", depth)
 206
 207
 208class PrototypeVisitor(EmitVisitor):
 209    """Generate function prototypes for the .h file"""
 210
 211    def visitModule(self, mod):
 212        for dfn in mod.dfns:
 213            self.visit(dfn)
 214
 215    def visitType(self, type):
 216        self.visit(type.value, type.name)
 217
 218    def visitSum(self, sum, name):
 219        if is_simple(sum):
 220            pass # XXX
 221        else:
 222            for t in sum.types:
 223                self.visit(t, name, sum.attributes)
 224
 225    def get_args(self, fields):
 226        """Return list of C argument into, one for each field.
 227
 228        Argument info is 3-tuple of a C type, variable name, and flag
 229        that is true if type can be NULL.
 230        """
 231        args = []
 232        unnamed = {}
 233        for f in fields:
 234            if f.name is None:
 235                name = f.type
 236                c = unnamed[name] = unnamed.get(name, 0) + 1
 237                if c > 1:
 238                    name = "name%d" % (c - 1)
 239            else:
 240                name = f.name
 241            # XXX should extend get_c_type() to handle this
 242            if f.seq:
 243                if f.type.value in ('cmpop',):
 244                    ctype = "asdl_int_seq *"
 245                else:
 246                    ctype = "asdl_seq *"
 247            else:
 248                ctype = get_c_type(f.type)
 249            args.append((ctype, name, f.opt or f.seq))
 250        return args
 251
 252    def visitConstructor(self, cons, type, attrs):
 253        args = self.get_args(cons.fields)
 254        attrs = self.get_args(attrs)
 255        ctype = get_c_type(type)
 256        self.emit_function(cons.name, ctype, args, attrs)
 257
 258    def emit_function(self, name, ctype, args, attrs, union=1):
 259        args = args + attrs
 260        if args:
 261            argstr = ", ".join(["%s %s" % (atype, aname)
 262                                for atype, aname, opt in args])
 263            argstr += ", PyArena *arena"
 264        else:
 265            argstr = "PyArena *arena"
 266        margs = "a0"
 267        for i in range(1, len(args)+1):
 268            margs += ", a%d" % i
 269        self.emit("#define %s(%s) _Py_%s(%s)" % (name, margs, name, margs), 0,
 270                reflow = 0)
 271        self.emit("%s _Py_%s(%s);" % (ctype, name, argstr), 0)
 272
 273    def visitProduct(self, prod, name):
 274        self.emit_function(name, get_c_type(name),
 275                           self.get_args(prod.fields), [], union=0)
 276
 277
 278class FunctionVisitor(PrototypeVisitor):
 279    """Visitor to generate constructor functions for AST."""
 280
 281    def emit_function(self, name, ctype, args, attrs, union=1):
 282        def emit(s, depth=0, reflow=1):
 283            self.emit(s, depth, reflow)
 284        argstr = ", ".join(["%s %s" % (atype, aname)
 285                            for atype, aname, opt in args + attrs])
 286        if argstr:
 287            argstr += ", PyArena *arena"
 288        else:
 289            argstr = "PyArena *arena"
 290        self.emit("%s" % ctype, 0)
 291        emit("%s(%s)" % (name, argstr))
 292        emit("{")
 293        emit("%s p;" % ctype, 1)
 294        for argtype, argname, opt in args:
 295            # XXX hack alert: false is allowed for a bool
 296            if not opt and not (argtype == "bool" or argtype == "int"):
 297                emit("if (!%s) {" % argname, 1)
 298                emit("PyErr_SetString(PyExc_ValueError,", 2)
 299                msg = "field %s is required for %s" % (argname, name)
 300                emit('                "%s");' % msg,
 301                     2, reflow=0)
 302                emit('return NULL;', 2)
 303                emit('}', 1)
 304
 305        emit("p = (%s)PyArena_Malloc(arena, sizeof(*p));" % ctype, 1);
 306        emit("if (!p)", 1)
 307        emit("return NULL;", 2)
 308        if union:
 309            self.emit_body_union(name, args, attrs)
 310        else:
 311            self.emit_body_struct(name, args, attrs)
 312        emit("return p;", 1)
 313        emit("}")
 314        emit("")
 315
 316    def emit_body_union(self, name, args, attrs):
 317        def emit(s, depth=0, reflow=1):
 318            self.emit(s, depth, reflow)
 319        emit("p->kind = %s_kind;" % name, 1)
 320        for argtype, argname, opt in args:
 321            emit("p->v.%s.%s = %s;" % (name, argname, argname), 1)
 322        for argtype, argname, opt in attrs:
 323            emit("p->%s = %s;" % (argname, argname), 1)
 324
 325    def emit_body_struct(self, name, args, attrs):
 326        def emit(s, depth=0, reflow=1):
 327            self.emit(s, depth, reflow)
 328        for argtype, argname, opt in args:
 329            emit("p->%s = %s;" % (argname, argname), 1)
 330        assert not attrs
 331
 332
 333class PickleVisitor(EmitVisitor):
 334
 335    def visitModule(self, mod):
 336        for dfn in mod.dfns:
 337            self.visit(dfn)
 338
 339    def visitType(self, type):
 340        self.visit(type.value, type.name)
 341
 342    def visitSum(self, sum, name):
 343        pass
 344
 345    def visitProduct(self, sum, name):
 346        pass
 347
 348    def visitConstructor(self, cons, name):
 349        pass
 350
 351    def visitField(self, sum):
 352        pass
 353
 354
 355class Obj2ModPrototypeVisitor(PickleVisitor):
 356    def visitProduct(self, prod, name):
 357        code = "static int obj2ast_%s(PyObject* obj, %s* out, PyArena* arena);"
 358        self.emit(code % (name, get_c_type(name)), 0)
 359
 360    visitSum = visitProduct
 361
 362
 363class Obj2ModVisitor(PickleVisitor):
 364    def funcHeader(self, name):
 365        ctype = get_c_type(name)
 366        self.emit("int", 0)
 367        self.emit("obj2ast_%s(PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
 368        self.emit("{", 0)
 369        self.emit("PyObject* tmp = NULL;", 1)
 370        self.emit("", 0)
 371
 372    def sumTrailer(self, name):
 373        self.emit("", 0)
 374        self.emit("tmp = PyObject_Repr(obj);", 1)
 375        # there's really nothing more we can do if this fails ...
 376        self.emit("if (tmp == NULL) goto failed;", 1)
 377        error = "expected some sort of %s, but got %%.400s" % name
 378        format = "PyErr_Format(PyExc_TypeError, \"%s\", PyString_AS_STRING(tmp));"
 379        self.emit(format % error, 1, reflow=False)
 380        self.emit("failed:", 0)
 381        self.emit("Py_XDECREF(tmp);", 1)
 382        self.emit("return 1;", 1)
 383        self.emit("}", 0)
 384        self.emit("", 0)
 385
 386    def simpleSum(self, sum, name):
 387        self.funcHeader(name)
 388        for t in sum.types:
 389            self.emit("if (PyObject_IsInstance(obj, (PyObject*)%s_type)) {" % t.name, 1)
 390            self.emit("*out = %s;" % t.name, 2)
 391            self.emit("return 0;", 2)
 392            self.emit("}", 1)
 393        self.sumTrailer(name)
 394
 395    def buildArgs(self, fields):
 396        return ", ".join(fields + ["arena"])
 397
 398    def complexSum(self, sum, name):
 399        self.funcHeader(name)
 400        for a in sum.attributes:
 401            self.visitAttributeDeclaration(a, name, sum=sum)
 402        self.emit("", 0)
 403        # XXX: should we only do this for 'expr'?
 404        self.emit("if (obj == Py_None) {", 1)
 405        self.emit("*out = NULL;", 2)
 406        self.emit("return 0;", 2)
 407        self.emit("}", 1)
 408        for a in sum.attributes:
 409            self.visitField(a, name, sum=sum, depth=1)
 410        for t in sum.types:
 411            self.emit("if (PyObject_IsInstance(obj, (PyObject*)%s_type)) {" % t.name, 1)
 412            for f in t.fields:
 413                self.visitFieldDeclaration(f, t.name, sum=sum, depth=2)
 414            self.emit("", 0)
 415            for f in t.fields:
 416                self.visitField(f, t.name, sum=sum, depth=2)
 417            args = [f.name.value for f in t.fields] + [a.name.value for a in sum.attributes]
 418            self.emit("*out = %s(%s);" % (t.name, self.buildArgs(args)), 2)
 419            self.emit("if (*out == NULL) goto failed;", 2)
 420            self.emit("return 0;", 2)
 421            self.emit("}", 1)
 422        self.sumTrailer(name)
 423
 424    def visitAttributeDeclaration(self, a, name, sum=sum):
 425        ctype = get_c_type(a.type)
 426        self.emit("%s %s;" % (ctype, a.name), 1)
 427
 428    def visitSum(self, sum, name):
 429        if is_simple(sum):
 430            self.simpleSum(sum, name)
 431        else:
 432            self.complexSum(sum, name)
 433
 434    def visitProduct(self, prod, name):
 435        ctype = get_c_type(name)
 436        self.emit("int", 0)
 437        self.emit("obj2ast_%s(PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
 438        self.emit("{", 0)
 439        self.emit("PyObject* tmp = NULL;", 1)
 440        for f in prod.fields:
 441            self.visitFieldDeclaration(f, name, prod=prod, depth=1)
 442        self.emit("", 0)
 443        for f in prod.fields:
 444            self.visitField(f, name, prod=prod, depth=1)
 445        args = [f.name.value for f in prod.fields]
 446        self.emit("*out = %s(%s);" % (name, self.buildArgs(args)), 1)
 447        self.emit("return 0;", 1)
 448        self.emit("failed:", 0)
 449        self.emit("Py_XDECREF(tmp);", 1)
 450        self.emit("return 1;", 1)
 451        self.emit("}", 0)
 452        self.emit("", 0)
 453
 454    def visitFieldDeclaration(self, field, name, sum=None, prod=None, depth=0):
 455        ctype = get_c_type(field.type)
 456        if field.seq:
 457            if self.isSimpleType(field):
 458                self.emit("asdl_int_seq* %s;" % field.name, depth)
 459            else:
 460                self.emit("asdl_seq* %s;" % field.name, depth)
 461        else:
 462            ctype = get_c_type(field.type)
 463            self.emit("%s %s;" % (ctype, field.name), depth)
 464
 465    def isSimpleSum(self, field):
 466        # XXX can the members of this list be determined automatically?
 467        return field.type.value in ('expr_context', 'boolop', 'operator',
 468                                    'unaryop', 'cmpop')
 469
 470    def isNumeric(self, field):
 471        return get_c_type(field.type) in ("int", "bool")
 472
 473    def isSimpleType(self, field):
 474        return self.isSimpleSum(field) or self.isNumeric(field)
 475
 476    def visitField(self, field, name, sum=None, prod=None, depth=0):
 477        ctype = get_c_type(field.type)
 478        self.emit("if (PyObject_HasAttrString(obj, \"%s\")) {" % field.name, depth)
 479        self.emit("int res;", depth+1)
 480        if field.seq:
 481            self.emit("Py_ssize_t len;", depth+1)
 482            self.emit("Py_ssize_t i;", depth+1)
 483        self.emit("tmp = PyObject_GetAttrString(obj, \"%s\");" % field.name, depth+1)
 484        self.emit("if (tmp == NULL) goto failed;", depth+1)
 485        if field.seq:
 486            self.emit("if (!PyList_Check(tmp)) {", depth+1)
 487            self.emit("PyErr_Format(PyExc_TypeError, \"%s field \\\"%s\\\" must "
 488                      "be a list, not a %%.200s\", tmp->ob_type->tp_name);" %
 489                      (name, field.name),
 490                      depth+2, reflow=False)
 491            self.emit("goto failed;", depth+2)
 492            self.emit("}", depth+1)
 493            self.emit("len = PyList_GET_SIZE(tmp);", depth+1)
 494            if self.isSimpleType(field):
 495                self.emit("%s = asdl_int_seq_new(len, arena);" % field.name, depth+1)
 496            else:
 497                self.emit("%s = asdl_seq_new(len, arena);" % field.name, depth+1)
 498            self.emit("if (%s == NULL) goto failed;" % field.name, depth+1)
 499            self.emit("for (i = 0; i < len; i++) {", depth+1)
 500            self.emit("%s value;" % ctype, depth+2)
 501            self.emit("res = obj2ast_%s(PyList_GET_ITEM(tmp, i), &value, arena);" %
 502                      field.type, depth+2, reflow=False)
 503            self.emit("if (res != 0) goto failed;", depth+2)
 504            self.emit("asdl_seq_SET(%s, i, value);" % field.name, depth+2)
 505            self.emit("}", depth+1)
 506        else:
 507            self.emit("res = obj2ast_%s(tmp, &%s, arena);" %
 508                      (field.type, field.name), depth+1)
 509            self.emit("if (res != 0) goto failed;", depth+1)
 510
 511        self.emit("Py_XDECREF(tmp);", depth+1)
 512        self.emit("tmp = NULL;", depth+1)
 513        self.emit("} else {", depth)
 514        if not field.opt:
 515            message = "required field \\\"%s\\\" missing from %s" % (field.name, name)
 516            format = "PyErr_SetString(PyExc_TypeError, \"%s\");"
 517            self.emit(format % message, depth+1, reflow=False)
 518            self.emit("return 1;", depth+1)
 519        else:
 520            if self.isNumeric(field):
 521                self.emit("%s = 0;" % field.name, depth+1)
 522            elif not self.isSimpleType(field):
 523                self.emit("%s = NULL;" % field.name, depth+1)
 524            else:
 525                raise TypeError("could not determine the default value for %s" % field.name)
 526        self.emit("}", depth)
 527
 528
 529class MarshalPrototypeVisitor(PickleVisitor):
 530
 531    def prototype(self, sum, name):
 532        ctype = get_c_type(name)
 533        self.emit("static int marshal_write_%s(PyObject **, int *, %s);"
 534                  % (name, ctype), 0)
 535
 536    visitProduct = visitSum = prototype
 537
 538
 539class PyTypesDeclareVisitor(PickleVisitor):
 540
 541    def visitProduct(self, prod, name):
 542        self.emit("static PyTypeObject *%s_type;" % name, 0)
 543        self.emit("static PyObject* ast2obj_%s(void*);" % name, 0)
 544        if prod.fields:
 545            self.emit("static char *%s_fields[]={" % name,0)
 546            for f in prod.fields:
 547                self.emit('"%s",' % f.name, 1)
 548            self.emit("};", 0)
 549
 550    def visitSum(self, sum, name):
 551        self.emit("static PyTypeObject *%s_type;" % name, 0)
 552        if sum.attributes:
 553            self.emit("static char *%s_attributes[] = {" % name, 0)
 554            for a in sum.attributes:
 555                self.emit('"%s",' % a.name, 1)
 556            self.emit("};", 0)
 557        ptype = "void*"
 558        if is_simple(sum):
 559            ptype = get_c_type(name)
 560            tnames = []
 561            for t in sum.types:
 562                tnames.append(str(t.name)+"_singleton")
 563            tnames = ", *".join(tnames)
 564            self.emit("static PyObject *%s;" % tnames, 0)
 565        self.emit("static PyObject* ast2obj_%s(%s);" % (name, ptype), 0)
 566        for t in sum.types:
 567            self.visitConstructor(t, name)
 568
 569    def visitConstructor(self, cons, name):
 570        self.emit("static PyTypeObject *%s_type;" % cons.name, 0)
 571        if cons.fields:
 572            self.emit("static char *%s_fields[]={" % cons.name, 0)
 573            for t in cons.fields:
 574                self.emit('"%s",' % t.name, 1)
 575            self.emit("};",0)
 576
 577class PyTypesVisitor(PickleVisitor):
 578
 579    def visitModule(self, mod):
 580        self.emit("""
 581static int
 582ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
 583{
 584    Py_ssize_t i, numfields = 0;
 585    int res = -1;
 586    PyObject *key, *value, *fields;
 587    fields = PyObject_GetAttrString((PyObject*)Py_TYPE(self), "_fields");
 588    if (!fields)
 589        PyErr_Clear();
 590    if (fields) {
 591        numfields = PySequence_Size(fields);
 592        if (numfields == -1)
 593            goto cleanup;
 594    }
 595    res = 0; /* if no error occurs, this stays 0 to the end */
 596    if (PyTuple_GET_SIZE(args) > 0) {
 597        if (numfields != PyTuple_GET_SIZE(args)) {
 598            PyErr_Format(PyExc_TypeError, "%.400s constructor takes %s"
 599                         "%zd positional argument%s",
 600                         Py_TYPE(self)->tp_name,
 601                         numfields == 0 ? "" : "either 0 or ",
 602                         numfields, numfields == 1 ? "" : "s");
 603            res = -1;
 604            goto cleanup;
 605        }
 606        for (i = 0; i < PyTuple_GET_SIZE(args); i++) {
 607            /* cannot be reached when fields is NULL */
 608            PyObject *name = PySequence_GetItem(fields, i);
 609            if (!name) {
 610                res = -1;
 611                goto cleanup;
 612            }
 613            res = PyObject_SetAttr(self, name, PyTuple_GET_ITEM(args, i));
 614            Py_DECREF(name);
 615            if (res < 0)
 616                goto cleanup;
 617        }
 618    }
 619    if (kw) {
 620        i = 0;  /* needed by PyDict_Next */
 621        while (PyDict_Next(kw, &i, &key, &value)) {
 622            res = PyObject_SetAttr(self, key, value);
 623            if (res < 0)
 624                goto cleanup;
 625        }
 626    }
 627  cleanup:
 628    Py_XDECREF(fields);
 629    return res;
 630}
 631
 632/* Pickling support */
 633static PyObject *
 634ast_type_reduce(PyObject *self, PyObject *unused)
 635{
 636    PyObject *res;
 637    PyObject *dict = PyObject_GetAttrString(self, "__dict__");
 638    if (dict == NULL) {
 639        if (PyErr_ExceptionMatches(PyExc_AttributeError))
 640            PyErr_Clear();
 641        else
 642            return NULL;
 643    }
 644    if (dict) {
 645        res = Py_BuildValue("O()O", Py_TYPE(self), dict);
 646        Py_DECREF(dict);
 647        return res;
 648    }
 649    return Py_BuildValue("O()", Py_TYPE(self));
 650}
 651
 652static PyMethodDef ast_type_methods[] = {
 653    {"__reduce__", ast_type_reduce, METH_NOARGS, NULL},
 654    {NULL}
 655};
 656
 657static PyTypeObject AST_type = {
 658    PyVarObject_HEAD_INIT(&PyType_Type, 0)
 659    "_ast.AST",
 660    sizeof(PyObject),
 661    0,
 662    0,                       /* tp_dealloc */
 663    0,                       /* tp_print */
 664    0,                       /* tp_getattr */
 665    0,                       /* tp_setattr */
 666    0,                       /* tp_compare */
 667    0,                       /* tp_repr */
 668    0,                       /* tp_as_number */
 669    0,                       /* tp_as_sequence */
 670    0,                       /* tp_as_mapping */
 671    0,                       /* tp_hash */
 672    0,                       /* tp_call */
 673    0,                       /* tp_str */
 674    PyObject_GenericGetAttr, /* tp_getattro */
 675    PyObject_GenericSetAttr, /* tp_setattro */
 676    0,                       /* tp_as_buffer */
 677    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
 678    0,                       /* tp_doc */
 679    0,                       /* tp_traverse */
 680    0,                       /* tp_clear */
 681    0,                       /* tp_richcompare */
 682    0,                       /* tp_weaklistoffset */
 683    0,                       /* tp_iter */
 684    0,                       /* tp_iternext */
 685    ast_type_methods,        /* tp_methods */
 686    0,                       /* tp_members */
 687    0,                       /* tp_getset */
 688    0,                       /* tp_base */
 689    0,                       /* tp_dict */
 690    0,                       /* tp_descr_get */
 691    0,                       /* tp_descr_set */
 692    0,                       /* tp_dictoffset */
 693    (initproc)ast_type_init, /* tp_init */
 694    PyType_GenericAlloc,     /* tp_alloc */
 695    PyType_GenericNew,       /* tp_new */
 696    PyObject_Del,            /* tp_free */
 697};
 698
 699
 700static PyTypeObject* make_type(char *type, PyTypeObject* base, char**fields, int num_fields)
 701{
 702    PyObject *fnames, *result;
 703    int i;
 704    fnames = PyTuple_New(num_fields);
 705    if (!fnames) return NULL;
 706    for (i = 0; i < num_fields; i++) {
 707        PyObject *field = PyString_FromString(fields[i]);
 708        if (!field) {
 709            Py_DECREF(fnames);
 710            return NULL;
 711        }
 712        PyTuple_SET_ITEM(fnames, i, field);
 713    }
 714    result = PyObject_CallFunction((PyObject*)&PyType_Type, "s(O){sOss}",
 715                    type, base, "_fields", fnames, "__module__", "_ast");
 716    Py_DECREF(fnames);
 717    return (PyTypeObject*)result;
 718}
 719
 720static int add_attributes(PyTypeObject* type, char**attrs, int num_fields)
 721{
 722    int i, result;
 723    PyObject *s, *l = PyTuple_New(num_fields);
 724    if (!l) return 0;
 725    for(i = 0; i < num_fields; i++) {
 726        s = PyString_FromString(attrs[i]);
 727        if (!s) {
 728            Py_DECREF(l);
 729            return 0;
 730        }
 731        PyTuple_SET_ITEM(l, i, s);
 732    }
 733    result = PyObject_SetAttrString((PyObject*)type, "_attributes", l) >= 0;
 734    Py_DECREF(l);
 735    return result;
 736}
 737
 738/* Conversion AST -> Python */
 739
 740static PyObject* ast2obj_list(asdl_seq *seq, PyObject* (*func)(void*))
 741{
 742    int i, n = asdl_seq_LEN(seq);
 743    PyObject *result = PyList_New(n);
 744    PyObject *value;
 745    if (!result)
 746        return NULL;
 747    for (i = 0; i < n; i++) {
 748        value = func(asdl_seq_GET(seq, i));
 749        if (!value) {
 750            Py_DECREF(result);
 751            return NULL;
 752        }
 753        PyList_SET_ITEM(result, i, value);
 754    }
 755    return result;
 756}
 757
 758static PyObject* ast2obj_object(void *o)
 759{
 760    if (!o)
 761        o = Py_None;
 762    Py_INCREF((PyObject*)o);
 763    return (PyObject*)o;
 764}
 765#define ast2obj_identifier ast2obj_object
 766#define ast2obj_string ast2obj_object
 767static PyObject* ast2obj_bool(bool b)
 768{
 769    return PyBool_FromLong(b);
 770}
 771
 772static PyObject* ast2obj_int(long b)
 773{
 774    return PyInt_FromLong(b);
 775}
 776
 777/* Conversion Python -> AST */
 778
 779static int obj2ast_object(PyObject* obj, PyObject** out, PyArena* arena)
 780{
 781    if (obj == Py_None)
 782        obj = NULL;
 783    if (obj)
 784        PyArena_AddPyObject(arena, obj);
 785    Py_XINCREF(obj);
 786    *out = obj;
 787    return 0;
 788}
 789
 790#define obj2ast_identifier obj2ast_object
 791#define obj2ast_string obj2ast_object
 792
 793static int obj2ast_int(PyObject* obj, int* out, PyArena* arena)
 794{
 795    int i;
 796    if (!PyInt_Check(obj) && !PyLong_Check(obj)) {
 797        PyObject *s = PyObject_Repr(obj);
 798        if (s == NULL) return 1;
 799        PyErr_Format(PyExc_ValueError, "invalid integer value: %.400s",
 800                     PyString_AS_STRING(s));
 801        Py_DECREF(s);
 802        return 1;
 803    }
 804
 805    i = (int)PyLong_AsLong(obj);
 806    if (i == -1 && PyErr_Occurred())
 807        return 1;
 808    *out = i;
 809    return 0;
 810}
 811
 812static int obj2ast_bool(PyObject* obj, bool* out, PyArena* arena)
 813{
 814    if (!PyBool_Check(obj)) {
 815        PyObject *s = PyObject_Repr(obj);
 816        if (s == NULL) return 1;
 817        PyErr_Format(PyExc_ValueError, "invalid boolean value: %.400s",
 818                     PyString_AS_STRING(s));
 819        Py_DECREF(s);
 820        return 1;
 821    }
 822
 823    *out = (obj == Py_True);
 824    return 0;
 825}
 826
 827static int add_ast_fields(void)
 828{
 829    PyObject *empty_tuple, *d;
 830    if (PyType_Ready(&AST_type) < 0)
 831        return -1;
 832    d = AST_type.tp_dict;
 833    empty_tuple = PyTuple_New(0);
 834    if (!empty_tuple ||
 835        PyDict_SetItemString(d, "_fields", empty_tuple) < 0 ||
 836        PyDict_SetItemString(d, "_attributes", empty_tuple) < 0) {
 837        Py_XDECREF(empty_tuple);
 838        return -1;
 839    }
 840    Py_DECREF(empty_tuple);
 841    return 0;
 842}
 843
 844""", 0, reflow=False)
 845
 846        self.emit("static int init_types(void)",0)
 847        self.emit("{", 0)
 848        self.emit("static int initialized;", 1)
 849        self.emit("if (initialized) return 1;", 1)
 850        self.emit("if (add_ast_fields() < 0) return 0;", 1)
 851        for dfn in mod.dfns:
 852            self.visit(dfn)
 853        self.emit("initialized = 1;", 1)
 854        self.emit("return 1;", 1);
 855        self.emit("}", 0)
 856
 857    def visitProduct(self, prod, name):
 858        if prod.fields:
 859            fields = name.value+"_fields"
 860        else:
 861            fields = "NULL"
 862        self.emit('%s_type = make_type("%s", &AST_type, %s, %d);' %
 863                        (name, name, fields, len(prod.fields)), 1)
 864        self.emit("if (!%s_type) return 0;" % name, 1)
 865
 866    def visitSum(self, sum, name):
 867        self.emit('%s_type = make_type("%s", &AST_type, NULL, 0);' %
 868                  (name, name), 1)
 869        self.emit("if (!%s_type) return 0;" % name, 1)
 870        if sum.attributes:
 871            self.emit("if (!add_attributes(%s_type, %s_attributes, %d)) return 0;" %
 872                            (name, name, len(sum.attributes)), 1)
 873        else:
 874            self.emit("if (!add_attributes(%s_type, NULL, 0)) return 0;" % name, 1)
 875        simple = is_simple(sum)
 876        for t in sum.types:
 877            self.visitConstructor(t, name, simple)
 878
 879    def visitConstructor(self, cons, name, simple):
 880        if cons.fields:
 881            fields = cons.name.value+"_fields"
 882        else:
 883            fields = "NULL"
 884        self.emit('%s_type = make_type("%s", %s_type, %s, %d);' %
 885                            (cons.name, cons.name, name, fields, len(cons.fields)), 1)
 886        self.emit("if (!%s_type) return 0;" % cons.name, 1)
 887        if simple:
 888            self.emit("%s_singleton = PyType_GenericNew(%s_type, NULL, NULL);" %
 889                             (cons.name, cons.name), 1)
 890            self.emit("if (!%s_singleton) return 0;" % cons.name, 1)
 891
 892
 893def parse_version(mod):
 894    return mod.version.value[12:-3]
 895
 896class ASTModuleVisitor(PickleVisitor):
 897
 898    def visitModule(self, mod):
 899        self.emit("PyMODINIT_FUNC", 0)
 900        self.emit("init_ast(void)", 0)
 901        self.emit("{", 0)
 902        self.emit("PyObject *m, *d;", 1)
 903        self.emit("if (!init_types()) return;", 1)
 904        self.emit('m = Py_InitModule3("_ast", NULL, NULL);', 1)
 905        self.emit("if (!m) return;", 1)
 906        self.emit("d = PyModule_GetDict(m);", 1)
 907        self.emit('if (PyDict_SetItemString(d, "AST", (PyObject*)&AST_type) < 0) return;', 1)
 908        self.emit('if (PyModule_AddIntConstant(m, "PyCF_ONLY_AST", PyCF_ONLY_AST) < 0)', 1)
 909        self.emit("return;", 2)
 910        # Value of version: "$Revision: 67146 $"
 911        self.emit('if (PyModule_AddStringConstant(m, "__version__", "%s") < 0)'
 912                % parse_version(mod), 1)
 913        self.emit("return;", 2)
 914        for dfn in mod.dfns:
 915            self.visit(dfn)
 916        self.emit("}", 0)
 917
 918    def visitProduct(self, prod, name):
 919        self.addObj(name)
 920
 921    def visitSum(self, sum, name):
 922        self.addObj(name)
 923        for t in sum.types:
 924            self.visitConstructor(t, name)
 925
 926    def visitConstructor(self, cons, name):
 927        self.addObj(cons.name)
 928
 929    def addObj(self, name):
 930        self.emit('if (PyDict_SetItemString(d, "%s", (PyObject*)%s_type) < 0) return;' % (name, name), 1)
 931
 932
 933_SPECIALIZED_SEQUENCES = ('stmt', 'expr')
 934
 935def find_sequence(fields, doing_specialization):
 936    """Return True if any field uses a sequence."""
 937    for f in fields:
 938        if f.seq:
 939            if not doing_specialization:
 940                return True
 941            if str(f.type) not in _SPECIALIZED_SEQUENCES:
 942                return True
 943    return False
 944
 945def has_sequence(types, doing_specialization):
 946    for t in types:
 947        if find_sequence(t.fields, doing_specialization):
 948            return True
 949    return False
 950
 951
 952class StaticVisitor(PickleVisitor):
 953    CODE = '''Very simple, always emit this static code.  Overide CODE'''
 954
 955    def visit(self, object):
 956        self.emit(self.CODE, 0, reflow=False)
 957
 958
 959class ObjVisitor(PickleVisitor):
 960
 961    def func_begin(self, name):
 962        ctype = get_c_type(name)
 963        self.emit("PyObject*", 0)
 964        self.emit("ast2obj_%s(void* _o)" % (name), 0)
 965        self.emit("{", 0)
 966        self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
 967        self.emit("PyObject *result = NULL, *value = NULL;", 1)
 968        self.emit('if (!o) {', 1)
 969        self.emit("Py_INCREF(Py_None);", 2)
 970        self.emit('return Py_None;', 2)
 971        self.emit("}", 1)
 972        self.emit('', 0)
 973
 974    def func_end(self):
 975        self.emit("return result;", 1)
 976        self.emit("failed:", 0)
 977        self.emit("Py_XDECREF(value);", 1)
 978        self.emit("Py_XDECREF(result);", 1)
 979        self.emit("return NULL;", 1)
 980        self.emit("}", 0)
 981        self.emit("", 0)
 982
 983    def visitSum(self, sum, name):
 984        if is_simple(sum):
 985            self.simpleSum(sum, name)
 986            return
 987        self.func_begin(name)
 988        self.emit("switch (o->kind) {", 1)
 989        for i in range(len(sum.types)):
 990            t = sum.types[i]
 991            self.visitConstructor(t, i + 1, name)
 992        self.emit("}", 1)
 993        for a in sum.attributes:
 994            self.emit("value = ast2obj_%s(o->%s);" % (a.type, a.name), 1)
 995            self.emit("if (!value) goto failed;", 1)
 996            self.emit('if (PyObject_SetAttrString(result, "%s", value) < 0)' % a.name, 1)
 997            self.emit('goto failed;', 2)
 998            self.emit('Py_DECREF(value);', 1)
 999        self.func_end()
1000
1001    def simpleSum(self, sum, name):
1002        self.emit("PyObject* ast2obj_%s(%s_ty o)" % (name, name), 0)
1003        self.emit("{", 0)
1004        self.emit("switch(o) {", 1)
1005        for t in sum.types:
1006            self.emit("case %s:" % t.name, 2)
1007            self.emit("Py_INCREF(%s_singleton);" % t.name, 3)
1008            self.emit("return %s_singleton;" % t.name, 3)
1009        self.emit("default:" % name, 2)
1010        self.emit('/* should never happen, but just in case ... */', 3)
1011        code = "PyErr_Format(PyExc_SystemError, \"unknown %s found\");" % name
1012        self.emit(code, 3, reflow=False)
1013        self.emit("return NULL;", 3)
1014        self.emit("}", 1)
1015        self.emit("}", 0)
1016
1017    def visitProduct(self, prod, name):
1018        self.func_begin(name)
1019        self.emit("result = PyType_GenericNew(%s_type, NULL, NULL);" % name, 1);
1020        self.emit("if (!result) return NULL;", 1)
1021        for field in prod.fields:
1022            self.visitField(field, name, 1, True)
1023        self.func_end()
1024
1025    def visitConstructor(self, cons, enum, name):
1026        self.emit("case %s_kind:" % cons.name, 1)
1027        self.emit("result = PyType_GenericNew(%s_type, NULL, NULL);" % cons.name, 2);
1028        self.emit("if (!result) goto failed;", 2)
1029        for f in cons.fields:
1030            self.visitField(f, cons.name, 2, False)
1031        self.emit("break;", 2)
1032
1033    def visitField(self, field, name, depth, product):
1034        def emit(s, d):
1035            self.emit(s, depth + d)
1036        if product:
1037            value = "o->%s" % field.name
1038        else:
1039            value = "o->v.%s.%s" % (name, field.name)
1040        self.set(field, value, depth)
1041        emit("if (!value) goto failed;", 0)
1042        emit('if (PyObject_SetAttrString(result, "%s", value) == -1)' % field.name, 0)
1043        emit("goto failed;", 1)
1044        emit("Py_DECREF(value);", 0)
1045
1046    def emitSeq(self, field, value, depth, emit):
1047        emit("seq = %s;" % value, 0)
1048        emit("n = asdl_seq_LEN(seq);", 0)
1049        emit("value = PyList_New(n);", 0)
1050        emit("if (!value) goto failed;", 0)
1051        emit("for (i = 0; i < n; i++) {", 0)
1052        self.set("value", field, "asdl_seq_GET(seq, i)", depth + 1)
1053        emit("if (!value1) goto failed;", 1)
1054        emit("PyList_SET_ITEM(value, i, value1);", 1)
1055        emit("value1 = NULL;", 1)
1056        emit("}", 0)
1057
1058    def set(self, field, value, depth):
1059        if field.seq:
1060            # XXX should really check for is_simple, but that requires a symbol table
1061            if field.type.value == "cmpop":
1062                # While the sequence elements are stored as void*,
1063                # ast2obj_cmpop expects an enum
1064                self.emit("{", depth)
1065                self.emit("int i, n = asdl_seq_LEN(%s);" % value, depth+1)
1066                self.emit("value = PyList_New(n);", depth+1)
1067                self.emit("if (!value) goto failed;", depth+1)
1068                self.emit("for(i = 0; i < n; i++)", depth+1)
1069                # This cannot fail, so no need for error handling
1070                self.emit("PyList_SET_ITEM(value, i, ast2obj_cmpop((cmpop_ty)asdl_seq_GET(%s, i)));" % value,
1071                          depth+2, reflow=False)
1072                self.emit("}", depth)
1073            else:
1074                self.emit("value = ast2obj_list(%s, ast2obj_%s);" % (value, field.type), depth)
1075        else:
1076            ctype = get_c_type(field.type)
1077            self.emit("value = ast2obj_%s(%s);" % (field.type, value), depth, reflow=False)
1078
1079
1080class PartingShots(StaticVisitor):
1081
1082    CODE = """
1083PyObject* PyAST_mod2obj(mod_ty t)
1084{
1085    init_types();
1086    return ast2obj_mod(t);
1087}
1088
1089/* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */
1090mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode)
1091{
1092    mod_ty res;
1093    PyObject *req_type[] = {(PyObject*)Module_type, (PyObject*)Expression_type,
1094                            (PyObject*)Interactive_type};
1095    char *req_name[] = {"Module", "Expression", "Interactive"};
1096    assert(0 <= mode && mode <= 2);
1097
1098    init_types();
1099
1100    if (!PyObject_IsInstance(ast, req_type[mode])) {
1101        PyErr_Format(PyExc_TypeError, "expected %s node, got %.400s",
1102                     req_name[mode], Py_TYPE(ast)->tp_name);
1103        return NULL;
1104    }
1105    if (obj2ast_mod(ast, &res, arena) != 0)
1106        return NULL;
1107    else
1108        return res;
1109}
1110
1111int PyAST_Check(PyObject* obj)
1112{
1113    init_types();
1114    return PyObject_IsInstance(obj, (PyObject*)&AST_type);
1115}
1116"""
1117
1118class ChainOfVisitors:
1119    def __init__(self, *visitors):
1120        self.visitors = visitors
1121
1122    def visit(self, object):
1123        for v in self.visitors:
1124            v.visit(object)
1125            v.emit("", 0)
1126
1127common_msg = "/* File automatically generated by %s. */\n\n"
1128
1129c_file_msg = """
1130/*
1131   __version__ %s.
1132
1133   This module must be committed separately after each AST grammar change;
1134   The __version__ number is set to the revision number of the commit
1135   containing the grammar change.
1136*/
1137
1138"""
1139
1140def main(srcfile):
1141    argv0 = sys.argv[0]
1142    components = argv0.split(os.sep)
1143    argv0 = os.sep.join(components[-2:])
1144    auto_gen_msg = common_msg % argv0
1145    mod = asdl.parse(srcfile)
1146    if not asdl.check(mod):
1147        sys.exit(1)
1148    if INC_DIR:
1149        p = "%s/%s-ast.h" % (INC_DIR, mod.name)
1150        f = open(p, "wb")
1151        f.write(auto_gen_msg)
1152        f.write('#include "asdl.h"\n\n')
1153        f.write('#ifdef __cplusplus\n'
1154                'extern "C" {\n'
1155                '#endif\n\n')
1156
1157        c = ChainOfVisitors(TypeDefVisitor(f),
1158                            StructVisitor(f),
1159                            PrototypeVisitor(f),
1160                            )
1161        c.visit(mod)
1162        f.write("PyObject* PyAST_mod2obj(mod_ty t);\n")
1163        f.write("mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode);\n")
1164        f.write("int PyAST_Check(PyObject* obj);\n")
1165        f.write('\n#ifdef __cplusplus\n'
1166                '}\n'  # To end the extern "C"
1167                '#endif\n')
1168        f.close()
1169
1170    if SRC_DIR:
1171        p = os.path.join(SRC_DIR, str(mod.name) + "-ast.c")
1172        f = open(p, "wb")
1173        f.write(auto_gen_msg)
1174        f.write(c_file_msg % parse_version(mod))
1175        f.write('#include "Python.h"\n')
1176        f.write('#include "%s-ast.h"\n' % mod.name)
1177        f.write('\n')
1178        f.write("static PyTypeObject AST_type;\n")
1179        v = ChainOfVisitors(
1180            PyTypesDeclareVisitor(f),
1181            PyTypesVisitor(f),
1182            Obj2ModPrototypeVisitor(f),
1183            FunctionVisitor(f),
1184            ObjVisitor(f),
1185            Obj2ModVisitor(f),
1186            ASTModuleVisitor(f),
1187            PartingShots(f),
1188            )
1189        v.visit(mod)
1190        f.close()
1191
1192if __name__ == "__main__":
1193    import sys
1194    import getopt
1195
1196    INC_DIR = ''
1197    SRC_DIR = ''
1198    opts, args = getopt.getopt(sys.argv[1:], "h:c:")
1199    if len(opts) != 1:
1200        print "Must specify exactly one output file"
1201        sys.exit(1)
1202    for o, v in opts:
1203        if o == '-h':
1204            INC_DIR = v
1205        if o == '-c':
1206            SRC_DIR = v
1207    if len(args) != 1:
1208        print "Must specify single input file"
1209        sys.exit(1)
1210    main(args[0])