from __future__ import absolute_import

import contextlib
import itertools
import re
import warnings

from . import util as testutil
from sqlalchemy import pool, orm, util
from sqlalchemy.engine import default, create_engine
from sqlalchemy import exc as sa_exc
from sqlalchemy.util import decorator
from sqlalchemy import types as sqltypes, schema
from .warnings import resetwarnings
from .exclusions import db_spec, _is_excluded
from . import assertsql
from . import config
from .util import fail

def emits_warning(*messages):
    """Mark a test as emitting a warning.

    With no arguments, squelches all SAWarning failures.  Or pass one or more
    strings; these will be matched to the root of the warning description by
    warnings.filterwarnings().

    SAPendingDeprecationWarning is ignored in either case.  The filters are
    installed just before the decorated test runs, and resetwarnings() is
    called once the test returns or raises.
    """
    # TODO: it would be nice to assert that a named warning was
    # emitted. should work with some monkeypatching of warnings,
    # and may work on non-CPython if they keep to the spirit of
    # warnings.showwarning's docstring.
    # - update: jython looks ok, it uses cpython's module

    # The filter list depends only on ``messages``, so it is built once at
    # decoration time instead of on every call of the decorated test.
    # todo: should probably be strict about this, too
    filters = [dict(action='ignore',
                    category=sa_exc.SAPendingDeprecationWarning)]
    if not messages:
        filters.append(dict(action='ignore',
                            category=sa_exc.SAWarning))
    else:
        filters.extend(dict(action='ignore',
                            message=message,
                            category=sa_exc.SAWarning)
                       for message in messages)

    @decorator
    def decorate(fn, *args, **kw):
        for f in filters:
            warnings.filterwarnings(**f)
        try:
            return fn(*args, **kw)
        finally:
            # NOTE(review): resetwarnings() (from .warnings) is presumably
            # what restores the suite's default filters -- confirm.
            resetwarnings()
    return decorate


def emits_warning_on(db, *messages):
    """Mark a test as emitting a warning on a specific dialect.

    With no arguments, squelches all SAWarning failures.  Or pass one or more
    strings; these will be matched to the root of the warning description by
    warnings.filterwarnings().

    ``db`` is either a string, turned into a predicate with db_spec() and
    checked against ``config.db``, or a tuple of arguments passed to
    _is_excluded().  When that check is false the test runs with no extra
    warning filters; otherwise it runs wrapped in emits_warning(*messages).
    """
    # NOTE: the varargs were previously named ``warnings``, which shadowed
    # the warnings module inside this function.
    spec = db_spec(db)

    @decorator
    def decorate(fn, *args, **kw):
        if isinstance(db, basestring):
            applies = spec(config.db)
        else:
            applies = _is_excluded(*db)
        if not applies:
            return fn(*args, **kw)
        wrapped = emits_warning(*messages)(fn)
        return wrapped(*args, **kw)
    return decorate


def uses_deprecated(*messages):
    """Mark a test as immune from fatal deprecation warnings.

    With no arguments, squelches all SADeprecationWarning failures.
    Or pass one or more strings; these will be matched to the root
    of the warning description by warnings.filterwarnings().

    As a special case, you may pass a function name prefixed with //
    and it will be re-written as needed to match the standard warning
    verbiage emitted by the sqlalchemy.util.deprecated decorator.

    SAPendingDeprecationWarning is ignored in either case, and
    resetwarnings() is called once the decorated test returns or raises.
    """
    # The filter list depends only on ``messages``, so it is built once at
    # decoration time instead of on every call of the decorated test.
    # todo: should probably be strict about this, too
    filters = [dict(action='ignore',
                    category=sa_exc.SAPendingDeprecationWarning)]
    if not messages:
        filters.append(dict(action='ignore',
                            category=sa_exc.SADeprecationWarning))
    else:
        # "//name" is shorthand for the warning text produced for a
        # deprecated function called "name".
        patterns = [('Call to deprecated function ' + m[2:])
                    if m.startswith('//') else m
                    for m in messages]
        filters.extend(dict(action='ignore',
                            message=pattern,
                            category=sa_exc.SADeprecationWarning)
                       for pattern in patterns)

    @decorator
    def decorate(fn, *args, **kw):
        for f in filters:
            warnings.filterwarnings(**f)
        try:
            return fn(*args, **kw)
        finally:
            # NOTE(review): resetwarnings() (from .warnings) is presumably
            # what restores the suite's default filters -- confirm.
            resetwarnings()
    return decorate


