~ubuntu-branches/debian/jessie/sqlalchemy/jessie

« back to all changes in this revision

Viewing changes to lib/sqlalchemy/testing/engines.py

  • Committer: Package Import Robot
  • Author(s): Piotr Ożarowski, Jakub Wilk, Piotr Ożarowski
  • Date: 2013-07-06 20:53:52 UTC
  • mfrom: (1.4.23) (16.1.17 experimental)
  • Revision ID: package-import@ubuntu.com-20130706205352-ryppl1eto3illd79
Tags: 0.8.2-1
[ Jakub Wilk ]
* Use canonical URIs for Vcs-* fields.

[ Piotr Ożarowski ]
* New upstream release
* Upload to unstable
* Build depend on python3-all instead of -dev, extensions are not built for
  Python 3.X 

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
from __future__ import absolute_import
 
2
 
 
3
import types
 
4
import weakref
 
5
from collections import deque
 
6
from . import config
 
7
from .util import decorator
 
8
from .. import event, pool
 
9
import re
 
10
import warnings
 
11
 
 
12
 
 
13
class ConnectionKiller(object):
 
14
 
 
15
    def __init__(self):
 
16
        self.proxy_refs = weakref.WeakKeyDictionary()
 
17
        self.testing_engines = weakref.WeakKeyDictionary()
 
18
        self.conns = set()
 
19
 
 
20
    def add_engine(self, engine):
 
21
        self.testing_engines[engine] = True
 
22
 
 
23
    def connect(self, dbapi_conn, con_record):
 
24
        self.conns.add((dbapi_conn, con_record))
 
25
 
 
26
    def checkout(self, dbapi_con, con_record, con_proxy):
 
27
        self.proxy_refs[con_proxy] = True
 
28
 
 
29
    def _safe(self, fn):
 
30
        try:
 
31
            fn()
 
32
        except (SystemExit, KeyboardInterrupt):
 
33
            raise
 
34
        except Exception, e:
 
35
            warnings.warn(
 
36
                    "testing_reaper couldn't "
 
37
                    "rollback/close connection: %s" % e)
 
38
 
 
39
    def rollback_all(self):
 
40
        for rec in self.proxy_refs.keys():
 
41
            if rec is not None and rec.is_valid:
 
42
                self._safe(rec.rollback)
 
43
 
 
44
    def close_all(self):
 
45
        for rec in self.proxy_refs.keys():
 
46
            if rec is not None:
 
47
                self._safe(rec._close)
 
48
 
 
49
    def _after_test_ctx(self):
 
50
        pass
 
51
        # this can cause a deadlock with pg8000 - pg8000 acquires
 
52
        # prepared statment lock inside of rollback() - if async gc
 
53
        # is collecting in finalize_fairy, deadlock.
 
54
        # not sure if this should be if pypy/jython only.
 
55
        # note that firebird/fdb definitely needs this though
 
56
        for conn, rec in self.conns:
 
57
            self._safe(conn.rollback)
 
58
 
 
59
    def _stop_test_ctx(self):
 
60
        if config.options.low_connections:
 
61
            self._stop_test_ctx_minimal()
 
62
        else:
 
63
            self._stop_test_ctx_aggressive()
 
64
 
 
65
    def _stop_test_ctx_minimal(self):
 
66
        self.close_all()
 
67
 
 
68
        self.conns = set()
 
69
 
 
70
        for rec in self.testing_engines.keys():
 
71
            if rec is not config.db:
 
72
                rec.dispose()
 
73
 
 
74
    def _stop_test_ctx_aggressive(self):
 
75
        self.close_all()
 
76
        for conn, rec in self.conns:
 
77
            self._safe(conn.close)
 
78
            rec.connection = None
 
79
            
 
80
        self.conns = set()
 
81
        for rec in self.testing_engines.keys():
 
82
            rec.dispose()
 
83
 
 
84
    def assert_all_closed(self):
 
85
        for rec in self.proxy_refs:
 
86
            if rec.is_valid:
 
87
                assert False
 
88
 
 
89
testing_reaper = ConnectionKiller()
 
90
 
 
91
 
 
92
def drop_all_tables(metadata, bind):
 
93
    testing_reaper.close_all()
 
94
    if hasattr(bind, 'close'):
 
95
        bind.close()
 
96
    metadata.drop_all(bind)
 
97
 
 
98
 
 
99
@decorator
 
100
def assert_conns_closed(fn, *args, **kw):
 
