~ntt-pf-lab/nova/monkey_patch_notification

« back to all changes in this revision

Viewing changes to vendor/Twisted-10.0.0/twisted/test/test_adbapi.py

  • Committer: Jesse Andrews
  • Date: 2010-05-28 06:05:26 UTC
  • Revision ID: git-v1:bf6e6e718cdc7488e2da87b21e258ccc065fe499
initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (c) 2001-2008 Twisted Matrix Laboratories.
 
2
# See LICENSE for details.
 
3
 
 
4
 
 
5
"""
 
6
Tests for twisted.enterprise.adbapi.
 
7
"""
 
8
 
 
9
from twisted.trial import unittest
 
10
 
 
11
import os, stat, new
 
12
 
 
13
from twisted.enterprise.adbapi import ConnectionPool, ConnectionLost, safe
 
14
from twisted.enterprise.adbapi import Connection, Transaction
 
15
from twisted.enterprise.adbapi import _unreleasedVersion
 
16
from twisted.internet import reactor, defer, interfaces
 
17
from twisted.python.failure import Failure
 
18
 
 
19
 
 
20
simple_table_schema = """
 
21
CREATE TABLE simple (
 
22
  x integer
 
23
)
 
24
"""
 
25
 
 
26
 
 
27
class ADBAPITestBase:
 
28
    """Test the asynchronous DB-API code."""
 
29
 
 
30
    openfun_called = {}
 
31
 
 
32
    if interfaces.IReactorThreads(reactor, None) is None:
 
33
        skip = "ADB-API requires threads, no way to test without them"
 
34
 
 
35
    def extraSetUp(self):
 
36
        """
 
37
        Set up the database and create a connection pool pointing at it.
 
38
        """
 
39
        self.startDB()
 
40
        self.dbpool = self.makePool(cp_openfun=self.openfun)
 
41
        self.dbpool.start()
 
42
 
 
43
 
 
44
    def tearDown(self):
 
45
        d =  self.dbpool.runOperation('DROP TABLE simple')
 
46
        d.addCallback(lambda res: self.dbpool.close())
 
47
        d.addCallback(lambda res: self.stopDB())
 
48
        return d
 
49
 
 
50
    def openfun(self, conn):
 
51
        self.openfun_called[conn] = True
 
52
 
 
53
    def checkOpenfunCalled(self, conn=None):
 
54
        if not conn:
 
55
            self.failUnless(self.openfun_called)
 
56
        else:
 
57
            self.failUnless(self.openfun_called.has_key(conn))
 
58
 
 
59
    def testPool(self):
 
60
        d = self.dbpool.runOperation(simple_table_schema)
 
61
        if self.test_failures:
 
62
            d.addCallback(self._testPool_1_1)
 
63
            d.addCallback(self._testPool_1_2)
 
64
            d.addCallback(self._testPool_1_3)
 
65
            d.addCallback(self._testPool_1_4)
 
66
            d.addCallback(lambda res: self.flushLoggedErrors())
 
67
        d.addCallback(self._testPool_2)
 
68
        d.addCallback(self._testPool_3)
 
69
        d.addCallback(self._testPool_4)
 
70
        d.addCallback(self._testPool_5)
 
71
        d.addCallback(self._testPool_6)
 
72
        d.addCallback(self._testPool_7)
 
73
        d.addCallback(self._testPool_8)
 
74
        d.addCallback(self._testPool_9)
 
75
        return d
 
76
 
 
77
    def _testPool_1_1(self, res):
 
78
        d = defer.maybeDeferred(self.dbpool.runQuery, "select * from NOTABLE")
 
79
        d.addCallbacks(lambda res: self.fail('no exception'),
 
80
                       lambda f: None)
 
81
        return d
 
82
 
 
83
    def _testPool_1_2(self, res):
 
84
        d = defer.maybeDeferred(self.dbpool.runOperation,
 
85
                                "deletexxx from NOTABLE")
 
86
        d.addCallbacks(lambda res: self.fail('no exception'),
 
87
                       lambda f: None)
 