def global_cleanup_assertions():
    """Check things that have to be finalized at the end of a test suite.

    Hardcoded at the moment, a modular system can be built here
    to support things like PG prepared transactions, tables all
    dropped, etc.

    Currently asserts that ``pool._refs`` is empty, reporting its contents
    on failure.
    """
    # NOTE(review): lazy_gc() presumably forces a garbage-collection pass so
    # that connections which are no longer referenced have dropped out of
    # pool._refs before the check -- confirm.
    testutil.lazy_gc()
    assert not pool._refs, str(pool._refs)


def eq_(a, b, msg=None):
    """Assert a == b, with repr messaging on failure.

    ``msg``, when given and non-empty, replaces the default
    ``"<a!r> != <b!r>"`` failure message.
    """
    assert a == b, msg or "%r != %r" % (a, b)


def ne_(a, b, msg=None):
    """Assert a != b, with repr messaging on failure.

    ``msg``, when given and non-empty, replaces the default
    ``"<a!r> == <b!r>"`` failure message.
    """
    assert a != b, msg or "%r == %r" % (a, b)


def is_(a, b, msg=None):
    """Assert a is b, with repr messaging on failure.

    ``msg``, when given and non-empty, replaces the default
    ``"<a!r> is not <b!r>"`` failure message.
    """
    assert a is b, msg or "%r is not %r" % (a, b)


def is_not_(a, b, msg=None):
    """Assert a is not b, with repr messaging on failure.

    ``msg``, when given and non-empty, replaces the default
    ``"<a!r> is <b!r>"`` failure message.
    """
    assert a is not b, msg or "%r is %r" % (a, b)


def startswith_(a, fragment, msg=None):
    """Assert a.startswith(fragment), with repr messaging on failure.

    ``msg``, when given and non-empty, replaces the default failure message.
    """
    assert a.startswith(fragment), msg or "%r does not start with %r" % (
        a, fragment)


def assert_raises(except_cls, callable_, *args, **kw):
    """Assert that ``callable_(*args, **kw)`` raises ``except_cls``.

    Exceptions of any other type propagate to the caller unchanged.
    """
    try:
        callable_(*args, **kw)
    except except_cls:
        pass
    else:
        # assert outside the block so it works for AssertionError too !
        # (the else clause is not covered by the except above)
        assert False, "Callable did not raise an exception"


def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
    """Assert that ``callable_(*args, **kwargs)`` raises ``except_cls`` and
    that the regular expression ``msg`` is found (``re.search``) in the text
    of the exception.

    On success the exception text is printed, UTF-8 encoded.  Exceptions of
    any other type propagate to the caller unchanged.
    """
    import sys
    try:
        callable_(*args, **kwargs)
    except except_cls:
        # sys.exc_info() avoids both the Python-2-only "except X, e" form and
        # "except X as e", which Python releases before 2.6 reject.
        e = sys.exc_info()[1]
        assert re.search(msg, unicode(e), re.UNICODE), u"%r !~ %s" % (msg, e)
        print(unicode(e).encode('utf-8'))
    else:
        # Fail outside the try block: when except_cls is AssertionError (or a
        # base class of it), an assert inside the try would be caught above
        # and merely regex-checked against ``msg`` (an empty ``msg`` would
        # then let a non-raising callable pass).
        assert False, "Callable did not raise an exception"


