2
from ..engine.default import DefaultDialect
7
# Base class for SQL assertion rules. It declares the hooks that receive
# statement-execution events; is_consumed() and rule_passed() raise
# NotImplementedError here, and consume_final() provides a default check.
# NOTE(review): this listing has lost its indentation and some lines (the
# bare integers look like original line numbers), so bodies that are not
# shown are left undocumented.
class AssertRule(object):
9
# Hook receiving the executed clause element with its positional and
# keyword parameters; the body is not visible in this listing.
def process_execute(self, clauseelement, *multiparams, **params):
12
# Hook receiving the statement string, its parameters and the execution
# context; the rest of the signature and the body are not visible here.
def process_cursor_execute(self, statement, parameters, context,
16
def is_consumed(self):
17
"""Return True if this rule has been consumed, False if not.
19
Should raise an AssertionError if this rule's condition has
24
raise NotImplementedError()
26
def rule_passed(self):
27
"""Return True if the last test of this rule passed, False if
28
failed, None if no test was applied."""
30
raise NotImplementedError()
32
def consume_final(self):
33
"""Return True if this rule has been consumed.
35
Should raise an AssertionError if this rule's condition has not
36
been consumed or has failed.
40
# A _result of None fails the assertion with 'Rule has not been consumed';
# otherwise the answer is delegated to is_consumed().
if self._result is None:
41
assert False, 'Rule has not been consumed'
42
return self.is_consumed()
45
# AssertRule subclass whose visible code works from two attributes:
# self._result (outcome of the last test, None when untested) and
# self._errmsg (message used when the assertion fails).
# NOTE(review): __init__ and most method bodies are not visible in this
# listing; presumably __init__ initialises both attributes -- confirm.
class SQLMatchRule(AssertRule):
50
def rule_passed(self):
53
# Checks for an untested rule (_result is None) before asserting; what the
# None branch does is not visible here. A falsy _result raises
# AssertionError carrying self._errmsg.
def is_consumed(self):
54
if self._result is None:
57
assert self._result, self._errmsg
62
# Rule that tests an executed statement for equality with an expected SQL
# string and, further down, compares expected parameters with
# context.compiled_parameters using ==.
# NOTE(review): the listing omits some lines (signature continuation, branch
# bodies, the assignment target of the failure message).
class ExactSQL(SQLMatchRule):
64
# NOTE(review): storing of sql/params is not visible here; self.sql and
# self.params are read below, so presumably they are assigned in __init__.
def __init__(self, sql, params=None):
65
SQLMatchRule.__init__(self)
69
def process_cursor_execute(self, statement, parameters, context,
73
# Normalise context.unicode_statement through _process_engine_statement
# before comparing it.
_received_statement = \
74
_process_engine_statement(context.unicode_statement,
76
_received_parameters = context.compiled_parameters
78
# TODO: remove this step once all unit tests are migrated, as
79
# ExactSQL should really be *exact* SQL
81
sql = _process_assertion_statement(self.sql, context)
82
equivalent = _received_statement == sql
84
# self.params may be a callable; it is then invoked with the execution
# context to produce the expected parameters.
if util.callable(self.params):
85
params = self.params(context)
88
# NOTE(review): the body of this branch is not visible; presumably a
# non-list value is wrapped in a list -- confirm.
if not isinstance(params, list):
90
equivalent = equivalent and params \
91
== context.compiled_parameters
94
# Record the outcome of this test on the rule.
self._result = equivalent
97
# Failure message reporting the expected and the received statement and
# parameters. NOTE(review): the line that assigns it is not visible here.
'Testing for exact statement %r exact params %r, '\
98
'received %r with params %r' % (sql, params,
99
_received_statement, _received_parameters)
102
# Rule that tests an executed statement against a regular expression and
# performs a partial ("positive") comparison of expected parameters.
# NOTE(review): the listing omits some lines (signature continuation, branch
# bodies, the assignment target of the failure message).
class RegexSQL(SQLMatchRule):
104
def __init__(self, regex, params=None):
105
SQLMatchRule.__init__(self)
106
# Keep both the compiled pattern (used for matching) and the original
# pattern text (used in the failure message).
self.regex = re.compile(regex)
107
self.orig_regex = regex
110
def process_cursor_execute(self, statement, parameters, context,
114
# Normalise context.unicode_statement through _process_engine_statement
# before matching it.
_received_statement = \
115
_process_engine_statement(context.unicode_statement,
117
_received_parameters = context.compiled_parameters
118
# re match() only anchors at the start of the string, so the pattern does
# not have to cover the whole statement.
equivalent = bool(self.regex.match(_received_statement))
120
# self.params may be a callable; it is then invoked with the execution
# context to produce the expected parameters.
if util.callable(self.params):
121
params = self.params(context)
124
# NOTE(review): the body of this branch is not visible; presumably a
# non-list value is wrapped in a list -- confirm.
if not isinstance(params, list):
127
# do a positive compare only
129
# Expected and received parameter sets are paired positionally; zip() stops
# at the shorter of the two sequences. Only keys named in the expected set
# are examined. The condition below is true when an expected key is missing
# or its value differs. NOTE(review): the handling of that case is not
# visible here.
for param, received in zip(params, _received_parameters):
130
for k, v in param.iteritems():
131
if k not in received or received[k] != v:
136
# Record the outcome of this test on the rule.
self._result = equivalent
139
# Failure message text. NOTE(review): the line that assigns it and one of
# the format arguments are not visible here.
'Testing for regex %r partial params %r, received %r '\
140
'with params %r' % (self.orig_regex, params,
142
_received_parameters)
145
# Rule that recompiles the executed statement with DefaultDialect and
# compares the resulting SQL string with an expected string, then checks
# expected parameter sets against the received ones.
# NOTE(review): the listing omits some lines (signature continuation, the
# assignment of `compiled`, several branch bodies).
class CompiledSQL(SQLMatchRule):
147
# NOTE(review): storing of params is not visible here; self.params is read
# below, so presumably it is assigned in __init__.
def __init__(self, statement, params=None):
148
SQLMatchRule.__init__(self)
149
self.statement = statement
152
def process_cursor_execute(self, statement, parameters, context,
156
from sqlalchemy.schema import _DDLCompiles
157
# Copy the received parameter sets: entries are removed from this list below.
_received_parameters = list(context.compiled_parameters)
159
# recompile from the context, using the default dialect
161
# _DDLCompiles statements are compiled with the dialect only; the other
# compile() call also passes the column_keys of the original compilation.
if isinstance(context.compiled.statement, _DDLCompiles):
163
context.compiled.statement.compile(dialect=DefaultDialect())
166
context.compiled.statement.compile(dialect=DefaultDialect(),
167
column_keys=context.compiled.column_keys)
168
# Remove newline and tab characters from the recompiled SQL before the
# equality comparison with the expected statement.
_received_statement = re.sub(r'[\n\t]', '', str(compiled))
169
equivalent = self.statement == _received_statement
171
# self.params may be a callable; it is then invoked with the execution
# context to produce the expected parameters.
if util.callable(self.params):
172
params = self.params(context)
175
# NOTE(review): the body of this branch is not visible; presumably a
# non-list value is wrapped in a list -- confirm.
if not isinstance(params, list):
178
# Work on copies: params is consumed with pop() below, while all_params and
# all_received keep the complete lists for the failure message.
params = list(params)
179
all_params = list(params)
180
all_received = list(_received_parameters)
182
# Take the next expected parameter set and fill in keys it does not mention
# from context.compiled.params; setdefault() never overrides an expected
# value.
param = dict(params.pop(0))
183
for k, v in context.compiled.params.iteritems():
184
param.setdefault(k, v)
185
# The completed set is looked up among the received sets, and remove()
# discards it from the remaining ones. NOTE(review): the handling of the
# not-found case is not visible here.
if param not in _received_parameters:
189
_received_parameters.remove(param)
190
# Tests whether received sets are still left over. NOTE(review): the
# handling of that case is not visible here.
if _received_parameters:
196
# Record the outcome of this test on the rule.
self._result = equivalent
198
# NOTE(review): Python 2 print statement emitting the same text as the
# failure message below; looks like leftover debugging output -- confirm
# whether it is still wanted.
print 'Testing for compiled statement %r partial params '\
199
'%r, received %r with params %r' % (self.statement,
200
all_params, _received_statement, all_received)
202
# Failure message text. NOTE(review): the line that assigns it is not
# visible here.
'Testing for compiled statement %r partial params %r, '\
203
'received %r with params %r' % (self.statement,
204
all_params, _received_statement, all_received)
209
# Rule that counts process_execute() calls and, in consume_final(), asserts
# that the count equals self.count.
# NOTE(review): the assignment of self.count and the bodies of
# process_cursor_execute() and is_consumed() are not visible in this
# listing; presumably self.count is set from the `count` argument -- confirm.
class CountStatements(AssertRule):
211
def __init__(self, count):
213
# Number of process_execute() calls seen so far.
self._statement_count = 0
215
def process_execute(self, clauseelement, *multiparams, **params):
216
self._statement_count += 1
218
def process_cursor_execute(self, statement, parameters, context,
222
def is_consumed(self):
225
# Raises AssertionError reporting both numbers when the desired and the
# observed statement counts differ.
def consume_final(self):
226
assert self.count == self._statement_count, \
227
'desired statement count %d does not match %d' \
228
% (self.count, self._statement_count)
232
# Composite rule holding a collection of rules: execution events are
# forwarded to every rule still held, rules are dropped once they pass, and
# the composite is finished when none remain.
# NOTE(review): the listing omits some lines and all indentation, so the
# exact nesting inside is_consumed() cannot be read from here.
class AllOf(AssertRule):
234
def __init__(self, *rules):
235
# Rules are kept in a set, so no ordering among them is retained.
self.rules = set(rules)
237
# Forward the event to every rule still held.
def process_execute(self, clauseelement, *multiparams, **params):
238
for rule in self.rules:
239
rule.process_execute(clauseelement, *multiparams, **params)
241
# Forward the event to every rule still held.
def process_cursor_execute(self, statement, parameters, context,
243
for rule in self.rules:
244
rule.process_cursor_execute(statement, parameters, context,
247
def is_consumed(self):
250
# Iterate over a snapshot (list(...)) because a passing rule is removed
# from self.rules while looping. The final assert raises AssertionError
# with the message below; NOTE(review): presumably it is reached only when
# no rule passed -- confirm against the original indentation.
for rule in list(self.rules):
251
if rule.rule_passed(): # a rule passed, move on
252
self.rules.remove(rule)
253
return len(self.rules) == 0
254
assert False, 'No assertion rules were satisfied for statement'
256
# True once every rule has been removed.
def consume_final(self):
257
return len(self.rules) == 0
260
# Normalise a statement received from the engine: coerce it to unicode,
# apply an mssql-specific check, and remove newline characters.
# NOTE(review): the body of the mssql branch and the return statement are
# not visible in this listing.
def _process_engine_statement(query, context):
263
# oracle+zxjdbc passes a PyStatement when returning into
265
query = unicode(query)
266
# Matches statements on an engine named 'mssql' that end with the
# '; select scope_identity()' suffix. NOTE(review): presumably the suffix
# is stripped in the branch body -- confirm.
if context.engine.name == 'mssql' \
267
and query.endswith('; select scope_identity()'):
269
# Remove every newline character from the statement.
query = re.sub(r'\n', '', query)
273
# Rewrite the ':name' placeholders of an expected statement according to
# context.dialect.paramstyle.
# NOTE(review): several branch bodies, the definition of `repl` and the
# return statement are not visible in this listing.
def _process_assertion_statement(query, context):
274
paramstyle = context.dialect.paramstyle
275
# NOTE(review): the body of the 'named' branch is not visible here.
if paramstyle == 'named':
277
elif paramstyle == 'pyformat':
278
# Turn each ':name' into '%(name)s'. (The '_' inside [\w_] is redundant,
# since \w already matches the underscore.)
query = re.sub(r':([\w_]+)', r"%(\1)s", query)
282
# NOTE(review): the bodies of the next three branches are not visible;
# presumably each one chooses the replacement `repl` used below -- confirm.
if paramstyle == 'qmark':
284
elif paramstyle == 'format':
286
elif paramstyle == 'numeric':
288
# Replace each ':name' placeholder using `repl`.
query = re.sub(r':([\w_]+)', repl, query)
293
# Holds the list of active assertion rules and feeds execution events to
# them through execute() and cursor_execute().
# NOTE(review): the listing omits many lines (how `rule` is selected, what
# happens once a rule is consumed, the body of clear_rules()), so only the
# visible statements are described.
class SQLAssert(object):
297
# Replace the active rules with a list copy of `rules`.
def add_rules(self, rules):
298
self.rules = list(rules)
300
# Calls consume_final() on each rule; the message below belongs to the case
# where a rule returns a false value. NOTE(review): the statement that uses
# the message is not visible here.
def statement_complete(self):
301
for rule in self.rules:
302
if not rule.consume_final():
304
'All statements are complete, but pending '\
305
'assertion rules remain'
307
def clear_rules(self):
310
# Acts only when rules are installed (self.rules is not None): passes the
# clause element and its parameters to a rule and then asks the rule whether
# it is consumed.
def execute(self, conn, clauseelement, multiparams, params, result):
311
if self.rules is not None:
314
'All rules have been exhausted, but further '\
317
rule.process_execute(clauseelement, *multiparams, **params)
318
if rule.is_consumed():
321
# Forwards the statement, parameters and context to a rule's
# process_cursor_execute().
def cursor_execute(self, conn, cursor, statement, parameters,
322
context, executemany):
325
rule.process_cursor_execute(statement, parameters, context,
328
# Module-level SQLAssert instance, created when the module is imported.
asserter = SQLAssert()