/meta/asttools/visitors/graph_visitor.py

https://github.com/srossross/Meta
Python | 402 lines | 259 code | 126 blank | 17 comment | 71 complexity | 75fa50c4a30cd05757823efd047c011f MD5 | raw file
  1. '''
  2. Created on Jul 18, 2011
  3. @author: sean
  4. '''
  5. from meta.asttools import Visitor, visit_children
  6. import _ast
  7. from meta.asttools.visitors.symbol_visitor import get_symbols
  8. try:
  9. from networkx import DiGraph
  10. except ImportError:
  11. DiGraph = None
  12. def collect_(self, node):
  13. names = set()
  14. for child in self.children(node):
  15. names.update(self.visit(child))
  16. if hasattr(node, 'ctx'):
  17. if isinstance(node.ctx, _ast.Store):
  18. self.modified.update(names)
  19. elif isinstance(node.ctx, _ast.Load):
  20. self.used.update(names)
  21. return names
  22. class CollectNodes(Visitor):
  23. def __init__(self, call_deps=False):
  24. self.graph = DiGraph()
  25. self.modified = set()
  26. self.used = set()
  27. self.undefined = set()
  28. self.sources = set()
  29. self.targets = set()
  30. self.context_names = set()
  31. self.call_deps = call_deps
  32. visitDefault = collect_
  33. def visitName(self, node):
  34. if isinstance(node.ctx, _ast.Store):
  35. self.modified.add(node.id)
  36. elif isinstance(node.ctx, _ast.Load):
  37. self.used.update(node.id)
  38. if not self.graph.has_node(node.id):
  39. self.graph.add_node(node.id)
  40. if isinstance(node.ctx, _ast.Load):
  41. self.undefined.add(node.id)
  42. for ctx_var in self.context_names:
  43. if not self.graph.has_edge(node.id, ctx_var):
  44. self.graph.add_edge(node.id, ctx_var)
  45. return {node.id}
  46. def visitalias(self, node):
  47. name = node.asname if node.asname else node.name
  48. if '.' in name:
  49. name = name.split('.', 1)[0]
  50. if not self.graph.has_node(name):
  51. self.graph.add_node(name)
  52. return {name}
  53. def visitCall(self, node):
  54. left = self.visit(node.func)
  55. right = set()
  56. for attr in ('args', 'keywords'):
  57. for child in getattr(node, attr):
  58. if child:
  59. right.update(self.visit(child))
  60. for attr in ('starargs', 'kwargs'):
  61. child = getattr(node, attr)
  62. if child:
  63. right.update(self.visit(child))
  64. for src in left | right:
  65. if not self.graph.has_node(src):
  66. self.undefined.add(src)
  67. if self.call_deps:
  68. add_edges(self.graph, left, right)
  69. add_edges(self.graph, right, left)
  70. right.update(left)
  71. return right
  72. def visitSubscript(self, node):
  73. if isinstance(node.ctx, _ast.Load):
  74. return collect_(self, node)
  75. else:
  76. sources = self.visit(node.slice)
  77. targets = self.visit(node.value)
  78. self.modified.update(targets)
  79. add_edges(self.graph, targets, sources)
  80. return targets
  81. def handle_generators(self, generators):
  82. defined = set()
  83. required = set()
  84. for generator in generators:
  85. get_symbols(generator, _ast.Load)
  86. required.update(get_symbols(generator, _ast.Load) - defined)
  87. defined.update(get_symbols(generator, _ast.Store))
  88. return defined, required
  89. def visitListComp(self, node):
  90. defined, required = self.handle_generators(node.generators)
  91. required.update(get_symbols(node.elt, _ast.Load) - defined)
  92. for symbol in required:
  93. if not self.graph.has_node(symbol):
  94. self.graph.add_node(symbol)
  95. self.undefined.add(symbol)
  96. return required
  97. def visitSetComp(self, node):
  98. defined, required = self.handle_generators(node.generators)
  99. required.update(get_symbols(node.elt, _ast.Load) - defined)
  100. for symbol in required:
  101. if not self.graph.has_node(symbol):
  102. self.graph.add_node(symbol)
  103. self.undefined.add(symbol)
  104. return required
  105. def visitDictComp(self, node):
  106. defined, required = self.handle_generators(node.generators)
  107. required.update(get_symbols(node.key, _ast.Load) - defined)
  108. required.update(get_symbols(node.value, _ast.Load) - defined)
  109. for symbol in required:
  110. if not self.graph.has_node(symbol):
  111. self.graph.add_node(symbol)
  112. self.undefined.add(symbol)
  113. return required
  114. def add_edges(graph, targets, sources):
  115. for target in targets:
  116. for src in sources:
  117. edge = target, src
  118. if not graph.has_edge(*edge):
  119. graph.add_edge(*edge)
  120. class GlobalDeps(object):
  121. def __init__(self, gen, nodes):
  122. self.nodes = nodes
  123. self.gen = gen
  124. def __enter__(self):
  125. self._old_context_names = set(self.gen.context_names)
  126. self.gen.context_names.update(self.nodes)
  127. def __exit__(self, *args):
  128. self.gen.context_names = self._old_context_names
  129. class GraphGen(CollectNodes):
  130. '''
  131. Create a graph from the execution flow of the ast
  132. '''
  133. visitModule = visit_children
  134. def depends_on(self, nodes):
  135. return GlobalDeps(self, set(nodes))
  136. def visit_lambda(self, node):
  137. sources = self.visit(node.args)
  138. self.sources.update(sources)
  139. self.visit(node.body)
  140. def visitLambda(self, node):
  141. gen = GraphGen()
  142. gen.visit_lambda(node)
  143. for undef in gen.undefined:
  144. if not self.graph.has_node(undef):
  145. self.graph.add_node(undef)
  146. return gen.undefined
  147. def visit_function_def(self, node):
  148. sources = self.visit(node.args)
  149. self.sources.update(sources)
  150. for stmnt in node.body:
  151. self.visit(stmnt)
  152. def visitFunctionDef(self, node):
  153. gen = GraphGen()
  154. gen.visit_function_def(node)
  155. if not self.graph.has_node(node.name):
  156. self.graph.add_node(node.name)
  157. for undef in gen.undefined:
  158. if not self.graph.has_node(undef):
  159. self.graph.add_node(undef)
  160. add_edges(self.graph, [node.name], gen.undefined)
  161. return gen.undefined
  162. def visitAssign(self, node):
  163. nodes = self.visit(node.value)
  164. tsymols = get_symbols(node, _ast.Store)
  165. re_defined = tsymols.intersection(set(self.graph.nodes()))
  166. if re_defined:
  167. add_edges(self.graph, re_defined, re_defined)
  168. targets = set()
  169. for target in node.targets:
  170. targets.update(self.visit(target))
  171. add_edges(self.graph, targets, nodes)
  172. return targets | nodes
  173. def visitAugAssign(self, node):
  174. targets = self.visit(node.target)
  175. values = self.visit(node.value)
  176. self.modified.update(targets)
  177. for target in targets:
  178. for value in values:
  179. edge = target, value
  180. if not self.graph.has_edge(*edge):
  181. self.graph.add_edge(*edge)
  182. for tgt2 in targets:
  183. edge = target, tgt2
  184. if not self.graph.has_edge(*edge):
  185. self.graph.add_edge(*edge)
  186. return targets | values
  187. def visitFor(self, node):
  188. nodes = set()
  189. targets = self.visit(node.target)
  190. for_iter = self.visit(node.iter)
  191. nodes.update(targets)
  192. nodes.update(for_iter)
  193. add_edges(self.graph, targets, for_iter)
  194. with self.depends_on(for_iter):
  195. for stmnt in node.body:
  196. nodes.update(self.visit(stmnt))
  197. return nodes
  198. def visitIf(self, node):
  199. nodes = set()
  200. names = self.visit(node.test)
  201. nodes.update(names)
  202. with self.depends_on(names):
  203. for stmnt in node.body:
  204. nodes.update(self.visit(stmnt))
  205. for stmnt in node.orelse:
  206. nodes.update(self.visit(stmnt))
  207. return nodes
  208. def visitReturn(self, node):
  209. targets = self.visit(node.value)
  210. self.targets.update(targets)
  211. return targets
  212. def visitWith(self, node):
  213. nodes = set()
  214. targets = self.visit(node.context_expr)
  215. nodes.update(targets)
  216. if node.optional_vars is None:
  217. vars = ()
  218. else:
  219. vars = self.visit(node.optional_vars)
  220. nodes.update(vars)
  221. add_edges(self.graph, vars, targets)
  222. with self.depends_on(targets):
  223. for stmnt in node.body:
  224. nodes.update(self.visit(stmnt))
  225. return nodes
  226. def visitWhile(self, node):
  227. nodes = set()
  228. targets = self.visit(node.test)
  229. nodes.update(targets)
  230. with self.depends_on(targets):
  231. for stmnt in node.body:
  232. nodes.update(self.visit(stmnt))
  233. for stmnt in node.orelse:
  234. nodes.update(self.visit(stmnt))
  235. return nodes
  236. def visitTryFinally(self, node):
  237. assert len(node.body) == 1
  238. nodes = self.visit(node.body[0])
  239. with self.depends_on(nodes):
  240. for stmnt in node.finalbody:
  241. nodes.update(self.visit(stmnt))
  242. def visitTryExcept(self, node):
  243. body_nodes = set()
  244. for stmnt in node.body:
  245. body_nodes.update(self.visit(stmnt))
  246. all_nodes = set(body_nodes)
  247. for hndlr in node.handlers:
  248. nodes = set(body_nodes)
  249. if hndlr.name:
  250. nodes.update(self.visit(hndlr.name))
  251. if hndlr.type:
  252. nodes.update(self.visit(hndlr.type))
  253. with self.depends_on(nodes):
  254. for stmnt in hndlr.body:
  255. nodes.update(self.visit(stmnt))
  256. all_nodes.update(nodes)
  257. nodes = set(body_nodes)
  258. with self.depends_on(nodes):
  259. for stmnt in node.orelse:
  260. nodes.update(self.visit(stmnt))
  261. all_nodes.update(nodes)
  262. return all_nodes
  263. def make_graph(node, call_deps=False):
  264. '''
  265. Create a dependency graph from an ast node.
  266. :param node: ast node.
  267. :param call_deps: if true, then the graph will create a cyclic dependance for all
  268. function calls. (i.e for `a.b(c)` a depends on b and b depends on a)
  269. :returns: a tuple of (graph, undefined)
  270. '''
  271. gen = GraphGen(call_deps=call_deps)
  272. gen.visit(node)
  273. return gen.graph, gen.undefined