1
from __future__ import absolute_import
5
from collections import deque
7
from .util import decorator
8
from .. import event, pool
13
class ConnectionKiller(object):
16
self.proxy_refs = weakref.WeakKeyDictionary()
17
self.testing_engines = weakref.WeakKeyDictionary()
20
def add_engine(self, engine):
21
self.testing_engines[engine] = True
23
def connect(self, dbapi_conn, con_record):
24
self.conns.add((dbapi_conn, con_record))
26
def checkout(self, dbapi_con, con_record, con_proxy):
27
self.proxy_refs[con_proxy] = True
32
except (SystemExit, KeyboardInterrupt):
36
"testing_reaper couldn't "
37
"rollback/close connection: %s" % e)
39
def rollback_all(self):
40
for rec in self.proxy_refs.keys():
41
if rec is not None and rec.is_valid:
42
self._safe(rec.rollback)
45
for rec in self.proxy_refs.keys():
47
self._safe(rec._close)
49
def _after_test_ctx(self):
51
# this can cause a deadlock with pg8000 - pg8000 acquires
52
# prepared statment lock inside of rollback() - if async gc
53
# is collecting in finalize_fairy, deadlock.
54
# not sure if this should be if pypy/jython only.
55
# note that firebird/fdb definitely needs this though
56
for conn, rec in self.conns:
57
self._safe(conn.rollback)
59
def _stop_test_ctx(self):
60
if config.options.low_connections:
61
self._stop_test_ctx_minimal()
63
self._stop_test_ctx_aggressive()
65
def _stop_test_ctx_minimal(self):
70
for rec in self.testing_engines.keys():
71
if rec is not config.db:
74
def _stop_test_ctx_aggressive(self):
76
for conn, rec in self.conns:
77
self._safe(conn.close)
81
for rec in self.testing_engines.keys():
84
def assert_all_closed(self):
85
for rec in self.proxy_refs:
89
testing_reaper = ConnectionKiller()
92
def drop_all_tables(metadata, bind):
93
testing_reaper.close_all()
94
if hasattr(bind, 'close'):
96
metadata.drop_all(bind)
100
def assert_conns_closed(fn, *args, **kw):
104
testing_reaper.assert_all_closed()
108
def rollback_open_connections(fn, *args, **kw):
109
"""Decorator that rolls back all open connections after fn execution."""
114
testing_reaper.rollback_all()
118
def close_first(fn, *args, **kw):
119
"""Decorator that closes all connections before fn execution."""
121
testing_reaper.close_all()
126
def close_open_connections(fn, *args, **kw):
127
"""Decorator that closes all connections after fn execution."""
131
testing_reaper.close_all()
134
def all_dialects(exclude=None):
135
import sqlalchemy.databases as d
136
for name in d.__all__:
138
if exclude and name in exclude:
140
mod = getattr(d, name, None)
142
mod = getattr(__import__(
143
'sqlalchemy.databases.%s' % name).databases, name)
147
class ReconnectFixture(object):
149
def __init__(self, dbapi):
151
self.connections = []
153
def __getattr__(self, key):
154
return getattr(self.dbapi, key)
156
def connect(self, *args, **kwargs):
157
conn = self.dbapi.connect(*args, **kwargs)
158
self.connections.append(conn)
164
except (SystemExit, KeyboardInterrupt):
168
"ReconnectFixture couldn't "
169
"close connection: %s" % e)
172
# TODO: this doesn't cover all cases
173
# as nicely as we'd like, namely MySQLdb.
174
# would need to implement R. Brewer's
175
# proxy server idea to get better
177
for c in list(self.connections):
179
self.connections = []
182
def reconnecting_engine(url=None, options=None):
183
url = url or config.db_url
184
dbapi = config.db.dialect.dbapi
187
options['module'] = ReconnectFixture(dbapi)
188
engine = testing_engine(url, options)
189
_dispose = engine.dispose
192
engine.dialect.dbapi.shutdown()
195
engine.test_shutdown = engine.dialect.dbapi.shutdown
196
engine.dispose = dispose
200
def testing_engine(url=None, options=None):
201
"""Produce an engine configured by --options with optional overrides."""
203
from sqlalchemy import create_engine
204
from .assertsql import asserter
209
use_reaper = options.pop('use_reaper', True)
211
url = url or config.db_url
213
options = config.db_opts
215
engine = create_engine(url, **options)
216
if isinstance(engine.pool, pool.QueuePool):
217
engine.pool._timeout = 0
218
engine.pool._max_overflow = 0
219
event.listen(engine, 'after_execute', asserter.execute)
220
event.listen(engine, 'after_cursor_execute', asserter.cursor_execute)
222
event.listen(engine.pool, 'connect', testing_reaper.connect)
223
event.listen(engine.pool, 'checkout', testing_reaper.checkout)
224
testing_reaper.add_engine(engine)
229
def utf8_engine(url=None, options=None):
230
"""Hook for dialects or drivers that don't handle utf8 by default."""
232
from sqlalchemy.engine import url as engine_url
234
if config.db.dialect.name == 'mysql' and \
235
config.db.driver in ['mysqldb', 'pymysql', 'cymysql']:
236
# note 1.2.1.gamma.6 or greater of MySQLdb
238
url = url or config.db_url
239
url = engine_url.make_url(url)
240
url.query['charset'] = 'utf8'
241
url.query['use_unicode'] = '0'
244
return testing_engine(url, options)
247
def mock_engine(dialect_name=None):
248
"""Provides a mocking engine based on the current testing.db.
250
This is normally used to test DDL generation flow as emitted
253
It should not be used in other cases, as assert_compile() and
254
assert_sql_execution() are much better choices with fewer
259
from sqlalchemy import create_engine
262
dialect_name = config.db.name
266
def executor(sql, *a, **kw):
269
def assert_sql(stmts):
270
recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
271
assert recv == stmts, recv
276
str(s.compile(dialect=d))
280
engine = create_engine(dialect_name + '://',
281
strategy='mock', executor=executor)
282
assert not hasattr(engine, 'mock')
284
engine.assert_sql = assert_sql
285
engine.print_sql = print_sql
289
class DBAPIProxyCursor(object):
290
"""Proxy a DBAPI cursor.
292
Tests can provide subclasses of this to intercept
293
DBAPI-level cursor operations.
296
def __init__(self, engine, conn):
298
self.connection = conn
299
self.cursor = conn.cursor()
301
def execute(self, stmt, parameters=None, **kw):
303
return self.cursor.execute(stmt, parameters, **kw)
305
return self.cursor.execute(stmt, **kw)
307
def executemany(self, stmt, params, **kw):
308
return self.cursor.executemany(stmt, params, **kw)
310
def __getattr__(self, key):
311
return getattr(self.cursor, key)
314
class DBAPIProxyConnection(object):
315
"""Proxy a DBAPI connection.
317
Tests can provide subclasses of this to intercept
318
DBAPI-level connection operations.
321
def __init__(self, engine, cursor_cls):
322
self.conn = self._sqla_unwrap = engine.pool._creator()
324
self.cursor_cls = cursor_cls
327
return self.cursor_cls(self.engine, self.conn)
332
def __getattr__(self, key):
333
return getattr(self.conn, key)
336
def proxying_engine(conn_cls=DBAPIProxyConnection,
337
cursor_cls=DBAPIProxyCursor):
338
"""Produce an engine that provides proxy hooks for
343
return conn_cls(config.db, cursor_cls)
344
return testing_engine(options={'creator': mock_conn})
347
class ReplayableSession(object):
348
"""A simple record/playback tool.
350
This is *not* a mock testing class. It only records a session for later
351
playback and makes no assertions on call consistency whatsoever. It's
352
unlikely to be suitable for anything other than DB-API recording.
357
NoAttribute = object()
360
#Natives = set([getattr(types, t)
361
# for t in dir(types) if not t.startswith('_')]). \
362
# union([type(t) if not isinstance(t, type)
363
# else t for t in __builtins__.values()]).\
364
# difference([getattr(types, t)
365
# for t in ('FunctionType', 'BuiltinFunctionType',
366
# 'MethodType', 'BuiltinMethodType',
369
Natives = set([getattr(types, t)
370
for t in dir(types) if not t.startswith('_')]). \
371
difference([getattr(types, t)
372
for t in ('FunctionType', 'BuiltinFunctionType',
373
'MethodType', 'BuiltinMethodType',
374
'LambdaType', 'UnboundMethodType',)])
378
self.buffer = deque()
380
def recorder(self, base):
381
return self.Recorder(self.buffer, base)
384
return self.Player(self.buffer)
386
class Recorder(object):
387
def __init__(self, buffer, subject):
388
self._buffer = buffer
389
self._subject = subject
391
def __call__(self, *args, **kw):
392
subject, buffer = [object.__getattribute__(self, x)
393
for x in ('_subject', '_buffer')]
395
result = subject(*args, **kw)
396
if type(result) not in ReplayableSession.Natives:
397
buffer.append(ReplayableSession.Callable)
398
return type(self)(buffer, result)
400
buffer.append(result)
404
def _sqla_unwrap(self):
407
def __getattribute__(self, key):
409
return object.__getattribute__(self, key)
410
except AttributeError:
413
subject, buffer = [object.__getattribute__(self, x)
414
for x in ('_subject', '_buffer')]
416
result = type(subject).__getattribute__(subject, key)
417
except AttributeError:
418
buffer.append(ReplayableSession.NoAttribute)
421
if type(result) not in ReplayableSession.Natives:
422
buffer.append(ReplayableSession.Callable)
423
return type(self)(buffer, result)
425
buffer.append(result)
428
class Player(object):
429
def __init__(self, buffer):
430
self._buffer = buffer
432
def __call__(self, *args, **kw):
433
buffer = object.__getattribute__(self, '_buffer')
434
result = buffer.popleft()
435
if result is ReplayableSession.Callable:
441
def _sqla_unwrap(self):
444
def __getattribute__(self, key):
446
return object.__getattribute__(self, key)
447
except AttributeError:
449
buffer = object.__getattribute__(self, '_buffer')
450
result = buffer.popleft()
451
if result is ReplayableSession.Callable:
453
elif result is ReplayableSession.NoAttribute:
454
raise AttributeError(key)