/notify_user/pymodules/python2.7/lib/python/sqlalchemy/testing/assertions.py
Python | 453 lines | 339 code | 52 blank | 62 comment | 46 complexity | e374925bc67cc9664e8842c19e899d5b MD5 | raw file
- # testing/assertions.py
- # Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: http://www.opensource.org/licenses/mit-license.php
- from __future__ import absolute_import
- from . import util as testutil
- from sqlalchemy import pool, orm, util
- from sqlalchemy.engine import default, create_engine, url
- from sqlalchemy import exc as sa_exc
- from sqlalchemy.util import decorator
- from sqlalchemy import types as sqltypes, schema
- import warnings
- import re
- from .warnings import resetwarnings
- from .exclusions import db_spec, _is_excluded
- from . import assertsql
- from . import config
- import itertools
- from .util import fail
- import contextlib
def emits_warning(*messages):
    """Decorate a test so that the SAWarnings it emits are not fatal.

    Called with no arguments, every ``SAWarning`` is ignored while the test
    runs.  Called with one or more strings, only SAWarnings whose text is
    matched by one of them (passed as ``message=`` to
    ``warnings.filterwarnings()``) are ignored.  ``SAPendingDeprecationWarning``
    is always ignored.  ``resetwarnings()`` is called once the test returns
    or raises.
    """
    # TODO: it would be nice to assert that a named warning was
    # emitted. should work with some monkeypatching of warnings,
    # and may work on non-CPython if they keep to the spirit of
    # warnings.showwarning's docstring.
    # - update: jython looks ok, it uses cpython's module

    @decorator
    def decorate(fn, *args, **kw):
        # todo: should probably be strict about this, too
        ignore_rules = [
            dict(action='ignore',
                 category=sa_exc.SAPendingDeprecationWarning),
        ]
        if messages:
            # One filter per message pattern.
            ignore_rules += [
                dict(action='ignore', message=text,
                     category=sa_exc.SAWarning)
                for text in messages
            ]
        else:
            # No patterns given: silence every SAWarning.
            ignore_rules.append(
                dict(action='ignore', category=sa_exc.SAWarning))

        for rule in ignore_rules:
            warnings.filterwarnings(**rule)
        try:
            return fn(*args, **kw)
        finally:
            resetwarnings()

    return decorate
def emits_warning_on(db, *warnings):
    """Like :func:`emits_warning`, but only on a particular backend.

    ``db`` may be a string dialect spec, in which case the warnings are
    squelched only when ``db_spec(db)`` accepts ``config._current``.
    Otherwise ``db`` is unpacked into ``_is_excluded(*db)`` and the warnings
    are squelched when that returns true (NOTE(review): presumably a
    ``(name, op, spec)`` tuple -- confirm against _is_excluded).

    With no further arguments, squelches all SAWarning failures.  Or pass
    one or more strings; these are matched against the warning text by
    warnings.filterwarnings().
    """
    spec = db_spec(db)

    @decorator
    def decorate(fn, *args, **kw):
        if isinstance(db, util.string_types):
            targeted = spec(config._current)
        else:
            targeted = _is_excluded(*db)

        if not targeted:
            # Not the backend in question: run the test untouched.
            return fn(*args, **kw)
        return emits_warning(*warnings)(fn)(*args, **kw)

    return decorate
def uses_deprecated(*messages):
    """Decorate a test so that its SADeprecationWarnings are not fatal.

    Called with no arguments, every ``SADeprecationWarning`` is ignored.
    Called with strings, only warnings matched by one of them (via
    ``warnings.filterwarnings()``) are ignored.  A string beginning with
    ``//`` is shorthand for a function name; it is rewritten to
    ``'Call to deprecated function <name>'``, the wording used by the
    ``sqlalchemy.util.deprecated`` decorator.

    The filtering itself is done by :func:`expect_deprecated`, which wraps
    each call of the decorated test.
    """
    @decorator
    def decorate(fn, *args, **kw):
        # The filters live exactly as long as this one invocation.
        with expect_deprecated(*messages):
            return fn(*args, **kw)

    return decorate
