~ubuntu-branches/debian/jessie/sqlalchemy/jessie

« back to all changes in this revision

Viewing changes to test/lib/engines.py

  • Committer: Package Import Robot
  • Author(s): Piotr Ożarowski, Jakub Wilk, Piotr Ożarowski
  • Date: 2013-07-06 20:53:52 UTC
  • mfrom: (1.4.23) (16.1.17 experimental)
  • Revision ID: package-import@ubuntu.com-20130706205352-ryppl1eto3illd79
Tags: 0.8.2-1
[ Jakub Wilk ]
* Use canonical URIs for Vcs-* fields.

[ Piotr Ożarowski ]
* New upstream release
* Upload to unstable
* Build-depend on python3-all instead of -dev; extensions are not built for
  Python 3.X

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
import sys, types, weakref
2
 
from collections import deque
3
 
from test.bootstrap import config
4
 
from test.lib.util import decorator, gc_collect
5
 
from sqlalchemy.util import callable
6
 
from sqlalchemy import event, pool
7
 
from sqlalchemy.engine import base as engine_base
8
 
import re
9
 
import warnings
10
 
 
11
 
class ConnectionKiller(object):
    """Track connections and engines used by tests so they can be
    rolled back, closed or disposed between tests.

    Connection proxies and engines are held through weak-key
    dictionaries, so tracking them here does not keep them alive.
    Raw DBAPI connections are held strongly in ``conns`` until one of
    the ``_stop_test_ctx*`` methods resets that set.

    """

    def __init__(self):
        self.proxy_refs = weakref.WeakKeyDictionary()
        self.testing_engines = weakref.WeakKeyDictionary()
        self.conns = set()

    def add_engine(self, engine):
        """Register ``engine`` so the stop-context methods dispose it."""
        self.testing_engines[engine] = True

    def connect(self, dbapi_conn, con_record):
        """Pool 'connect' listener: remember the raw DBAPI connection."""
        self.conns.add(dbapi_conn)

    def checkout(self, dbapi_con, con_record, con_proxy):
        """Pool 'checkout' listener: remember the checked-out proxy."""
        self.proxy_refs[con_proxy] = True

    def _safe(self, fn):
        """Call ``fn()``; turn any ordinary exception into a warning.

        SystemExit and KeyboardInterrupt are always re-raised.

        """
        try:
            fn()
        except (SystemExit, KeyboardInterrupt):
            raise
        except Exception as e:
            warnings.warn(
                    "testing_reaper couldn't "
                    "rollback/close connection: %s" % e)

    def rollback_all(self):
        """Roll back every tracked proxy that is still valid."""
        # iterate over a snapshot; weak entries may vanish at any time
        for rec in list(self.proxy_refs.keys()):
            if rec is not None and rec.is_valid:
                self._safe(rec.rollback)

    def close_all(self):
        """Call ``_close()`` on every tracked proxy."""
        for rec in list(self.proxy_refs.keys()):
            if rec is not None:
                self._safe(rec._close)

    def _after_test_ctx(self):
        """Hook run after each test; intentionally a no-op for now."""
        pass
        # this can cause a deadlock with pg8000 - pg8000 acquires
        # prepared statement lock inside of rollback() - if async gc
        # is collecting in finalize_fairy, deadlock.
        # not sure if this should be if pypy/jython only
        #for conn in self.conns:
        #    self._safe(conn.rollback)

    def _stop_test_ctx(self):
        """Run the minimal or aggressive cleanup, depending on the
        ``low_connections`` option."""
        if config.options.low_connections:
            self._stop_test_ctx_minimal()
        else:
            self._stop_test_ctx_aggressive()

    def _stop_test_ctx_minimal(self):
        """Close all proxies, forget the raw connections without closing
        them, and dispose every tracked engine except ``testing.db``."""
        from test.lib import testing
        self.close_all()

        self.conns = set()

        for rec in list(self.testing_engines.keys()):
            if rec is not testing.db:
                rec.dispose()

    def _stop_test_ctx_aggressive(self):
        """Close all proxies and raw connections, then dispose every
        tracked engine."""
        self.close_all()
        for conn in self.conns:
            self._safe(conn.close)
        self.conns = set()
        for rec in list(self.testing_engines.keys()):
            rec.dispose()

    def assert_all_closed(self):
        """Raise AssertionError if any tracked proxy is still valid."""
        # snapshot the keys, as rollback_all() / close_all() do
        for rec in list(self.proxy_refs.keys()):
            assert not rec.is_valid, \
                "connection proxy is still open: %r" % (rec,)
83
 
 
84
 
# Module-level ConnectionKiller instance shared by the helpers in this module.
testing_reaper = ConnectionKiller()
85
 
 
86
 
