~ubuntu-branches/debian/jessie/sqlalchemy/jessie

« back to all changes in this revision

Viewing changes to lib/sqlalchemy/testing/assertions.py

  • Committer: Package Import Robot
  • Author(s): Piotr Ożarowski, Jakub Wilk, Piotr Ożarowski
  • Date: 2013-07-06 20:53:52 UTC
  • mfrom: (1.4.23) (16.1.17 experimental)
  • Revision ID: package-import@ubuntu.com-20130706205352-ryppl1eto3illd79
Tags: 0.8.2-1
[ Jakub Wilk ]
* Use canonical URIs for Vcs-* fields.

[ Piotr Ożarowski ]
* New upstream release
* Upload to unstable
* Build depend on python3-all instead of -dev, extensions are not built for
  Python 3.X 

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
from __future__ import absolute_import
 
2
 
 
3
from . import util as testutil
 
4
from sqlalchemy import pool, orm, util
 
5
from sqlalchemy.engine import default, create_engine
 
6
from sqlalchemy import exc as sa_exc
 
7
from sqlalchemy.util import decorator
 
8
from sqlalchemy import types as sqltypes, schema
 
9
import warnings
 
10
import re
 
11
from .warnings import resetwarnings
 
12
from .exclusions import db_spec, _is_excluded
 
13
from . import assertsql
 
14
from . import config
 
15
import itertools
 
16
from .util import fail
 
17
import contextlib
 
18
 
 
19
 
 
20
def emits_warning(*messages):
 
21
    """Mark a test as emitting a warning.
 
22
 
 
23
    With no arguments, squelches all SAWarning failures.  Or pass one or more
 
24
    strings; these will be matched to the root of the warning description by
 
25
    warnings.filterwarnings().
 
26
    """
 
27
    # TODO: it would be nice to assert that a named warning was
 
28
    # emitted. should work with some monkeypatching of warnings,
 
29
    # and may work on non-CPython if they keep to the spirit of
 
30
    # warnings.showwarning's docstring.
 
31
    # - update: jython looks ok, it uses cpython's module
 
32
 
 
33
    @decorator
 
34
    def decorate(fn, *args, **kw):
 
35
        # todo: should probably be strict about this, too
 
36
        filters = [dict(action='ignore',
 
37
                        category=sa_exc.SAPendingDeprecationWarning)]
 
38
        if not messages:
 
39
            filters.append(dict(action='ignore',
 
40
                                 category=sa_exc.SAWarning))
 
41
        else:
 
42
            filters.extend(dict(action='ignore',
 
43
                                 message=message,
 
44
                                 category=sa_exc.SAWarning)
 
45
                            for message in messages)
 
46
        for f in filters:
 
47
            warnings.filterwarnings(**f)
 
48
        try:
 
49
            return fn(*args, **kw)
 
50
        finally:
 
51
            resetwarnings()
 
52
    return decorate
 
53
 
 
54
 
 
55
def emits_warning_on(db, *warnings):
 
56
    """Mark a test as emitting a warning on a specific dialect.
 
57
 
 
58
    With no arguments, squelches all SAWarning failures.  Or pass one or more
 
59
    strings; these will be matched to the root of the warning description by
 
60
    warnings.filterwarnings().
 
61
    """
 
62
    spec = db_spec(db)
 
63
 
 
64
    @decorator
 
65
    def decorate(fn, *args, **kw):
 
66
        if isinstance(db, basestring):
 
67
            if not spec(config.db):
 
68
                return fn(*args, **kw)
 
69
            else:
 
70
                wrapped = emits_warning(*warnings)(fn)
 
71
                return wrapped(*args, **kw)
 
72
        else:
 
73
            if not _is_excluded(*db):
 
74
                return fn(*args, **kw)
 
75
            else:
 
76
                wrapped = emits_warning(*warnings)(fn)
 
77
                return wrapped(*args, **kw)
 
78
    return decorate
 
79
 
 
80
 
 
81
def uses_deprecated(*messages):
 
