PageRenderTime 54ms CodeModel.GetById 19ms RepoModel.GetById 0ms app.codeStats 1ms

/_pytest/assertion/rewrite.py

https://bitbucket.org/bwesterb/pypy
Python | 597 lines | 489 code | 40 blank | 68 comment | 93 complexity | 316f1c5dc5ad862b2985c8761f777424 MD5 | raw file
  1. """Rewrite assertion AST to produce nice error messages"""
  2. import ast
  3. import errno
  4. import itertools
  5. import imp
  6. import marshal
  7. import os
  8. import struct
  9. import sys
  10. import types
  11. import py
  12. from _pytest.assertion import util
  13. # Windows gives ENOENT in places *nix gives ENOTDIR.
  14. if sys.platform.startswith("win"):
  15. PATH_COMPONENT_NOT_DIR = errno.ENOENT
  16. else:
  17. PATH_COMPONENT_NOT_DIR = errno.ENOTDIR
  18. # py.test caches rewritten pycs in __pycache__.
  19. if hasattr(imp, "get_tag"):
  20. PYTEST_TAG = imp.get_tag() + "-PYTEST"
  21. else:
  22. if hasattr(sys, "pypy_version_info"):
  23. impl = "pypy"
  24. elif sys.platform == "java":
  25. impl = "jython"
  26. else:
  27. impl = "cpython"
  28. ver = sys.version_info
  29. PYTEST_TAG = "%s-%s%s-PYTEST" % (impl, ver[0], ver[1])
  30. del ver, impl
  31. PYC_EXT = ".py" + "c" if __debug__ else "o"
  32. PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
  33. REWRITE_NEWLINES = sys.version_info[:2] != (2, 7) and sys.version_info < (3, 2)