88
        return d
 
89
 
 
90
    def _testPool_1_3(self, res):
 
91
        d = defer.maybeDeferred(self.dbpool.runInteraction,
 
92
                                self.bad_interaction)
 
93
        d.addCallbacks(lambda res: self.fail('no exception'),
 
94
                       lambda f: None)
 
95
        return d
 
96
 
 
97
    def _testPool_1_4(self, res):
 
98
        d = defer.maybeDeferred(self.dbpool.runWithConnection,
 
99
                                self.bad_withConnection)
 
100
        d.addCallbacks(lambda res: self.fail('no exception'),
 
101
                       lambda f: None)
 
102
        return d
 
103
 
 
104
    def _testPool_2(self, res):
 
105
        # verify simple table is empty
 
106
        sql = "select count(1) from simple"
 
107
        d = self.dbpool.runQuery(sql)
 
108
        def _check(row):
 
109
            self.failUnless(int(row[0][0]) == 0, "Interaction not rolled back")
 
110
            self.checkOpenfunCalled()
 
111
        d.addCallback(_check)
 
112
        return d
 
113
 
 
114
    def _testPool_3(self, res):
 
115
        sql = "select count(1) from simple"
 
116
        inserts = []
 
117
        # add some rows to simple table (runOperation)
 
118
        for i in range(self.num_iterations):
 
119
            sql = "insert into simple(x) values(%d)" % i
 
120
            inserts.append(self.dbpool.runOperation(sql))
 
121
        d = defer.gatherResults(inserts)
 
122
 
 
123
        def _select(res):
 
124
            # make sure they were added (runQuery)
 
125
            sql = "select x from simple order by x";
 
126
            d = self.dbpool.runQuery(sql)
 
127
            return d
 
128
        d.addCallback(_select)
 
129
 
 
130
        def _check(rows):
 
131
            self.failUnless(len(rows) == self.num_iterations,
 
132
                            "Wrong number of rows")
 
133
            for i in range(self.num_iterations):
 
134
                self.failUnless(len(rows[i]) == 1, "Wrong size row")
 
135
                self.failUnless(rows[i][0] == i, "Values not returned.")
 
136
        d.addCallback(_check)
 
137
 
 
138
        return d
 
139
 
 
140
    def _testPool_4(self, res):
 
141
        # runInteraction
 
142
        d = self.dbpool.runInteraction(self.interaction)
 
143
        d.addCallback(lambda res: self.assertEquals(res, "done"))
 
144
        return d
 
145
 
 
146
    def _testPool_5(self, res):
 
147
        # withConnection
 
148
        d = self.dbpool.runWithConnection(self.withConnection)
 
149
        d.addCallback(lambda res: self.assertEquals(res, "done"))
 
150
        return d
 
151
 
 
152
    def _testPool_6(self, res):
 
153
        # Test a withConnection cannot be closed
 
154
        d = self.dbpool.runWithConnection(self.close_withConnection)
 
155
        return d
 
156
 
 
157
    def _testPool_7(self, res):
 
158
        # give the pool a workout
 
159
        ds = []
 
160
        for i in range(self.num_iterations):
 
161
            sql = "select x from simple where x = %d" % i
 
162
            ds.append(self.dbpool.runQuery(sql))
 
163
        dlist = defer.DeferredList(ds, fireOnOneErrback=True)
 
164
        def _check(result):
 
165
            for i in range(self.num_iterations):
 
166
                self.failUnless(result[i][1][0][0] == i, "Value not returned")
 
167
        dlist.addCallback(_check)
 
168
        return dlist
 
169
 
 
170
    def _testPool_8(self, res):
 
171
        # now delete everything
 
172
        ds = []
 
173
        for i in range(self.num_iterations):
 
174
            sql = "delete from simple where x = %d" % i
 
175
            ds.append(self.dbpool.runOperation(sql))
 
176
        dlist = defer.DeferredList(ds, fireOnOneErrback=True)
 
177
        return dlist
 
178
 
 
179
    def _testPool_9(self, res):
 
180
        # verify simple table is empty
 
