assertsql.py 15 KB


  1. # testing/assertsql.py
  2. # Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. import collections
  8. import contextlib
  9. import re
  10. from .. import event
  11. from .. import util
  12. from ..engine import url
  13. from ..engine.default import DefaultDialect
  14. from ..engine.util import _distill_cursor_params
  15. from ..schema import _DDLCompiles
  16. class AssertRule(object):
  17. is_consumed = False
  18. errormessage = None
  19. consume_statement = True
  20. def process_statement(self, execute_observed):
  21. pass
  22. def no_more_statements(self):
  23. assert False, (
  24. "All statements are complete, but pending "
  25. "assertion rules remain"
  26. )
  27. class SQLMatchRule(AssertRule):
  28. pass
  29. class CursorSQL(SQLMatchRule):
  30. def __init__(self, statement, params=None, consume_statement=True):
  31. self.statement = statement
  32. self.params = params
  33. self.consume_statement = consume_statement
  34. def process_statement(self, execute_observed):
  35. stmt = execute_observed.statements[0]
  36. if self.statement != stmt.statement or (
  37. self.params is not None and self.params != stmt.parameters
  38. ):
  39. self.errormessage = (
  40. "Testing for exact SQL %s parameters %s received %s %s"
  41. % (
  42. self.statement,
  43. self.params,
  44. stmt.statement,
  45. stmt.parameters,
  46. )
  47. )
  48. else:
  49. execute_observed.statements.pop(0)
  50. self.is_consumed = True
  51. if not execute_observed.statements:
  52. self.consume_statement = True