class AssertsCompiledSQL(object):
    def assert_compile(self, clause, result, params=None,
                       checkparams=None, dialect=None,
                       checkpositional=None,
                       use_default_dialect=False,
                       allow_dialect_select=False):
        """Compile ``clause`` and assert its SQL string equals ``result``.

        Dialect selection:

        * ``use_default_dialect=True`` -- a plain DefaultDialect;
        * an explicit ``dialect`` argument -- used as given;
        * otherwise, unless ``allow_dialect_select`` is set, the class
          attribute ``__dialect__``: ``'default'`` gives a DefaultDialect,
          any other string names a dialect built via create_engine(), a
          non-string value is used as given, and a missing attribute falls
          back to ``config.db.dialect``;
        * with ``allow_dialect_select`` set and no ``dialect``, None is
          passed through to ``clause.compile()``.

        ``clause`` may also be an ORM Query, in which case its labeled
        statement is compiled.  Newlines and tabs are stripped from the
        compiled string before comparison.  When given, ``checkparams`` is
        compared with the compiled parameters and ``checkpositional`` with
        the parameters ordered by the compiled ``positiontup``.
        """
        if use_default_dialect:
            dialect = default.DefaultDialect()
        elif dialect is None and not allow_dialect_select:
            dialect = getattr(self, '__dialect__', None)
            if dialect == 'default':
                dialect = default.DefaultDialect()
            elif dialect is None:
                dialect = config.db.dialect
            elif isinstance(dialect, basestring):
                dialect = create_engine("%s://" % dialect).dialect

        kw = {}
        if params is not None:
            kw['column_keys'] = params.keys()

        if isinstance(clause, orm.Query):
            context = clause._compile_context()
            context.statement.use_labels = True
            clause = context.statement

        c = clause.compile(dialect=dialect, **kw)

        param_str = repr(getattr(c, 'params', {}))
        # NOTE(review): the commented-out line below looks like Python 3-only
        # code meant for SQLAlchemy's Py3K source preprocessing; any marker
        # lines around it are not visible here -- confirm against upstream.
        #param_str = param_str.encode('utf-8').decode('ascii', 'ignore')

        print("\nSQL String:\n" + str(c) + param_str)

        cc = re.sub(r'[\n\t]', '', str(c))

        eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))

        if checkparams is not None:
            eq_(c.construct_params(params), checkparams)
        if checkpositional is not None:
            p = c.construct_params(params)
            eq_(tuple([p[x] for x in c.positiontup]), checkpositional)


class ComparesTables(object):

    def assert_tables_equal(self, table, reflected_table, strict_types=False):
        """Assert that ``reflected_table`` matches ``table`` column by column.

        Column order and names, primary-key and nullable flags, String
        lengths and the names of foreign-key target columns must agree.
        Column types are compared by exact class when ``strict_types`` is
        set, otherwise by type affinity via assert_types_base().  A column
        with a server default must be reflected with a schema.FetchedValue
        server default, and every primary-key column of ``table`` must be
        present in the reflected primary key.
        """
        assert len(table.c) == len(reflected_table.c)
        for c, reflected_c in zip(table.c, reflected_table.c):
            eq_(c.name, reflected_c.name)
            assert reflected_c is reflected_table.c[c.name]
            eq_(c.primary_key, reflected_c.primary_key)
            eq_(c.nullable, reflected_c.nullable)

            if strict_types:
                msg = "Type '%s' doesn't correspond to type '%s'"
                assert type(reflected_c.type) is type(c.type), \
                    msg % (reflected_c.type, c.type)
            else:
                self.assert_types_base(reflected_c, c)

            if isinstance(c.type, sqltypes.String):
                eq_(c.type.length, reflected_c.type.length)

            eq_(
                set([f.column.name for f in c.foreign_keys]),
                set([f.column.name for f in reflected_c.foreign_keys])
            )
            if c.server_default:
                assert isinstance(reflected_c.server_default,
                                  schema.FetchedValue)

        assert len(table.primary_key) == len(reflected_table.primary_key)
        for c in table.primary_key:
            assert reflected_table.primary_key.columns[c.name] is not None

    def assert_types_base(self, c1, c2):
        """Assert that column ``c1``'s type has the same affinity as
        ``c2``'s, as judged by ``c1.type._compare_type_affinity``."""
        assert c1.type._compare_type_affinity(c2.type),\
            "On column %r, type '%s' doesn't correspond to type '%s'" % \
            (c1.name, c1.type, c2.type)