181
        sql = "select count(1) from simple"
 
182
        d = self.dbpool.runQuery(sql)
 
183
        def _check(row):
 
184
            self.failUnless(int(row[0][0]) == 0,
 
185
                            "Didn't successfully delete table contents")
 
186
            self.checkConnect()
 
187
        d.addCallback(_check)
 
188
        return d
 
189
 
 
190
    def checkConnect(self):
 
191
        """Check the connect/disconnect synchronous calls."""
 
192
        conn = self.dbpool.connect()
 
193
        self.checkOpenfunCalled(conn)
 
194
        curs = conn.cursor()
 
195
        curs.execute("insert into simple(x) values(1)")
 
196
        curs.execute("select x from simple")
 
197
        res = curs.fetchall()
 
198
        self.failUnlessEqual(len(res), 1)
 
199
        self.failUnlessEqual(len(res[0]), 1)
 
200
        self.failUnlessEqual(res[0][0], 1)
 
201
        curs.execute("delete from simple")
 
202
        curs.execute("select x from simple")
 
203
        self.failUnlessEqual(len(curs.fetchall()), 0)
 
204
        curs.close()
 
205
        self.dbpool.disconnect(conn)
 
206
 
 
207
    def interaction(self, transaction):
 
208
        transaction.execute("select x from simple order by x")
 
209
        for i in range(self.num_iterations):
 
210
            row = transaction.fetchone()
 
211
            self.failUnless(len(row) == 1, "Wrong size row")
 
212
            self.failUnless(row[0] == i, "Value not returned.")
 
213
        # should test this, but gadfly throws an exception instead
 
214
        #self.failUnless(transaction.fetchone() is None, "Too many rows")
 
215
        return "done"
 
216
 
 
217
    def bad_interaction(self, transaction):
 
218
        if self.can_rollback:
 
219
            transaction.execute("insert into simple(x) values(0)")
 
220
 
 
221
        transaction.execute("select * from NOTABLE")
 
222
 
 
223
    def withConnection(self, conn):
 
224
        curs = conn.cursor()
 
225
        try:
 
226
            curs.execute("select x from simple order by x")
 
227
            for i in range(self.num_iterations):
 
228
                row = curs.fetchone()
 
229
                self.failUnless(len(row) == 1, "Wrong size row")
 
230
                self.failUnless(row[0] == i, "Value not returned.")
 
231
            # should test this, but gadfly throws an exception instead
 
232
            #self.failUnless(transaction.fetchone() is None, "Too many rows")
 
233
        finally:
 
234
            curs.close()
 
235
        return "done"
 
236
 
 
237
    def close_withConnection(self, conn):
 
238
        conn.close()
 
239
 
 
240
    def bad_withConnection(self, conn):
 
241
        curs = conn.cursor()
 
242
        try:
 
243
            curs.execute("select * from NOTABLE")
 
244
        finally:
 
245
            curs.close()
 
246
 
 
247
 
 
248
class ReconnectTestBase:
 
249
    """Test the asynchronous DB-API code with reconnect."""
 
250
 
 
251
    if interfaces.IReactorThreads(reactor, None) is None:
 
252
        skip = "ADB-API requires threads, no way to test without them"
 
253
 
 
254
    def extraSetUp(self):
 
255
        """
 
256
        Skip the test if C{good_sql} is unavailable.  Otherwise, set up the
 
257
        database, create a connection pool pointed at it, and set up a simple
 
258
        schema in it.
 
259
        """
 
260
        if self.good_sql is None:
 
261
            raise unittest.SkipTest('no good sql for reconnect test')
 
262
        self.startDB()
 
263
        self.dbpool = self.makePool(cp_max=1, cp_reconnect=True,
 
264
                                    cp_good_sql=self.good_sql)
 
265
        self.dbpool.start()
 
266
        return self.dbpool.runOperation(simple_table_schema)
 
267
 
 
268
 
 
269
    def tearDown(self):
 
270
        d = self.dbpool.runOperation('DROP TABLE simple')
 
