/Windows/Python3.8/WPy64-3830/WPy64-3830/python-3.8.3.amd64/Lib/site-packages/sqlalchemy/testing/assertsql.py
Python | 419 lines | 397 code | 16 blank | 6 comment | 9 complexity | 252352d7db623ef5dbc9189bbea55718 MD5 | raw file
- # testing/assertsql.py
- # Copyright (C) 2005-2020 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: http://www.opensource.org/licenses/mit-license.php
-
- import collections
- import contextlib
- import re
-
- from .. import event
- from .. import util
- from ..engine import url
- from ..engine.default import DefaultDialect
- from ..engine.util import _distill_params
- from ..schema import _DDLCompiles
-
-
class AssertRule(object):
    """Base class for a rule that is checked against observed executions.

    SQLAsserter.assert_ offers each observed execution to
    :meth:`process_statement`.  A rule reports its state through three
    attributes:

    * ``is_consumed`` - True once the rule has been fully satisfied.
    * ``errormessage`` - a string when the latest execution did not
      satisfy the rule, otherwise None.
    * ``consume_statement`` - while False, an enclosing EachOf keeps
      offering the same observed execution to the following rule after
      this one is consumed.
    """

    is_consumed = False
    errormessage = None
    consume_statement = True

    def process_statement(self, execute_observed):
        """Examine one observed execution; the base implementation does nothing."""

    def no_more_statements(self):
        """Called when executions ran out before this rule was satisfied.

        :raises AssertionError: always.  Raised explicitly rather than via
         ``assert False`` so the failure is still reported under ``python -O``.
        """
        raise AssertionError(
            "All statements are complete, but pending "
            "assertion rules remain"
        )
-
-
class SQLMatchRule(AssertRule):
    """Common base for rules that compare the SQL text (and parameters)
    of an observed execution against an expectation."""
-
-
class CursorSQL(SQLMatchRule):
    """Match the exact SQL string (and optionally parameters) that was
    handed to the DBAPI cursor.

    Works on the cursor-level entries in ``execute_observed.statements``
    and pops each one it matches.  ``consume_statement`` starts out False,
    so an enclosing EachOf hands the same observed execution to the next
    rule.  It flips to True once this rule has popped the last remaining
    cursor entry.

    :param statement: the expected SQL string, compared with ``!=``.
    :param params: expected cursor parameters; None skips the check.
    """

    consume_statement = False

    def __init__(self, statement, params=None):
        self.statement = statement
        self.params = params

    def process_statement(self, execute_observed):
        pending = execute_observed.statements
        head = pending[0]

        # statement is compared first; params only when they were given
        mismatch = self.statement != head.statement or (
            self.params is not None and self.params != head.parameters
        )

        if not mismatch:
            pending.pop(0)
            self.is_consumed = True
            if not pending:
                # nothing left on this execution for a following rule
                self.consume_statement = True
            return

        self.errormessage = (
            "Testing for exact SQL %s parameters %s received %s %s"
            % (self.statement, self.params, head.statement, head.parameters)
        )
