1
import sys, types, weakref
2
from collections import deque
3
from test.bootstrap import config
4
from test.lib.util import decorator, gc_collect
5
from sqlalchemy.util import callable
6
from sqlalchemy import event, pool
7
from sqlalchemy.engine import base as engine_base
11
class ConnectionKiller(object):
13
self.proxy_refs = weakref.WeakKeyDictionary()
14
self.testing_engines = weakref.WeakKeyDictionary()
17
def add_engine(self, engine):
18
self.testing_engines[engine] = True
20
def connect(self, dbapi_conn, con_record):
21
self.conns.add(dbapi_conn)
23
def checkout(self, dbapi_con, con_record, con_proxy):
24
self.proxy_refs[con_proxy] = True
29
except (SystemExit, KeyboardInterrupt):
33
"testing_reaper couldn't "
34
"rollback/close connection: %s" % e)
36
def rollback_all(self):
37
for rec in self.proxy_refs.keys():
38
if rec is not None and rec.is_valid:
39
self._safe(rec.rollback)
42
for rec in self.proxy_refs.keys():
44
self._safe(rec._close)
46
def _after_test_ctx(self):
48
# this can cause a deadlock with pg8000 - pg8000 acquires
49
# prepared statment lock inside of rollback() - if async gc
50
# is collecting in finalize_fairy, deadlock.
51
# not sure if this should be if pypy/jython only
52
#for conn in self.conns:
53
# self._safe(conn.rollback)
55
def _stop_test_ctx(self):
56
if config.options.low_connections:
57
self._stop_test_ctx_minimal()
59
self._stop_test_ctx_aggressive()
61
def _stop_test_ctx_minimal(self):
62
from test.lib import testing
67
for rec in self.testing_engines.keys():
68
if rec is not testing.db:
71
def _stop_test_ctx_aggressive(self):
73
for conn in self.conns:
74
self._safe(conn.close)
76
for rec in self.testing_engines.keys():
79
def assert_all_closed(self):
80
for rec in self.proxy_refs:
84
testing_reaper = ConnectionKiller()
86
def drop_all_tables(metadata, bind):
87
testing_reaper.close_all()
88
if hasattr(bind, 'close'):
90
metadata.drop_all(bind)
93
def assert_conns_closed(fn, *args, **kw):
97
testing_reaper.assert_all_closed()
100
def rollback_open_connections(fn, *args, **kw):
101
"""Decorator that rolls back all open connections after fn execution."""
106
testing_reaper.rollback_all()
109
def close_first(fn, *args, **kw):
110
"""Decorator that closes all connections before fn execution."""
112
testing_reaper.close_all()
117
def close_open_connections(fn, *args, **kw):
118
"""Decorator that closes all connections after fn execution."""
122
testing_reaper.close_all()
124
def all_dialects(exclude=None):
125
import sqlalchemy.databases as d
126
for name in d.__all__:
128
if exclude and name in exclude:
130
mod = getattr(d, name, None)
132
mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
135
class ReconnectFixture(object):
136
def __init__(self, dbapi):
138
self.connections = []
140
def __getattr__(self, key):
141
return getattr(self.dbapi, key)
143
def connect(self, *args, **kwargs):
144
conn = self.dbapi.connect(*args, **kwargs)
145
self.connections.append(conn)
151
except (SystemExit, KeyboardInterrupt):
155
"ReconnectFixture couldn't "
156
"close connection: %s" % e)
159
# TODO: this doesn't cover all cases
160
# as nicely as we'd like, namely MySQLdb.
161
# would need to implement R. Brewer's
162
# proxy server idea to get better
164
for c in list(self.connections):
166
self.connections = []
168
def reconnecting_engine(url=None, options=None):
169
url = url or config.db_url
170
dbapi = config.db.dialect.dbapi
173
options['module'] = ReconnectFixture(dbapi)
174
engine = testing_engine(url, options)
175
_dispose = engine.dispose
177
engine.dialect.dbapi.shutdown()
179
engine.test_shutdown = engine.dialect.dbapi.shutdown
180
engine.dispose = dispose
183
def testing_engine(url=None, options=None):
184
"""Produce an engine configured by --options with optional overrides."""
186
from sqlalchemy import create_engine
187
from test.lib.assertsql import asserter
192
use_reaper = options.pop('use_reaper', True)
194
url = url or config.db_url
195
options = options or config.db_opts
197
engine = create_engine(url, **options)
198
if isinstance(engine.pool, pool.QueuePool):
199
engine.pool._timeout = 0
200
engine.pool._max_overflow = 0
201
event.listen(engine, 'after_execute', asserter.execute)
202
event.listen(engine, 'after_cursor_execute', asserter.cursor_execute)
204
event.listen(engine.pool, 'connect', testing_reaper.connect)
205
event.listen(engine.pool, 'checkout', testing_reaper.checkout)
206
testing_reaper.add_engine(engine)
210
def utf8_engine(url=None, options=None):
211
"""Hook for dialects or drivers that don't handle utf8 by default."""
213
from sqlalchemy.engine import url as engine_url
215
if config.db.dialect.name == 'mysql' and \
216
config.db.driver in ['mysqldb', 'pymysql']:
217
# note 1.2.1.gamma.6 or greater of MySQLdb
219
url = url or config.db_url
220
url = engine_url.make_url(url)
221
url.query['charset'] = 'utf8'
222
url.query['use_unicode'] = '0'
225
return testing_engine(url, options)
227
def mock_engine(dialect_name=None):
228
"""Provides a mocking engine based on the current testing.db.
230
This is normally used to test DDL generation flow as emitted
233
It should not be used in other cases, as assert_compile() and
234
assert_sql_execution() are much better choices with fewer
239
from sqlalchemy import create_engine
242
dialect_name = config.db.name
245
def executor(sql, *a, **kw):
247
def assert_sql(stmts):
248
recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
249
assert recv == stmts, recv
253
str(s.compile(dialect=d))
256
engine = create_engine(dialect_name + '://',
257
strategy='mock', executor=executor)
258
assert not hasattr(engine, 'mock')
260
engine.assert_sql = assert_sql
261
engine.print_sql = print_sql
264
class DBAPIProxyCursor(object):
265
"""Proxy a DBAPI cursor.
267
Tests can provide subclasses of this to intercept
268
DBAPI-level cursor operations.
271
def __init__(self, engine, conn):
273
self.connection = conn
274
self.cursor = conn.cursor()
276
def execute(self, stmt, parameters=None, **kw):
278
return self.cursor.execute(stmt, parameters, **kw)
280
return self.cursor.execute(stmt, **kw)
282
def executemany(self, stmt, params, **kw):
283
return self.cursor.executemany(stmt, params, **kw)
285
def __getattr__(self, key):
286
return getattr(self.cursor, key)
288
class DBAPIProxyConnection(object):
289
"""Proxy a DBAPI connection.
291
Tests can provide subclasses of this to intercept
292
DBAPI-level connection operations.
295
def __init__(self, engine, cursor_cls):
296
self.conn = self._sqla_unwrap = engine.pool._creator()
298
self.cursor_cls = cursor_cls
301
return self.cursor_cls(self.engine, self.conn)
306
def __getattr__(self, key):
307
return getattr(self.conn, key)
309
def proxying_engine(conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor):
310
"""Produce an engine that provides proxy hooks for
315
return conn_cls(config.db, cursor_cls)
316
return testing_engine(options={'creator':mock_conn})
318
class ReplayableSession(object):
319
"""A simple record/playback tool.
321
This is *not* a mock testing class. It only records a session for later
322
playback and makes no assertions on call consistency whatsoever. It's
323
unlikely to be suitable for anything other than DB-API recording.
328
NoAttribute = object()
331
#Natives = set([getattr(types, t)
332
# for t in dir(types) if not t.startswith('_')]). \
333
# union([type(t) if not isinstance(t, type)
334
# else t for t in __builtins__.values()]).\
335
# difference([getattr(types, t)
336
# for t in ('FunctionType', 'BuiltinFunctionType',
337
# 'MethodType', 'BuiltinMethodType',
340
Natives = set([getattr(types, t)
341
for t in dir(types) if not t.startswith('_')]). \
342
difference([getattr(types, t)
343
for t in ('FunctionType', 'BuiltinFunctionType',
344
'MethodType', 'BuiltinMethodType',
345
'LambdaType', 'UnboundMethodType',)])
349
self.buffer = deque()
351
def recorder(self, base):
352
return self.Recorder(self.buffer, base)
355
return self.Player(self.buffer)
357
class Recorder(object):
358
def __init__(self, buffer, subject):
359
self._buffer = buffer
360
self._subject = subject
362
def __call__(self, *args, **kw):
363
subject, buffer = [object.__getattribute__(self, x)
364
for x in ('_subject', '_buffer')]
366
result = subject(*args, **kw)
367
if type(result) not in ReplayableSession.Natives:
368
buffer.append(ReplayableSession.Callable)
369
return type(self)(buffer, result)
371
buffer.append(result)
375
def _sqla_unwrap(self):
378
def __getattribute__(self, key):
380
return object.__getattribute__(self, key)
381
except AttributeError:
384
subject, buffer = [object.__getattribute__(self, x)
385
for x in ('_subject', '_buffer')]
387
result = type(subject).__getattribute__(subject, key)
388
except AttributeError:
389
buffer.append(ReplayableSession.NoAttribute)
392
if type(result) not in ReplayableSession.Natives:
393
buffer.append(ReplayableSession.Callable)
394
return type(self)(buffer, result)
396
buffer.append(result)
399
class Player(object):
400
def __init__(self, buffer):
401
self._buffer = buffer
403
def __call__(self, *args, **kw):
404
buffer = object.__getattribute__(self, '_buffer')
405
result = buffer.popleft()
406
if result is ReplayableSession.Callable:
412
def _sqla_unwrap(self):
415
def __getattribute__(self, key):
417
return object.__getattribute__(self, key)
418
except AttributeError:
420
buffer = object.__getattribute__(self, '_buffer')
421
result = buffer.popleft()
422
if result is ReplayableSession.Callable:
424
elif result is ReplayableSession.NoAttribute:
425
raise AttributeError(key)