271
        d.addCallback(lambda res: self.dbpool.close())
 
272
        d.addCallback(lambda res: self.stopDB())
 
273
        return d
 
274
 
 
275
    def testPool(self):
 
276
        d = defer.succeed(None)
 
277
        d.addCallback(self._testPool_1)
 
278
        d.addCallback(self._testPool_2)
 
279
        if not self.early_reconnect:
 
280
            d.addCallback(self._testPool_3)
 
281
        d.addCallback(self._testPool_4)
 
282
        d.addCallback(self._testPool_5)
 
283
        return d
 
284
 
 
285
    def _testPool_1(self, res):
 
286
        sql = "select count(1) from simple"
 
287
        d = self.dbpool.runQuery(sql)
 
288
        def _check(row):
 
289
            self.failUnless(int(row[0][0]) == 0, "Table not empty")
 
290
        d.addCallback(_check)
 
291
        return d
 
292
 
 
293
    def _testPool_2(self, res):
 
294
        # reach in and close the connection manually
 
295
        self.dbpool.connections.values()[0].close()
 
296
 
 
297
    def _testPool_3(self, res):
 
298
        sql = "select count(1) from simple"
 
299
        d = defer.maybeDeferred(self.dbpool.runQuery, sql)
 
300
        d.addCallbacks(lambda res: self.fail('no exception'),
 
301
                       lambda f: None)
 
302
        return d
 
303
 
 
304
    def _testPool_4(self, res):
 
305
        sql = "select count(1) from simple"
 
306
        d = self.dbpool.runQuery(sql)
 
307
        def _check(row):
 
308
            self.failUnless(int(row[0][0]) == 0, "Table not empty")
 
309
        d.addCallback(_check)
 
310
        return d
 
311
 
 
312
    def _testPool_5(self, res):
 
313
        self.flushLoggedErrors()
 
314
        sql = "select * from NOTABLE" # bad sql
 
315
        d = defer.maybeDeferred(self.dbpool.runQuery, sql)
 
316
        d.addCallbacks(lambda res: self.fail('no exception'),
 
317
                       lambda f: self.failIf(f.check(ConnectionLost)))
 
318
        return d
 
319
 
 
320
 
 
321
class DBTestConnector:
 
322
    """A class which knows how to test for the presence of
 
323
    and establish a connection to a relational database.
 
324
 
 
325
    To enable test cases  which use a central, system database,
 
326
    you must create a database named DB_NAME with a user DB_USER
 
327
    and password DB_PASS with full access rights to database DB_NAME.
 
328
    """
 
329
 
 
330
    TEST_PREFIX = None # used for creating new test cases
 
331
 
 
332
    DB_NAME = "twisted_test"
 
333
    DB_USER = 'twisted_test'
 
334
    DB_PASS = 'twisted_test'
 
335
 
 
336
    DB_DIR = None # directory for database storage
 
337
 
 
338
    nulls_ok = True # nulls supported
 
339
    trailing_spaces_ok = True # trailing spaces in strings preserved
 
340
    can_rollback = True # rollback supported
 
341
    test_failures = True # test bad sql?
 
342
    escape_slashes = True # escape \ in sql?
 
343
    good_sql = ConnectionPool.good_sql
 
344
    early_reconnect = True # cursor() will fail on closed connection
 
345
    can_clear = True # can try to clear out tables when starting
 
346
 
 
347
    num_iterations = 50 # number of iterations for test loops
 
348
                        # (lower this for slow db's)
 
349
 
 
350
    def setUp(self):
 
351
        self.DB_DIR = self.mktemp()
 
352
        os.mkdir(self.DB_DIR)
 
353
        if not self.can_connect():
 
354
            raise unittest.SkipTest('%s: Cannot access db' % self.TEST_PREFIX)
 
355
        return self.extraSetUp()
 
356
 
 
357
    def can_connect(self):
 
358
        """Return true if this database is present on the system
 
359
        and can be used in a test."""
 
360
        raise NotImplementedError()
 
361
 
 
362
    def startDB(self):
 