class AssertionRewritingHook(object):
    """Import hook which rewrites asserts.

    Uses the find_module()/load_module() finder-loader protocol:
    find_module() returns the hook itself for modules that should be
    rewritten (and None otherwise), and load_module() then executes the
    rewritten code object prepared by find_module().
    """
    def __init__(self):
        # Active session; while it is None, find_module() declines every
        # module.
        self.session = None
        # module name -> (rewritten code object, pyc path), filled by
        # find_module() and consumed by load_module().
        self.modules = {}
    def set_session(self, session):
        """Enable rewriting for *session*.

        Records the ini ``python_files`` patterns that identify test files.
        """
        self.fnpats = session.config.getini("python_files")
        self.session = session
    def find_module(self, name, path=None):
        """Return self if module *name* is a test file to rewrite, else None.

        A module qualifies when its source file was given as an initial
        (command line) path or matches one of ``self.fnpats``. The
        rewritten code object is read from the pytest pyc cache when it is
        current, otherwise produced by _rewrite_test() and (if bytecode
        writing is allowed) cached. It is stored in ``self.modules`` for
        load_module().
        """
        if self.session is None:
            return None
        sess = self.session
        state = sess.config._assertstate
        state.trace("find_module called for: %s" % name)
        names = name.rsplit(".", 1)
        lastname = names[-1]
        pth = None
        if path is not None and len(path) == 1:
            pth = path[0]
        if pth is None:
            try:
                fd, fn, desc = imp.find_module(lastname, path)
            except ImportError:
                return None
            if fd is not None:
                fd.close()
            tp = desc[2]
            if tp == imp.PY_COMPILED:
                # Map a compiled file back to its source file.
                if hasattr(imp, "source_from_cache"):
                    fn = imp.source_from_cache(fn)
                else:
                    fn = fn[:-1]
            elif tp != imp.PY_SOURCE:
                # Don't know what this is.
                return None
        else:
            # A single search directory: assume a plain "<lastname>.py" in it
            # without asking imp. NOTE(review): for a package or extension
            # module this file won't exist; reading it later fails and this
            # method then returns None.
            fn = os.path.join(pth, name.rpartition(".")[2] + ".py")
        fn_pypath = py.path.local(fn)
        # Is this a test file?
        if not sess.isinitpath(fn):
            # We have to be very careful here because imports in this code can
            # trigger a cycle. With self.session cleared, any re-entrant
            # find_module() call returns None immediately (see the top).
            self.session = None
            try:
                for pat in self.fnpats:
                    if fn_pypath.fnmatch(pat):
                        state.trace("matched test file %r" % (fn,))
                        break
                else:
                    return None
            finally:
                self.session = sess
        else:
            state.trace("matched test file (was specified on cmdline): %r" % (fn,))
        # The requested module looks like a test file, so rewrite it. This is
        # the most magical part of the process: load the source, rewrite the
        # asserts, and load the rewritten source. We also cache the rewritten
        # module code in a special pyc. We must be aware of the possibility of
        # concurrent py.test processes rewriting and loading pycs. To avoid
        # tricky race conditions, we maintain the following invariant: The
        # cached pyc is always a complete, valid pyc. Operations on it must be
        # atomic. POSIX's atomic rename comes in handy.
        write = not sys.dont_write_bytecode
        cache_dir = os.path.join(fn_pypath.dirname, "__pycache__")
        if write:
            try:
                os.mkdir(cache_dir)
            except OSError:
                e = sys.exc_info()[1].errno
                if e == errno.EEXIST:
                    # Either the __pycache__ directory already exists (the
                    # common case) or it's blocked by a non-dir node. In the
                    # latter case, we'll ignore it in _write_pyc.
                    pass
                elif e == PATH_COMPONENT_NOT_DIR:
                    # One of the path components was not a directory, likely
                    # because we're in a zip file.
                    write = False
                elif e == errno.EACCES:
                    state.trace("read only directory: %r" % (fn_pypath.dirname,))
                    write = False
                else:
                    raise
        # "<base>.py" -> "<base>" + PYC_TAIL inside __pycache__.
        cache_name = fn_pypath.basename[:-3] + PYC_TAIL
        pyc = os.path.join(cache_dir, cache_name)
        # Notice that even if we're in a read-only directory, I'm going to check
        # for a cached pyc. This may not be optimal...
        co = _read_pyc(fn_pypath, pyc)
        if co is None:
            state.trace("rewriting %r" % (fn,))
            co = _rewrite_test(state, fn_pypath)
            if co is None:
                # Probably a SyntaxError in the test.
                return None
            # Only freshly rewritten code is written back to the cache.
            if write:
                _make_rewritten_pyc(state, fn_pypath, pyc, co)
        else:
            state.trace("found cached rewritten pyc for %r" % (fn,))
        self.modules[name] = co, pyc
        return self
    def load_module(self, name):
        """Create module *name* by executing the code stashed by find_module().

        The module is registered in sys.modules before its code runs and is
        removed again if execution raises. Raises KeyError if find_module()
        did not prepare *name*.
        """
        co, pyc = self.modules.pop(name)
        # I wish I could just call imp.load_compiled here, but __file__ has to
        # be set properly. In Python 3.2+, this all would be handled correctly
        # by load_compiled.
        mod = sys.modules[name] = imp.new_module(name)
        try:
            mod.__file__ = co.co_filename
            # Normally, this attribute is 3.2+.
            mod.__cached__ = pyc
            py.builtin.exec_(co, mod.__dict__)
        except:
            # Bare except on purpose: drop the half-initialized module on any
            # failure (KeyboardInterrupt included), then re-raise.
            del sys.modules[name]
            raise
        # Re-read sys.modules rather than returning mod, presumably so that a
        # module which replaced its own sys.modules entry is honored.
        return sys.modules[name]
  149. def _write_pyc(co, source_path, pyc):
  150. # Technically, we don't have to have the same pyc format as (C)Python, since
  151. # these "pycs" should never be seen by builtin import. However, there's
  152. # little reason deviate, and I hope sometime to be able to use
  153. # imp.load_compiled to load them. (See the comment in load_module above.)
  154. mtime = int(source_path.mtime())
  155. try:
  156. fp = open(pyc, "wb")
  157. except IOError:
  158. err = sys.exc_info()[1].errno
  159. if err == PATH_COMPONENT_NOT_DIR:
  160. # This happens when we get a EEXIST in find_module creating the
  161. # __pycache__ directory and __pycache__ is by some non-dir node.
  162. return False
  163. raise
  164. try:
  165. fp.write(imp.get_magic())
  166. fp.write(struct.pack("<l", mtime))
  167. marshal.dump(co, fp)
  168. finally:
  169. fp.close()
  170. return True
  171. RN = "\r\n".encode("utf-8")
  172. N = "\n".encode("utf-8")
  173. def _rewrite_test(state, fn):
  174. """Try to read and rewrite *fn* and return the code object."""
  175. try:
  176. source = fn.read("rb")
  177. except EnvironmentError:
  178. return None
  179. # On Python versions which are not 2.7 and less than or equal to 3.1, the
  180. # parser expects *nix newlines.
  181. if REWRITE_NEWLINES:
  182. source = source.replace(RN, N) + N
  183. try:
  184. tree = ast.parse(source)
  185. except SyntaxError:
  186. # Let this pop up again in the real import.
  187. state.trace("failed to parse: %r" % (fn,))
  188. return None
  189. rewrite_asserts(tree)
  190. try:
  191. co = compile(tree, fn.strpath, "exec")
  192. except SyntaxError:
  193. # It's possible that this error is from some bug in the
  194. # assertion rewriting, but I don't know of a fast way to tell.
  195. state.trace("failed to compile: %r" % (fn,))
  196. return None
  197. return co
  198. def _make_rewritten_pyc(state, fn, pyc, co):
  199. """Try to dump rewritten code to *pyc*."""
  200. if sys.platform.startswith("win"):
  201. # Windows grants exclusive access to open files and doesn't have atomic
  202. # rename, so just write into the final file.
  203. _write_pyc(co, fn, pyc)
  204. else:
  205. # When not on windows, assume rename is atomic. Dump the code object
  206. # into a file specific to this process and atomically replace it.
  207. proc_pyc = pyc + "." + str(os.getpid())
  208. if _write_pyc(co, fn, proc_pyc):
  209. os.rename(proc_pyc, pyc)
  210. def _read_pyc(source, pyc):
  211. """Possibly read a py.test pyc containing rewritten code.
  212. Return rewritten code if successful or None if not.
  213. """
  214. try:
  215. fp = open(pyc, "rb")
  216. except IOError:
  217. return None
  218. try:
  219. try:
  220. mtime = int(source.mtime())
  221. data = fp.read(8)
  222. except EnvironmentError:
  223. return None
  224. # Check for invalid or out of date pyc file.
  225. if (len(data) != 8 or
  226. data[:4] != imp.get_magic() or
  227. struct.unpack("<l", data[4:])[0] != mtime):
  228. return None
  229. co = marshal.load(fp)
  230. if not isinstance(co, types.CodeType):
  231. # That's interesting....
  232. return None
  233. return co
  234. finally:
  235. fp.close()
  236. def rewrite_asserts(mod):
  237. """Rewrite the assert statements in mod."""
  238. AssertionRewriter().run(mod)
