PageRenderTime 48ms CodeModel.GetById 15ms RepoModel.GetById 0ms app.codeStats 0ms

/lib/sqlalchemy/testing/assertions.py

https://bitbucket.org/eastfox2002/sqlalchemy
Python | 377 lines | 296 code | 37 blank | 44 comment | 37 complexity | 558a72542f07a1df1da8eb8d2655566c MD5 | raw file
  1. from __future__ import absolute_import
  2. from . import util as testutil
  3. from sqlalchemy import pool, orm, util
  4. from sqlalchemy.engine import default, create_engine
  5. from sqlalchemy import exc as sa_exc
  6. from sqlalchemy.util import decorator
  7. from sqlalchemy import types as sqltypes, schema
  8. import warnings
  9. import re
  10. from .warnings import resetwarnings
  11. from .exclusions import db_spec, _is_excluded
  12. from . import assertsql
  13. from . import config
  14. import itertools
  15. from .util import fail
  16. import contextlib
  17. def emits_warning(*messages):
  18. """Mark a test as emitting a warning.
  19. With no arguments, squelches all SAWarning failures. Or pass one or more
  20. strings; these will be matched to the root of the warning description by
  21. warnings.filterwarnings().
  22. """
  23. # TODO: it would be nice to assert that a named warning was
  24. # emitted. should work with some monkeypatching of warnings,
  25. # and may work on non-CPython if they keep to the spirit of
  26. # warnings.showwarning's docstring.
  27. # - update: jython looks ok, it uses cpython's module
  28. @decorator
  29. def decorate(fn, *args, **kw):
  30. # todo: should probably be strict about this, too
  31. filters = [dict(action='ignore',
  32. category=sa_exc.SAPendingDeprecationWarning)]
  33. if not messages:
  34. filters.append(dict(action='ignore',
  35. category=sa_exc.SAWarning))
  36. else:
  37. filters.extend(dict(action='ignore',
  38. message=message,
  39. category=sa_exc.SAWarning)
  40. for message in messages)
  41. for f in filters:
  42. warnings.filterwarnings(**f)
  43. try:
  44. return fn(*args, **kw)
  45. finally:
  46. resetwarnings()
  47. return decorate
  48. def emits_warning_on(db, *warnings):
  49. """Mark a test as emitting a warning on a specific dialect.
  50. With no arguments, squelches all SAWarning failures. Or pass one or more
  51. strings; these will be matched to the root of the warning description by
  52. warnings.filterwarnings().
  53. """
  54. spec = db_spec(db)
  55. @decorator
  56. def decorate(fn, *args, **kw):
  57. if isinstance(db, basestring):
  58. if not spec(config.db):
  59. return fn(*args, **kw)
  60. else:
  61. wrapped = emits_warning(*warnings)(fn)
  62. return wrapped(*args, **kw)
  63. else:
  64. if not _is_excluded(*db):
  65. return fn(*args, **kw)
  66. else:
  67. wrapped = emits_warning(*warnings)(fn)
  68. return wrapped(*args, **kw)
  69. return decorate
  70. def uses_deprecated(*messages):
  71. """Mark a test as immune from fatal deprecation warnings.
  72. With no arguments, squelches all SADeprecationWarning failures.
  73. Or pass one or more strings; these will be matched to the root
  74. of the warning description by warnings.filterwarnings().
  75. As a special case, you may pass a function name prefixed with //
  76. and it will be re-written as needed to match the standard warning
  77. verbiage emitted by the sqlalchemy.util.deprecated decorator.
  78. """
  79. @decorator
  80. def decorate(fn, *args, **kw):
  81. # todo: should probably be strict about this, too
  82. filters = [dict(action='ignore',
  83. category=sa_exc.SAPendingDeprecationWarning)]
  84. if not messages:
  85. filters.append(dict(action='ignore',
  86. category=sa_exc.SADeprecationWarning))
  87. else:
  88. filters.extend(
  89. [dict(action='ignore',
  90. message=message,
  91. category=sa_exc.SADeprecationWarning)
  92. for message in
  93. [(m.startswith('//') and
  94. ('Call to deprecated function ' + m[2:]) or m)
  95. for m in messages]])
  96. for f in filters:
  97. warnings.filterwarnings(**f)
  98. try:
  99. return fn(*args, **kw)
  100. finally:
  101. resetwarnings()
  102. return decorate
  103. def global_cleanup_assertions():
  104. """Check things that have to be finalized at the end of a test suite.
  105. Hardcoded at the moment, a modular system can be built here
  106. to support things like PG prepared transactions, tables all
  107. dropped, etc.
  108. """
  109. testutil.lazy_gc()
  110. assert not pool._refs, str(pool._refs)
  111. def eq_(a, b, msg=None):
  112. """Assert a == b, with repr messaging on failure."""
  113. assert a == b, msg or "%r != %r" % (a, b)
  114. def ne_(a, b, msg=None):
  115. """Assert a != b, with repr messaging on failure."""
  116. assert a != b, msg or "%r == %r" % (a, b)
  117. def is_(a, b, msg=None):
  118. """Assert a is b, with repr messaging on failure."""
  119. assert a is b, msg or "%r is not %r" % (a, b)
  120. def is_not_(a, b, msg=None):
  121. """Assert a is not b, with repr messaging on failure."""
  122. assert a is not b, msg or "%r is %r" % (a, b)
  123. def startswith_(a, fragment, msg=None):
  124. """Assert a.startswith(fragment), with repr messaging on failure."""
  125. assert a.startswith(fragment), msg or "%r does not start with %r" % (
  126. a, fragment)
  127. def assert_raises(except_cls, callable_, *args, **kw):
  128. try:
  129. callable_(*args, **kw)
  130. success = False
  131. except except_cls:
  132. success = True
  133. # assert outside the block so it works for AssertionError too !
  134. assert success, "Callable did not raise an exception"
  135. def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
  136. try:
  137. callable_(*args, **kwargs)
  138. assert False, "Callable did not raise an exception"
  139. except except_cls, e:
  140. assert re.search(msg, unicode(e), re.UNICODE), u"%r !~ %s" % (msg, e)
  141. print unicode(e).encode('utf-8')