82
    """Mark a test as immune from fatal deprecation warnings.
 
83
 
 
84
    With no arguments, squelches all SADeprecationWarning failures.
 
85
    Or pass one or more strings; these will be matched to the root
 
86
    of the warning description by warnings.filterwarnings().
 
87
 
 
88
    As a special case, you may pass a function name prefixed with //
 
89
    and it will be re-written as needed to match the standard warning
 
90
    verbiage emitted by the sqlalchemy.util.deprecated decorator.
 
91
    """
 
92
 
 
93
    @decorator
 
94
    def decorate(fn, *args, **kw):
 
95
        # todo: should probably be strict about this, too
 
96
        filters = [dict(action='ignore',
 
97
                        category=sa_exc.SAPendingDeprecationWarning)]
 
98
        if not messages:
 
99
            filters.append(dict(action='ignore',
 
100
                                category=sa_exc.SADeprecationWarning))
 
101
        else:
 
102
            filters.extend(
 
103
                [dict(action='ignore',
 
104
                      message=message,
 
105
                      category=sa_exc.SADeprecationWarning)
 
106
                 for message in
 
107
                 [(m.startswith('//') and
 
108
                    ('Call to deprecated function ' + m[2:]) or m)
 
109
                   for m in messages]])
 
110
 
 
111
        for f in filters:
 
112
            warnings.filterwarnings(**f)
 
113
        try:
 
114
            return fn(*args, **kw)
 
115
        finally:
 
116
            resetwarnings()
 
117
    return decorate
 
118
 
 
119
 
 
120
def global_cleanup_assertions():
 
121
    """Check things that have to be finalized at the end of a test suite.
 
122
 
 
123
    Hardcoded at the moment, a modular system can be built here
 
124
    to support things like PG prepared transactions, tables all
 
125
    dropped, etc.
 
126
 
 
127
    """
 
128
 
 
129
    testutil.lazy_gc()
 
130
    assert not pool._refs, str(pool._refs)
 
131
 
 
132
 
 
133
def eq_(a, b, msg=None):
 
134
    """Assert a == b, with repr messaging on failure."""
 
135
    assert a == b, msg or "%r != %r" % (a, b)
 
136
 
 
137
 
 
138
def ne_(a, b, msg=None):
 
139
    """Assert a != b, with repr messaging on failure."""
 
140
    assert a != b, msg or "%r == %r" % (a, b)
 
141
 
 
142
 
 
143
def is_(a, b, msg=None):
 
144
    """Assert a is b, with repr messaging on failure."""
 
145
    assert a is b, msg or "%r is not %r" % (a, b)
 
146
 
 
147
 
 
148
def is_not_(a, b, msg=None):
 
149
    """Assert a is not b, with repr messaging on failure."""
 
150
    assert a is not b, msg or "%r is %r" % (a, b)
 
151
 
 
152
 
 
153
def startswith_(a, fragment, msg=None):
 
154
    """Assert a.startswith(fragment), with repr messaging on failure."""
 
155
    assert a.startswith(fragment), msg or "%r does not start with %r" % (
 
156
        a, fragment)
 
157
 
 
158
 
 
159
def assert_raises(except_cls, callable_, *args, **kw):
 
160
    try:
 
161
        callable_(*args, **kw)
 
162
        success = False
 
163
    except except_cls:
 
164
        success = True
 
165
 
 
166
    # assert outside the block so it works for AssertionError too !
 
167
    assert success, "Callable did not raise an exception"
 
168
 
 
169
 
 
170
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
 
171
    try:
 
172
        callable_(*args, **kwargs)
 
173
        assert False, "Callable did not raise an exception"
 
174
    except except_cls, e:
 
175
        assert re.search(msg, unicode(e), re.UNICODE), u"%r !~ %s" % (msg, e)
 
176
        print unicode(e).encode('utf-8')
 
177
 
 
178
 
 
179
class AssertsCompiledSQL(object):
 
180
    def assert_compile(self, clause, result, params=None,
 
181
                        checkparams=None, dialect=None,
 
182
                        checkpositional=None,
 
183
                        use_default_dialect=False,
 
184
                        allow_dialect_select=False):
 