class AssertsExecutionResults(object):
    def assert_result(self, result, class_, *objects):
        """Assert that ``result`` holds, in order, instances of ``class_``
        whose attributes match the descriptions in ``objects`` (see
        assert_row())."""
        result = list(result)
        self.assert_list(result, class_, objects)

    def assert_list(self, result, class_, list):
        """Assert that ``result`` has one row per entry of ``list`` and that
        each row matches its entry positionally via assert_row()."""
        self.assert_(len(result) == len(list),
                     "result list is not the same size as test list, " +
                     "for class " + class_.__name__)
        for rowobj, desc in zip(result, list):
            self.assert_row(class_, rowobj, desc)

    def assert_row(self, class_, rowobj, desc):
        """Assert that ``rowobj`` is exactly of class ``class_`` and that its
        attributes match ``desc``.

        ``desc`` maps attribute names to expected values.  A tuple value
        ``(cls, [...])`` checks a collection attribute with assert_list();
        any other tuple ``(cls, desc2)`` checks a related object recursively
        with assert_row().
        """
        self.assert_(rowobj.__class__ is class_,
                     "item class is not " + repr(class_))
        for key, value in desc.iteritems():
            if isinstance(value, tuple):
                if isinstance(value[1], list):
                    self.assert_list(getattr(rowobj, key), value[0], value[1])
                else:
                    self.assert_row(value[0], getattr(rowobj, key), value[1])
            else:
                self.assert_(getattr(rowobj, key) == value,
                             "attribute %s value %s does not match %s" % (
                                 key, getattr(rowobj, key), value))

    def assert_unordered_result(self, result, cls, *expected):
        """As assert_result, but the order of objects is not considered.

        The algorithm is very expensive but not a big deal for the small
        numbers of rows that the test suite manipulates.

        Every object must be exactly of type ``cls``.  Tuple values in an
        expected dict, ``(cls2, [dicts])``, are matched recursively in the
        same unordered way.  Returns True when everything matched.
        """

        class immutabledict(dict):
            # hashed by identity so the expected dicts can be held in a set
            def __hash__(self):
                return id(self)

        found = util.IdentitySet(result)
        expected = set([immutabledict(e) for e in expected])

        for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found):
            fail('Unexpected type "%s", expected "%s"' % (
                type(wrong).__name__, cls.__name__))

        if len(found) != len(expected):
            fail('Unexpected object count "%s", expected "%s"' % (
                len(found), len(expected)))

        # sentinel so that a missing attribute never equals an expected value
        NOVALUE = object()

        def _compare_item(obj, spec):
            # True if ``obj`` matches every entry of the expected dict ``spec``
            for key, value in spec.iteritems():
                if isinstance(value, tuple):
                    try:
                        self.assert_unordered_result(
                            getattr(obj, key), value[0], *value[1])
                    except AssertionError:
                        # NOTE(review): assumes fail() signals by raising
                        # AssertionError -- confirm in .util
                        return False
                else:
                    if getattr(obj, key, NOVALUE) != value:
                        return False
            return True

        for expected_item in expected:
            for found_item in found:
                if _compare_item(found_item, expected_item):
                    # consume the match so it can't satisfy a second item;
                    # safe because iteration stops right here
                    found.remove(found_item)
                    break
            else:
                fail(
                    "Expected %s instance with attributes %s not found." % (
                        cls.__name__, repr(expected_item)))
        return True

    def assert_sql_execution(self, db, callable_, *rules):
        """Run ``callable_()`` with ``rules`` registered on
        ``assertsql.asserter``, then call statement_complete().

        The rules are cleared afterwards even if ``callable_`` or the final
        check raises.  ``db`` is not used by this method.
        """
        assertsql.asserter.add_rules(rules)
        try:
            callable_()
            assertsql.asserter.statement_complete()
        finally:
            assertsql.asserter.clear_rules()

    def assert_sql(self, db, callable_, list_, with_sequences=None):
        """Assert the SQL emitted by ``callable_`` against ``list_``.

        Each rule is either a dict of ``{sql: params}`` entries, wrapped in
        assertsql.AllOf of ExactSQL rules, or a sequence of arguments for a
        single assertsql.ExactSQL.  ``with_sequences``, when given, is used
        instead of ``list_`` if ``config.db``'s dialect supports sequences.
        """
        if with_sequences is not None and config.db.dialect.supports_sequences:
            rules = with_sequences
        else:
            rules = list_

        newrules = []
        for rule in rules:
            if isinstance(rule, dict):
                newrule = assertsql.AllOf(*[
                    assertsql.ExactSQL(k, v) for k, v in rule.iteritems()
                ])
            else:
                newrule = assertsql.ExactSQL(*rule)
            newrules.append(newrule)

        self.assert_sql_execution(db, callable_, *newrules)

    def assert_sql_count(self, db, callable_, count):
        """Assert via assertsql.CountStatements that ``callable_()`` emits
        ``count`` statements."""
        self.assert_sql_execution(
            db, callable_, assertsql.CountStatements(count))

    @contextlib.contextmanager
    def assert_execution(self, *rules):
        """Context-manager counterpart of assert_sql_execution().

        ``rules`` stay registered on ``assertsql.asserter`` for the duration
        of the ``with`` block.  statement_complete() is called when the
        block exits normally, and the rules are cleared in every case.
        """
        assertsql.asserter.add_rules(rules)
        try:
            yield
            assertsql.asserter.statement_complete()
        finally:
            assertsql.asserter.clear_rules()

    def assert_statement_count(self, count):
        """Context manager asserting, via assertsql.CountStatements, the
        number of statements executed inside the ``with`` block."""
        return self.assert_execution(assertsql.CountStatements(count))