def drop_all_tables(metadata, bind):
    """Close all tracked connections, then drop every table in ``metadata``.

    ``bind`` is closed first when it has a ``close`` attribute; it is
    then passed to ``metadata.drop_all()``.

    """
    testing_reaper.close_all()

    bind_has_close = hasattr(bind, 'close')
    if bind_has_close:
        bind.close()

    metadata.drop_all(bind)
91
 
 
92
 
@decorator
def assert_conns_closed(fn, *args, **kw):
    """Decorator that asserts all connections are closed after fn execution.

    The wrapped function's return value is passed through.

    """
    try:
        return fn(*args, **kw)
    finally:
        testing_reaper.assert_all_closed()
98
 
 
99
 
@decorator
def rollback_open_connections(fn, *args, **kw):
    """Decorator that rolls back all open connections after fn execution.

    The rollback runs even when ``fn`` raises; the wrapped function's
    return value is passed through.

    """
    try:
        return fn(*args, **kw)
    finally:
        testing_reaper.rollback_all()
107
 
 
108
 
@decorator
def close_first(fn, *args, **kw):
    """Decorator that closes all connections before fn execution.

    The wrapped function's return value is passed through.

    """
    testing_reaper.close_all()
    return fn(*args, **kw)
114
 
 
115
 
 
116
 
@decorator
def close_open_connections(fn, *args, **kw):
    """Decorator that closes all connections after fn execution.

    The close runs even when ``fn`` raises; the wrapped function's
    return value is passed through.

    """
    try:
        return fn(*args, **kw)
    finally:
        testing_reaper.close_all()
123
 
 
124
 
def all_dialects(exclude=None):
    """Yield one dialect instance per name in ``sqlalchemy.databases.__all__``.

    Names contained in ``exclude`` are skipped.  A dialect module that
    is not already an attribute of the package is imported by name.

    """
    import sqlalchemy.databases as databases_pkg

    for dialect_name in databases_pkg.__all__:
        # TEMPORARY
        if not exclude or dialect_name not in exclude:
            module = getattr(databases_pkg, dialect_name, None)
            if not module:
                # __import__ returns the top-level package; walk back down
                top = __import__('sqlalchemy.databases.%s' % dialect_name)
                module = getattr(top.databases, dialect_name)
            yield module.dialect()
134
 
 
135
 
class ReconnectFixture(object):
    """Wrap a DBAPI module, remembering every connection it hands out.

    Attributes not defined here are looked up on the wrapped ``dbapi``.
    ``shutdown()`` closes all remembered connections.

    """

    def __init__(self, dbapi):
        self.dbapi = dbapi
        self.connections = []

    def __getattr__(self, key):
        # Only reached when normal lookup fails.  If 'dbapi' itself is
        # missing (instance not initialized), fail cleanly rather than
        # recursing back into this method forever.
        if key == 'dbapi':
            raise AttributeError(key)
        return getattr(self.dbapi, key)

    def connect(self, *args, **kwargs):
        """Connect through the wrapped DBAPI and record the connection."""
        conn = self.dbapi.connect(*args, **kwargs)
        self.connections.append(conn)
        return conn

    def _safe(self, fn):
        """Call ``fn()``; turn any ordinary exception into a warning.

        SystemExit and KeyboardInterrupt are always re-raised.

        """
        try:
            fn()
        except (SystemExit, KeyboardInterrupt):
            raise
        except Exception as e:
            warnings.warn(
                    "ReconnectFixture couldn't "
                    "close connection: %s" % e)

    def shutdown(self):
        """Close every recorded connection and forget them all."""
        # TODO: this doesn't cover all cases
        # as nicely as we'd like, namely MySQLdb.
        # would need to implement R. Brewer's
        # proxy server idea to get better
        # coverage.
        for c in list(self.connections):
            self._safe(c.close)
        self.connections = []
167
 
 
168
 
def reconnecting_engine(url=None, options=None):
    """Produce a testing engine whose DBAPI is wrapped in a ReconnectFixture.

    The returned engine gains a ``test_shutdown`` attribute bound to the
    fixture's ``shutdown()``, and its ``dispose()`` is replaced by a
    version that shuts the fixture down before disposing.

    """
    url = url or config.db_url
    dbapi = config.db.dialect.dbapi
    # copy, so injecting 'module' never modifies the caller's dict
    options = dict(options) if options else {}
    options['module'] = ReconnectFixture(dbapi)
    engine = testing_engine(url, options)
    _dispose = engine.dispose

    def dispose():
        engine.dialect.dbapi.shutdown()
        _dispose()

    engine.test_shutdown = engine.dialect.dbapi.shutdown
    engine.dispose = dispose
    return engine
182
 
 
183
 
