2
from sqlalchemy.interfaces import ConnectionProxy
3
from sqlalchemy.engine.default import DefaultDialect
4
from sqlalchemy.engine.base import Connection
5
from sqlalchemy import util
8
class AssertRule(object):
10
def process_execute(self, clauseelement, *multiparams, **params):
13
def process_cursor_execute(self, statement, parameters, context,
17
def is_consumed(self):
18
"""Return True if this rule has been consumed, False if not.
20
Should raise an AssertionError if this rule's condition has
25
raise NotImplementedError()
27
def rule_passed(self):
28
"""Return True if the last test of this rule passed, False if
29
failed, None if no test was applied."""
31
raise NotImplementedError()
33
def consume_final(self):
34
"""Return True if this rule has been consumed.
36
Should raise an AssertionError if this rule's condition has not
37
been consumed or has failed.
41
if self._result is None:
42
assert False, 'Rule has not been consumed'
43
return self.is_consumed()
45
class SQLMatchRule(AssertRule):
50
def rule_passed(self):
53
def is_consumed(self):
54
if self._result is None:
57
assert self._result, self._errmsg
61
class ExactSQL(SQLMatchRule):
63
def __init__(self, sql, params=None):
64
SQLMatchRule.__init__(self)
68
def process_cursor_execute(self, statement, parameters, context,
72
_received_statement = \
73
_process_engine_statement(context.unicode_statement,
75
_received_parameters = context.compiled_parameters
77
# TODO: remove this step once all unit tests are migrated, as
78
# ExactSQL should really be *exact* SQL
80
sql = _process_assertion_statement(self.sql, context)
81
equivalent = _received_statement == sql
83
if util.callable(self.params):
84
params = self.params(context)
87
if not isinstance(params, list):
89
equivalent = equivalent and params \
90
== context.compiled_parameters
93
self._result = equivalent
96
'Testing for exact statement %r exact params %r, '\
97
'received %r with params %r' % (sql, params,
98
_received_statement, _received_parameters)
101
class RegexSQL(SQLMatchRule):
103
def __init__(self, regex, params=None):
104
SQLMatchRule.__init__(self)
105
self.regex = re.compile(regex)
106
self.orig_regex = regex
109
def process_cursor_execute(self, statement, parameters, context,
113
_received_statement = \
114
_process_engine_statement(context.unicode_statement,
116
_received_parameters = context.compiled_parameters
117
equivalent = bool(self.regex.match(_received_statement))
119
if util.callable(self.params):
120
params = self.params(context)
123
if not isinstance(params, list):
126
# do a positive compare only
128
for param, received in zip(params, _received_parameters):
129
for k, v in param.iteritems():
130
if k not in received or received[k] != v:
135
self._result = equivalent
138
'Testing for regex %r partial params %r, received %r '\
139
'with params %r' % (self.orig_regex, params,
141
_received_parameters)
143
class CompiledSQL(SQLMatchRule):
145
def __init__(self, statement, params):
146
SQLMatchRule.__init__(self)
147
self.statement = statement
150
def process_cursor_execute(self, statement, parameters, context,
154
_received_parameters = list(context.compiled_parameters)
156
# recompile from the context, using the default dialect
159
context.compiled.statement.compile(dialect=DefaultDialect(),
160
column_keys=context.compiled.column_keys)
161
_received_statement = re.sub(r'\n', '', str(compiled))
162
equivalent = self.statement == _received_statement
164
if util.callable(self.params):
165
params = self.params(context)
168
if not isinstance(params, list):
170
all_params = list(params)
171
all_received = list(_received_parameters)
173
param = dict(params.pop(0))
174
for k, v in context.compiled.params.iteritems():
175
param.setdefault(k, v)
176
if param not in _received_parameters:
180
_received_parameters.remove(param)
181
if _received_parameters:
187
self._result = equivalent
189
print 'Testing for compiled statement %r partial params '\
190
'%r, received %r with params %r' % (self.statement,
191
all_params, _received_statement, all_received)
193
'Testing for compiled statement %r partial params %r, '\
194
'received %r with params %r' % (self.statement,
195
all_params, _received_statement, all_received)
200
class CountStatements(AssertRule):
202
def __init__(self, count):
204
self._statement_count = 0
206
def process_execute(self, clauseelement, *multiparams, **params):
207
self._statement_count += 1
209
def process_cursor_execute(self, statement, parameters, context,
213
def is_consumed(self):
216
def consume_final(self):
217
assert self.count == self._statement_count, \
218
'desired statement count %d does not match %d' \
219
% (self.count, self._statement_count)
222
class AllOf(AssertRule):
224
def __init__(self, *rules):
225
self.rules = set(rules)
227
def process_execute(self, clauseelement, *multiparams, **params):
228
for rule in self.rules:
229
rule.process_execute(clauseelement, *multiparams, **params)
231
def process_cursor_execute(self, statement, parameters, context,
233
for rule in self.rules:
234
rule.process_cursor_execute(statement, parameters, context,
237
def is_consumed(self):
240
for rule in list(self.rules):
241
if rule.rule_passed(): # a rule passed, move on
242
self.rules.remove(rule)
243
return len(self.rules) == 0
244
assert False, 'No assertion rules were satisfied for statement'
246
def consume_final(self):
247
return len(self.rules) == 0
249
def _process_engine_statement(query, context):
252
# oracle+zxjdbc passes a PyStatement when returning into
254
query = unicode(query)
255
if context.engine.name == 'mssql' \
256
and query.endswith('; select scope_identity()'):
258
query = re.sub(r'\n', '', query)
261
def _process_assertion_statement(query, context):
262
paramstyle = context.dialect.paramstyle
263
if paramstyle == 'named':
265
elif paramstyle =='pyformat':
266
query = re.sub(r':([\w_]+)', r"%(\1)s", query)
270
if paramstyle=='qmark':
272
elif paramstyle=='format':
274
elif paramstyle=='numeric':
276
query = re.sub(r':([\w_]+)', repl, query)
280
class SQLAssert(object):
284
def add_rules(self, rules):
285
self.rules = list(rules)
287
def statement_complete(self):
288
for rule in self.rules:
289
if not rule.consume_final():
291
'All statements are complete, but pending '\
292
'assertion rules remain'
294
def clear_rules(self):
297
def execute(self, conn, clauseelement, multiparams, params, result):
298
if self.rules is not None:
301
'All rules have been exhausted, but further '\
304
rule.process_execute(clauseelement, *multiparams, **params)
305
if rule.is_consumed():
308
def cursor_execute(self, conn, cursor, statement, parameters,
309
context, executemany):
312
rule.process_cursor_execute(statement, parameters, context,
315
asserter = SQLAssert()