class CompiledSQL(SQLMatchRule):
    """Match an execution by recompiling its statement for a dialect.

    The observed ``clauseelement`` is compiled again against the dialect
    returned by :meth:`_compile_dialect` (a ``DefaultDialect`` unless
    ``dialect`` names another one). The rendered SQL is then compared to
    ``statement``, with newlines and tabs removed from both.

    :param statement: expected SQL string.
    :param params: optional expected parameters -- a dict, a list of
     dicts, or a callable that receives the execution context and returns
     either. Each expected dict is a *partial* match: every key it names
     must be present, with an equal value, in one distinct received
     parameter set. The number of expected and received sets must also be
     equal. A falsy value (``None``, ``{}``, ``[]``) skips the parameter
     check entirely.
    :param dialect: ``"default"`` for ``DefaultDialect``; otherwise a name
     resolved through ``url.URL.create(dialect).get_dialect()``.
    """

    def __init__(self, statement, params=None, dialect="default"):
        self.statement = statement
        self.params = params
        self.dialect = dialect

    def _compare_sql(self, execute_observed, received_statement):
        # strip newlines/tabs from the expected SQL; the received SQL is
        # stripped the same way in _received_statement()
        stmt = re.sub(r"[\n\t]", "", self.statement)
        return received_statement == stmt

    def _compile_dialect(self, execute_observed):
        """Return the dialect instance used to recompile the statement."""
        if self.dialect == "default":
            dialect = DefaultDialect()
            # this is currently what tests are expecting
            # dialect.supports_default_values = True
            dialect.supports_default_metavalue = True
            return dialect
        else:
            # ugh
            # postgresql is instantiated with implicit_returning=True;
            # any other named dialect gets its constructor defaults
            if self.dialect == "postgresql":
                params = {"implicit_returning": True}
            else:
                params = {}
            return url.URL.create(self.dialect).get_dialect()(**params)

    def _received_statement(self, execute_observed):
        """Reconstruct the statement and params for the target dialect.

        The target dialect comes from :meth:`_compile_dialect`.

        Returns a ``(statement, parameters)`` tuple:

        * the recompiled SQL string, with newlines and tabs removed;
        * a list holding one ``compiled.construct_params()`` result per
          observed parameter set, or a single entry when no parameters
          were observed.
        """
        context = execute_observed.context
        compare_dialect = self._compile_dialect(execute_observed)

        # received_statement runs a full compile(). we should not need to
        # consider extracted_parameters; if we do this indicates some state
        # is being sent from a previous cached query, which some misbehaviors
        # in the ORM can cause, see #6881
        cache_key = None  # execute_observed.context.compiled.cache_key
        extracted_parameters = (
            None  # execute_observed.context.extracted_parameters
        )

        # pass the execution's schema_translate_map (if any) through to
        # compile() so schema names render the same way
        if "schema_translate_map" in context.execution_options:
            map_ = context.execution_options["schema_translate_map"]
        else:
            map_ = None

        # DDL constructs are compiled without the cache_key /
        # column_keys / for_executemany arguments used for other statements
        if isinstance(execute_observed.clauseelement, _DDLCompiles):
            compiled = execute_observed.clauseelement.compile(
                dialect=compare_dialect,
                schema_translate_map=map_,
            )
        else:
            compiled = execute_observed.clauseelement.compile(
                cache_key=cache_key,
                dialect=compare_dialect,
                column_keys=context.compiled.column_keys,
                for_executemany=context.compiled.for_executemany,
                schema_translate_map=map_,
            )
        _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
        parameters = execute_observed.parameters

        if not parameters:
            _received_parameters = [
                compiled.construct_params(
                    extracted_parameters=extracted_parameters
                )
            ]
        else:
            _received_parameters = [
                compiled.construct_params(
                    m, extracted_parameters=extracted_parameters
                )
                for m in parameters
            ]

        return _received_statement, _received_parameters

    def process_statement(self, execute_observed):
        """Compare one observed execution against this rule.

        On a match, sets ``is_consumed`` and clears ``errormessage``.
        Otherwise, sets ``errormessage`` to a description of the expected
        and received statement/parameters.
        """
        context = execute_observed.context

        _received_statement, _received_parameters = self._received_statement(
            execute_observed
        )
        params = self._all_params(context)

        equivalent = self._compare_sql(execute_observed, _received_statement)

        if equivalent:
            if params is not None:
                # greedily pair each expected param dict with the first
                # still-unpaired received param dict it matches; the
                # order of either list does not matter
                all_params = list(params)
                all_received = list(_received_parameters)
                while all_params and all_received:
                    param = dict(all_params.pop(0))

                    for idx, received in enumerate(list(all_received)):
                        # do a positive compare only
                        for param_key in param:
                            # a key in param did not match current
                            # 'received'
                            if (
                                param_key not in received
                                or received[param_key] != param[param_key]
                            ):
                                break
                        else:
                            # all keys in param matched 'received';
                            # onto next param
                            del all_received[idx]
                            break
                    else:
                        # param did not match any entry
                        # in all_received
                        equivalent = False
                        break
                # anything left over on either side means the number of
                # expected and received parameter sets differs
                if all_params or all_received:
                    equivalent = False

        if equivalent:
            self.is_consumed = True
            self.errormessage = None
        else:
            self.errormessage = self._failure_message(params) % {
                "received_statement": _received_statement,
                "received_parameters": _received_parameters,
            }

    def _all_params(self, context):
        """Return the expected parameters as a list, or None to skip.

        A callable ``params`` is invoked with the execution context. A
        non-list result is wrapped in a one-element list. A falsy
        ``params`` returns None.
        """
        if self.params:
            if callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]
            return params
        else:
            return None

    def _failure_message(self, expected_params):
        # Returns a %-template rather than a finished message. Literal "%"
        # in the expected statement/params is doubled, so that the caller's
        # second % substitution only fills the received_statement /
        # received_parameters placeholders.
        return (
            "Testing for compiled statement\n%r partial params %s, "
            "received\n%%(received_statement)r with params "
            "%%(received_parameters)r"
            % (
                self.statement.replace("%", "%%"),
                repr(expected_params).replace("%", "%%"),
            )
        )
  185. class RegexSQL(CompiledSQL):
  186. def __init__(self, regex, params=None, dialect="default"):
  187. SQLMatchRule.__init__(self)
  188. self.regex = re.compile(regex)
  189. self.orig_regex = regex
  190. self.params = params
  191. self.dialect = dialect
  192. def _failure_message(self, expected_params):
  193. return (
  194. "Testing for compiled statement ~%r partial params %s, "
  195. "received %%(received_statement)r with params "
  196. "%%(received_parameters)r"
  197. % (
  198. self.orig_regex.replace("%", "%%"),
  199. repr(expected_params).replace("%", "%%"),
  200. )
  201. )
  202. def _compare_sql(self, execute_observed, received_statement):
  203. return bool(self.regex.match(received_statement))
  204. class DialectSQL(CompiledSQL):
  205. def _compile_dialect(self, execute_observed):
  206. return execute_observed.context.dialect
  207. def _compare_no_space(self, real_stmt, received_stmt):
  208. stmt = re.sub(r"[\n\t]", "", real_stmt)
  209. return received_stmt == stmt
  210. def _received_statement(self, execute_observed):
  211. received_stmt, received_params = super(
  212. DialectSQL, self
  213. )._received_statement(execute_observed)
  214. # TODO: why do we need this part?
  215. for real_stmt in execute_observed.statements:
  216. if self._compare_no_space(real_stmt.statement, received_stmt):
  217. break
  218. else:
  219. raise AssertionError(
  220. "Can't locate compiled statement %r in list of "
  221. "statements actually invoked" % received_stmt
  222. )
  223. return received_stmt, execute_observed.context.compiled_parameters
  224. def _compare_sql(self, execute_observed, received_statement):
  225. stmt = re.sub(r"[\n\t]", "", self.statement)
  226. # convert our comparison statement to have the
  227. # paramstyle of the received
  228. paramstyle = execute_observed.context.dialect.paramstyle
  229. if paramstyle == "pyformat":
  230. stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
  231. else:
  232. # positional params
  233. repl = None
  234. if paramstyle == "qmark":
  235. repl = "?"
  236. elif paramstyle == "format":
  237. repl = r"%s"
  238. elif paramstyle == "numeric":
  239. repl = None
  240. stmt = re.sub(r":([\w_]+)", repl, stmt)
  241. return received_statement == stmt
  242. class CountStatements(AssertRule):
  243. def __init__(self, count):
  244. self.count = count
  245. self._statement_count = 0
  246. def process_statement(self, execute_observed):
  247. self._statement_count += 1
  248. def no_more_statements(self):
  249. if self.count != self._statement_count:
  250. assert False, "desired statement count %d does not match %d" % (
  251. self.count,
  252. self._statement_count,
  253. )
  254. class AllOf(AssertRule):
  255. def __init__(self, *rules):
  256. self.rules = set(rules)
  257. def process_statement(self, execute_observed):
  258. for rule in list(self.rules):
  259. rule.errormessage = None
  260. rule.process_statement(execute_observed)
  261. if rule.is_consumed:
  262. self.rules.discard(rule)
  263. if not self.rules:
  264. self.is_consumed = True
  265. break
  266. elif not rule.errormessage:
  267. # rule is not done yet
  268. self.errormessage = None
  269. break
  270. else:
  271. self.errormessage = list(self.rules)[0].errormessage
  272. class EachOf(AssertRule):
  273. def __init__(self, *rules):
  274. self.rules = list(rules)
  275. def process_statement(self, execute_observed):
  276. while self.rules:
  277. rule = self.rules[0]
  278. rule.process_statement(execute_observed)
  279. if rule.is_consumed:
  280. self.rules.pop(0)
  281. elif rule.errormessage:
  282. self.errormessage = rule.errormessage
  283. if rule.consume_statement:
  284. break
  285. if not self.rules:
  286. self.is_consumed = True
  287. def no_more_statements(self):
  288. if self.rules and not self.rules[0].is_consumed:
  289. self.rules[0].no_more_statements()
  290. elif self.rules:
  291. super(EachOf, self).no_more_statements()
  292. class Conditional(EachOf):
  293. def __init__(self, condition, rules, else_rules):
  294. if condition:
  295. super(Conditional, self).__init__(*rules)
  296. else:
  297. super(Conditional, self).__init__(*else_rules)
  298. class Or(AllOf):
  299. def process_statement(self, execute_observed):
  300. for rule in self.rules:
  301. rule.process_statement(execute_observed)
  302. if rule.is_consumed:
  303. self.is_consumed = True
  304. break
  305. else:
  306. self.errormessage = list(self.rules)[0].errormessage
  307. class SQLExecuteObserved(object):
  308. def __init__(self, context, clauseelement, multiparams, params):
  309. self.context = context
  310. self.clauseelement = clauseelement
  311. self.parameters = _distill_cursor_params(
  312. context.connection, tuple(multiparams), params
  313. )
  314. self.statements = []
  315. def __repr__(self):
  316. return str(self.statements)
  317. class SQLCursorExecuteObserved(
  318. collections.namedtuple(
  319. "SQLCursorExecuteObserved",
  320. ["statement", "parameters", "context", "executemany"],
  321. )
  322. ):
  323. pass
  324. class SQLAsserter(object):
  325. def __init__(self):
  326. self.accumulated = []
  327. def _close(self):
  328. self._final = self.accumulated
  329. del self.accumulated
  330. def assert_(self, *rules):
  331. rule = EachOf(*rules)
  332. observed = list(self._final)
  333. while observed:
  334. statement = observed.pop(0)
  335. rule.process_statement(statement)
  336. if rule.is_consumed:
  337. break
  338. elif rule.errormessage:
  339. assert False, rule.errormessage
  340. if observed:
  341. assert False, "Additional SQL statements remain:\n%s" % observed
  342. elif not rule.is_consumed:
  343. rule.no_more_statements()
  344. @contextlib.contextmanager
  345. def assert_engine(engine):
  346. asserter = SQLAsserter()
  347. orig = []
  348. @event.listens_for(engine, "before_execute")
  349. def connection_execute(
  350. conn, clauseelement, multiparams, params, execution_options
  351. ):
  352. # grab the original statement + params before any cursor
  353. # execution
  354. orig[:] = clauseelement, multiparams, params
  355. @event.listens_for(engine, "after_cursor_execute")
  356. def cursor_execute(
  357. conn, cursor, statement, parameters, context, executemany
  358. ):
  359. if not context:
  360. return
  361. # then grab real cursor statements and associate them all
  362. # around a single context
  363. if (
  364. asserter.accumulated
  365. and asserter.accumulated[-1].context is context
  366. ):
  367. obs = asserter.accumulated[-1]
  368. else:
  369. obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
  370. asserter.accumulated.append(obs)
  371. obs.statements.append(
  372. SQLCursorExecuteObserved(
  373. statement, parameters, context, executemany
  374. )
  375. )
  376. try:
  377. yield asserter
  378. finally:
  379. event.remove(engine, "after_cursor_execute", cursor_execute)
  380. event.remove(engine, "before_execute", connection_execute)
  381. asserter._close()