@contextlib.contextmanager
def expect_deprecated(*messages):
    """Context manager that ignores SADeprecationWarnings inside its body.

    With no ``messages``, every ``SADeprecationWarning`` is ignored;
    otherwise only those matched by one of the given strings.  A string
    starting with ``//`` is expanded to
    ``'Call to deprecated function ' + <rest of string>``.
    ``SAPendingDeprecationWarning`` is always ignored.  ``resetwarnings()``
    is called on exit, whether or not the body raised.
    """
    def _filter_text(pattern):
        # '//name' -> the message emitted by sqlalchemy.util.deprecated
        if pattern.startswith('//'):
            return 'Call to deprecated function ' + pattern[2:]
        return pattern

    # todo: should probably be strict about this, too
    ignore_rules = [
        dict(action='ignore',
             category=sa_exc.SAPendingDeprecationWarning),
    ]
    if messages:
        expanded = [_filter_text(pattern) for pattern in messages]
        for text in expanded:
            ignore_rules.append(
                dict(action='ignore', message=text,
                     category=sa_exc.SADeprecationWarning))
    else:
        ignore_rules.append(
            dict(action='ignore', category=sa_exc.SADeprecationWarning))

    for rule in ignore_rules:
        warnings.filterwarnings(**rule)
    try:
        yield
    finally:
        resetwarnings()
def global_cleanup_assertions():
    """Run the checks that only make sense at the end of a test suite.

    For now this is a fixed list containing just the stray pool-connection
    check.  A pluggable mechanism could be added here later for things like
    leftover PG prepared transactions or tables that were never dropped.
    """
    _assert_no_stray_pool_connections()
# Count of cleanups that found stray entries in ``pool._refs`` which a full
# gc sweep then reclaimed.  Reset to 0 whenever one of the assertions in
# _assert_no_stray_pool_connections() fires.
_STRAY_CONNECTION_FAILURES = 0


def _assert_no_stray_pool_connections():
    """Fail if connections are still referenced from ``pool._refs``.

    If ``pool._refs`` is non-empty after ``testutil.lazy_gc()``, the stray
    is printed and counted, and ``testutil.gc_collect()`` is run.

    * If refs survive that sweep, ``pool._refs`` is cleared, the counter is
      reset and the assertion fails.
    * Otherwise, if the counter has exceeded 10, the counter is reset and the
      assertion fails.

    In both cases the reset happens first, so a single failure is not
    re-reported on every subsequent cleanup.

    :raises AssertionError: in the two cases above.
    """
    global _STRAY_CONNECTION_FAILURES

    # lazy gc on cPython means "do nothing." pool connections
    # shouldn't be in cycles, should go away.
    testutil.lazy_gc()

    # however, once in awhile, on an EC2 machine usually,
    # there's a ref in there. usually just one.
    if not pool._refs:
        return

    # OK, let's be somewhat forgiving.
    _STRAY_CONNECTION_FAILURES += 1

    print("Encountered a stray connection in test cleanup: %s"
          % str(pool._refs))

    # then do a real GC sweep.   We shouldn't even be here
    # so a single sweep should really be doing it, otherwise
    # there's probably a real unreachable cycle somewhere.
    testutil.gc_collect()

    # if we've already had more than 10 of these occurrences, or
    # after a hard gc sweep we still have pool._refs?!
    # now we have to raise.
    if pool._refs:
        err = str(pool._refs)

        # but clean out the pool refs collection directly,
        # reset the counter,
        # so the error doesn't at least keep happening.
        pool._refs.clear()
        _STRAY_CONNECTION_FAILURES = 0
        assert False, "Stray connection refused to leave "\
            "after gc.collect(): %s" % err
    elif _STRAY_CONNECTION_FAILURES > 10:
        # Reset *before* asserting: a reset placed after ``assert False``
        # never runs, which would make every later stray connection fail
        # this check as well.
        _STRAY_CONNECTION_FAILURES = 0
        assert False, "Encountered more than 10 stray connections"
def eq_(a, b, msg=None):
    """Assert that ``a == b``.

    On failure the AssertionError carries ``msg`` when it is truthy,
    otherwise ``"<repr of a> != <repr of b>"``.  The message is only built
    if the comparison fails.
    """
    assert a == b, (msg if msg else "%r != %r" % (a, b))
def ne_(a, b, msg=None):
    """Assert that ``a != b``.

    On failure the AssertionError carries ``msg`` when it is truthy,
    otherwise ``"<repr of a> == <repr of b>"`` (built lazily).
    """
    assert a != b, (msg if msg else "%r == %r" % (a, b))
def is_(a, b, msg=None):
    """Assert that ``a is b`` (identity, not equality).

    On failure the AssertionError carries ``msg`` when it is truthy,
    otherwise ``"<repr of a> is not <repr of b>"`` (built lazily).
    """
    assert a is b, (msg if msg else "%r is not %r" % (a, b))