363
        """Take any steps needed to bring database up."""
 
364
        pass
 
365
 
 
366
    def stopDB(self):
 
367
        """Bring database down, if needed."""
 
368
        pass
 
369
 
 
370
    def makePool(self, **newkw):
 
371
        """Create a connection pool with additional keyword arguments."""
 
372
        args, kw = self.getPoolArgs()
 
373
        kw = kw.copy()
 
374
        kw.update(newkw)
 
375
        return ConnectionPool(*args, **kw)
 
376
 
 
377
    def getPoolArgs(self):
 
378
        """Return a tuple (args, kw) of list and keyword arguments
 
379
        that need to be passed to ConnectionPool to create a connection
 
380
        to this database."""
 
381
        raise NotImplementedError()
 
382
 
 
383
class GadflyConnector(DBTestConnector):
 
384
    TEST_PREFIX = 'Gadfly'
 
385
 
 
386
    nulls_ok = False
 
387
    can_rollback = False
 
388
    escape_slashes = False
 
389
    good_sql = 'select * from simple where 1=0'
 
390
 
 
391
    num_iterations = 1 # slow
 
392
 
 
393
    def can_connect(self):
 
394
        try: import gadfly
 
395
        except: return False
 
396
        if not getattr(gadfly, 'connect', None):
 
397
            gadfly.connect = gadfly.gadfly
 
398
        return True
 
399
 
 
400
    def startDB(self):
 
401
        import gadfly
 
402
        conn = gadfly.gadfly()
 
403
        conn.startup(self.DB_NAME, self.DB_DIR)
 
404
 
 
405
        # gadfly seems to want us to create something to get the db going
 
406
        cursor = conn.cursor()
 
407
        cursor.execute("create table x (x integer)")
 
408
        conn.commit()
 
409
        conn.close()
 
410
 
 
411
    def getPoolArgs(self):
 
412
        args = ('gadfly', self.DB_NAME, self.DB_DIR)
 
413
        kw = {'cp_max': 1}
 
414
        return args, kw
 
415
 
 
416
class SQLiteConnector(DBTestConnector):
 
417
    TEST_PREFIX = 'SQLite'
 
418
 
 
419
    escape_slashes = False
 
420
 
 
421
    num_iterations = 1 # slow
 
422
 
 
423
    def can_connect(self):
 
424
        try: import sqlite
 
425
        except: return False
 
426
        return True
 
427
 
 
428
    def startDB(self):
 
429
        self.database = os.path.join(self.DB_DIR, self.DB_NAME)
 
430
        if os.path.exists(self.database):
 
431
            os.unlink(self.database)
 
432
 
 
433
    def getPoolArgs(self):
 
434
        args = ('sqlite',)
 
435
        kw = {'database': self.database, 'cp_max': 1}
 
436
        return args, kw
 
437
 
 
438
class PyPgSQLConnector(DBTestConnector):
 
439
    TEST_PREFIX = "PyPgSQL"
 
440
 
 
441
    def can_connect(self):
 
442
        try: from pyPgSQL import PgSQL
 
443
        except: return False
 
444
        try:
 
445
            conn = PgSQL.connect(database=self.DB_NAME, user=self.DB_USER,
 
446
                                 password=self.DB_PASS)
 
447
            conn.close()
 
448
            return True
 
449
        except:
 
450
            return False
 
451
 
 
452
    def getPoolArgs(self):
 
453
        args = ('pyPgSQL.PgSQL',)
 
454
        kw = {'database': self.DB_NAME, 'user': self.DB_USER,
 
455
              'password': self.DB_PASS, 'cp_min': 0}
 
456
        return args, kw
 
457
 
 
458
class PsycopgConnector(DBTestConnector):
 
459
    TEST_PREFIX = 'Psycopg'
 
460
 
 
461
    def can_connect(self):
 
462
        try: import psycopg
 
463
        except: return False
 
464
        try:
 
465
            conn = psycopg.connect(database=self.DB_NAME, user=self.DB_USER,
 
466
                                   password=self.DB_PASS)
 
467
            conn.close()
 