101
    try:
 
102
        fn(*args, **kw)
 
103
    finally:
 
104
        testing_reaper.assert_all_closed()
 
105
 
 
106
 
 
107
@decorator
 
108
def rollback_open_connections(fn, *args, **kw):
 
109
    """Decorator that rolls back all open connections after fn execution."""
 
110
 
 
111
    try:
 
112
        fn(*args, **kw)
 
113
    finally:
 
114
        testing_reaper.rollback_all()
 
115
 
 
116
 
 
117
@decorator
 
118
def close_first(fn, *args, **kw):
 
119
    """Decorator that closes all connections before fn execution."""
 
120
 
 
121
    testing_reaper.close_all()
 
122
    fn(*args, **kw)
 
123
 
 
124
 
 
125
@decorator
 
126
def close_open_connections(fn, *args, **kw):
 
127
    """Decorator that closes all connections after fn execution."""
 
128
    try:
 
129
        fn(*args, **kw)
 
130
    finally:
 
131
        testing_reaper.close_all()
 
132
 
 
133
 
 
134
def all_dialects(exclude=None):
 
135
    import sqlalchemy.databases as d
 
136
    for name in d.__all__:
 
137
        # TEMPORARY
 
138
        if exclude and name in exclude:
 
139
            continue
 
140
        mod = getattr(d, name, None)
 
141
        if not mod:
 
142
            mod = getattr(__import__(
 
143
                'sqlalchemy.databases.%s' % name).databases, name)
 
144
        yield mod.dialect()
 
145
 
 
146
 
 
147
class ReconnectFixture(object):
 
148
 
 
149
    def __init__(self, dbapi):
 
150
        self.dbapi = dbapi
 
151
        self.connections = []
 
152
 
 
153
    def __getattr__(self, key):
 
154
        return getattr(self.dbapi, key)
 
155
 
 
156
    def connect(self, *args, **kwargs):
 
157
        conn = self.dbapi.connect(*args, **kwargs)
 
158
        self.connections.append(conn)
 
159
        return conn
 
160
 
 
161
    def _safe(self, fn):
 
162
        try:
 
163
            fn()
 
164
        except (SystemExit, KeyboardInterrupt):
 
165
            raise
 
166
        except Exception, e:
 
167
            warnings.warn(
 
168
                    "ReconnectFixture couldn't "
 
169
                    "close connection: %s" % e)
 
170
 
 
171
    def shutdown(self):
 
172
        # TODO: this doesn't cover all cases
 
173
        # as nicely as we'd like, namely MySQLdb.
 
174
        # would need to implement R. Brewer's
 
175
        # proxy server idea to get better
 
176
        # coverage.
 
177
        for c in list(self.connections):
 
178
            self._safe(c.close)
 
179
        self.connections = []
 
180
 
 
181
 
 
182
def reconnecting_engine(url=None, options=None):
 
183
    url = url or config.db_url
 
184
    dbapi = config.db.dialect.dbapi
 
185
    if not options:
 
186
        options = {}
 
187
    options['module'] = ReconnectFixture(dbapi)
 
188
    engine = testing_engine(url, options)
 
189
    _dispose = engine.dispose
 
190
 
 
191
    def dispose():
 
192
        engine.dialect.dbapi.shutdown()
 
193
        _dispose()
 
194
 
 
195
    engine.test_shutdown = engine.dialect.dbapi.shutdown
 
196
    engine.dispose = dispose
 
197
    return engine
 
198
 
 
199
 
 
200
def testing_engine(url=None, options=None):
 
201
    """Produce an engine configured by --options with optional overrides."""
 
202
 
 
203
    from sqlalchemy import create_engine
 
204
    from .assertsql import asserter
 
205
 
 
206
    if not options:
 
207
        use_reaper = True
 
208
    else:
 
209
        use_reaper = options.pop('use_reaper', True)
 
210
 
 
211
    url = url or config.db_url
 
212
    if options is None:
 
213
        options = config.db_opts
 
214
 
 
215
    engine = create_engine(url, **options)
 
216
    if isinstance(engine.pool, pool.QueuePool):
 
217
        engine.pool._timeout = 0
 
218
        engine.pool._max_overflow = 0
 
219
    event.listen(engine, 'after_execute', asserter.execute)
 
220
    event.listen(engine, 'after_cursor_execute', asserter.cursor_execute)
 
221
    if use_reaper:
 
222
        event.listen(engine.pool, 'connect', testing_reaper.connect)
 