def is_not_(a, b, msg=None):
    """Assert that ``a is not b`` (identity, not equality).

    On failure the AssertionError carries ``msg`` when it is truthy,
    otherwise ``"<repr of a> is <repr of b>"`` (built lazily).
    """
    assert a is not b, (msg if msg else "%r is %r" % (a, b))
def startswith_(a, fragment, msg=None):
    """Assert that ``a.startswith(fragment)`` is true.

    On failure the AssertionError carries ``msg`` when it is truthy,
    otherwise a message showing the reprs of both values (built lazily).
    """
    assert a.startswith(fragment), (
        msg if msg else "%r does not start with %r" % (a, fragment))
def assert_raises(except_cls, callable_, *args, **kw):
    """Assert that ``callable_(*args, **kw)`` raises ``except_cls``.

    Exceptions of any other type propagate unchanged.  If nothing is raised,
    an AssertionError ``"Callable did not raise an exception"`` is raised.
    """
    raised = True
    try:
        callable_(*args, **kw)
    except except_cls:
        pass
    else:
        raised = False

    # assert outside the block so it works for AssertionError too !
    assert raised, "Callable did not raise an exception"
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
    """Assert ``callable_`` raises ``except_cls`` with a message matching ``msg``.

    :param except_cls: exception class (or tuple of classes) expected.
    :param msg: regular expression searched (``re.search``, ``re.UNICODE``)
        in the text of the raised exception.
    :raises AssertionError: if nothing is raised, or if the exception text
        does not match ``msg``.  Exceptions not matching ``except_cls``
        propagate unchanged.
    """
    try:
        callable_(*args, **kwargs)
    except except_cls as e:
        assert re.search(
            msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e)
        print(util.text_type(e).encode('utf-8'))
    else:
        # Fail outside the ``try`` body (in ``else``) so that when
        # except_cls is AssertionError (or a base of it) this failure is not
        # caught and then regex-matched as if the callable had raised it.
        # This mirrors assert_raises().
        assert False, "Callable did not raise an exception"
class AssertsCompiledSQL(object):
    """Test mixin providing :meth:`assert_compile`."""

    def assert_compile(self, clause, result, params=None,
                       checkparams=None, dialect=None,
                       checkpositional=None,
                       use_default_dialect=False,
                       allow_dialect_select=False,
                       literal_binds=False):
        """Compile ``clause`` and assert its SQL text equals ``result``.

        Dialect selection, in priority order:

        * ``use_default_dialect=True``: a new ``default.DefaultDialect()``;
        * ``allow_dialect_select=True``: ``dialect=None`` is passed to
          ``clause.compile()``;
        * otherwise the ``dialect`` argument, falling back to the class's
          ``__dialect__`` attribute and then to ``config.db.dialect``.  The
          string ``'default'`` means ``default.DefaultDialect()``; any other
          string is resolved with ``url.URL(name).get_dialect()()``.

        When ``params`` is given, its keys are passed as ``column_keys`` and
        it is handed to ``construct_params()``.  ``literal_binds=True`` is
        forwarded through ``compile_kwargs``.  An ORM ``Query`` is replaced
        by the statement from its ``_compile_context()`` with ``use_labels``
        set.  Newlines and tabs are stripped from the compiled text before
        the comparison.  ``checkparams`` is compared with
        ``construct_params(params)``.  ``checkpositional`` is compared with
        those params ordered by ``positiontup``.  The compiled SQL and its
        params are printed for debugging.
        """
        if use_default_dialect:
            dialect = default.DefaultDialect()
        elif not allow_dialect_select:
            if dialect is None:
                dialect = getattr(self, '__dialect__', None)

            if dialect is None:
                dialect = config.db.dialect
            elif dialect == 'default':
                dialect = default.DefaultDialect()
            elif isinstance(dialect, util.string_types):
                dialect = url.URL(dialect).get_dialect()()
        else:
            dialect = None

        compile_args = {}
        if params is not None:
            compile_args['column_keys'] = list(params)
        if literal_binds:
            compile_args['compile_kwargs'] = {'literal_binds': True}

        if isinstance(clause, orm.Query):
            query_context = clause._compile_context()
            query_context.statement.use_labels = True
            clause = query_context.statement

        compiled = clause.compile(dialect=dialect, **compile_args)

        param_str = repr(getattr(compiled, 'params', {}))
        sql_text = util.text_type(compiled)
        if util.py3k:
            # Drop non-ASCII chars from the params repr and print encoded
            # bytes.
            param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
            print(("\nSQL String:\n" + sql_text + param_str).encode('utf-8'))
        else:
            print("\nSQL String:\n" + sql_text.encode('utf-8') + param_str)

        flattened = re.sub(r'[\n\t]', '', sql_text)
        eq_(flattened, result,
            "%r != %r on dialect %r" % (flattened, result, dialect))

        if checkparams is not None:
            eq_(compiled.construct_params(params), checkparams)

        if checkpositional is not None:
            bound = compiled.construct_params(params)
            eq_(tuple([bound[name] for name in compiled.positiontup]),
                checkpositional)