class AssertsCompiledSQL(object):
    """Test mixin providing ``assert_compile()`` for checking rendered SQL."""

    def assert_compile(self, clause, result, params=None,
                       checkparams=None, dialect=None,
                       checkpositional=None,
                       use_default_dialect=False,
                       allow_dialect_select=False):
        """Compile ``clause`` and assert its SQL string equals ``result``.

        :param clause: a construct with a ``compile()`` method, or an
          ``orm.Query``, in which case the statement from its
          ``_compile_context()`` is compiled with ``use_labels`` switched on.
        :param result: the expected SQL string.  Newlines and tabs are
          removed from the compiled string before comparing, so ``result``
          must not contain them either.
        :param params: optional dict; its keys are passed to ``compile()``
          as ``column_keys``, and the dict itself to ``construct_params()``
          when ``checkparams`` / ``checkpositional`` are checked.
        :param checkparams: if not None, the expected return value of
          ``construct_params(params)``.
        :param dialect: the dialect to compile with; see below when None.
        :param checkpositional: if not None, the expected tuple of bound
          values taken in the compiled object's ``positiontup`` order.
        :param use_default_dialect: if true, compile with
          ``default.DefaultDialect()`` regardless of ``dialect``.
        :param allow_dialect_select: if true and ``dialect`` is None, skip
          the lookup below and pass ``dialect=None`` to ``compile()``.

        When ``dialect`` is None and neither flag is set, the class
        attribute ``__dialect__`` is consulted: ``'default'`` selects
        ``default.DefaultDialect()``, a missing/None value selects
        ``config.db.dialect``, and any other string ``name`` selects
        ``create_engine("name://").dialect``; any other value is used as-is.

        The compiled SQL and its ``params`` are printed before comparing.
        """
        if use_default_dialect:
            dialect = default.DefaultDialect()
        elif dialect == None and not allow_dialect_select:
            # No explicit dialect: fall back to the test class's
            # __dialect__ attribute, resolving the shorthand forms.
            dialect = getattr(self, '__dialect__', None)
            if dialect == 'default':
                dialect = default.DefaultDialect()
            elif dialect is None:
                dialect = config.db.dialect
            elif isinstance(dialect, basestring):
                dialect = create_engine("%s://" % dialect).dialect

        kw = {}
        if params is not None:
            kw['column_keys'] = params.keys()

        if isinstance(clause, orm.Query):
            # Compile the Query's underlying statement with labeled columns.
            context = clause._compile_context()
            context.statement.use_labels = True
            clause = context.statement

        c = clause.compile(dialect=dialect, **kw)

        param_str = repr(getattr(c, 'params', {}))
        # The commented-out line directly after the "# Py3K" marker is
        # presumably switched on by the Python 3 conversion step; keep the
        # marker and that line adjacent.
        # Py3K
        #param_str = param_str.encode('utf-8').decode('ascii', 'ignore')

        print "\nSQL String:\n" + str(c) + param_str

        # Strip newlines/tabs so expected strings can be written on one line.
        cc = re.sub(r'[\n\t]', '', str(c))

        eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))

        if checkparams is not None:
            eq_(c.construct_params(params), checkparams)
        if checkpositional is not None:
            p = c.construct_params(params)
            eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
  177. class ComparesTables(object):
  178. def assert_tables_equal(self, table, reflected_table, strict_types=False):
  179. assert len(table.c) == len(reflected_table.c)
  180. for c, reflected_c in zip(table.c, reflected_table.c):
  181. eq_(c.name, reflected_c.name)
  182. assert reflected_c is reflected_table.c[c.name]
  183. eq_(c.primary_key, reflected_c.primary_key)
  184. eq_(c.nullable, reflected_c.nullable)
  185. if strict_types:
  186. msg = "Type '%s' doesn't correspond to type '%s'"
  187. assert type(reflected_c.type) is type(c.type), \
  188. msg % (reflected_c.type, c.type)
  189. else:
  190. self.assert_types_base(reflected_c, c)
  191. if isinstance(c.type, sqltypes.String):
  192. eq_(c.type.length, reflected_c.type.length)
  193. eq_(
  194. set([f.column.name for f in c.foreign_keys]),
  195. set([f.column.name for f in reflected_c.foreign_keys])
  196. )
  197. if c.server_default:
  198. assert isinstance(reflected_c.server_default,
  199. schema.FetchedValue)
  200. assert len(table.primary_key) == len(reflected_table.primary_key)
  201. for c in table.primary_key:
  202. assert reflected_table.primary_key.columns[c.name] is not None
  203. def assert_types_base(self, c1, c2):
  204. assert c1.type._compare_type_affinity(c2.type),\
  205. "On column %r, type '%s' doesn't correspond to type '%s'" % \
  206. (c1.name, c1.type, c2.type)
  207. class AssertsExecutionResults(object):
  208. def assert_result(self, result, class_, *objects):
  209. result = list(result)
  210. print repr(result)
  211. self.assert_list(result, class_, objects)
  212. def assert_list(self, result, class_, list):
  213. self.assert_(len(result) == len(list),
  214. "result list is not the same size as test list, " +
  215. "for class " + class_.__name__)
  216. for i in range(0, len(list)):
  217. self.assert_row(class_, result[i], list[i])
  218. def assert_row(self, class_, rowobj, desc):
  219. self.assert_(rowobj.__class__ is class_,
  220. "item class is not " + repr(class_))
  221. for key, value in desc.iteritems():
  222. if isinstance(value, tuple):
  223. if isinstance(value[1], list):
  224. self.assert_list(getattr(rowobj, key), value[0], value[1])
  225. else:
  226. self.assert_row(value[0], getattr(rowobj, key), value[1])
  227. else:
  228. self.assert_(getattr(rowobj, key) == value,
  229. "attribute %s value %s does not match %s" % (
  230. key, getattr(rowobj, key), value))
  231. def assert_unordered_result(self, result, cls, *expected):
  232. """As assert_result, but the order of objects is not considered.
  233. The algorithm is very expensive but not a big deal for the small
  234. numbers of rows that the test suite manipulates.
  235. """
  236. class immutabledict(dict):
  237. def __hash__(self):
  238. return id(self)
  239. found = util.IdentitySet(result)
  240. expected = set([immutabledict(e) for e in expected])
  241. for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found):
  242. fail('Unexpected type "%s", expected "%s"' % (
  243. type(wrong).__name__, cls.__name__))
  244. if len(found) != len(expected):
  245. fail('Unexpected object count "%s", expected "%s"' % (
  246. len(found), len(expected)))
  247. NOVALUE = object()
  248. def _compare_item(obj, spec):
  249. for key, value in spec.iteritems():
  250. if isinstance(value, tuple):
  251. try:
  252. self.assert_unordered_result(
  253. getattr(obj, key), value[0], *value[1])
  254. except AssertionError:
  255. return False
  256. else:
  257. if getattr(obj, key, NOVALUE) != value:
  258. return False
  259. return True
  260. for expected_item in expected:
  261. for found_item in found:
  262. if _compare_item(found_item, expected_item):
  263. found.remove(found_item)
  264. break
  265. else:
  266. fail(
  267. "Expected %s instance with attributes %s not found." % (
  268. cls.__name__, repr(expected_item)))
  269. return True
  270. def assert_sql_execution(self, db, callable_, *rules):
  271. assertsql.asserter.add_rules(rules)
  272. try:
  273. callable_()
  274. assertsql.asserter.statement_complete()
  275. finally:
  276. assertsql.asserter.clear_rules()
  277. def assert_sql(self, db, callable_, list_, with_sequences=None):
  278. if with_sequences is not None and config.db.dialect.supports_sequences:
  279. rules = with_sequences
  280. else:
  281. rules = list_
  282. newrules = []
  283. for rule in rules:
  284. if isinstance(rule, dict):
  285. newrule = assertsql.AllOf(*[
  286. assertsql.ExactSQL(k, v) for k, v in rule.iteritems()
  287. ])
  288. else:
  289. newrule = assertsql.ExactSQL(*rule)
  290. newrules.append(newrule)
  291. self.assert_sql_execution(db, callable_, *newrules)
  292. def assert_sql_count(self, db, callable_, count):
  293. self.assert_sql_execution(
  294. db, callable_, assertsql.CountStatements(count))
  295. @contextlib.contextmanager
  296. def assert_execution(self, *rules):
  297. assertsql.asserter.add_rules(rules)
  298. try:
  299. yield
  300. assertsql.asserter.statement_complete()
  301. finally:
  302. assertsql.asserter.clear_rules()
  303. def assert_statement_count(self, count):
  304. return self.assert_execution(assertsql.CountStatements(count))