-
-
class CompiledSQL(SQLMatchRule):
    """Match an observed execution by recompiling its statement.

    The compiled statement object from the execution context is compiled
    again against a comparison dialect.  The resulting SQL, with newlines
    and tabs removed, is compared with ``statement`` (also stripped of
    newlines and tabs).

    When ``params`` is provided, each expected parameter dict must pair
    off with a distinct received parameter dict containing all of its
    keys with equal values.  Extra received keys are ignored.

    :param statement: expected SQL string.
    :param params: a dict, a list of dicts, or a callable taking the
     execution context and returning either.  A falsy value disables the
     parameter check.
    :param dialect: ``"default"`` for DefaultDialect, or a dialect name
     resolved through ``url.URL(name).get_dialect()``.
    """

    def __init__(self, statement, params=None, dialect="default"):
        self.statement = statement
        self.params = params
        self.dialect = dialect

    def _compare_sql(self, execute_observed, received_statement):
        expected = re.sub(r"[\n\t]", "", self.statement)
        return received_statement == expected

    def _compile_dialect(self, execute_observed):
        if self.dialect == "default":
            return DefaultDialect()
        # postgresql is constructed with implicit_returning enabled;
        # every other named dialect uses its default constructor
        dialect_kwargs = (
            {"implicit_returning": True}
            if self.dialect == "postgresql"
            else {}
        )
        dialect_cls = url.URL(self.dialect).get_dialect()
        return dialect_cls(**dialect_kwargs)

    def _received_statement(self, execute_observed):
        """Recompile the executed statement for the comparison dialect.

        :return: ``(sql_string, [param_dict, ...])``.  The SQL has newlines
         and tabs removed.  There is one param dict per entry in
         ``execute_observed.parameters``, or a single dict of the compiled
         defaults when there are none.
        """
        context = execute_observed.context
        compare_dialect = self._compile_dialect(execute_observed)
        source = context.compiled.statement
        translate_map = context.execution_options.get("schema_translate_map")

        if isinstance(source, _DDLCompiles):
            compiled = source.compile(
                dialect=compare_dialect, schema_translate_map=translate_map
            )
        else:
            compiled = source.compile(
                dialect=compare_dialect,
                column_keys=context.compiled.column_keys,
                inline=context.compiled.inline,
                schema_translate_map=translate_map,
            )
        received_sql = re.sub(r"[\n\t]", "", util.text_type(compiled))

        observed_params = execute_observed.parameters
        if observed_params:
            received_params = [
                compiled.construct_params(entry) for entry in observed_params
            ]
        else:
            received_params = [compiled.construct_params()]

        return received_sql, received_params

    @staticmethod
    def _params_match(expected_params, received_params):
        """Return True if expected and received param dicts pair off exactly.

        Expected dicts are taken in order.  Each one claims the first
        still-unclaimed received dict that has every one of its keys with
        an equal value.  The match fails as soon as an expected dict
        finds no partner, or if either list has entries left over.
        """
        unmatched_expected = list(expected_params)
        unmatched_received = list(received_params)
        while unmatched_expected and unmatched_received:
            wanted = dict(unmatched_expected.pop(0))
            for position, candidate in enumerate(unmatched_received):
                # positive compare only: keys absent from ``wanted`` are ignored
                if not any(
                    key not in candidate or candidate[key] != wanted[key]
                    for key in wanted
                ):
                    del unmatched_received[position]
                    break
            else:
                return False
        return not unmatched_expected and not unmatched_received

    def process_statement(self, execute_observed):
        context = execute_observed.context

        received_sql, received_params = self._received_statement(
            execute_observed
        )
        expected_params = self._all_params(context)

        matched = self._compare_sql(execute_observed, received_sql)
        if matched and expected_params is not None:
            matched = self._params_match(expected_params, received_params)

        if matched:
            self.is_consumed = True
            self.errormessage = None
            return

        self.errormessage = self._failure_message(expected_params) % {
            "received_statement": received_sql,
            "received_parameters": received_params,
        }

    def _all_params(self, context):
        """Resolve ``self.params`` to a list of expected dicts, or None."""
        if not self.params:
            return None
        resolved = (
            self.params(context) if util.callable(self.params) else self.params
        )
        return resolved if isinstance(resolved, list) else [resolved]

    def _failure_message(self, expected_params):
        # the result is itself a %-template, filled in by process_statement;
        # literal "%" in the embedded values is therefore doubled
        template = (
            "Testing for compiled statement %r partial params %s, "
            "received %%(received_statement)r with params "
            "%%(received_parameters)r"
        )
        return template % (
            self.statement.replace("%", "%%"),
            repr(expected_params).replace("%", "%%"),
        )
-
-
class RegexSQL(CompiledSQL):
    """CompiledSQL variant that matches the recompiled SQL with a regex.

    The pattern is applied with ``match()``, so it is anchored at the
    start of the received statement only.  The comparison dialect is
    always ``"default"``.
    """

    def __init__(self, regex, params=None):
        # CompiledSQL.__init__ is not called: no ``statement`` attribute is
        # set, and the two overrides below do not need one.
        SQLMatchRule.__init__(self)
        self.orig_regex = regex
        self.regex = re.compile(regex)
        self.params = params
        self.dialect = "default"

    def _failure_message(self, expected_params):
        escaped_regex = self.orig_regex.replace("%", "%%")
        escaped_params = repr(expected_params).replace("%", "%%")
        template = (
            "Testing for compiled statement ~%r partial params %s, "
            "received %%(received_statement)r with params "
            "%%(received_parameters)r"
        )
        return template % (escaped_regex, escaped_params)

    def _compare_sql(self, execute_observed, received_statement):
        return self.regex.match(received_statement) is not None