class ComparesTables(object):
    """Test mixin comparing a Table definition with its reflected copy."""

    def assert_tables_equal(self, table, reflected_table, strict_types=False):
        """Assert ``reflected_table`` matches ``table`` column by column.

        Columns are paired positionally and must agree on:

        * name, primary_key and nullable;
        * type: ``isinstance`` when ``strict_types`` is true, otherwise type
          affinity via :meth:`assert_types_base`;
        * length for ``String`` types;
        * the set of referenced foreign-key column names.

        A column with a ``server_default`` must have a ``FetchedValue``
        server default on the reflected side.  Finally, both tables must
        have primary keys of the same size, containing the same column
        names.
        """
        assert len(table.c) == len(reflected_table.c)

        for orig_col, refl_col in zip(table.c, reflected_table.c):
            eq_(orig_col.name, refl_col.name)
            assert refl_col is reflected_table.c[orig_col.name]
            eq_(orig_col.primary_key, refl_col.primary_key)
            eq_(orig_col.nullable, refl_col.nullable)

            if not strict_types:
                self.assert_types_base(refl_col, orig_col)
            else:
                msg = "Type '%s' doesn't correspond to type '%s'"
                assert isinstance(refl_col.type, type(orig_col.type)), \
                    msg % (refl_col.type, orig_col.type)

            if isinstance(orig_col.type, sqltypes.String):
                eq_(orig_col.type.length, refl_col.type.length)

            orig_fk_targets = set(fk.column.name
                                  for fk in orig_col.foreign_keys)
            refl_fk_targets = set(fk.column.name
                                  for fk in refl_col.foreign_keys)
            eq_(orig_fk_targets, refl_fk_targets)

            if orig_col.server_default:
                assert isinstance(refl_col.server_default,
                                  schema.FetchedValue)

        assert len(table.primary_key) == len(reflected_table.primary_key)
        for pk_col in table.primary_key:
            assert reflected_table.primary_key.columns[pk_col.name] is not None

    def assert_types_base(self, c1, c2):
        """Assert ``c1.type`` has the same type affinity as ``c2.type``.

        The check is ``c1.type._compare_type_affinity(c2.type)``; the failure
        message names ``c1``.
        """
        assert c1.type._compare_type_affinity(c2.type), (
            "On column %r, type '%s' doesn't correspond to type '%s'"
            % (c1.name, c1.type, c2.type))