468
            return True
 
469
        except:
 
470
            return False
 
471
 
 
472
    def getPoolArgs(self):
 
473
        args = ('psycopg',)
 
474
        kw = {'database': self.DB_NAME, 'user': self.DB_USER,
 
475
              'password': self.DB_PASS, 'cp_min': 0}
 
476
        return args, kw
 
477
 
 
478
class MySQLConnector(DBTestConnector):
 
479
    TEST_PREFIX = 'MySQL'
 
480
 
 
481
    trailing_spaces_ok = False
 
482
    can_rollback = False
 
483
    early_reconnect = False
 
484
 
 
485
    def can_connect(self):
 
486
        try: import MySQLdb
 
487
        except: return False
 
488
        try:
 
489
            conn = MySQLdb.connect(db=self.DB_NAME, user=self.DB_USER,
 
490
                                   passwd=self.DB_PASS)
 
491
            conn.close()
 
492
            return True
 
493
        except:
 
494
            return False
 
495
 
 
496
    def getPoolArgs(self):
 
497
        args = ('MySQLdb',)
 
498
        kw = {'db': self.DB_NAME, 'user': self.DB_USER, 'passwd': self.DB_PASS}
 
499
        return args, kw
 
500
 
 
501
class FirebirdConnector(DBTestConnector):
 
502
    TEST_PREFIX = 'Firebird'
 
503
 
 
504
    test_failures = False # failure testing causes problems
 
505
    escape_slashes = False
 
506
    good_sql = None # firebird doesn't handle failed sql well
 
507
    can_clear = False # firebird is not so good
 
508
 
 
509
    num_iterations = 5 # slow
 
510
 
 
511
    def can_connect(self):
 
512
        try: import kinterbasdb
 
513
        except: return False
 
514
        try:
 
515
            self.startDB()
 
516
            self.stopDB()
 
517
            return True
 
518
        except:
 
519
            return False
 
520
 
 
521
 
 
522
    def startDB(self):
 
523
        import kinterbasdb
 
524
        self.DB_NAME = os.path.join(self.DB_DIR, DBTestConnector.DB_NAME)
 
525
        os.chmod(self.DB_DIR, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO)
 
526
        sql = 'create database "%s" user "%s" password "%s"'
 
527
        sql %= (self.DB_NAME, self.DB_USER, self.DB_PASS);
 
528
        conn = kinterbasdb.create_database(sql)
 
529
        conn.close()
 
530
 
 
531
 
 
532
    def getPoolArgs(self):
 
533
        args = ('kinterbasdb',)
 
534
        kw = {'database': self.DB_NAME, 'host': '127.0.0.1',
 
535
              'user': self.DB_USER, 'password': self.DB_PASS}
 
536
        return args, kw
 
537
 
 
538
    def stopDB(self):
 
539
        import kinterbasdb
 
540
        conn = kinterbasdb.connect(database=self.DB_NAME,
 
541
                                   host='127.0.0.1', user=self.DB_USER,
 
542
                                   password=self.DB_PASS)
 
543
        conn.drop_database()
 
544
 
 
545
def makeSQLTests(base, suffix, globals):
 
546
    """
 
547
    Make a test case for every db connector which can connect.
 
548
 
 
549
    @param base: Base class for test case. Additional base classes
 
550
                 will be a DBConnector subclass and unittest.TestCase
 
551
    @param suffix: A suffix used to create test case names. Prefixes
 
552
                   are defined in the DBConnector subclasses.
 
553
    """
 
554
    connectors = [GadflyConnector, SQLiteConnector, PyPgSQLConnector,
 
555
                  PsycopgConnector, MySQLConnector, FirebirdConnector]
 
556
    for connclass in connectors:
 
557
        name = connclass.TEST_PREFIX + suffix
 
558
        klass = new.classobj(name, (connclass, base, unittest.TestCase), base.__dict__)
 
559
        globals[name] = klass
 
560
 
 
561
# GadflyADBAPITestCase SQLiteADBAPITestCase PyPgSQLADBAPITestCase
 