-
-
class DialectSQL(CompiledSQL):
    """CompiledSQL variant that compares using the executing engine's dialect.

    The expected statement is written with ``:name`` placeholders.  Before
    comparing, they are rewritten into the paramstyle of the dialect that
    actually ran the statement.
    """

    # ``:name`` bind placeholder.  The lookbehind skips the second colon of
    # a PostgreSQL ``::`` cast, which is never a bind parameter.
    _bind_placeholder = re.compile(r"(?<!:):(\w+)")

    def _compile_dialect(self, execute_observed):
        return execute_observed.context.dialect

    def _compare_no_space(self, real_stmt, received_stmt):
        stmt = re.sub(r"[\n\t]", "", real_stmt)
        return received_stmt == stmt

    def _received_statement(self, execute_observed):
        """Recompile with the engine's dialect and use the context's params.

        :raises AssertionError: if the recompiled SQL is not among the
         cursor statements recorded for this execution.
        """
        received_stmt, received_params = super(
            DialectSQL, self
        )._received_statement(execute_observed)

        # Sanity check that the recompiled SQL was actually sent to the
        # cursor.  (Long-standing TODO: unclear whether this is still needed.)
        for real_stmt in execute_observed.statements:
            if self._compare_no_space(real_stmt.statement, received_stmt):
                break
        else:
            raise AssertionError(
                "Can't locate compiled statement %r in list of "
                "statements actually invoked" % received_stmt
            )

        return received_stmt, execute_observed.context.compiled_parameters

    def _dialect_adjusted_statement(self, paramstyle):
        """Return ``self.statement`` rewritten for a DBAPI ``paramstyle``.

        Newlines and tabs are removed, then each ``:name`` placeholder
        becomes:

        * ``%(name)s`` for ``pyformat``
        * ``?`` for ``qmark``
        * ``%s`` for ``format``
        * ``:1``, ``:2``, ... (in order of appearance) for ``numeric``

        ``named`` already uses ``:name`` and is returned unchanged, as is
        any unrecognized style.  Previously these fell through to
        ``re.sub(..., None, ...)`` and raised TypeError.
        """
        import itertools

        stmt = re.sub(r"[\n\t]", "", self.statement)

        if paramstyle == "pyformat":
            repl = r"%(\1)s"
        elif paramstyle == "qmark":
            repl = "?"
        elif paramstyle == "format":
            repl = r"%s"
        elif paramstyle == "numeric":
            positions = itertools.count(1)

            def repl(match):
                return ":%d" % next(positions)

        else:
            return stmt

        return self._bind_placeholder.sub(repl, stmt)

    def _compare_sql(self, execute_observed, received_statement):
        paramstyle = execute_observed.context.dialect.paramstyle
        return received_statement == self._dialect_adjusted_statement(
            paramstyle
        )
-
-
class CountStatements(AssertRule):
    """Assert that exactly ``count`` executions were offered to this rule.

    Every call to :meth:`process_statement` increments the counter.  The
    rule never marks itself consumed, so the total is checked in
    :meth:`no_more_statements` once the observed executions run out.
    """

    def __init__(self, count):
        self.count = count
        self._statement_count = 0

    def process_statement(self, execute_observed):
        self._statement_count += 1

    def no_more_statements(self):
        """:raises AssertionError: when the observed count differs from ``count``."""
        # raised explicitly so the check is not stripped under ``python -O``
        if self._statement_count != self.count:
            raise AssertionError(
                "desired statement count %d does not match %d"
                % (self.count, self._statement_count)
            )
-
-
class AllOf(AssertRule):
    """Match a group of rules in any order.

    Each observed execution is offered to the pending rules one at a time
    until one of them consumes it or makes error-free progress.  Consumed
    rules are removed, and the group is consumed once none remain.
    """

    def __init__(self, *rules):
        # stored as a set: rules are tried in set iteration order, not in
        # the order they were passed
        self.rules = set(rules)

    def process_statement(self, execute_observed):
        # iterate over a copy because consumed rules are discarded below
        for rule in list(self.rules):
            # clear any error left over from a previous execution
            rule.errormessage = None
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.rules.discard(rule)
                if not self.rules:
                    self.is_consumed = True
                # this execution has been used up by ``rule``
                break
            elif not rule.errormessage:
                # rule is not done yet
                self.errormessage = None
                break
        else:
            # no pending rule accepted the execution; report the error of
            # whichever rule comes first in set iteration order
            self.errormessage = list(self.rules)[0].errormessage
-
-
class EachOf(AssertRule):
    """Match a sequence of rules in the order given.

    Observed executions are offered to the first pending rule.  When it is
    consumed it is removed and the next rule becomes current.  If the
    consumed rule's ``consume_statement`` is False (e.g. CursorSQL), the
    same execution is offered to the next rule as well.
    """

    def __init__(self, *rules):
        self.rules = list(rules)

    def process_statement(self, execute_observed):
        while self.rules:
            rule = self.rules[0]
            rule.process_statement(execute_observed)
            if not rule.is_consumed:
                # The current rule did not finish: record its error (if any)
                # and stop.  Re-offering the same execution to the same rule
                # would spin forever for a rule whose consume_statement is
                # False, e.g. a CursorSQL mismatch.
                if rule.errormessage:
                    self.errormessage = rule.errormessage
                break
            self.rules.pop(0)
            if rule.consume_statement:
                break

        if not self.rules:
            self.is_consumed = True

    def no_more_statements(self):
        """Delegate to the current rule, or fail if rules are still pending."""
        if not self.rules:
            return
        current = self.rules[0]
        if not current.is_consumed:
            current.no_more_statements()
        else:
            super(EachOf, self).no_more_statements()