class AssertsExecutionResults(object):
    """Test mixin with assertions over query results and emitted SQL.

    Several methods call ``self.assert_``, which this class does not define.
    NOTE(review): presumably it is mixed into a unittest.TestCase-like
    class -- confirm.
    """

    def assert_result(self, result, class_, *objects):
        """Materialize ``result``, print it, then check it with assert_list."""
        rows = list(result)
        print(repr(rows))
        self.assert_list(rows, class_, objects)

    def assert_list(self, result, class_, list):
        """Assert ``result`` matches the row specs in ``list``, in order.

        Each element is checked by :meth:`assert_row`.  The parameter name
        ``list`` shadows the builtin but is kept for callers that pass it by
        keyword.
        """
        self.assert_(len(result) == len(list),
                     "result list is not the same size as test list, "
                     "for class " + class_.__name__)
        for position, row_spec in enumerate(list):
            self.assert_row(class_, result[position], row_spec)

    def assert_row(self, class_, rowobj, desc):
        """Assert ``rowobj`` is exactly a ``class_`` and matches ``desc``.

        ``desc`` maps attribute names to expected values.  A tuple value
        ``(cls, spec)`` means the attribute is checked recursively: by
        :meth:`assert_list` when ``spec`` is a list, otherwise by
        :meth:`assert_row`.
        """
        self.assert_(rowobj.__class__ is class_,
                     "item class is not " + repr(class_))
        for attr, expected in desc.items():
            if not isinstance(expected, tuple):
                self.assert_(getattr(rowobj, attr) == expected,
                             "attribute %s value %s does not match %s" % (
                                 attr, getattr(rowobj, attr), expected))
                continue

            nested_class, nested_spec = expected[0], expected[1]
            if isinstance(nested_spec, list):
                self.assert_list(getattr(rowobj, attr),
                                 nested_class, nested_spec)
            else:
                self.assert_row(nested_class, getattr(rowobj, attr),
                                nested_spec)

    def assert_unordered_result(self, result, cls, *expected):
        """As assert_result, but the order of objects is not considered.

        Every object in ``result`` must be an instance of ``cls``, the counts
        must match, and each expected attribute dict must be matched by a
        distinct found object (greedy matching).  A tuple value ``(cls,
        specs)`` is compared recursively, also unordered.  The matching is
        quadratic, which is acceptable for the small row counts the test
        suite uses.  Returns True on success; failures go through
        ``fail()``.
        """
        class immutabledict(dict):
            # Hash by identity so the spec dicts can be put in a set.
            def __hash__(self):
                return id(self)

        found = util.IdentitySet(result)
        pending = set([immutabledict(spec) for spec in expected])

        for obj in found:
            if not isinstance(obj, cls):
                fail('Unexpected type "%s", expected "%s"' % (
                    type(obj).__name__, cls.__name__))

        if len(found) != len(pending):
            fail('Unexpected object count "%s", expected "%s"' % (
                len(found), len(pending)))

        missing = object()

        def _matches(obj, spec):
            # True when every attribute in spec agrees with obj.
            for attr, want in spec.items():
                if not isinstance(want, tuple):
                    if getattr(obj, attr, missing) != want:
                        return False
                    continue
                try:
                    self.assert_unordered_result(
                        getattr(obj, attr), want[0], *want[1])
                except AssertionError:
                    return False
            return True

        for spec in pending:
            for candidate in found:
                if _matches(candidate, spec):
                    # Safe despite mutating during iteration: we break
                    # immediately.
                    found.remove(candidate)
                    break
            else:
                fail(
                    "Expected %s instance with attributes %s not found." % (
                        cls.__name__, repr(spec)))
        return True

    def assert_sql_execution(self, db, callable_, *rules):
        """Run ``callable_`` with ``rules`` registered on assertsql.asserter.

        ``statement_complete()`` is called only if ``callable_`` returns
        normally.  The rules are cleared either way.  ``db`` is not used
        here.
        """
        assertsql.asserter.add_rules(rules)
        try:
            callable_()
            assertsql.asserter.statement_complete()
        finally:
            assertsql.asserter.clear_rules()

    def assert_sql(self, db, callable_, list_, with_sequences=None):
        """Assert ``callable_`` emits exactly the given SQL.

        ``with_sequences`` replaces ``list_`` when it is given and
        ``config.db.dialect.supports_sequences`` is true.  Each rule is
        either a dict of ``{sql: params}`` (matched via ``AllOf``, so in any
        order) or a tuple of arguments for ``ExactSQL``.
        """
        if with_sequences is not None and \
                config.db.dialect.supports_sequences:
            rules = with_sequences
        else:
            rules = list_

        def _as_rule(rule):
            if not isinstance(rule, dict):
                return assertsql.ExactSQL(*rule)
            return assertsql.AllOf(
                *[assertsql.ExactSQL(sql, params)
                  for sql, params in rule.items()])

        converted = [_as_rule(rule) for rule in rules]
        self.assert_sql_execution(db, callable_, *converted)

    def assert_sql_count(self, db, callable_, count):
        """Assert ``callable_`` emits ``count`` statements."""
        self.assert_sql_execution(
            db, callable_, assertsql.CountStatements(count))

    @contextlib.contextmanager
    def assert_execution(self, *rules):
        """Context-manager form of :meth:`assert_sql_execution`."""
        assertsql.asserter.add_rules(rules)
        try:
            yield
            assertsql.asserter.statement_complete()
        finally:
            assertsql.asserter.clear_rules()

    def assert_statement_count(self, count):
        """Return a context manager asserting ``count`` statements run in it."""
        return self.assert_execution(assertsql.CountStatements(count))