def testing_engine(url=None, options=None):
    """Produce an engine configured by --options with optional overrides.

    ``url`` defaults to ``config.db_url``.  ``options`` may carry a
    ``use_reaper`` key (default True), which is removed before the rest
    is passed to ``create_engine()``.  If no options remain,
    ``config.db_opts`` is used.  With ``use_reaper`` on, the engine and
    its pool are registered with ``testing_reaper``.

    """

    from sqlalchemy import create_engine
    from test.lib.assertsql import asserter

    if not options:
        use_reaper = True
    else:
        # copy, so popping 'use_reaper' never modifies the caller's dict
        options = dict(options)
        use_reaper = options.pop('use_reaper', True)

    url = url or config.db_url
    options = options or config.db_opts

    engine = create_engine(url, **options)
    if isinstance(engine.pool, pool.QueuePool):
        engine.pool._timeout = 0
        engine.pool._max_overflow = 0
    event.listen(engine, 'after_execute', asserter.execute)
    event.listen(engine, 'after_cursor_execute', asserter.cursor_execute)
    if use_reaper:
        event.listen(engine.pool, 'connect', testing_reaper.connect)
        event.listen(engine.pool, 'checkout', testing_reaper.checkout)
        testing_reaper.add_engine(engine)

    return engine
209
 
 
210
 
def utf8_engine(url=None, options=None):
    """Hook for dialects or drivers that don't handle utf8 by default.

    For MySQL with the ``mysqldb`` or ``pymysql`` driver, the URL gets
    ``charset=utf8`` and ``use_unicode=0`` query arguments before the
    engine is built; otherwise ``url`` is passed through unchanged.

    """

    from sqlalchemy.engine import url as engine_url

    wants_charset = (
        config.db.dialect.name == 'mysql' and
        config.db.driver in ['mysqldb', 'pymysql']
    )
    if wants_charset:
        # note 1.2.1.gamma.6 or greater of MySQLdb
        # needed here
        parsed = engine_url.make_url(url or config.db_url)
        parsed.query['charset'] = 'utf8'
        parsed.query['use_unicode'] = '0'
        url = str(parsed)

    return testing_engine(url, options)
226
 
 
227
 
def mock_engine(dialect_name=None):
    """Provides a mocking engine based on the current testing.db.

    This is normally used to test DDL generation flow as emitted
    by an Engine.

    It should not be used in other cases, as assert_compile() and
    assert_sql_execution() are much better choices with fewer
    moving parts.

    The returned engine carries three extra attributes: ``mock`` (the
    list of captured statements), ``assert_sql(stmts)`` and
    ``print_sql()``.

    """

    from sqlalchemy import create_engine

    if not dialect_name:
        dialect_name = config.db.name

    captured = []

    def executor(sql, *a, **kw):
        # the mock strategy hands each statement here instead of running it
        captured.append(sql)

    def assert_sql(stmts):
        recv = []
        for stmt in captured:
            # compare with newlines and tabs stripped out
            recv.append(re.sub(r'[\n\t]', '', str(stmt)))
        assert recv == stmts, recv

    def print_sql():
        dialect = engine.dialect
        compiled = [str(stmt.compile(dialect=dialect)) for stmt in engine.mock]
        return "\n".join(compiled)

    engine = create_engine(
        dialect_name + '://', strategy='mock', executor=executor)
    assert not hasattr(engine, 'mock')
    engine.mock = captured
    engine.assert_sql = assert_sql
    engine.print_sql = print_sql
    return engine
263
 
 
264
 
class DBAPIProxyCursor(object):
    """Proxy a DBAPI cursor.

    Tests can provide subclasses of this to intercept
    DBAPI-level cursor operations.

    Attributes not defined here are looked up on the wrapped cursor.

    """
    def __init__(self, engine, conn):
        self.engine = engine
        self.connection = conn
        self.cursor = conn.cursor()

    def execute(self, stmt, parameters=None, **kw):
        """Execute ``stmt``; falsy ``parameters`` are not passed along."""
        if parameters:
            return self.cursor.execute(stmt, parameters, **kw)
        else:
            return self.cursor.execute(stmt, **kw)

    def executemany(self, stmt, params, **kw):
        """Forward to the wrapped cursor's ``executemany()``."""
        return self.cursor.executemany(stmt, params, **kw)

    def __getattr__(self, key):
        # Only reached when normal lookup fails.  If 'cursor' itself is
        # missing (instance not initialized), fail cleanly rather than
        # recursing back into this method forever.
        if key == 'cursor':
            raise AttributeError(key)
        return getattr(self.cursor, key)
287
 
 
288
 