562
# PsycopgADBAPITestCase MySQLADBAPITestCase FirebirdADBAPITestCase
 
563
makeSQLTests(ADBAPITestBase, 'ADBAPITestCase', globals())
 
564
 
 
565
# GadflyReconnectTestCase SQLiteReconnectTestCase PyPgSQLReconnectTestCase
 
566
# PsycopgReconnectTestCase MySQLReconnectTestCase FirebirdReconnectTestCase
 
567
makeSQLTests(ReconnectTestBase, 'ReconnectTestCase', globals())
 
568
 
 
569
 
 
570
 
 
571
class DeprecationTestCase(unittest.TestCase):
 
572
    """
 
573
    Test deprecations in twisted.enterprise.adbapi
 
574
    """
 
575
 
 
576
    def test_safe(self):
 
577
        """
 
578
        Test deprecation of twisted.enterprise.adbapi.safe()
 
579
        """
 
580
        result = self.callDeprecated(_unreleasedVersion,
 
581
                                     safe, "test'")
 
582
 
 
583
        # make sure safe still behaves like the original
 
584
        self.assertEqual(result, "test''")
 
585
 
 
586
 
 
587
 
 
588
class FakePool(object):
 
589
    """
 
590
    A fake L{ConnectionPool} for tests.
 
591
 
 
592
    @ivar connectionFactory: factory for making connections returned by the
 
593
        C{connect} method.
 
594
    @type connectionFactory: any callable
 
595
    """
 
596
    reconnect = True
 
597
    noisy = True
 
598
 
 
599
    def __init__(self, connectionFactory):
 
600
        self.connectionFactory = connectionFactory
 
601
 
 
602
 
 
603
    def connect(self):
 
604
        """
 
605
        Return an instance of C{self.connectionFactory}.
 
606
        """
 
607
        return self.connectionFactory()
 
608
 
 
609
 
 
610
    def disconnect(self, connection):
 
611
        """
 
612
        Do nothing.
 
613
        """
 
614
 
 
615
 
 
616
 
 
617
class ConnectionTestCase(unittest.TestCase):
 
618
    """
 
619
    Tests for the L{Connection} class.
 
620
    """
 
621
 
 
622
    def test_rollbackErrorLogged(self):
 
623
        """
 
624
        If an error happens during rollback, L{ConnectionLost} is raised but
 
625
        the original error is logged.
 
626
        """
 
627
        class ConnectionRollbackRaise(object):
 
628
            def rollback(self):
 
629
                raise RuntimeError("problem!")
 
630
 
 
631
        pool = FakePool(ConnectionRollbackRaise)
 
632
        connection = Connection(pool)
 
633
        self.assertRaises(ConnectionLost, connection.rollback)
 
634
        errors = self.flushLoggedErrors(RuntimeError)
 
635
        self.assertEquals(len(errors), 1)
 
636
        self.assertEquals(errors[0].value.args[0], "problem!")
 
637
 
 
638
 
 
639
 
 
640
class TransactionTestCase(unittest.TestCase):
 
641
    """
 
642
    Tests for the L{Transaction} class.
 
643
    """
 
644
 
 
645
    def test_reopenLogErrorIfReconnect(self):
 
646
        """
 
647
        If the cursor creation raises an error in L{Transaction.reopen}, it
 
648
        reconnects but log the error occurred.
 
649
        """
 
650
        class ConnectionCursorRaise(object):
 
651
            count = 0
 
652
 
 
653
            def reconnect(self):
 
654
                pass
 
655
 
 
656
            def cursor(self):
 
657
                if self.count == 0:
 
658
                    self.count += 1
 
659
                    raise RuntimeError("problem!")
 
660
 
 
661
        pool = FakePool(None)
 
662
        transaction = Transaction(pool, ConnectionCursorRaise())
 
663
        transaction.reopen()
 
664
        errors = self.flushLoggedErrors(RuntimeError)
 
665
        self.assertEquals(len(errors), 1)
 
666
        self.assertEquals(errors[0].value.args[0], "problem!")
 
