1
# Copyright (c) 2001-2008 Twisted Matrix Laboratories.
2
# See LICENSE for details.
6
Tests for twisted.enterprise.adbapi.
9
from twisted.trial import unittest
13
from twisted.enterprise.adbapi import ConnectionPool, ConnectionLost, safe
14
from twisted.enterprise.adbapi import Connection, Transaction
15
from twisted.enterprise.adbapi import _unreleasedVersion
16
from twisted.internet import reactor, defer, interfaces
17
from twisted.python.failure import Failure
20
simple_table_schema = """
28
"""Test the asynchronous DB-API code."""
32
if interfaces.IReactorThreads(reactor, None) is None:
33
skip = "ADB-API requires threads, no way to test without them"
37
Set up the database and create a connection pool pointing at it.
40
self.dbpool = self.makePool(cp_openfun=self.openfun)
45
d = self.dbpool.runOperation('DROP TABLE simple')
46
d.addCallback(lambda res: self.dbpool.close())
47
d.addCallback(lambda res: self.stopDB())
50
def openfun(self, conn):
51
self.openfun_called[conn] = True
53
def checkOpenfunCalled(self, conn=None):
55
self.failUnless(self.openfun_called)
57
self.failUnless(self.openfun_called.has_key(conn))
60
d = self.dbpool.runOperation(simple_table_schema)
61
if self.test_failures:
62
d.addCallback(self._testPool_1_1)
63
d.addCallback(self._testPool_1_2)
64
d.addCallback(self._testPool_1_3)
65
d.addCallback(self._testPool_1_4)
66
d.addCallback(lambda res: self.flushLoggedErrors())
67
d.addCallback(self._testPool_2)
68
d.addCallback(self._testPool_3)
69
d.addCallback(self._testPool_4)
70
d.addCallback(self._testPool_5)
71
d.addCallback(self._testPool_6)
72
d.addCallback(self._testPool_7)
73
d.addCallback(self._testPool_8)
74
d.addCallback(self._testPool_9)
77
def _testPool_1_1(self, res):
78
d = defer.maybeDeferred(self.dbpool.runQuery, "select * from NOTABLE")
79
d.addCallbacks(lambda res: self.fail('no exception'),
83
def _testPool_1_2(self, res):
84
d = defer.maybeDeferred(self.dbpool.runOperation,
85
"deletexxx from NOTABLE")
86
d.addCallbacks(lambda res: self.fail('no exception'),
90
def _testPool_1_3(self, res):
91
d = defer.maybeDeferred(self.dbpool.runInteraction,
93
d.addCallbacks(lambda res: self.fail('no exception'),
97
def _testPool_1_4(self, res):
98
d = defer.maybeDeferred(self.dbpool.runWithConnection,
99
self.bad_withConnection)
100
d.addCallbacks(lambda res: self.fail('no exception'),
104
def _testPool_2(self, res):
105
# verify simple table is empty
106
sql = "select count(1) from simple"
107
d = self.dbpool.runQuery(sql)
109
self.failUnless(int(row[0][0]) == 0, "Interaction not rolled back")
110
self.checkOpenfunCalled()
111
d.addCallback(_check)
114
def _testPool_3(self, res):
115
sql = "select count(1) from simple"
117
# add some rows to simple table (runOperation)
118
for i in range(self.num_iterations):
119
sql = "insert into simple(x) values(%d)" % i
120
inserts.append(self.dbpool.runOperation(sql))
121
d = defer.gatherResults(inserts)
124
# make sure they were added (runQuery)
125
sql = "select x from simple order by x";
126
d = self.dbpool.runQuery(sql)
128
d.addCallback(_select)
131
self.failUnless(len(rows) == self.num_iterations,
132
"Wrong number of rows")
133
for i in range(self.num_iterations):
134
self.failUnless(len(rows[i]) == 1, "Wrong size row")
135
self.failUnless(rows[i][0] == i, "Values not returned.")
136
d.addCallback(_check)
140
def _testPool_4(self, res):
142
d = self.dbpool.runInteraction(self.interaction)
143
d.addCallback(lambda res: self.assertEquals(res, "done"))
146
def _testPool_5(self, res):
148
d = self.dbpool.runWithConnection(self.withConnection)
149
d.addCallback(lambda res: self.assertEquals(res, "done"))
152
def _testPool_6(self, res):
153
# Test a withConnection cannot be closed
154
d = self.dbpool.runWithConnection(self.close_withConnection)
157
def _testPool_7(self, res):
158
# give the pool a workout
160
for i in range(self.num_iterations):
161
sql = "select x from simple where x = %d" % i
162
ds.append(self.dbpool.runQuery(sql))
163
dlist = defer.DeferredList(ds, fireOnOneErrback=True)
165
for i in range(self.num_iterations):
166
self.failUnless(result[i][1][0][0] == i, "Value not returned")
167
dlist.addCallback(_check)
170
def _testPool_8(self, res):
171
# now delete everything
173
for i in range(self.num_iterations):
174
sql = "delete from simple where x = %d" % i
175
ds.append(self.dbpool.runOperation(sql))
176
dlist = defer.DeferredList(ds, fireOnOneErrback=True)
179
def _testPool_9(self, res):
180
# verify simple table is empty
181
sql = "select count(1) from simple"
182
d = self.dbpool.runQuery(sql)
184
self.failUnless(int(row[0][0]) == 0,
185
"Didn't successfully delete table contents")
187
d.addCallback(_check)
190
def checkConnect(self):
191
"""Check the connect/disconnect synchronous calls."""
192
conn = self.dbpool.connect()
193
self.checkOpenfunCalled(conn)
195
curs.execute("insert into simple(x) values(1)")
196
curs.execute("select x from simple")
197
res = curs.fetchall()
198
self.failUnlessEqual(len(res), 1)
199
self.failUnlessEqual(len(res[0]), 1)
200
self.failUnlessEqual(res[0][0], 1)
201
curs.execute("delete from simple")
202
curs.execute("select x from simple")
203
self.failUnlessEqual(len(curs.fetchall()), 0)
205
self.dbpool.disconnect(conn)
207
def interaction(self, transaction):
208
transaction.execute("select x from simple order by x")
209
for i in range(self.num_iterations):
210
row = transaction.fetchone()
211
self.failUnless(len(row) == 1, "Wrong size row")
212
self.failUnless(row[0] == i, "Value not returned.")
213
# should test this, but gadfly throws an exception instead
214
#self.failUnless(transaction.fetchone() is None, "Too many rows")
217
def bad_interaction(self, transaction):
218
if self.can_rollback:
219
transaction.execute("insert into simple(x) values(0)")
221
transaction.execute("select * from NOTABLE")
223
def withConnection(self, conn):
226
curs.execute("select x from simple order by x")
227
for i in range(self.num_iterations):
228
row = curs.fetchone()
229
self.failUnless(len(row) == 1, "Wrong size row")
230
self.failUnless(row[0] == i, "Value not returned.")
231
# should test this, but gadfly throws an exception instead
232
#self.failUnless(transaction.fetchone() is None, "Too many rows")
237
def close_withConnection(self, conn):
240
def bad_withConnection(self, conn):
243
curs.execute("select * from NOTABLE")
248
class ReconnectTestBase:
249
"""Test the asynchronous DB-API code with reconnect."""
251
if interfaces.IReactorThreads(reactor, None) is None:
252
skip = "ADB-API requires threads, no way to test without them"
254
def extraSetUp(self):
256
Skip the test if C{good_sql} is unavailable. Otherwise, set up the
257
database, create a connection pool pointed at it, and set up a simple
260
if self.good_sql is None:
261
raise unittest.SkipTest('no good sql for reconnect test')
263
self.dbpool = self.makePool(cp_max=1, cp_reconnect=True,
264
cp_good_sql=self.good_sql)
266
return self.dbpool.runOperation(simple_table_schema)
270
d = self.dbpool.runOperation('DROP TABLE simple')
271
d.addCallback(lambda res: self.dbpool.close())
272
d.addCallback(lambda res: self.stopDB())
276
d = defer.succeed(None)
277
d.addCallback(self._testPool_1)
278
d.addCallback(self._testPool_2)
279
if not self.early_reconnect:
280
d.addCallback(self._testPool_3)
281
d.addCallback(self._testPool_4)
282
d.addCallback(self._testPool_5)
285
def _testPool_1(self, res):
286
sql = "select count(1) from simple"
287
d = self.dbpool.runQuery(sql)
289
self.failUnless(int(row[0][0]) == 0, "Table not empty")
290
d.addCallback(_check)
293
def _testPool_2(self, res):
294
# reach in and close the connection manually
295
self.dbpool.connections.values()[0].close()
297
def _testPool_3(self, res):
298
sql = "select count(1) from simple"
299
d = defer.maybeDeferred(self.dbpool.runQuery, sql)
300
d.addCallbacks(lambda res: self.fail('no exception'),
304
def _testPool_4(self, res):
305
sql = "select count(1) from simple"
306
d = self.dbpool.runQuery(sql)
308
self.failUnless(int(row[0][0]) == 0, "Table not empty")
309
d.addCallback(_check)
312
def _testPool_5(self, res):
313
self.flushLoggedErrors()
314
sql = "select * from NOTABLE" # bad sql
315
d = defer.maybeDeferred(self.dbpool.runQuery, sql)
316
d.addCallbacks(lambda res: self.fail('no exception'),
317
lambda f: self.failIf(f.check(ConnectionLost)))
321
class DBTestConnector:
322
"""A class which knows how to test for the presence of
323
and establish a connection to a relational database.
325
To enable test cases which use a central, system database,
326
you must create a database named DB_NAME with a user DB_USER
327
and password DB_PASS with full access rights to database DB_NAME.
330
TEST_PREFIX = None # used for creating new test cases
332
DB_NAME = "twisted_test"
333
DB_USER = 'twisted_test'
334
DB_PASS = 'twisted_test'
336
DB_DIR = None # directory for database storage
338
nulls_ok = True # nulls supported
339
trailing_spaces_ok = True # trailing spaces in strings preserved
340
can_rollback = True # rollback supported
341
test_failures = True # test bad sql?
342
escape_slashes = True # escape \ in sql?
343
good_sql = ConnectionPool.good_sql
344
early_reconnect = True # cursor() will fail on closed connection
345
can_clear = True # can try to clear out tables when starting
347
num_iterations = 50 # number of iterations for test loops
348
# (lower this for slow db's)
351
self.DB_DIR = self.mktemp()
352
os.mkdir(self.DB_DIR)
353
if not self.can_connect():
354
raise unittest.SkipTest('%s: Cannot access db' % self.TEST_PREFIX)
355
return self.extraSetUp()
357
def can_connect(self):
358
"""Return true if this database is present on the system
359
and can be used in a test."""
360
raise NotImplementedError()
363
"""Take any steps needed to bring database up."""
367
"""Bring database down, if needed."""
370
def makePool(self, **newkw):
371
"""Create a connection pool with additional keyword arguments."""
372
args, kw = self.getPoolArgs()
375
return ConnectionPool(*args, **kw)
377
def getPoolArgs(self):
378
"""Return a tuple (args, kw) of list and keyword arguments
379
that need to be passed to ConnectionPool to create a connection
381
raise NotImplementedError()
383
class GadflyConnector(DBTestConnector):
384
TEST_PREFIX = 'Gadfly'
388
escape_slashes = False
389
good_sql = 'select * from simple where 1=0'
391
num_iterations = 1 # slow
393
def can_connect(self):
396
if not getattr(gadfly, 'connect', None):
397
gadfly.connect = gadfly.gadfly
402
conn = gadfly.gadfly()
403
conn.startup(self.DB_NAME, self.DB_DIR)
405
# gadfly seems to want us to create something to get the db going
406
cursor = conn.cursor()
407
cursor.execute("create table x (x integer)")
411
def getPoolArgs(self):
412
args = ('gadfly', self.DB_NAME, self.DB_DIR)
416
class SQLiteConnector(DBTestConnector):
417
TEST_PREFIX = 'SQLite'
419
escape_slashes = False
421
num_iterations = 1 # slow
423
def can_connect(self):
429
self.database = os.path.join(self.DB_DIR, self.DB_NAME)
430
if os.path.exists(self.database):
431
os.unlink(self.database)
433
def getPoolArgs(self):
435
kw = {'database': self.database, 'cp_max': 1}
438
class PyPgSQLConnector(DBTestConnector):
439
TEST_PREFIX = "PyPgSQL"
441
def can_connect(self):
442
try: from pyPgSQL import PgSQL
445
conn = PgSQL.connect(database=self.DB_NAME, user=self.DB_USER,
446
password=self.DB_PASS)
452
def getPoolArgs(self):
453
args = ('pyPgSQL.PgSQL',)
454
kw = {'database': self.DB_NAME, 'user': self.DB_USER,
455
'password': self.DB_PASS, 'cp_min': 0}
458
class PsycopgConnector(DBTestConnector):
459
TEST_PREFIX = 'Psycopg'
461
def can_connect(self):
465
conn = psycopg.connect(database=self.DB_NAME, user=self.DB_USER,
466
password=self.DB_PASS)
472
def getPoolArgs(self):
474
kw = {'database': self.DB_NAME, 'user': self.DB_USER,
475
'password': self.DB_PASS, 'cp_min': 0}
478
class MySQLConnector(DBTestConnector):
479
TEST_PREFIX = 'MySQL'
481
trailing_spaces_ok = False
483
early_reconnect = False
485
def can_connect(self):
489
conn = MySQLdb.connect(db=self.DB_NAME, user=self.DB_USER,
496
def getPoolArgs(self):
498
kw = {'db': self.DB_NAME, 'user': self.DB_USER, 'passwd': self.DB_PASS}
501
class FirebirdConnector(DBTestConnector):
502
TEST_PREFIX = 'Firebird'
504
test_failures = False # failure testing causes problems
505
escape_slashes = False
506
good_sql = None # firebird doesn't handle failed sql well
507
can_clear = False # firebird is not so good
509
num_iterations = 5 # slow
511
def can_connect(self):
512
try: import kinterbasdb
524
self.DB_NAME = os.path.join(self.DB_DIR, DBTestConnector.DB_NAME)
525
os.chmod(self.DB_DIR, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO)
526
sql = 'create database "%s" user "%s" password "%s"'
527
sql %= (self.DB_NAME, self.DB_USER, self.DB_PASS);
528
conn = kinterbasdb.create_database(sql)
532
def getPoolArgs(self):
533
args = ('kinterbasdb',)
534
kw = {'database': self.DB_NAME, 'host': '127.0.0.1',
535
'user': self.DB_USER, 'password': self.DB_PASS}
540
conn = kinterbasdb.connect(database=self.DB_NAME,
541
host='127.0.0.1', user=self.DB_USER,
542
password=self.DB_PASS)
545
def makeSQLTests(base, suffix, globals):
547
Make a test case for every db connector which can connect.
549
@param base: Base class for test case. Additional base classes
550
will be a DBConnector subclass and unittest.TestCase
551
@param suffix: A suffix used to create test case names. Prefixes
552
are defined in the DBConnector subclasses.
554
connectors = [GadflyConnector, SQLiteConnector, PyPgSQLConnector,
555
PsycopgConnector, MySQLConnector, FirebirdConnector]
556
for connclass in connectors:
557
name = connclass.TEST_PREFIX + suffix
558
klass = new.classobj(name, (connclass, base, unittest.TestCase), base.__dict__)
559
globals[name] = klass
561
# GadflyADBAPITestCase SQLiteADBAPITestCase PyPgSQLADBAPITestCase
562
# PsycopgADBAPITestCase MySQLADBAPITestCase FirebirdADBAPITestCase
563
makeSQLTests(ADBAPITestBase, 'ADBAPITestCase', globals())
565
# GadflyReconnectTestCase SQLiteReconnectTestCase PyPgSQLReconnectTestCase
566
# PsycopgReconnectTestCase MySQLReconnectTestCase FirebirdReconnectTestCase
567
makeSQLTests(ReconnectTestBase, 'ReconnectTestCase', globals())
571
class DeprecationTestCase(unittest.TestCase):
573
Test deprecations in twisted.enterprise.adbapi
578
Test deprecation of twisted.enterprise.adbapi.safe()
580
result = self.callDeprecated(_unreleasedVersion,
583
# make sure safe still behaves like the original
584
self.assertEqual(result, "test''")
588
class FakePool(object):
590
A fake L{ConnectionPool} for tests.
592
@ivar connectionFactory: factory for making connections returned by the
594
@type connectionFactory: any callable
599
def __init__(self, connectionFactory):
600
self.connectionFactory = connectionFactory
605
Return an instance of C{self.connectionFactory}.
607
return self.connectionFactory()
610
def disconnect(self, connection):
617
class ConnectionTestCase(unittest.TestCase):
619
Tests for the L{Connection} class.
622
def test_rollbackErrorLogged(self):
624
If an error happens during rollback, L{ConnectionLost} is raised but
625
the original error is logged.
627
class ConnectionRollbackRaise(object):
629
raise RuntimeError("problem!")
631
pool = FakePool(ConnectionRollbackRaise)
632
connection = Connection(pool)
633
self.assertRaises(ConnectionLost, connection.rollback)
634
errors = self.flushLoggedErrors(RuntimeError)
635
self.assertEquals(len(errors), 1)
636
self.assertEquals(errors[0].value.args[0], "problem!")
640
class TransactionTestCase(unittest.TestCase):
642
Tests for the L{Transaction} class.
645
def test_reopenLogErrorIfReconnect(self):
647
If the cursor creation raises an error in L{Transaction.reopen}, it
648
reconnects but log the error occurred.
650
class ConnectionCursorRaise(object):
659
raise RuntimeError("problem!")
661
pool = FakePool(None)
662
transaction = Transaction(pool, ConnectionCursorRaise())
664
errors = self.flushLoggedErrors(RuntimeError)
665
self.assertEquals(len(errors), 1)
666
self.assertEquals(errors[0].value.args[0], "problem!")
670
class NonThreadPool(object):
671
def callInThreadWithCallback(self, onResult, f, *a, **kw):
678
onResult(success, result)
682
class DummyConnectionPool(ConnectionPool):
684
A testable L{ConnectionPool};
686
threadpool = NonThreadPool()
690
Don't forward init call.
695
class ConnectionPoolTestCase(unittest.TestCase):
697
Unit tests for L{ConnectionPool}.
700
def test_runWithConnectionRaiseOriginalError(self):
702
If rollback fails, L{ConnectionPool.runWithConnection} raises the
703
original exception and log the error of the rollback.
705
class ConnectionRollbackRaise(object):
706
def __init__(self, pool):
710
raise RuntimeError("problem!")
712
def raisingFunction(connection):
713
raise ValueError("foo")
715
pool = DummyConnectionPool()
716
pool.connectionFactory = ConnectionRollbackRaise
717
d = pool.runWithConnection(raisingFunction)
718
d = self.assertFailure(d, ValueError)
719
def cbFailed(ignored):
720
errors = self.flushLoggedErrors(RuntimeError)
721
self.assertEquals(len(errors), 1)
722
self.assertEquals(errors[0].value.args[0], "problem!")
723
d.addCallback(cbFailed)
727
def test_closeLogError(self):
729
L{ConnectionPool._close} logs exceptions.
731
class ConnectionCloseRaise(object):
733
raise RuntimeError("problem!")
735
pool = DummyConnectionPool()
736
pool._close(ConnectionCloseRaise())
738
errors = self.flushLoggedErrors(RuntimeError)
739
self.assertEquals(len(errors), 1)
740
self.assertEquals(errors[0].value.args[0], "problem!")
743
def test_runWithInteractionRaiseOriginalError(self):
745
If rollback fails, L{ConnectionPool.runInteraction} raises the
746
original exception and log the error of the rollback.
748
class ConnectionRollbackRaise(object):
749
def __init__(self, pool):
753
raise RuntimeError("problem!")
755
class DummyTransaction(object):
756
def __init__(self, pool, connection):
759
def raisingFunction(transaction):
760
raise ValueError("foo")
762
pool = DummyConnectionPool()
763
pool.connectionFactory = ConnectionRollbackRaise
764
pool.transactionFactory = DummyTransaction
766
d = pool.runInteraction(raisingFunction)
767
d = self.assertFailure(d, ValueError)
768
def cbFailed(ignored):
769
errors = self.flushLoggedErrors(RuntimeError)
770
self.assertEquals(len(errors), 1)
771
self.assertEquals(errors[0].value.args[0], "problem!")
772
d.addCallback(cbFailed)