class DBAPIProxyConnection(object):
    """Proxy a DBAPI connection.

    Tests can provide subclasses of this to intercept
    DBAPI-level connection operations.

    Attributes not defined here are looked up on the wrapped connection.

    """
    def __init__(self, engine, cursor_cls):
        # the real connection comes from the engine pool's creator; it
        # is also exposed as _sqla_unwrap
        self.conn = self._sqla_unwrap = engine.pool._creator()
        self.engine = engine
        self.cursor_cls = cursor_cls

    def cursor(self):
        """Return a ``cursor_cls`` built from the engine and the real
        connection."""
        return self.cursor_cls(self.engine, self.conn)

    def close(self):
        """Close the wrapped connection."""
        self.conn.close()

    def __getattr__(self, key):
        # Only reached when normal lookup fails.  If 'conn' itself is
        # missing (instance not initialized), fail cleanly rather than
        # recursing back into this method forever.
        if key == 'conn':
            raise AttributeError(key)
        return getattr(self.conn, key)
308
 
 
309
 
def proxying_engine(conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor):
    """Produce an engine that provides proxy hooks for
    common methods.

    Each new connection is a ``conn_cls`` instance built around
    ``config.db`` and ``cursor_cls``.

    """
    def mock_conn():
        return conn_cls(config.db, cursor_cls)

    engine_options = {'creator': mock_conn}
    return testing_engine(options=engine_options)
317
 
 
318
 
def _native_types():
    """Return the set of types that ReplayableSession records by value.

    On Python 2 this is every public object in the ``types`` module.
    On Python 3, where ``types`` no longer lists int, str and friends,
    the types of everything in builtins are added as well.  Function
    and method types are removed in both cases, so results of those
    types are wrapped rather than stored.

    """
    natives = set(getattr(types, name)
                  for name in dir(types) if not name.startswith('_'))
    if sys.version_info[0] >= 3:
        # __builtins__ is a dict in imported modules, a module in __main__
        namespace = __builtins__
        if not isinstance(namespace, dict):
            namespace = vars(namespace)
        natives.update(obj if isinstance(obj, type) else type(obj)
                       for obj in namespace.values())
    callable_names = ('FunctionType', 'BuiltinFunctionType',
                      'MethodType', 'BuiltinMethodType',
                      'LambdaType', 'UnboundMethodType',)
    # UnboundMethodType exists only on Python 2
    return natives.difference(getattr(types, name)
                              for name in callable_names
                              if hasattr(types, name))


class ReplayableSession(object):
    """A simple record/playback tool.

    This is *not* a mock testing class.  It only records a session for later
    playback and makes no assertions on call consistency whatsoever.  It's
    unlikely to be suitable for anything other than DB-API recording.

    """

    # buffer marker: the recorded result was wrapped in a new Recorder
    Callable = object()
    # buffer marker: the recorded attribute lookup raised AttributeError
    NoAttribute = object()

    # result types stored in the buffer as-is; anything else is wrapped
    Natives = _native_types()

    def __init__(self):
        self.buffer = deque()

    def recorder(self, base):
        """Return a Recorder that wraps ``base`` and writes to the buffer."""
        return self.Recorder(self.buffer, base)

    def player(self):
        """Return a Player that replays the buffer."""
        return self.Player(self.buffer)

    class Recorder(object):
        """Wrap a subject, logging each call result and attribute read."""

        def __init__(self, buffer, subject):
            self._buffer = buffer
            self._subject = subject

        def __call__(self, *args, **kw):
            # bypass our own __getattribute__ so nothing extra is recorded
            subject, buffer = [object.__getattribute__(self, x)
                               for x in ('_subject', '_buffer')]

            result = subject(*args, **kw)
            if type(result) not in ReplayableSession.Natives:
                buffer.append(ReplayableSession.Callable)
                return type(self)(buffer, result)
            else:
                buffer.append(result)
                return result

        @property
        def _sqla_unwrap(self):
            return self._subject

        def __getattribute__(self, key):
            # the Recorder's own attributes win and are not recorded
            try:
                return object.__getattribute__(self, key)
            except AttributeError:
                pass

            subject, buffer = [object.__getattribute__(self, x)
                               for x in ('_subject', '_buffer')]
            try:
                result = type(subject).__getattribute__(subject, key)
            except AttributeError:
                buffer.append(ReplayableSession.NoAttribute)
                raise
            else:
                if type(result) not in ReplayableSession.Natives:
                    buffer.append(ReplayableSession.Callable)
                    return type(self)(buffer, result)
                else:
                    buffer.append(result)
                    return result

    class Player(object):
        """Replay a recorded buffer, one entry per call or attribute read."""

        def __init__(self, buffer):
            self._buffer = buffer

        def __call__(self, *args, **kw):
            # arguments are ignored; the next buffered entry is the answer
            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                return self
            else:
                return result

        @property
        def _sqla_unwrap(self):
            return None

        def __getattribute__(self, key):
            # the Player's own attributes win and consume nothing
            try:
                return object.__getattribute__(self, key)
            except AttributeError:
                pass
            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                return self
            elif result is ReplayableSession.NoAttribute:
                raise AttributeError(key)
            else:
                return result
428