-
-
class Or(AllOf):
    """Satisfied as soon as any one of the given rules consumes an execution.

    Rules are stored as a set (inherited from AllOf's ``__init__``), so the
    alternatives are tried in set iteration order.
    """

    def process_statement(self, execute_observed):
        for rule in self.rules:
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                # one alternative matched; the whole Or is satisfied
                self.is_consumed = True
                break
        else:
            # No alternative consumed the execution: report the errormessage
            # of whichever rule comes first in set iteration order.
            # NOTE(review): unlike AllOf, sub-rule errormessage is not reset
            # first.  If that first rule progressed without error (e.g. a
            # multi-step EachOf), this becomes None even though other
            # alternatives failed.  Confirm this is intended.
            self.errormessage = list(self.rules)[0].errormessage
-
-
class SQLExecuteObserved(object):
    """Record of one statement execution captured through engine events.

    :param context: the execution context; cursor-level statements that
     share it are grouped onto this record.
    :param clauseelement: the statement object passed to ``before_execute``.
    :param multiparams: raw positional parameter argument of that call.
    :param params: raw keyword parameter argument of that call.

    ``parameters`` holds ``multiparams``/``params`` normalized by
    ``_distill_params``.  ``statements`` is a list that assert_engine's
    cursor listener appends SQLCursorExecuteObserved entries to.
    """

    def __init__(self, context, clauseelement, multiparams, params):
        self.context, self.clauseelement = context, clauseelement
        self.parameters = _distill_params(multiparams, params)
        self.statements = []
-
-
class SQLCursorExecuteObserved(
    collections.namedtuple(
        "SQLCursorExecuteObserved",
        "statement parameters context executemany",
    )
):
    """One DBAPI cursor execution as delivered to ``after_cursor_execute``.

    Fields are the SQL string, its cursor parameters, the owning execution
    context, and the ``executemany`` flag.
    """
-
-
class SQLAsserter(object):
    """Collects observed executions and checks them against rules.

    ``accumulated`` is appended to while recording (see assert_engine).
    :meth:`_close` moves it to ``_final``, after which :meth:`assert_`
    can be called.
    """

    def __init__(self):
        self.accumulated = []

    def _close(self):
        self._final = self.accumulated
        del self.accumulated

    def assert_(self, *rules):
        """Assert that the recorded executions satisfy ``rules`` in order.

        The rules are wrapped in an EachOf.

        :raises AssertionError: when a rule reports an error, when
         executions remain after the rules are satisfied, or (via
         ``no_more_statements``) when rules remain after the executions
         run out.  Raised explicitly so the checks survive ``python -O``.
        """
        rule = EachOf(*rules)

        # iterate a snapshot instead of draining it with list.pop(0), which
        # made the loop quadratic in the number of observed executions
        pending = iter(list(self._final))
        for observed in pending:
            rule.process_statement(observed)
            if rule.is_consumed:
                break
            if rule.errormessage:
                raise AssertionError(rule.errormessage)

        exhausted = object()
        if next(pending, exhausted) is not exhausted:
            raise AssertionError("Additional SQL statements remain")
        if not rule.is_consumed:
            rule.no_more_statements()
-
-
@contextlib.contextmanager
def assert_engine(engine):
    """Record SQL executed on ``engine`` while the ``with`` block runs.

    Yields a SQLAsserter.

    * A ``before_execute`` listener remembers the statement and parameters
      of the latest execute call.
    * An ``after_cursor_execute`` listener groups consecutive cursor-level
      statements that share an execution context under one
      SQLExecuteObserved.

    On exit both listeners are removed and the asserter is closed, which
    is what makes ``asserter.assert_(...)`` usable afterwards.
    """
    asserter = SQLAsserter()

    # (clauseelement, multiparams, params) from the most recent before_execute
    last_execute = []

    def _remember_execute(conn, clauseelement, multiparams, params):
        last_execute[:] = clauseelement, multiparams, params

    def _record_cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        # cursor executions without an execution context are not recorded
        if not context:
            return
        recorded = asserter.accumulated
        if recorded and recorded[-1].context is context:
            # another cursor statement for the execution already recorded
            observed = recorded[-1]
        else:
            observed = SQLExecuteObserved(
                context, last_execute[0], last_execute[1], last_execute[2]
            )
            recorded.append(observed)
        observed.statements.append(
            SQLCursorExecuteObserved(
                statement, parameters, context, executemany
            )
        )

    event.listen(engine, "before_execute", _remember_execute)
    event.listen(engine, "after_cursor_execute", _record_cursor_execute)

    try:
        yield asserter
    finally:
        event.remove(engine, "after_cursor_execute", _record_cursor_execute)
        event.remove(engine, "before_execute", _remember_execute)
        asserter._close()