185
        if use_default_dialect:
 
186
            dialect = default.DefaultDialect()
 
187
        elif dialect == None and not allow_dialect_select:
 
188
            dialect = getattr(self, '__dialect__', None)
 
189
            if dialect == 'default':
 
190
                dialect = default.DefaultDialect()
 
191
            elif dialect is None:
 
192
                dialect = config.db.dialect
 
193
            elif isinstance(dialect, basestring):
 
194
                dialect = create_engine("%s://" % dialect).dialect
 
195
 
 
196
        kw = {}
 
197
        if params is not None:
 
198
            kw['column_keys'] = params.keys()
 
199
 
 
200
        if isinstance(clause, orm.Query):
 
201
            context = clause._compile_context()
 
202
            context.statement.use_labels = True
 
203
            clause = context.statement
 
204
 
 
205
        c = clause.compile(dialect=dialect, **kw)
 
206
 
 
207
        param_str = repr(getattr(c, 'params', {}))
 
208
        # Py3K
 
209
        #param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
 
210
 
 
211
        print "\nSQL String:\n" + str(c) + param_str
 
212
 
 
213
        cc = re.sub(r'[\n\t]', '', str(c))
 
214
 
 
215
        eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
 
216
 
 
217
        if checkparams is not None:
 
218
            eq_(c.construct_params(params), checkparams)
 
219
        if checkpositional is not None:
 
220
            p = c.construct_params(params)
 
221
            eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
 
222
 
 
223
 
 
224
class ComparesTables(object):
 
225
 
 
226
    def assert_tables_equal(self, table, reflected_table, strict_types=False):
 
227
        assert len(table.c) == len(reflected_table.c)
 
228
        for c, reflected_c in zip(table.c, reflected_table.c):
 
229
            eq_(c.name, reflected_c.name)
 
230
            assert reflected_c is reflected_table.c[c.name]
 
231
            eq_(c.primary_key, reflected_c.primary_key)
 
232
            eq_(c.nullable, reflected_c.nullable)
 
233
 
 
234
            if strict_types:
 
235
                msg = "Type '%s' doesn't correspond to type '%s'"
 
236
                assert type(reflected_c.type) is type(c.type), \
 
237
                    msg % (reflected_c.type, c.type)
 
238
            else:
 
239
                self.assert_types_base(reflected_c, c)
 
240
 
 
241
            if isinstance(c.type, sqltypes.String):
 
242
                eq_(c.type.length, reflected_c.type.length)
 
243
 
 
244
            eq_(
 
245
                set([f.column.name for f in c.foreign_keys]),
 
246
                set([f.column.name for f in reflected_c.foreign_keys])
 
247
            )
 
248
            if c.server_default:
 
249
                assert isinstance(reflected_c.server_default,
 
250
                                  schema.FetchedValue)
 
251
 
 
252
        assert len(table.primary_key) == len(reflected_table.primary_key)
 
253
        for c in table.primary_key:
 
254
            assert reflected_table.primary_key.columns[c.name] is not None
 
255
 
 
256
    def assert_types_base(self, c1, c2):
 
257
        assert c1.type._compare_type_affinity(c2.type),\
 
258
                "On column %r, type '%s' doesn't correspond to type '%s'" % \
 
259
                (c1.name, c1.type, c2.type)
 
260
 
 
261
 
 
262
class AssertsExecutionResults(object):
 
263
    def assert_result(self, result, class_, *objects):
 
264
        result = list(result)
 
265
        print repr(result)
 
266
        self.assert_list(result, class_, objects)
 
267
 
 
268
    def assert_list(self, result, class_, list):
 
269
        self.assert_(len(result) == len(list),
 
270
                     "result list is not the same size as test list, " +
 
271
                     "for class " + class_.__name__)
 
272
        for i in range(0, len(list)):
 
273
            self.assert_row(class_, result[i], list[i])
 
274
 
 
275
    def assert_row(self, class_, rowobj, desc):
 
276
        self.assert_(rowobj.__class__ is class_,
 
277
                     "item class is not " + repr(class_))
 
