PageRenderTime 48ms CodeModel.GetById 14ms RepoModel.GetById 1ms app.codeStats 0ms

/pypy/interpreter/astcompiler/tools/asdl_py.py

https://bitbucket.org/pypy/pypy/
Python | 572 lines | 550 code | 19 blank | 3 comment | 19 complexity | 24256331f25a86e081299dcd5a996796 MD5 | raw file
Possible License(s): AGPL-3.0, BSD-3-Clause, Apache-2.0
  1. """
  2. Generate AST node definitions from an ASDL description.
  3. """
  4. import sys
  5. import os
  6. import asdl
  7. class ASDLVisitor(asdl.VisitorBase):
  8. def __init__(self, stream, data):
  9. super(ASDLVisitor, self).__init__()
  10. self.stream = stream
  11. self.data = data
  12. def visitModule(self, mod, *args):
  13. for df in mod.dfns:
  14. self.visit(df, *args)
  15. def visitSum(self, sum, *args):
  16. for tp in sum.types:
  17. self.visit(tp, *args)
  18. def visitType(self, tp, *args):
  19. self.visit(tp.value, *args)
  20. def visitProduct(self, prod, *args):
  21. for field in prod.fields:
  22. self.visit(field, *args)
  23. def visitConstructor(self, cons, *args):
  24. for field in cons.fields:
  25. self.visit(field, *args)
  26. def visitField(self, field, *args):
  27. pass
  28. def emit(self, line, level=0):
  29. indent = " "*level
  30. self.stream.write(indent + line + "\n")
  31. def is_simple_sum(sum):
  32. assert isinstance(sum, asdl.Sum)
  33. for constructor in sum.types:
  34. if constructor.fields:
  35. return False
  36. return True
