PageRenderTime 44ms CodeModel.GetById 13ms RepoModel.GetById 0ms app.codeStats 0ms

/Windows/Python3.8/WPy64-3830/WPy64-3830/python-3.8.3.amd64/Lib/site-packages/sqlalchemy/testing/assertsql.py

https://gitlab.com/abhi1tb/build
Python | 419 lines | 397 code | 16 blank | 6 comment | 9 complexity | 252352d7db623ef5dbc9189bbea55718 MD5 | raw file
  1. # testing/assertsql.py
  2. # Copyright (C) 2005-2020 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. import collections
  8. import contextlib
  9. import re
  10. from .. import event
  11. from .. import util
  12. from ..engine import url
  13. from ..engine.default import DefaultDialect
  14. from ..engine.util import _distill_params
  15. from ..schema import _DDLCompiles
  16. class AssertRule(object):
  17. is_consumed = False
  18. errormessage = None
  19. consume_statement = True
  20. def process_statement(self, execute_observed):
  21. pass
  22. def no_more_statements(self):
  23. assert False, (
  24. "All statements are complete, but pending "
  25. "assertion rules remain"
  26. )
  27. class SQLMatchRule(AssertRule):
  28. pass
  29. class CursorSQL(SQLMatchRule):
  30. consume_statement = False
  31. def __init__(self, statement, params=None):
  32. self.statement = statement
  33. self.params = params
  34. def process_statement(self, execute_observed):
  35. stmt = execute_observed.statements[0]
  36. if self.statement != stmt.statement or (
  37. self.params is not None and self.params != stmt.parameters
  38. ):
  39. self.errormessage = (
  40. "Testing for exact SQL %s parameters %s received %s %s"
  41. % (
  42. self.statement,
  43. self.params,
  44. stmt.statement,
  45. stmt.parameters,
  46. )
  47. )
  48. else:
  49. execute_observed.statements.pop(0)
  50. self.is_consumed = True
  51. if not execute_observed.statements:
  52. self.consume_statement = True
  53. class CompiledSQL(SQLMatchRule):
  54. def __init__(self, statement, params=None, dialect="default"):
  55. self.statement = statement
  56. self.params = params
  57. self.dialect = dialect
  58. def _compare_sql(self, execute_observed, received_statement):
  59. stmt = re.sub(r"[\n\t]", "", self.statement)
  60. return received_statement == stmt
  61. def _compile_dialect(self, execute_observed):
  62. if self.dialect == "default":
  63. return DefaultDialect()
  64. else:
  65. # ugh
  66. if self.dialect == "postgresql":
  67. params = {"implicit_returning": True}
  68. else:
  69. params = {}
  70. return url.URL(self.dialect).get_dialect()(**params)
  71. def _received_statement(self, execute_observed):
  72. """reconstruct the statement and params in terms
  73. of a target dialect, which for CompiledSQL is just DefaultDialect."""
  74. context = execute_observed.context
  75. compare_dialect = self._compile_dialect(execute_observed)
  76. if isinstance(context.compiled.statement, _DDLCompiles):
  77. compiled = context.compiled.statement.compile(
  78. dialect=compare_dialect,
  79. schema_translate_map=context.execution_options.get(
  80. "schema_translate_map"
  81. ),
  82. )
  83. else:
  84. compiled = context.compiled.statement.compile(
  85. dialect=compare_dialect,
  86. column_keys=context.compiled.column_keys,
  87. inline=context.compiled.inline,
  88. schema_translate_map=context.execution_options.get(
  89. "schema_translate_map"
  90. ),
  91. )
  92. _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
  93. parameters = execute_observed.parameters
  94. if not parameters:
  95. _received_parameters = [compiled.construct_params()]
  96. else:
  97. _received_parameters = [
  98. compiled.construct_params(m) for m in parameters
  99. ]
  100. return _received_statement, _received_parameters
  101. def process_statement(self, execute_observed):
  102. context = execute_observed.context
  103. _received_statement, _received_parameters = self._received_statement(
  104. execute_observed
  105. )
  106. params = self._all_params(context)
  107. equivalent = self._compare_sql(execute_observed, _received_statement)
  108. if equivalent:
  109. if params is not None:
  110. all_params = list(params)
  111. all_received = list(_received_parameters)
  112. while all_params and all_received:
  113. param = dict(all_params.pop(0))
  114. for idx, received in enumerate(list(all_received)):
  115. # do a positive compare only
  116. for param_key in param:
  117. # a key in param did not match current
  118. # 'received'
  119. if (
  120. param_key not in received
  121. or received[param_key] != param[param_key]
  122. ):
  123. break
  124. else:
  125. # all keys in param matched 'received';
  126. # onto next param
  127. del all_received[idx]
  128. break
  129. else:
  130. # param did not match any entry
  131. # in all_received
  132. equivalent = False
  133. break
  134. if all_params or all_received:
  135. equivalent = False
  136. if equivalent:
  137. self.is_consumed = True
  138. self.errormessage = None
  139. else:
  140. self.errormessage = self._failure_message(params) % {
  141. "received_statement": _received_statement,
  142. "received_parameters": _received_parameters,
  143. }
  144. def _all_params(self, context):
  145. if self.params:
  146. if util.callable(self.params):
  147. params = self.params(context)
  148. else:
  149. params = self.params
  150. if not isinstance(params, list):
  151. params = [params]
  152. return params
  153. else:
  154. return None
  155. def _failure_message(self, expected_params):
  156. return (
  157. "Testing for compiled statement %r partial params %s, "
  158. "received %%(received_statement)r with params "
  159. "%%(received_parameters)r"
  160. % (
  161. self.statement.replace("%", "%%"),
  162. repr(expected_params).replace("%", "%%"),
  163. )
  164. )
  165. class RegexSQL(CompiledSQL):
  166. def __init__(self, regex, params=None):
  167. SQLMatchRule.__init__(self)
  168. self.regex = re.compile(regex)
  169. self.orig_regex = regex
  170. self.params = params
  171. self.dialect = "default"
  172. def _failure_message(self, expected_params):
  173. return (
  174. "Testing for compiled statement ~%r partial params %s, "
  175. "received %%(received_statement)r with params "
  176. "%%(received_parameters)r"
  177. % (
  178. self.orig_regex.replace("%", "%%"),
  179. repr(expected_params).replace("%", "%%"),
  180. )
  181. )
  182. def _compare_sql(self, execute_observed, received_statement):
  183. return bool(self.regex.match(received_statement))
  184. class DialectSQL(CompiledSQL):
  185. def _compile_dialect(self, execute_observed):
  186. return execute_observed.context.dialect
  187. def _compare_no_space(self, real_stmt, received_stmt):
  188. stmt = re.sub(r"[\n\t]", "", real_stmt)
  189. return received_stmt == stmt
  190. def _received_statement(self, execute_observed):
  191. received_stmt, received_params = super(
  192. DialectSQL, self
  193. )._received_statement(execute_observed)
  194. # TODO: why do we need this part?
  195. for real_stmt in execute_observed.statements:
  196. if self._compare_no_space(real_stmt.statement, received_stmt):
  197. break
  198. else:
  199. raise AssertionError(
  200. "Can't locate compiled statement %r in list of "
  201. "statements actually invoked" % received_stmt
  202. )
  203. return received_stmt, execute_observed.context.compiled_parameters
  204. def _compare_sql(self, execute_observed, received_statement):
  205. stmt = re.sub(r"[\n\t]", "", self.statement)
  206. # convert our comparison statement to have the
  207. # paramstyle of the received
  208. paramstyle = execute_observed.context.dialect.paramstyle
  209. if paramstyle == "pyformat":
  210. stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
  211. else:
  212. # positional params
  213. repl = None
  214. if paramstyle == "qmark":
  215. repl = "?"
  216. elif paramstyle == "format":
  217. repl = r"%s"
  218. elif paramstyle == "numeric":
  219. repl = None
  220. stmt = re.sub(r":([\w_]+)", repl, stmt)
  221. return received_statement == stmt
  222. class CountStatements(AssertRule):
  223. def __init__(self, count):
  224. self.count = count
  225. self._statement_count = 0
  226. def process_statement(self, execute_observed):
  227. self._statement_count += 1
  228. def no_more_statements(self):
  229. if self.count != self._statement_count:
  230. assert False, "desired statement count %d does not match %d" % (
  231. self.count,
  232. self._statement_count,
  233. )
  234. class AllOf(AssertRule):
  235. def __init__(self, *rules):
  236. self.rules = set(rules)
  237. def process_statement(self, execute_observed):
  238. for rule in list(self.rules):
  239. rule.errormessage = None
  240. rule.process_statement(execute_observed)
  241. if rule.is_consumed:
  242. self.rules.discard(rule)
  243. if not self.rules:
  244. self.is_consumed = True
  245. break
  246. elif not rule.errormessage:
  247. # rule is not done yet
  248. self.errormessage = None
  249. break
  250. else:
  251. self.errormessage = list(self.rules)[0].errormessage
  252. class EachOf(AssertRule):
  253. def __init__(self, *rules):
  254. self.rules = list(rules)
  255. def process_statement(self, execute_observed):
  256. while self.rules:
  257. rule = self.rules[0]
  258. rule.process_statement(execute_observed)
  259. if rule.is_consumed:
  260. self.rules.pop(0)
  261. elif rule.errormessage:
  262. self.errormessage = rule.errormessage
  263. if rule.consume_statement:
  264. break
  265. if not self.rules:
  266. self.is_consumed = True
  267. def no_more_statements(self):
  268. if self.rules and not self.rules[0].is_consumed:
  269. self.rules[0].no_more_statements()
  270. elif self.rules:
  271. super(EachOf, self).no_more_statements()
  272. class Or(AllOf):
  273. def process_statement(self, execute_observed):
  274. for rule in self.rules:
  275. rule.process_statement(execute_observed)
  276. if rule.is_consumed:
  277. self.is_consumed = True
  278. break
  279. else:
  280. self.errormessage = list(self.rules)[0].errormessage
  281. class SQLExecuteObserved(object):
  282. def __init__(self, context, clauseelement, multiparams, params):
  283. self.context = context
  284. self.clauseelement = clauseelement
  285. self.parameters = _distill_params(multiparams, params)
  286. self.statements = []
  287. class SQLCursorExecuteObserved(
  288. collections.namedtuple(
  289. "SQLCursorExecuteObserved",
  290. ["statement", "parameters", "context", "executemany"],
  291. )
  292. ):
  293. pass
  294. class SQLAsserter(object):
  295. def __init__(self):
  296. self.accumulated = []
  297. def _close(self):
  298. self._final = self.accumulated
  299. del self.accumulated
  300. def assert_(self, *rules):
  301. rule = EachOf(*rules)
  302. observed = list(self._final)
  303. while observed:
  304. statement = observed.pop(0)
  305. rule.process_statement(statement)
  306. if rule.is_consumed:
  307. break
  308. elif rule.errormessage:
  309. assert False, rule.errormessage
  310. if observed:
  311. assert False, "Additional SQL statements remain"
  312. elif not rule.is_consumed:
  313. rule.no_more_statements()
  314. @contextlib.contextmanager
  315. def assert_engine(engine):
  316. asserter = SQLAsserter()
  317. orig = []
  318. @event.listens_for(engine, "before_execute")
  319. def connection_execute(conn, clauseelement, multiparams, params):
  320. # grab the original statement + params before any cursor
  321. # execution
  322. orig[:] = clauseelement, multiparams, params
  323. @event.listens_for(engine, "after_cursor_execute")
  324. def cursor_execute(
  325. conn, cursor, statement, parameters, context, executemany
  326. ):
  327. if not context:
  328. return
  329. # then grab real cursor statements and associate them all
  330. # around a single context
  331. if (
  332. asserter.accumulated
  333. and asserter.accumulated[-1].context is context
  334. ):
  335. obs = asserter.accumulated[-1]
  336. else:
  337. obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
  338. asserter.accumulated.append(obs)
  339. obs.statements.append(
  340. SQLCursorExecuteObserved(
  341. statement, parameters, context, executemany
  342. )
  343. )
  344. try:
  345. yield asserter
  346. finally:
  347. event.remove(engine, "after_cursor_execute", cursor_execute)
  348. event.remove(engine, "before_execute", connection_execute)
  349. asserter._close()