# Helpers called by rewritten test code at run time. Rewritten modules import
# this module under the alias "@pytest_ar", and AssertionRewriter.helper(name)
# emits "@pytest_ar._<name>(...)". These module-level names (like
# _format_boolop and _call_reprcompare below) must therefore keep exactly this
# leading-underscore spelling.
_saferepr = py.io.saferepr
from _pytest.assertion.util import format_explanation as _format_explanation
  241. def _format_boolop(explanations, is_or):
  242. return "(" + (is_or and " or " or " and ").join(explanations) + ")"
  243. def _call_reprcompare(ops, results, expls, each_obj):
  244. for i, res, expl in zip(range(len(ops)), results, expls):
  245. try:
  246. done = not res
  247. except Exception:
  248. done = True
  249. if done:
  250. break
  251. if util._reprcompare is not None:
  252. custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1])
  253. if custom is not None:
  254. return custom
  255. return expl
  256. unary_map = {
  257. ast.Not : "not %s",
  258. ast.Invert : "~%s",
  259. ast.USub : "-%s",
  260. ast.UAdd : "+%s"
  261. }
  262. binop_map = {
  263. ast.BitOr : "|",
  264. ast.BitXor : "^",
  265. ast.BitAnd : "&",
  266. ast.LShift : "<<",
  267. ast.RShift : ">>",
  268. ast.Add : "+",
  269. ast.Sub : "-",
  270. ast.Mult : "*",
  271. ast.Div : "/",
  272. ast.FloorDiv : "//",
  273. ast.Mod : "%",
  274. ast.Eq : "==",
  275. ast.NotEq : "!=",
  276. ast.Lt : "<",
  277. ast.LtE : "<=",
  278. ast.Gt : ">",
  279. ast.GtE : ">=",
  280. ast.Pow : "**",
  281. ast.Is : "is",
  282. ast.IsNot : "is not",
  283. ast.In : "in",
  284. ast.NotIn : "not in"
  285. }
  286. def set_location(node, lineno, col_offset):
  287. """Set node location information recursively."""
  288. def _fix(node, lineno, col_offset):
  289. if "lineno" in node._attributes:
  290. node.lineno = lineno
  291. if "col_offset" in node._attributes:
  292. node.col_offset = col_offset
  293. for child in ast.iter_child_nodes(node):
  294. _fix(child, lineno, col_offset)
  295. _fix(node, lineno, col_offset)
  296. return node
  297. class AssertionRewriter(ast.NodeVisitor):
    def run(self, mod):
        """Find all assert statements in *mod* and rewrite them.

        Also inserts ``import <builtins module> as @py_builtins`` and
        ``import _pytest.assertion.rewrite as @pytest_ar`` after the module
        docstring and any ``from __future__`` imports. Nothing is changed if
        the body is empty or the docstring contains "PYTEST_DONT_REWRITE".
        """
        if not mod.body:
            # Nothing to do.
            return
        # Insert some special imports at the top of the module but after any
        # docstrings and __future__ imports. The "@" aliases cannot clash
        # with user names since "@" is invalid in Python identifiers.
        aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
                   ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
        expect_docstring = True
        pos = 0
        lineno = 0
        for item in mod.body:
            if (expect_docstring and isinstance(item, ast.Expr) and
                    isinstance(item.value, ast.Str)):
                doc = item.value.s
                if "PYTEST_DONT_REWRITE" in doc:
                    # The module has disabled assertion rewriting.
                    return
                # NOTE(review): len(doc) counts characters, not lines. This
                # value only survives when no ordinary statement follows the
                # docstring/__future__ imports (otherwise it is overwritten
                # below).
                lineno += len(doc) - 1
                expect_docstring = False
            elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or
                  item.module != "__future__"):
                # First statement that is neither docstring nor __future__
                # import: insert the imports right before it.
                lineno = item.lineno
                break
            pos += 1
        imports = [ast.Import([alias], lineno=lineno, col_offset=0)
                   for alias in aliases]
        mod.body[pos:pos] = imports
        # Collect asserts. Depth-first walk over statement containers,
        # replacing each Assert in a statement list by the statements that
        # visit() (visit_Assert) returns.
        nodes = [mod]
        while nodes:
            node = nodes.pop()
            for name, field in ast.iter_fields(node):
                if isinstance(field, list):
                    new = []
                    for i, child in enumerate(field):
                        if isinstance(child, ast.Assert):
                            # Transform assert.
                            new.extend(self.visit(child))
                        else:
                            new.append(child)
                            if isinstance(child, ast.AST):
                                nodes.append(child)
                    setattr(node, name, new)
                elif (isinstance(field, ast.AST) and
                      # Don't recurse into expressions as they can't contain
                      # asserts.
                      not isinstance(field, ast.expr)):
                    nodes.append(field)
  348. def variable(self):
  349. """Get a new variable."""
  350. # Use a character invalid in python identifiers to avoid clashing.
  351. name = "@py_assert" + str(next(self.variable_counter))
  352. self.variables.append(name)
  353. return name
  354. def assign(self, expr):
  355. """Give *expr* a name."""
  356. name = self.variable()
  357. self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
  358. return ast.Name(name, ast.Load())
  359. def display(self, expr):
  360. """Call py.io.saferepr on the expression."""
  361. return self.helper("saferepr", expr)
  362. def helper(self, name, *args):
  363. """Call a helper in this module."""
  364. py_name = ast.Name("@pytest_ar", ast.Load())
  365. attr = ast.Attribute(py_name, "_" + name, ast.Load())
  366. return ast.Call(attr, list(args), [], None, None)
  367. def builtin(self, name):
  368. """Return the builtin called *name*."""
  369. builtin_name = ast.Name("@py_builtins", ast.Load())
  370. return ast.Attribute(builtin_name, name, ast.Load())
  371. def explanation_param(self, expr):
  372. specifier = "py" + str(next(self.variable_counter))
  373. self.explanation_specifiers[specifier] = expr
  374. return "%(" + specifier + ")s"
  375. def push_format_context(self):
  376. self.explanation_specifiers = {}
  377. self.stack.append(self.explanation_specifiers)
  378. def pop_format_context(self, expl_expr):
  379. current = self.stack.pop()
  380. if self.stack:
  381. self.explanation_specifiers = self.stack[-1]
  382. keys = [ast.Str(key) for key in current.keys()]
  383. format_dict = ast.Dict(keys, list(current.values()))
  384. form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
  385. name = "@py_format" + str(next(self.variable_counter))
  386. self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
  387. return ast.Name(name, ast.Load())
  388. def generic_visit(self, node):
  389. """Handle expressions we don't have custom code for."""
  390. assert isinstance(node, ast.expr)
  391. res = self.assign(node)
  392. return res, self.explanation_param(self.display(res))
    def visit_Assert(self, assert_):
        """Return the list of statements that replaces *assert_*.

        Asserts that already carry a message are returned unchanged. For the
        others, the test expression is unrolled into temporaries, followed by
        ``if not <condition>:`` whose body builds the explanation and raises
        ``AssertionError(_format_explanation(msg))``. Finally an assignment
        resets all temporaries to None. Every new statement gets the
        assert's line number and column offset.
        """
        if assert_.msg:
            # There's already a message. Don't mess with it.
            return [assert_]
        # Fresh per-assert state. The order of the variable() and
        # explanation_param() calls below determines the numbering of the
        # generated names.
        self.statements = []
        self.cond_chain = ()  # not read anywhere in the visible code
        self.variables = []
        self.variable_counter = itertools.count()
        self.stack = []
        self.on_failure = []
        self.push_format_context()
        # Rewrite assert into a bunch of statements.
        top_condition, explanation = self.visit(assert_.test)
        # Create failure message. body aliases self.on_failure, so the
        # explanation assignment appended by pop_format_context() below, and
        # the raise, both end up inside this If.
        body = self.on_failure
        negation = ast.UnaryOp(ast.Not(), top_condition)
        self.statements.append(ast.If(negation, body, []))
        explanation = "assert " + explanation
        template = ast.Str(explanation)
        msg = self.pop_format_context(template)
        fmt = self.helper("format_explanation", msg)
        err_name = ast.Name("AssertionError", ast.Load())
        exc = ast.Call(err_name, [fmt], [], None, None)
        # ast.Raise takes (exc, cause) on Python 3 and (type, inst, tback) on
        # Python 2.
        if sys.version_info[0] >= 3:
            raise_ = ast.Raise(exc, None)
        else:
            raise_ = ast.Raise(exc, None, None)
        body.append(raise_)
        # Clear temporary variables by setting them to None (presumably so
        # they don't keep the evaluated objects alive). NOTE(review): newer
        # Python versions likely reject a Name node with id "None".
        if self.variables:
            variables = [ast.Name(name, ast.Store()) for name in self.variables]
            clear = ast.Assign(variables, ast.Name("None", ast.Load()))
            self.statements.append(clear)
        # Fix line numbers.
        for stmt in self.statements:
            set_location(stmt, assert_.lineno, assert_.col_offset)
        return self.statements
  430. def visit_Name(self, name):
  431. # Check if the name is local or not.
  432. locs = ast.Call(self.builtin("locals"), [], [], None, None)
  433. globs = ast.Call(self.builtin("globals"), [], [], None, None)
  434. ops = [ast.In(), ast.IsNot()]
  435. test = ast.Compare(ast.Str(name.id), ops, [locs, globs])
  436. expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
  437. return name, self.explanation_param(expr)
    def visit_BoolOp(self, boolop):
        """Explain ``a and b ...`` / ``a or b ...`` while keeping short-circuiting.

        Each operand after the first is evaluated inside an ``if`` that only
        runs when the previous operand did not decide the result (truthy for
        ``and``, falsy for ``or``). The failure-path statements are nested
        the same way, so only explanations of operands that were actually
        evaluated get appended to a temporary list. That list is then joined
        by _format_boolop. Returns (Load of the result temporary,
        explanation placeholder).
        """
        res_var = self.variable()
        # Temporary list collecting the per-operand explanations at run time.
        expl_list = self.assign(ast.List([], ast.Load()))
        app = ast.Attribute(expl_list, "append", ast.Load())
        is_or = int(isinstance(boolop.op, ast.Or))
        body = save = self.statements
        fail_save = self.on_failure
        levels = len(boolop.values) - 1
        self.push_format_context()
        # Process each operand, short-circuting if needed.
        for i, v in enumerate(boolop.values):
            if i:
                # Nest this operand's failure explanation under the same
                # condition that guards its evaluation (cond from the
                # previous iteration).
                fail_inner = []
                self.on_failure.append(ast.If(cond, fail_inner, []))
                self.on_failure = fail_inner
            self.push_format_context()
            res, expl = self.visit(v)
            body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
            expl_format = self.pop_format_context(ast.Str(expl))
            call = ast.Call(app, [expl_format], [], None, None)
            self.on_failure.append(ast.Expr(call))
            if i < levels:
                # Continue to the next operand only if this one didn't decide
                # the result: truthy for "and", falsy for "or".
                cond = res
                if is_or:
                    cond = ast.UnaryOp(ast.Not(), cond)
                inner = []
                self.statements.append(ast.If(cond, inner, []))
                self.statements = body = inner
        self.statements = save
        self.on_failure = fail_save
        expl_template = self.helper("format_boolop", expl_list, ast.Num(is_or))
        expl = self.pop_format_context(expl_template)
        return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
  471. def visit_UnaryOp(self, unary):
  472. pattern = unary_map[unary.op.__class__]
  473. operand_res, operand_expl = self.visit(unary.operand)
  474. res = self.assign(ast.UnaryOp(unary.op, operand_res))
  475. return res, pattern % (operand_expl,)
  476. def visit_BinOp(self, binop):
  477. symbol = binop_map[binop.op.__class__]
  478. left_expr, left_expl = self.visit(binop.left)
  479. right_expr, right_expl = self.visit(binop.right)
  480. explanation = "(%s %s %s)" % (left_expl, symbol, right_expl)
  481. res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
  482. return res, explanation
    def visit_Call(self, call):
        """Explain a call as ``<result> = func(arg explanations...)``.

        The callee and every argument (positional, keyword, ``*`` and
        ``**``) are visited first, in source order. The call is then rebuilt
        from the resulting value expressions and evaluated once into a
        temporary.

        NOTE(review): relies on the pre-3.5 AST layout where ast.Call has
        ``starargs``/``kwargs`` fields and takes five positional arguments.
        """
        new_func, func_expl = self.visit(call.func)
        arg_expls = []
        new_args = []
        new_kwargs = []
        new_star = new_kwarg = None
        for arg in call.args:
            res, expl = self.visit(arg)
            new_args.append(res)
            arg_expls.append(expl)
        for keyword in call.keywords:
            res, expl = self.visit(keyword.value)
            new_kwargs.append(ast.keyword(keyword.arg, res))
            arg_expls.append(keyword.arg + "=" + expl)
        if call.starargs:
            new_star, expl = self.visit(call.starargs)
            arg_expls.append("*" + expl)
        if call.kwargs:
            new_kwarg, expl = self.visit(call.kwargs)
            arg_expls.append("**" + expl)
        expl = "%s(%s)" % (func_expl, ', '.join(arg_expls))
        new_call = ast.Call(new_func, new_args, new_kwargs, new_star, new_kwarg)
        res = self.assign(new_call)
        res_expl = self.explanation_param(self.display(res))
        # "{...}" block markers: presumably rendered by _format_explanation
        # as a nested "where <result> = <call>" line (not visible here).
        outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
        return res, outer_expl
  509. def visit_Attribute(self, attr):
  510. if not isinstance(attr.ctx, ast.Load):
  511. return self.generic_visit(attr)
  512. value, value_expl = self.visit(attr.value)
  513. res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
  514. res_expl = self.explanation_param(self.display(res))
  515. pat = "%s\n{%s = %s.%s\n}"
  516. expl = pat % (res_expl, res_expl, value_expl, attr.attr)
  517. return res, expl
    def visit_Compare(self, comp):
        """Explain a (possibly chained) comparison.

        Each link ``left <op> right`` is stored in its own result temporary.
        The overall value is the ``and`` of those temporaries, or the single
        temporary for a plain comparison. On failure, _call_reprcompare
        picks the explanation of the first false link (possibly a custom one
        from util._reprcompare).

        NOTE(review): every link is computed unconditionally, unlike Python's
        chained comparison, where ``a < b < c`` does not evaluate ``c`` once
        ``a < b`` is false.
        """
        self.push_format_context()
        left_res, left_expl = self.visit(comp.left)
        res_variables = [self.variable() for i in range(len(comp.ops))]
        load_names = [ast.Name(v, ast.Load()) for v in res_variables]
        store_names = [ast.Name(v, ast.Store()) for v in res_variables]
        it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
        expls = []
        syms = []
        # Operand value expressions: left operand first, then each comparator.
        results = [left_res]
        for i, op, next_operand in it:
            next_res, next_expl = self.visit(next_operand)
            results.append(next_res)
            sym = binop_map[op.__class__]
            syms.append(ast.Str(sym))
            expl = "%s %s %s" % (left_expl, sym, next_expl)
            expls.append(ast.Str(expl))
            res_expr = ast.Compare(left_res, [op], [next_res])
            self.statements.append(ast.Assign([store_names[i]], res_expr))
            # The right operand of this link is the left operand of the next.
            left_res, left_expl = next_res, next_expl
        # Use util._reprcompare (via _call_reprcompare) if that's available.
        expl_call = self.helper("call_reprcompare",
                                ast.Tuple(syms, ast.Load()),
                                ast.Tuple(load_names, ast.Load()),
                                ast.Tuple(expls, ast.Load()),
                                ast.Tuple(results, ast.Load()))
        if len(comp.ops) > 1:
            res = ast.BoolOp(ast.And(), load_names)
        else:
            res = load_names[0]
        return res, self.explanation_param(self.pop_format_context(expl_call))