223
        event.listen(engine.pool, 'checkout', testing_reaper.checkout)
 
224
        testing_reaper.add_engine(engine)
 
225
 
 
226
    return engine
 
227
 
 
228
 
 
229
def utf8_engine(url=None, options=None):
 
230
    """Hook for dialects or drivers that don't handle utf8 by default."""
 
231
 
 
232
    from sqlalchemy.engine import url as engine_url
 
233
 
 
234
    if config.db.dialect.name == 'mysql' and \
 
235
        config.db.driver in ['mysqldb', 'pymysql', 'cymysql']:
 
236
        # note 1.2.1.gamma.6 or greater of MySQLdb
 
237
        # needed here
 
238
        url = url or config.db_url
 
239
        url = engine_url.make_url(url)
 
240
        url.query['charset'] = 'utf8'
 
241
        url.query['use_unicode'] = '0'
 
242
        url = str(url)
 
243
 
 
244
    return testing_engine(url, options)
 
245
 
 
246
 
 
247
def mock_engine(dialect_name=None):
 
248
    """Provides a mocking engine based on the current testing.db.
 
249
 
 
250
    This is normally used to test DDL generation flow as emitted
 
251
    by an Engine.
 
252
 
 
253
    It should not be used in other cases, as assert_compile() and
 
254
    assert_sql_execution() are much better choices with fewer
 
255
    moving parts.
 
256
 
 
257
    """
 
258
 
 
259
    from sqlalchemy import create_engine
 
260
 
 
261
    if not dialect_name:
 
262
        dialect_name = config.db.name
 
263
 
 
264
    buffer = []
 
265
 
 
266
    def executor(sql, *a, **kw):
 
267
        buffer.append(sql)
 
268
 
 
269
    def assert_sql(stmts):
 
270
        recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
 
271
        assert  recv == stmts, recv
 
272
 
 
273
    def print_sql():
 
274
        d = engine.dialect
 
275
        return "\n".join(
 
276
            str(s.compile(dialect=d))
 
277
            for s in engine.mock
 
278
        )
 
279
 
 
280
    engine = create_engine(dialect_name + '://',
 
281
                           strategy='mock', executor=executor)
 
282
    assert not hasattr(engine, 'mock')
 
283
    engine.mock = buffer
 
284
    engine.assert_sql = assert_sql
 
285
    engine.print_sql = print_sql
 
286
    return engine
 
287
 
 
288
 
 
289
class DBAPIProxyCursor(object):
 
290
    """Proxy a DBAPI cursor.
 
291
 
 
292
    Tests can provide subclasses of this to intercept
 
293
    DBAPI-level cursor operations.
 
294
 
 
295
    """
 
296
    def __init__(self, engine, conn):
 
297
        self.engine = engine
 
298
        self.connection = conn
 
299
        self.cursor = conn.cursor()
 
300
 
 
301
    def execute(self, stmt, parameters=None, **kw):
 
302
        if parameters:
 
303
            return self.cursor.execute(stmt, parameters, **kw)
 
304
        else:
 
305
            return self.cursor.execute(stmt, **kw)
 
306
 
 
307
    def executemany(self, stmt, params, **kw):
 
308
        return self.cursor.executemany(stmt, params, **kw)
 
309
 
 
310
    def __getattr__(self, key):
 
311
        return getattr(self.cursor, key)
 
312
 
 
313
 
 
314
class DBAPIProxyConnection(object):
 
315
    """Proxy a DBAPI connection.
 
316
 
 
317
    Tests can provide subclasses of this to intercept
 
318
    DBAPI-level connection operations.
 
319
 
 
320
    """
 
321
    def __init__(self, engine, cursor_cls):
 
322
        self.conn = self._sqla_unwrap = engine.pool._creator()
 
323
        self.engine = engine
 
324
        self.cursor_cls = cursor_cls
 
325
 
 
326
    def cursor(self):
 
327
        return self.cursor_cls(self.engine, self.conn)
 
328
 
 
329
    def close(self):
 
330
        self.conn.close()
 
331
 
 
332
    def __getattr__(self, key):
 
333
        return getattr(self.conn, key)
 
334
 
 
335
 
 
336
def proxying_engine(conn_cls=DBAPIProxyConnection,
 
337
                    cursor_cls=DBAPIProxyCursor):
 
338
    """Produce an engine that provides proxy hooks for
 
339
    common methods.
 
340
 
 
341
    """
 
342
    def mock_conn():
 