667
 
 
668
 
 
669
 
 
670
class NonThreadPool(object):
 
671
    def callInThreadWithCallback(self, onResult, f, *a, **kw):
 
672
        success = True
 
673
        try:
 
674
            result = f(*a, **kw)
 
675
        except Exception, e:
 
676
            success = False
 
677
            result = Failure()
 
678
        onResult(success, result)
 
679
 
 
680
 
 
681
 
 
682
class DummyConnectionPool(ConnectionPool):
 
683
    """
 
684
    A testable L{ConnectionPool};
 
685
    """
 
686
    threadpool = NonThreadPool()
 
687
 
 
688
    def __init__(self):
 
689
        """
 
690
        Don't forward init call.
 
691
        """
 
692
 
 
693
 
 
694
 
 
695
class ConnectionPoolTestCase(unittest.TestCase):
 
696
    """
 
697
    Unit tests for L{ConnectionPool}.
 
698
    """
 
699
 
 
700
    def test_runWithConnectionRaiseOriginalError(self):
 
701
        """
 
702
        If rollback fails, L{ConnectionPool.runWithConnection} raises the
 
703
        original exception and log the error of the rollback.
 
704
        """
 
705
        class ConnectionRollbackRaise(object):
 
706
            def __init__(self, pool):
 
707
                pass
 
708
 
 
709
            def rollback(self):
 
710
                raise RuntimeError("problem!")
 
711
 
 
712
        def raisingFunction(connection):
 
713
            raise ValueError("foo")
 
714
 
 
715
        pool = DummyConnectionPool()
 
716
        pool.connectionFactory = ConnectionRollbackRaise
 
717
        d = pool.runWithConnection(raisingFunction)
 
718
        d = self.assertFailure(d, ValueError)
 
719
        def cbFailed(ignored):
 
720
            errors = self.flushLoggedErrors(RuntimeError)
 
721
            self.assertEquals(len(errors), 1)
 
722
            self.assertEquals(errors[0].value.args[0], "problem!")
 
723
        d.addCallback(cbFailed)
 
724
        return d
 
725
 
 
726
 
 
727
    def test_closeLogError(self):
 
728
        """
 
729
        L{ConnectionPool._close} logs exceptions.
 
730
        """
 
731
        class ConnectionCloseRaise(object):
 
732
            def close(self):
 
733
                raise RuntimeError("problem!")
 
734
 
 
735
        pool = DummyConnectionPool()
 
736
        pool._close(ConnectionCloseRaise())
 
737
 
 
738
        errors = self.flushLoggedErrors(RuntimeError)
 
739
        self.assertEquals(len(errors), 1)
 
740
        self.assertEquals(errors[0].value.args[0], "problem!")
 
741
 
 
742
 
 
743
    def test_runWithInteractionRaiseOriginalError(self):
 
744
        """
 
745
        If rollback fails, L{ConnectionPool.runInteraction} raises the
 
746
        original exception and log the error of the rollback.
 
747
        """
 
748
        class ConnectionRollbackRaise(object):
 
749
            def __init__(self, pool):
 
750
                pass
 
751
 
 
752
            def rollback(self):
 
753
                raise RuntimeError("problem!")
 
754
 
 
755
        class DummyTransaction(object):
 
756
            def __init__(self, pool, connection):
 
757
                pass
 
758
 
 
759
        def raisingFunction(transaction):
 
760
            raise ValueError("foo")
 
761
 
 
762
        pool = DummyConnectionPool()
 
763
        pool.connectionFactory = ConnectionRollbackRaise
 
764
        pool.transactionFactory = DummyTransaction
 
765
 
 
766
        d = pool.runInteraction(raisingFunction)
 
767
        d = self.assertFailure(d, ValueError)
 
768
        def cbFailed(ignored):
 
769
            errors = self.flushLoggedErrors(RuntimeError)
 
770
            self.assertEquals(len(errors), 1)
 
771
            self.assertEquals(errors[0].value.args[0], "problem!")
 
772
        d.addCallback(cbFailed)
 
773
        return d
 
774