278
        for key, value in desc.iteritems():
 
279
            if isinstance(value, tuple):
 
280
                if isinstance(value[1], list):
 
281
                    self.assert_list(getattr(rowobj, key), value[0], value[1])
 
282
                else:
 
283
                    self.assert_row(value[0], getattr(rowobj, key), value[1])
 
284
            else:
 
285
                self.assert_(getattr(rowobj, key) == value,
 
286
                             "attribute %s value %s does not match %s" % (
 
287
                             key, getattr(rowobj, key), value))
 
288
 
 
289
    def assert_unordered_result(self, result, cls, *expected):
 
290
        """As assert_result, but the order of objects is not considered.
 
291
 
 
292
        The algorithm is very expensive but not a big deal for the small
 
293
        numbers of rows that the test suite manipulates.
 
294
        """
 
295
 
 
296
        class immutabledict(dict):
 
297
            def __hash__(self):
 
298
                return id(self)
 
299
 
 
300
        found = util.IdentitySet(result)
 
301
        expected = set([immutabledict(e) for e in expected])
 
302
 
 
303
        for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found):
 
304
            fail('Unexpected type "%s", expected "%s"' % (
 
305
                type(wrong).__name__, cls.__name__))
 
306
 
 
307
        if len(found) != len(expected):
 
308
            fail('Unexpected object count "%s", expected "%s"' % (
 
309
                len(found), len(expected)))
 
310
 
 
311
        NOVALUE = object()
 
312
 
 
313
        def _compare_item(obj, spec):
 
314
            for key, value in spec.iteritems():
 
315
                if isinstance(value, tuple):
 
316
                    try:
 
317
                        self.assert_unordered_result(
 
318
                            getattr(obj, key), value[0], *value[1])
 
319
                    except AssertionError:
 
320
                        return False
 
321
                else:
 
322
                    if getattr(obj, key, NOVALUE) != value:
 
323
                        return False
 
324
            return True
 
325
 
 
326
        for expected_item in expected:
 
327
            for found_item in found:
 
328
                if _compare_item(found_item, expected_item):
 
329
                    found.remove(found_item)
 
330
                    break
 
331
            else:
 
332
                fail(
 
333
                    "Expected %s instance with attributes %s not found." % (
 
334
                    cls.__name__, repr(expected_item)))
 
335
        return True
 
336
 
 
337
    def assert_sql_execution(self, db, callable_, *rules):
 
338
        assertsql.asserter.add_rules(rules)
 
339
        try:
 
340
            callable_()
 
341
            assertsql.asserter.statement_complete()
 
342
        finally:
 
343
            assertsql.asserter.clear_rules()
 
344
 
 
345
    def assert_sql(self, db, callable_, list_, with_sequences=None):
 
346
        if with_sequences is not None and config.db.dialect.supports_sequences:
 
347
            rules = with_sequences
 
348
        else:
 
349
            rules = list_
 
350
 
 
351
        newrules = []
 
352
        for rule in rules:
 
353
            if isinstance(rule, dict):
 
354
                newrule = assertsql.AllOf(*[
 
355
                    assertsql.ExactSQL(k, v) for k, v in rule.iteritems()
 
356
                ])
 
357
            else:
 
358
                newrule = assertsql.ExactSQL(*rule)
 
359
            newrules.append(newrule)
 
360
 
 
361
        self.assert_sql_execution(db, callable_, *newrules)
 
362
 
 
363
    def assert_sql_count(self, db, callable_, count):
 
364
        self.assert_sql_execution(
 
365
            db, callable_, assertsql.CountStatements(count))
 
366
 
 
367
    @contextlib.contextmanager
 
368
    def assert_execution(self, *rules):
 
369
        assertsql.asserter.add_rules(rules)
 
370
        try:
 
371
            yield
 
372
            assertsql.asserter.statement_complete()
 
373
        finally:
 
374
            assertsql.asserter.clear_rules()
 
375
 
 
376
    def assert_statement_count(self, count):
 
377
        return self.assert_execution(assertsql.CountStatements(count))