343
        return conn_cls(config.db, cursor_cls)
 
344
    return testing_engine(options={'creator': mock_conn})
 
345
 
 
346
 
 
347
class ReplayableSession(object):
 
348
    """A simple record/playback tool.
 
349
 
 
350
    This is *not* a mock testing class.  It only records a session for later
 
351
    playback and makes no assertions on call consistency whatsoever.  It's
 
352
    unlikely to be suitable for anything other than DB-API recording.
 
353
 
 
354
    """
 
355
 
 
356
    Callable = object()
 
357
    NoAttribute = object()
 
358
 
 
359
    # Py3K
 
360
    #Natives = set([getattr(types, t)
 
361
    #               for t in dir(types) if not t.startswith('_')]). \
 
362
    #               union([type(t) if not isinstance(t, type)
 
363
    #                        else t for t in __builtins__.values()]).\
 
364
    #               difference([getattr(types, t)
 
365
    #                        for t in ('FunctionType', 'BuiltinFunctionType',
 
366
    #                                  'MethodType', 'BuiltinMethodType',
 
367
    #                                  'LambdaType', )])
 
368
    # Py2K
 
369
    Natives = set([getattr(types, t)
 
370
                   for t in dir(types) if not t.startswith('_')]). \
 
371
                   difference([getattr(types, t)
 
372
                           for t in ('FunctionType', 'BuiltinFunctionType',
 
373
                                     'MethodType', 'BuiltinMethodType',
 
374
                                     'LambdaType', 'UnboundMethodType',)])
 
375
    # end Py2K
 
376
 
 
377
    def __init__(self):
 
378
        self.buffer = deque()
 
379
 
 
380
    def recorder(self, base):
 
381
        return self.Recorder(self.buffer, base)
 
382
 
 
383
    def player(self):
 
384
        return self.Player(self.buffer)
 
385
 
 
386
    class Recorder(object):
 
387
        def __init__(self, buffer, subject):
 
388
            self._buffer = buffer
 
389
            self._subject = subject
 
390
 
 
391
        def __call__(self, *args, **kw):
 
392
            subject, buffer = [object.__getattribute__(self, x)
 
393
                               for x in ('_subject', '_buffer')]
 
394
 
 
395
            result = subject(*args, **kw)
 
396
            if type(result) not in ReplayableSession.Natives:
 
397
                buffer.append(ReplayableSession.Callable)
 
398
                return type(self)(buffer, result)
 
399
            else:
 
400
                buffer.append(result)
 
401
                return result
 
402
 
 
403
        @property
 
404
        def _sqla_unwrap(self):
 
405
            return self._subject
 
406
 
 
407
        def __getattribute__(self, key):
 
408
            try:
 
409
                return object.__getattribute__(self, key)
 
410
            except AttributeError:
 
411
                pass
 
412
 
 
413
            subject, buffer = [object.__getattribute__(self, x)
 
414
                               for x in ('_subject', '_buffer')]
 
415
            try:
 
416
                result = type(subject).__getattribute__(subject, key)
 
417
            except AttributeError:
 
418
                buffer.append(ReplayableSession.NoAttribute)
 
419
                raise
 
420
            else:
 
421
                if type(result) not in ReplayableSession.Natives:
 
422
                    buffer.append(ReplayableSession.Callable)
 
423
                    return type(self)(buffer, result)
 
424
                else:
 
425
                    buffer.append(result)
 
426
                    return result
 
427
 
 
428
    class Player(object):
 
429
        def __init__(self, buffer):
 
430
            self._buffer = buffer
 
431
 
 
432
        def __call__(self, *args, **kw):
 
433
            buffer = object.__getattribute__(self, '_buffer')
 
434
            result = buffer.popleft()
 
435
            if result is ReplayableSession.Callable:
 
436
                return self
 
437
            else:
 
438
                return result
 
439
 
 
440
        @property
 
441
        def _sqla_unwrap(self):
 
442
            return None
 
443
 
 
444
        def __getattribute__(self, key):
 
445
            try:
 
446
                return object.__getattribute__(self, key)
 
447
            except AttributeError:
 
448
                pass
 
449
            buffer = object.__getattribute__(self, '_buffer')
 
450
            result = buffer.popleft()
 
451
            if result is ReplayableSession.Callable:
 
452
                return self
 
453
            elif result is ReplayableSession.NoAttribute:
 
454
                raise AttributeError(key)
 
455
            else:
 
456
                return result