PageRenderTime 44ms CodeModel.GetById 10ms RepoModel.GetById 0ms app.codeStats 0ms

/monitor_batch/pymodules/python2.7/lib/python/sqlalchemy/testing/assertions.py

https://gitlab.com/pooja043/Globus_Docker_4
Python | 453 lines | 339 code | 52 blank | 62 comment | 46 complexity | e374925bc67cc9664e8842c19e899d5b MD5 | raw file
  1. # testing/assertions.py
  2. # Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. from __future__ import absolute_import
  8. from . import util as testutil
  9. from sqlalchemy import pool, orm, util
  10. from sqlalchemy.engine import default, create_engine, url
  11. from sqlalchemy import exc as sa_exc
  12. from sqlalchemy.util import decorator
  13. from sqlalchemy import types as sqltypes, schema
  14. import warnings
  15. import re
  16. from .warnings import resetwarnings
  17. from .exclusions import db_spec, _is_excluded
  18. from . import assertsql
  19. from . import config
  20. import itertools
  21. from .util import fail
  22. import contextlib
  23. def emits_warning(*messages):
  24. """Mark a test as emitting a warning.
  25. With no arguments, squelches all SAWarning failures. Or pass one or more
  26. strings; these will be matched to the root of the warning description by
  27. warnings.filterwarnings().
  28. """
  29. # TODO: it would be nice to assert that a named warning was
  30. # emitted. should work with some monkeypatching of warnings,
  31. # and may work on non-CPython if they keep to the spirit of
  32. # warnings.showwarning's docstring.
  33. # - update: jython looks ok, it uses cpython's module
  34. @decorator
  35. def decorate(fn, *args, **kw):
  36. # todo: should probably be strict about this, too
  37. filters = [dict(action='ignore',
  38. category=sa_exc.SAPendingDeprecationWarning)]
  39. if not messages:
  40. filters.append(dict(action='ignore',
  41. category=sa_exc.SAWarning))
  42. else:
  43. filters.extend(dict(action='ignore',
  44. message=message,
  45. category=sa_exc.SAWarning)
  46. for message in messages)
  47. for f in filters:
  48. warnings.filterwarnings(**f)
  49. try:
  50. return fn(*args, **kw)
  51. finally:
  52. resetwarnings()
  53. return decorate
  54. def emits_warning_on(db, *warnings):
  55. """Mark a test as emitting a warning on a specific dialect.
  56. With no arguments, squelches all SAWarning failures. Or pass one or more
  57. strings; these will be matched to the root of the warning description by
  58. warnings.filterwarnings().
  59. """
  60. spec = db_spec(db)
  61. @decorator
  62. def decorate(fn, *args, **kw):
  63. if isinstance(db, util.string_types):
  64. if not spec(config._current):
  65. return fn(*args, **kw)
  66. else:
  67. wrapped = emits_warning(*warnings)(fn)
  68. return wrapped(*args, **kw)
  69. else:
  70. if not _is_excluded(*db):
  71. return fn(*args, **kw)
  72. else:
  73. wrapped = emits_warning(*warnings)(fn)
  74. return wrapped(*args, **kw)
  75. return decorate
  76. def uses_deprecated(*messages):
  77. """Mark a test as immune from fatal deprecation warnings.
  78. With no arguments, squelches all SADeprecationWarning failures.
  79. Or pass one or more strings; these will be matched to the root
  80. of the warning description by warnings.filterwarnings().
  81. As a special case, you may pass a function name prefixed with //
  82. and it will be re-written as needed to match the standard warning
  83. verbiage emitted by the sqlalchemy.util.deprecated decorator.
  84. """
  85. @decorator
  86. def decorate(fn, *args, **kw):
  87. with expect_deprecated(*messages):
  88. return fn(*args, **kw)
  89. return decorate
  90. @contextlib.contextmanager
  91. def expect_deprecated(*messages):
  92. # todo: should probably be strict about this, too
  93. filters = [dict(action='ignore',
  94. category=sa_exc.SAPendingDeprecationWarning)]
  95. if not messages:
  96. filters.append(dict(action='ignore',
  97. category=sa_exc.SADeprecationWarning))
  98. else:
  99. filters.extend(
  100. [dict(action='ignore',
  101. message=message,
  102. category=sa_exc.SADeprecationWarning)
  103. for message in
  104. [(m.startswith('//') and
  105. ('Call to deprecated function ' + m[2:]) or m)
  106. for m in messages]])
  107. for f in filters:
  108. warnings.filterwarnings(**f)
  109. try:
  110. yield
  111. finally:
  112. resetwarnings()
  113. def global_cleanup_assertions():
  114. """Check things that have to be finalized at the end of a test suite.
  115. Hardcoded at the moment, a modular system can be built here
  116. to support things like PG prepared transactions, tables all
  117. dropped, etc.
  118. """
  119. _assert_no_stray_pool_connections()
  120. _STRAY_CONNECTION_FAILURES = 0
  121. def _assert_no_stray_pool_connections():
  122. global _STRAY_CONNECTION_FAILURES
  123. # lazy gc on cPython means "do nothing." pool connections
  124. # shouldn't be in cycles, should go away.
  125. testutil.lazy_gc()
  126. # however, once in awhile, on an EC2 machine usually,
  127. # there's a ref in there. usually just one.
  128. if pool._refs:
  129. # OK, let's be somewhat forgiving.
  130. _STRAY_CONNECTION_FAILURES += 1
  131. print("Encountered a stray connection in test cleanup: %s"
  132. % str(pool._refs))
  133. # then do a real GC sweep. We shouldn't even be here
  134. # so a single sweep should really be doing it, otherwise
  135. # there's probably a real unreachable cycle somewhere.
  136. testutil.gc_collect()
  137. # if we've already had two of these occurrences, or
  138. # after a hard gc sweep we still have pool._refs?!
  139. # now we have to raise.
  140. if pool._refs:
  141. err = str(pool._refs)
  142. # but clean out the pool refs collection directly,
  143. # reset the counter,
  144. # so the error doesn't at least keep happening.
  145. pool._refs.clear()
  146. _STRAY_CONNECTION_FAILURES = 0
  147. assert False, "Stray connection refused to leave "\
  148. "after gc.collect(): %s" % err
  149. elif _STRAY_CONNECTION_FAILURES > 10:
  150. assert False, "Encountered more than 10 stray connections"
  151. _STRAY_CONNECTION_FAILURES = 0
  152. def eq_(a, b, msg=None):
  153. """Assert a == b, with repr messaging on failure."""
  154. assert a == b, msg or "%r != %r" % (a, b)
  155. def ne_(a, b, msg=None):
  156. """Assert a != b, with repr messaging on failure."""
  157. assert a != b, msg or "%r == %r" % (a, b)
  158. def is_(a, b, msg=None):
  159. """Assert a is b, with repr messaging on failure."""
  160. assert a is b, msg or "%r is not %r" % (a, b)
  161. def is_not_(a, b, msg=None):
  162. """Assert a is not b, with repr messaging on failure."""
  163. assert a is not b, msg or "%r is %r" % (a, b)
  164. def startswith_(a, fragment, msg=None):
  165. """Assert a.startswith(fragment), with repr messaging on failure."""
  166. assert a.startswith(fragment), msg or "%r does not start with %r" % (
  167. a, fragment)
  168. def assert_raises(except_cls, callable_, *args, **kw):
  169. try:
  170. callable_(*args, **kw)
  171. success = False
  172. except except_cls:
  173. success = True
  174. # assert outside the block so it works for AssertionError too !
  175. assert success, "Callable did not raise an exception"
  176. def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
  177. try:
  178. callable_(*args, **kwargs)
  179. assert False, "Callable did not raise an exception"
  180. except except_cls as e:
  181. assert re.search(
  182. msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e)
  183. print(util.text_type(e).encode('utf-8'))
  184. class AssertsCompiledSQL(object):
  185. def assert_compile(self, clause, result, params=None,
  186. checkparams=None, dialect=None,
  187. checkpositional=None,
  188. use_default_dialect=False,
  189. allow_dialect_select=False,
  190. literal_binds=False):
  191. if use_default_dialect:
  192. dialect = default.DefaultDialect()
  193. elif allow_dialect_select:
  194. dialect = None
  195. else:
  196. if dialect is None:
  197. dialect = getattr(self, '__dialect__', None)
  198. if dialect is None:
  199. dialect = config.db.dialect
  200. elif dialect == 'default':
  201. dialect = default.DefaultDialect()
  202. elif isinstance(dialect, util.string_types):
  203. dialect = url.URL(dialect).get_dialect()()
  204. kw = {}
  205. compile_kwargs = {}
  206. if params is not None:
  207. kw['column_keys'] = list(params)
  208. if literal_binds:
  209. compile_kwargs['literal_binds'] = True
  210. if isinstance(clause, orm.Query):
  211. context = clause._compile_context()
  212. context.statement.use_labels = True
  213. clause = context.statement
  214. if compile_kwargs:
  215. kw['compile_kwargs'] = compile_kwargs
  216. c = clause.compile(dialect=dialect, **kw)
  217. param_str = repr(getattr(c, 'params', {}))
  218. if util.py3k:
  219. param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
  220. print(
  221. ("\nSQL String:\n" +
  222. util.text_type(c) +
  223. param_str).encode('utf-8'))
  224. else:
  225. print(
  226. "\nSQL String:\n" +
  227. util.text_type(c).encode('utf-8') +
  228. param_str)
  229. cc = re.sub(r'[\n\t]', '', util.text_type(c))
  230. eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
  231. if checkparams is not None:
  232. eq_(c.construct_params(params), checkparams)
  233. if checkpositional is not None:
  234. p = c.construct_params(params)
  235. eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
  236. class ComparesTables(object):
  237. def assert_tables_equal(self, table, reflected_table, strict_types=False):
  238. assert len(table.c) == len(reflected_table.c)
  239. for c, reflected_c in zip(table.c, reflected_table.c):
  240. eq_(c.name, reflected_c.name)
  241. assert reflected_c is reflected_table.c[c.name]
  242. eq_(c.primary_key, reflected_c.primary_key)
  243. eq_(c.nullable, reflected_c.nullable)
  244. if strict_types:
  245. msg = "Type '%s' doesn't correspond to type '%s'"
  246. assert isinstance(reflected_c.type, type(c.type)), \
  247. msg % (reflected_c.type, c.type)
  248. else:
  249. self.assert_types_base(reflected_c, c)
  250. if isinstance(c.type, sqltypes.String):
  251. eq_(c.type.length, reflected_c.type.length)
  252. eq_(
  253. set([f.column.name for f in c.foreign_keys]),
  254. set([f.column.name for f in reflected_c.foreign_keys])
  255. )
  256. if c.server_default:
  257. assert isinstance(reflected_c.server_default,
  258. schema.FetchedValue)
  259. assert len(table.primary_key) == len(reflected_table.primary_key)
  260. for c in table.primary_key:
  261. assert reflected_table.primary_key.columns[c.name] is not None
  262. def assert_types_base(self, c1, c2):
  263. assert c1.type._compare_type_affinity(c2.type),\
  264. "On column %r, type '%s' doesn't correspond to type '%s'" % \
  265. (c1.name, c1.type, c2.type)
  266. class AssertsExecutionResults(object):
  267. def assert_result(self, result, class_, *objects):
  268. result = list(result)
  269. print(repr(result))
  270. self.assert_list(result, class_, objects)
  271. def assert_list(self, result, class_, list):
  272. self.assert_(len(result) == len(list),
  273. "result list is not the same size as test list, " +
  274. "for class " + class_.__name__)
  275. for i in range(0, len(list)):
  276. self.assert_row(class_, result[i], list[i])
  277. def assert_row(self, class_, rowobj, desc):
  278. self.assert_(rowobj.__class__ is class_,
  279. "item class is not " + repr(class_))
  280. for key, value in desc.items():
  281. if isinstance(value, tuple):
  282. if isinstance(value[1], list):
  283. self.assert_list(getattr(rowobj, key), value[0], value[1])
  284. else:
  285. self.assert_row(value[0], getattr(rowobj, key), value[1])
  286. else:
  287. self.assert_(getattr(rowobj, key) == value,
  288. "attribute %s value %s does not match %s" % (
  289. key, getattr(rowobj, key), value))
  290. def assert_unordered_result(self, result, cls, *expected):
  291. """As assert_result, but the order of objects is not considered.
  292. The algorithm is very expensive but not a big deal for the small
  293. numbers of rows that the test suite manipulates.
  294. """
  295. class immutabledict(dict):
  296. def __hash__(self):
  297. return id(self)
  298. found = util.IdentitySet(result)
  299. expected = set([immutabledict(e) for e in expected])
  300. for wrong in util.itertools_filterfalse(lambda o:
  301. isinstance(o, cls), found):
  302. fail('Unexpected type "%s", expected "%s"' % (
  303. type(wrong).__name__, cls.__name__))
  304. if len(found) != len(expected):
  305. fail('Unexpected object count "%s", expected "%s"' % (
  306. len(found), len(expected)))
  307. NOVALUE = object()
  308. def _compare_item(obj, spec):
  309. for key, value in spec.items():
  310. if isinstance(value, tuple):
  311. try:
  312. self.assert_unordered_result(
  313. getattr(obj, key), value[0], *value[1])
  314. except AssertionError:
  315. return False
  316. else:
  317. if getattr(obj, key, NOVALUE) != value:
  318. return False
  319. return True
  320. for expected_item in expected:
  321. for found_item in found:
  322. if _compare_item(found_item, expected_item):
  323. found.remove(found_item)
  324. break
  325. else:
  326. fail(
  327. "Expected %s instance with attributes %s not found." % (
  328. cls.__name__, repr(expected_item)))
  329. return True
  330. def assert_sql_execution(self, db, callable_, *rules):
  331. assertsql.asserter.add_rules(rules)
  332. try:
  333. callable_()
  334. assertsql.asserter.statement_complete()
  335. finally:
  336. assertsql.asserter.clear_rules()
  337. def assert_sql(self, db, callable_, list_, with_sequences=None):
  338. if (with_sequences is not None and
  339. config.db.dialect.supports_sequences):
  340. rules = with_sequences
  341. else:
  342. rules = list_
  343. newrules = []
  344. for rule in rules:
  345. if isinstance(rule, dict):
  346. newrule = assertsql.AllOf(*[
  347. assertsql.ExactSQL(k, v) for k, v in rule.items()
  348. ])
  349. else:
  350. newrule = assertsql.ExactSQL(*rule)
  351. newrules.append(newrule)
  352. self.assert_sql_execution(db, callable_, *newrules)
  353. def assert_sql_count(self, db, callable_, count):
  354. self.assert_sql_execution(
  355. db, callable_, assertsql.CountStatements(count))
  356. @contextlib.contextmanager
  357. def assert_execution(self, *rules):
  358. assertsql.asserter.add_rules(rules)
  359. try:
  360. yield
  361. assertsql.asserter.statement_complete()
  362. finally:
  363. assertsql.asserter.clear_rules()
  364. def assert_statement_count(self, count):
  365. return self.assert_execution(assertsql.CountStatements(count))