class ASTNodeVisitor(ASDLVisitor):
    """Emit the interp-level node classes for every type of the ASDL module.

    * Product types become ``AST`` subclasses.
    * Non-simple sums become an ``AST`` subclass (holding the shared
      attributes) plus one subclass per constructor.
    * Simple sums become integer constants numbered from 1, a ``_<Name>``
      wrapper class per constructor and a ``<base>_to_class`` lookup list.

    Every generated class is also registered through ``State.ast_type``.
    """

    def visitType(self, tp):
        # The type's name becomes the class / base name for its value.
        self.visit(tp.value, tp.name)

    def visitSum(self, sum, base):
        """Emit the classes for the sum ``sum`` whose type name is ``base``."""
        if is_simple_sum(sum):
            # Simple sum: from_object maps an app-level instance to its
            # 1-based constructor index.
            self.emit("class %s(AST):" % (base,))
            self.emit("@staticmethod", 1)
            self.emit("def from_object(space, w_node):", 1)
            for i, cons in enumerate(sum.types):
                self.emit("if space.isinstance_w(w_node, get(space).w_%s):"
                          % (cons.name,), 2)
                self.emit("return %i" % (i+1,), 3)
            self.emit("raise oefmt(space.w_TypeError,", 2)
            # Continuation of the oefmt(...) call above.
            self.emit(" \"Expected %s node, got %%T\", w_node)" % (base,), 2)
            self.emit("State.ast_type('%s', 'AST', None)" % (base,))
            self.emit("")
            # One wrapper class per constructor, able to build the matching
            # app-level object.  (The index ``i`` is unused in this loop.)
            for i, cons in enumerate(sum.types):
                self.emit("class _%s(%s):" % (cons.name, base))
                self.emit("def to_object(self, space):", 1)
                self.emit("return space.call_function(get(space).w_%s)" % (cons.name,), 2)
                self.emit("State.ast_type('%s', '%s', None)" % (cons.name, base))
                self.emit("")
            # Module-level integer constants, numbered from 1 ...
            for i, cons in enumerate(sum.types):
                self.emit("%s = %i" % (cons.name, i + 1))
            self.emit("")
            # ... and the list indexed by (constant - 1) that
            # get_value_converter uses to find the wrapper class.
            self.emit("%s_to_class = [" % (base,))
            for cons in sum.types:
                self.emit("_%s," % (cons.name,), 1)
            self.emit("]")
            self.emit("")
        else:
            # Non-simple sum: a base class storing the sum's attributes and
            # dispatching from_object to the matching constructor class.
            self.emit("class %s(AST):" % (base,))
            if sum.attributes:
                self.emit("")
                args = ", ".join(attr.name.value for attr in sum.attributes)
                self.emit("def __init__(self, %s):" % (args,), 1)
                for attr in sum.attributes:
                    # visitField emits ``self.<attr> = <attr>``.
                    self.visit(attr)
                self.emit("")
            self.emit("@staticmethod", 1)
            self.emit("def from_object(space, w_node):", 1)
            self.emit("if space.is_w(w_node, space.w_None):", 2)
            # The leading space nests this line under the ``if`` above.
            self.emit(" return None", 2)
            for typ in sum.types:
                self.emit("if space.isinstance_w(w_node, get(space).w_%s):"
                          % (typ.name,), 2)
                self.emit("return %s.from_object(space, w_node)"
                          % (typ.name,), 3)
            self.emit("raise oefmt(space.w_TypeError,", 2)
            # Continuation of the oefmt(...) call above.
            self.emit(" \"Expected %s node, got %%T\", w_node)" % (base,), 2)
            self.emit("State.ast_type('%r', 'AST', None, %s)" %
                      (base, [repr(attr.name) for attr in sum.attributes]))
            self.emit("")
            for cons in sum.types:
                # The sum's attributes are passed on as extra constructor
                # arguments (see visitConstructor).
                self.visit(cons, base, sum.attributes)
                self.emit("")

    def visitProduct(self, product, name):
        """Emit the class for the product type ``product`` named ``name``."""
        self.emit("class %s(AST):" % (name,))
        self.emit("")
        self.make_constructor(product.fields, product)
        self.emit("")
        self.make_mutate_over(product, name)
        self.emit("def walkabout(self, visitor):", 1)
        self.emit("visitor.visit_%s(self)" % (name,), 2)
        self.emit("")
        self.make_converters(product.fields, name)
        self.emit("State.ast_type('%r', 'AST', %s)" %
                  (name, [repr(f.name) for f in product.fields]))
        self.emit("")

    def get_value_converter(self, field, value):
        """Return source for an expression turning ``value`` into an app-level
        object, according to the ASDL type of ``field``."""
        if field.type.value in self.data.simple_types:
            # Simple sums are stored as 1-based ints; look up the wrapper.
            return "%s_to_class[%s - 1]().to_object(space)" % (field.type, value)
        elif field.type.value in ("object", "string"):
            # Already stored as wrapped objects (see get_value_extractor).
            return value
        elif field.type.value in ("identifier", "int", "bool"):
            return "space.wrap(%s)" % (value,)
        else:
            wrapper = "%s.to_object(space)" % (value,)
            if field.opt:
                wrapper += " if %s is not None else space.w_None" % (value,)
            return wrapper

    def get_value_extractor(self, field, value):
        """Return source for an expression turning the app-level ``value``
        into its interp-level form, according to the ASDL type of ``field``."""
        if field.type.value in self.data.simple_types:
            return "%s.from_object(space, %s)" % (field.type, value)
        elif field.type.value in ("object",):
            return value
        elif field.type.value in ("string",):
            return "check_string(space, %s)" % (value,)
        elif field.type.value in ("identifier",):
            if field.opt:
                return "space.str_or_None_w(%s)" % (value,)
            return "space.realstr_w(%s)" % (value,)
        elif field.type.value in ("int",):
            return "space.int_w(%s)" % (value,)
        elif field.type.value in ("bool",):
            return "space.bool_w(%s)" % (value,)
        else:
            return "%s.from_object(space, %s)" % (field.type, value)

    def get_field_converter(self, field):
        """Return the source lines computing ``w_<field>`` inside ``to_object``.

        The lines are emitted at one common level by make_converters, so the
        ``if``/``else`` bodies below carry a leading space to nest them.
        """
        if field.seq:
            lines = []
            lines.append("if self.%s is None:" % field.name)
            lines.append(" %s_w = []" % field.name)
            lines.append("else:")
            wrapper = self.get_value_converter(field, "node")
            lines.append(" %s_w = [%s for node in self.%s] # %s" %
                         (field.name, wrapper, field.name, field.type))
            lines.append("w_%s = space.newlist(%s_w)" % (field.name, field.name))
            return lines
        else:
            wrapper = self.get_value_converter(field, "self.%s" % field.name)
            return ["w_%s = %s # %s" % (field.name, wrapper, field.type)]

    def get_field_extractor(self, field):
        """Return the source lines computing ``_<field>`` from ``w_<field>``
        inside ``from_object``."""
        if field.seq:
            lines = []
            lines.append("%s_w = space.unpackiterable(w_%s)" %
                         (field.name, field.name))
            value = self.get_value_extractor(field, "w_item")
            lines.append("_%s = [%s for w_item in %s_w]" %
                         (field.name, value, field.name))
        else:
            value = self.get_value_extractor(field, "w_%s" % (field.name,))
            lines = ["_%s = %s" % (field.name, value)]
        return lines

    def make_converters(self, fields, name, extras=None):
        """Emit ``to_object`` and the static ``from_object`` for class ``name``.

        ``extras`` (the sum's attributes, if any) are converted after
        ``fields`` and passed last to the constructor.
        """
        self.emit("def to_object(self, space):", 1)
        self.emit("w_node = space.call_function(get(space).w_%s)" % name, 2)
        all_fields = fields + extras if extras else fields
        for field in all_fields:
            wrapping_code = self.get_field_converter(field)
            for line in wrapping_code:
                self.emit(line, 2)
            self.emit("space.setattr(w_node, space.wrap(%r), w_%s)" % (
                    str(field.name), field.name), 2)
        self.emit("return w_node", 2)
        self.emit("")
        self.emit("@staticmethod", 1)
        self.emit("def from_object(space, w_node):", 1)
        # Fetch every field first; the last argument emits True/False as
        # get_field's ``optional`` flag.
        for field in all_fields:
            self.emit("w_%s = get_field(space, w_node, '%s', %s)" % (
                    field.name, field.name, field.opt), 2)
        for field in all_fields:
            unwrapping_code = self.get_field_extractor(field)
            for line in unwrapping_code:
                self.emit(line, 2)
        self.emit("return %s(%s)" % (
                name, ', '.join("_%s" % (field.name,) for field in all_fields)), 2)
        self.emit("")

    def make_constructor(self, fields, node, extras=None, base=None):
        """Emit ``__init__`` taking ``fields`` then ``extras``.

        ``fields`` are stored on ``self``; ``extras`` are forwarded to
        ``base.__init__``.  Nothing is emitted when both are empty.
        ``node`` is not used.
        """
        if fields or extras:
            arg_fields = fields + extras if extras else fields
            args = ", ".join(str(field.name) for field in arg_fields)
            self.emit("def __init__(self, %s):" % args, 1)
            for field in fields:
                self.visit(field)
            if extras:
                base_args = ", ".join(str(field.name) for field in extras)
                self.emit("%s.__init__(self, %s)" % (base, base_args), 2)

    def make_mutate_over(self, cons, name):
        """Emit ``mutate_over``, which lets ``visitor`` replace child nodes.

        Fields of builtin or simple-sum type are not nodes and are skipped.
        Optional and sequence fields are guarded by a truthiness test.
        """
        self.emit("def mutate_over(self, visitor):", 1)
        for field in cons.fields:
            if (field.type.value not in asdl.builtin_types and
                field.type.value not in self.data.simple_types):
                if field.opt or field.seq:
                    level = 3
                    self.emit("if self.%s:" % (field.name,), 2)
                else:
                    level = 2
                if field.seq:
                    sub = (field.name,)
                    self.emit("visitor._mutate_sequence(self.%s)" % sub, level)
                else:
                    sub = (field.name, field.name)
                    self.emit("self.%s = self.%s.mutate_over(visitor)" % sub,
                              level)
        self.emit("return visitor.visit_%s(self)" % (name,), 2)
        self.emit("")

    def visitConstructor(self, cons, base, extra_attributes):
        """Emit the class for constructor ``cons`` of the sum named ``base``."""
        self.emit("class %s(%s):" % (cons.name, base))
        self.emit("")
        self.make_constructor(cons.fields, cons, extra_attributes, base)
        self.emit("")
        self.emit("def walkabout(self, visitor):", 1)
        self.emit("visitor.visit_%s(self)" % (cons.name,), 2)
        self.emit("")
        self.make_mutate_over(cons, cons.name)
        self.make_converters(cons.fields, cons.name, extra_attributes)
        self.emit("State.ast_type('%r', '%s', %s)" %
                  (cons.name, base, [repr(f.name) for f in cons.fields]))
        self.emit("")

    def visitField(self, field):
        # Inside a generated __init__: store the argument of the same name.
        self.emit("self.%s = %s" % (field.name, field.name), 2)
  229. class ASTVisitorVisitor(ASDLVisitor):
  230. """A meta visitor! :)"""
  231. def visitModule(self, mod):
  232. self.emit("class ASTVisitor(object):")
  233. self.emit("")
  234. self.emit("def visit_sequence(self, seq):", 1)
  235. self.emit("if seq is not None:", 2)
  236. self.emit("for node in seq:", 3)
  237. self.emit("node.walkabout(self)", 4)
  238. self.emit("")
  239. self.emit("def default_visitor(self, node):", 1)
  240. self.emit("raise NodeVisitorNotImplemented", 2)
  241. self.emit("")
  242. self.emit("def _mutate_sequence(self, seq):", 1)
  243. self.emit("for i in range(len(seq)):", 2)
  244. self.emit("seq[i] = seq[i].mutate_over(self)", 3)
  245. self.emit("")
  246. super(ASTVisitorVisitor, self).visitModule(mod)
  247. self.emit("")
  248. def visitType(self, tp):
  249. if not (isinstance(tp.value, asdl.Sum) and
  250. is_simple_sum(tp.value)):
  251. super(ASTVisitorVisitor, self).visitType(tp, tp.name)
  252. def visitProduct(self, prod, name):
  253. self.emit("def visit_%s(self, node):" % (name,), 1)
  254. self.emit("return self.default_visitor(node)", 2)
  255. def visitConstructor(self, cons, _):
  256. self.emit("def visit_%s(self, node):" % (cons.name,), 1)
  257. self.emit("return self.default_visitor(node)", 2)
  258. class GenericASTVisitorVisitor(ASDLVisitor):
  259. def visitModule(self, mod):
  260. self.emit("class GenericASTVisitor(ASTVisitor):")
  261. self.emit("")
  262. super(GenericASTVisitorVisitor, self).visitModule(mod)
  263. self.emit("")
  264. def visitType(self, tp):
  265. if not (isinstance(tp.value, asdl.Sum) and
  266. is_simple_sum(tp.value)):
  267. super(GenericASTVisitorVisitor, self).visitType(tp, tp.name)
  268. def visitProduct(self, prod, name):
  269. self.make_visitor(name, prod.fields)
  270. def visitConstructor(self, cons, _):
  271. self.make_visitor(cons.name, cons.fields)
  272. def make_visitor(self, name, fields):
  273. self.emit("def visit_%s(self, node):" % (name,), 1)
  274. have_body = False
  275. for field in fields:
  276. if self.visitField(field):
  277. have_body = True
  278. if not have_body:
  279. self.emit("pass", 2)
  280. self.emit("")
  281. def visitField(self, field):
  282. if (field.type.value not in asdl.builtin_types and
  283. field.type.value not in self.data.simple_types):
  284. level = 2
  285. template = "node.%s.walkabout(self)"
  286. if field.seq:
  287. template = "self.visit_sequence(node.%s)"
  288. elif field.opt:
  289. self.emit("if node.%s:" % (field.name,), 2)
  290. level = 3
  291. self.emit(template % (field.name,), level)
  292. return True
  293. return False
  294. class ASDLData(object):
  295. def __init__(self, tree):
  296. simple_types = set()
  297. prod_simple = set()
  298. field_masks = {}
  299. required_masks = {}
  300. optional_masks = {}
  301. cons_attributes = {}
  302. def add_masks(fields, node):
  303. required_mask = 0
  304. optional_mask = 0
  305. for i, field in enumerate(fields):
  306. flag = 1 << i
  307. if field not in field_masks:
  308. field_masks[field] = flag
  309. else:
  310. assert field_masks[field] == flag
  311. if field.opt:
  312. optional_mask |= flag
  313. else:
  314. required_mask |= flag
  315. required_masks[node] = required_mask
  316. optional_masks[node] = optional_mask
  317. for tp in tree.dfns:
  318. if isinstance(tp.value, asdl.Sum):
  319. sum = tp.value
  320. if is_simple_sum(sum):
  321. simple_types.add(tp.name.value)
  322. else:
  323. attrs = [field for field in sum.attributes]
  324. for cons in sum.types:
  325. add_masks(attrs + cons.fields, cons)
  326. cons_attributes[cons] = attrs
  327. else:
  328. prod = tp.value
  329. prod_simple.add(tp.name.value)
  330. add_masks(prod.fields, prod)
  331. prod_simple.update(simple_types)
  332. self.cons_attributes = cons_attributes
  333. self.simple_types = simple_types
  334. self.prod_simple = prod_simple
  335. self.field_masks = field_masks
  336. self.required_masks = required_masks
  337. self.optional_masks = optional_masks
  338. HEAD = r"""# Generated by tools/asdl_py.py
  339. from rpython.tool.pairtype import extendabletype
  340. from rpython.tool.sourcetools import func_with_new_name
  341. from pypy.interpreter import typedef
  342. from pypy.interpreter.baseobjspace import W_Root
  343. from pypy.interpreter.error import OperationError, oefmt
  344. from pypy.interpreter.gateway import interp2app
  345. def raise_attriberr(space, w_obj, name):
  346. raise oefmt(space.w_AttributeError,
  347. "'%T' object has no attribute '%s'", w_obj, name)
  348. def check_string(space, w_obj):
  349. if not (space.isinstance_w(w_obj, space.w_str) or
  350. space.isinstance_w(w_obj, space.w_unicode)):
  351. raise oefmt(space.w_TypeError,
  352. "AST string must be of type str or unicode")
  353. return w_obj
  354. def get_field(space, w_node, name, optional):
  355. w_obj = w_node.getdictvalue(space, name)
  356. if w_obj is None:
  357. if not optional:
  358. raise oefmt(space.w_TypeError,
  359. "required field \"%s\" missing from %T", name, w_node)
  360. w_obj = space.w_None
  361. return w_obj
  362. class AST(object):
  363. __metaclass__ = extendabletype
  364. def walkabout(self, visitor):
  365. raise AssertionError("walkabout() implementation not provided")
  366. def mutate_over(self, visitor):
  367. raise AssertionError("mutate_over() implementation not provided")
  368. class NodeVisitorNotImplemented(Exception):
  369. pass
  370. class _FieldsWrapper(W_Root):
  371. "Hack around the fact we can't store tuples on a TypeDef."
  372. def __init__(self, fields):
  373. self.fields = fields
  374. def __spacebind__(self, space):
  375. return space.newtuple([space.wrap(field) for field in self.fields])
  376. class W_AST(W_Root):
  377. w_dict = None
  378. def getdict(self, space):
  379. if self.w_dict is None:
  380. self.w_dict = space.newdict(instance=True)
  381. return self.w_dict
  382. def reduce_w(self, space):
  383. w_dict = self.w_dict
  384. if w_dict is None:
  385. w_dict = space.newdict()
  386. w_type = space.type(self)
  387. w_fields = space.getattr(w_type, space.wrap("_fields"))
  388. for w_name in space.fixedview(w_fields):
  389. try:
  390. space.setitem(w_dict, w_name,
  391. space.getattr(self, w_name))
  392. except OperationError:
  393. pass
  394. w_attrs = space.findattr(w_type, space.wrap("_attributes"))
  395. if w_attrs:
  396. for w_name in space.fixedview(w_attrs):
  397. try:
  398. space.setitem(w_dict, w_name,
  399. space.getattr(self, w_name))
  400. except OperationError:
  401. pass
  402. return space.newtuple([space.type(self),
  403. space.newtuple([]),
  404. w_dict])
  405. def setstate_w(self, space, w_state):
  406. for w_name in space.unpackiterable(w_state):
  407. space.setattr(self, w_name,
  408. space.getitem(w_state, w_name))
  409. def W_AST_new(space, w_type, __args__):
  410. node = space.allocate_instance(W_AST, w_type)
  411. return space.wrap(node)
  412. def W_AST_init(space, w_self, __args__):
  413. args_w, kwargs_w = __args__.unpack()
  414. fields_w = space.fixedview(space.getattr(space.type(w_self),
  415. space.wrap("_fields")))
  416. num_fields = len(fields_w) if fields_w else 0
  417. if args_w and len(args_w) != num_fields:
  418. if num_fields == 0:
  419. raise oefmt(space.w_TypeError,
  420. "%T constructor takes 0 positional arguments", w_self)
  421. elif num_fields == 1:
  422. raise oefmt(space.w_TypeError,
  423. "%T constructor takes either 0 or %d positional argument", w_self, num_fields)
  424. else:
  425. raise oefmt(space.w_TypeError,
  426. "%T constructor takes either 0 or %d positional arguments", w_self, num_fields)
  427. if args_w:
  428. for i, w_field in enumerate(fields_w):
  429. space.setattr(w_self, w_field, args_w[i])
  430. for field, w_value in kwargs_w.iteritems():
  431. space.setattr(w_self, space.wrap(field), w_value)
  432. W_AST.typedef = typedef.TypeDef("_ast.AST",
  433. _fields=_FieldsWrapper([]),
  434. _attributes=_FieldsWrapper([]),
  435. __reduce__=interp2app(W_AST.reduce_w),
  436. __setstate__=interp2app(W_AST.setstate_w),
  437. __dict__ = typedef.GetSetProperty(typedef.descr_get_dict,
  438. typedef.descr_set_dict, cls=W_AST),
  439. __new__=interp2app(W_AST_new),
  440. __init__=interp2app(W_AST_init),
  441. )
  442. class State:
  443. AST_TYPES = []
  444. @classmethod
  445. def ast_type(cls, name, base, fields, attributes=None):
  446. cls.AST_TYPES.append((name, base, fields, attributes))
  447. def __init__(self, space):
  448. self.w_AST = space.gettypeobject(W_AST.typedef)
  449. for (name, base, fields, attributes) in self.AST_TYPES:
  450. self.make_new_type(space, name, base, fields, attributes)
  451. def make_new_type(self, space, name, base, fields, attributes):
  452. w_base = getattr(self, 'w_%s' % base)
  453. w_dict = space.newdict()
  454. space.setitem_str(w_dict, '__module__', space.wrap('_ast'))
  455. if fields is not None:
  456. space.setitem_str(w_dict, "_fields",
  457. space.newtuple([space.wrap(f) for f in fields]))
  458. if attributes is not None:
  459. space.setitem_str(w_dict, "_attributes",
  460. space.newtuple([space.wrap(a) for a in attributes]))
  461. w_type = space.call_function(
  462. space.w_type,
  463. space.wrap(name), space.newtuple([w_base]), w_dict)
  464. setattr(self, 'w_%s' % name, w_type)
  465. def get(space):
  466. return space.fromcache(State)
  467. """
# Code generators run in this order by main(); each one writes its classes
# to the output file after HEAD.
visitors = [ASTNodeVisitor, ASTVisitorVisitor, GenericASTVisitorVisitor]
  469. def main(argv):
  470. if len(argv) == 3:
  471. def_file, out_file = argv[1:]
  472. elif len(argv) == 1:
  473. print "Assuming default values of Python.asdl and ast.py"
  474. here = os.path.dirname(__file__)
  475. def_file = os.path.join(here, "Python.asdl")
  476. out_file = os.path.join(here, "..", "ast.py")
  477. else:
  478. print >> sys.stderr, "invalid arguments"
  479. return 2
  480. mod = asdl.parse(def_file)
  481. data = ASDLData(mod)
  482. fp = open(out_file, "w")
  483. try:
  484. fp.write(HEAD)
  485. for visitor in visitors:
  486. visitor(fp, data).visit(mod)
  487. finally:
  488. fp.close()
  489. if __name__ == "__main__":
  490. sys.exit(main(sys.argv))