~crunch.io/ubuntu/precise/pymongo/unstable

« back to all changes in this revision

Viewing changes to test/test_threads.py

  • Committer: Joseph Tate
  • Date: 2013-01-31 08:00:57 UTC
  • mfrom: (1.1.12)
  • Revision ID: jtate@dragonstrider.com-20130131080057-y7lv17xi6x8c1j5x
New upstream release.

Show diffs side-by-side

added added

removed removed

Lines of Context:
20
20
 
21
21
from nose.plugins.skip import SkipTest
22
22
 
23
 
from test.utils import server_started_with_auth, joinall
 
23
from test.utils import server_started_with_auth, joinall, RendezvousThread
24
24
from test.test_connection import get_connection
25
25
from pymongo.connection import Connection
26
26
from pymongo.replica_set_connection import ReplicaSetConnection
27
27
from pymongo.pool import SocketInfo, _closed
28
 
from pymongo.errors import (AutoReconnect,
29
 
                            OperationFailure,
30
 
                            DuplicateKeyError)
 
28
from pymongo.errors import AutoReconnect, OperationFailure
31
29
 
32
30
 
33
31
def get_pool(connection):
34
32
    if isinstance(connection, Connection):
35
 
        return connection._Connection__pool
 
33
        return connection._MongoClient__pool
36
34
    elif isinstance(connection, ReplicaSetConnection):
37
 
        writer = connection._ReplicaSetConnection__writer
38
 
        pools = connection._ReplicaSetConnection__pools
39
 
        return pools[writer]['pool']
 
35
        writer = connection._MongoReplicaSetClient__writer
 
36
        pools = connection._MongoReplicaSetClient__members
 
37
        return pools[writer].pool
40
38
    else:
41
39
        raise TypeError(str(connection))
42
40
 
139
137
                pass
140
138
 
141
139
 
142
 
class FindPauseFind(threading.Thread):
 
140
class FindPauseFind(RendezvousThread):
143
141
    """See test_server_disconnect() for details"""
144
 
    @classmethod
145
 
    def shared_state(cls, nthreads):
146
 
        class SharedState(object):
147
 
            pass
148
 
 
149
 
        state = SharedState()
150
 
 
151
 
        # Number of threads total
152
 
        state.nthreads = nthreads
153
 
 
154
 
        # Number of threads that have arrived at rendezvous point
155
 
        state.arrived_threads = 0
156
 
        state.arrived_threads_lock = threading.Lock()
157
 
 
158
 
        # set when all threads reach rendezvous
159
 
        state.ev_arrived = threading.Event()
160
 
 
161
 
        # set from outside FindPauseFind to let threads resume after
162
 
        # rendezvous
163
 
        state.ev_resume = threading.Event()
164
 
        return state
165
 
 
166
142
    def __init__(self, collection, state):
167
 
        """Params: A collection, an event to signal when all threads have
168
 
        done the first find(), an event to signal when threads should resume,
169
 
        and the total number of threads
 
143
        """Params:
 
144
          `collection`: A collection for testing
 
145
          `state`: A shared state object from RendezvousThread.shared_state()
170
146
        """
171
 
        super(FindPauseFind, self).__init__()
 
147
        super(FindPauseFind, self).__init__(state)
172
148
        self.collection = collection
173
 
        self.state = state
174
 
        self.passed = False
175
 
 
176
 
        # If this thread fails to terminate, don't hang the whole program
177
 
        self.setDaemon(True)
178
 
 
179
 
    def rendezvous(self):
180
 
        # pause until all threads arrive here
181
 
        s = self.state
182
 
        s.arrived_threads_lock.acquire()
183
 
        s.arrived_threads += 1
184
 
        if s.arrived_threads == s.nthreads:
185
 
            s.arrived_threads_lock.release()
186
 
            s.ev_arrived.set()
187
 
        else:
188
 
            s.arrived_threads_lock.release()
189
 
            s.ev_arrived.wait()
190
 
 
191
 
    def run(self):
192
 
        try:
193
 
            # acquire a socket
194
 
            list(self.collection.find())
195
 
 
196
 
            pool = get_pool(self.collection.database.connection)
197
 
            socket_info = pool._get_request_state()
198
 
            assert isinstance(socket_info, SocketInfo)
199
 
            self.request_sock = socket_info.sock
200
 
            assert not _closed(self.request_sock)
201
 
 
202
 
            # Dereference socket_info so it can potentially return to the pool
203
 
            del socket_info
204
 
        finally:
205
 
            self.rendezvous()
206
 
 
207
 
        # all threads have passed the rendezvous, wait for
208
 
        # test_server_disconnect() to disconnect the connection
209
 
        self.state.ev_resume.wait()
210
 
 
 
149
 
 
150
    def before_rendezvous(self):
 
151
        # acquire a socket
 
152
        list(self.collection.find())
 
153
 
 
154
        self.pool = get_pool(self.collection.database.connection)
 
155
        socket_info = self.pool._get_request_state()
 
156
        assert isinstance(socket_info, SocketInfo)
 
157
        self.request_sock = socket_info.sock
 
158
        assert not _closed(self.request_sock)
 
159
 
 
160
    def after_rendezvous(self):
211
161
        # test_server_disconnect() has closed this socket, but that's ok
212
162
        # because it's not our request socket anymore
213
163
        assert _closed(self.request_sock)
214
164
 
215
165
        # if disconnect() properly closed all threads' request sockets, then
216
166
        # this won't raise AutoReconnect because it will acquire a new socket
217
 
        assert self.request_sock == pool._get_request_state().sock
 
167
        assert self.request_sock == self.pool._get_request_state().sock
218
168
        list(self.collection.find())
219
169
        assert self.collection.database.connection.in_request()
220
 
        assert self.request_sock != pool._get_request_state().sock
221
 
        self.passed = True
 
170
        assert self.request_sock != self.pool._get_request_state().sock
222
171
 
223
172
 
224
173
class BaseTestThreads(object):
232
181
        self.db = self._get_connection().pymongo_test
233
182
 
234
183
    def tearDown(self):
235
 
        pass
 
184
        # Clear connection reference so that RSC's monitor thread
 
185
        # dies.
 
186
        self.db = None
236
187
 
237
188
    def _get_connection(self):
238
189
        """
294
245
        error.join()
295
246
        okay.join()
296
247
 
297
 
    def test_low_network_timeout(self):
298
 
        db = None
299
 
        i = 0
300
 
        n = 10
301
 
        while db is None and i < n:
302
 
            try:
303
 
                db = get_connection(network_timeout=0.0001).pymongo_test
304
 
            except AutoReconnect:
305
 
                i += 1
306
 
        if i == n:
307
 
            raise SkipTest()
308
 
 
309
 
        threads = []
310
 
        for _ in range(4):
311
 
            t = IgnoreAutoReconnect(db.test, 100)
312
 
            t.start()
313
 
            threads.append(t)
314
 
 
315
 
        joinall(threads)
316
 
 
317
248
    def test_server_disconnect(self):
318
249
        # PYTHON-345, we need to make sure that threads' request sockets are
319
250
        # closed by disconnect().
329
260
        #
330
261
        # If we've fixed PYTHON-345, then only one AutoReconnect is raised,
331
262
        # and all the threads get new request sockets.
332
 
 
333
263
        cx = self.db.connection
334
264
        self.assertTrue(cx.auto_start_request)
335
265
        collection = self.db.pymongo_test
341
271
        assert isinstance(socket_info, SocketInfo)
342
272
        request_sock = socket_info.sock
343
273
 
344
 
        state = FindPauseFind.shared_state(nthreads=40)
 
274
        state = FindPauseFind.create_shared_state(nthreads=40)
345
275
 
346
276
        threads = [
347
277
            FindPauseFind(collection, state)
353
283
            t.start()
354
284
 
355
285
        # Wait for the threads to reach the rendezvous
356
 
        state.ev_arrived.wait(10)
357
 
        self.assertTrue(state.ev_arrived.isSet(), "Thread timeout")
 
286
        FindPauseFind.wait_for_rendezvous(state)
358
287
 
359
288
        try:
360
 
            self.assertEqual(state.nthreads, state.arrived_threads)
361
 
 
362
289
            # Simulate an event that closes all sockets, e.g. primary stepdown
363
290
            for t in threads:
364
291
                t.request_sock.close()
376
303
 
377
304
        finally:
378
305
            # Let threads do a second find()
379
 
            state.ev_resume.set()
 
306
            FindPauseFind.resume_after_rendezvous(state)
380
307
 
381
308
        joinall(threads)
382
309
 
401
328
        return get_connection()
402
329
 
403
330
    def setUp(self):
404
 
        self.conn = self._get_connection()
405
 
        if not server_started_with_auth(self.conn):
 
331
        conn = self._get_connection()
 
332
        if not server_started_with_auth(conn):
406
333
            raise SkipTest("Authentication is not enabled on server")
 
334
        self.conn = conn
407
335
        self.conn.admin.system.users.remove({})
408
336
        self.conn.admin.add_user('admin-user', 'password')
409
337
        self.conn.admin.authenticate("admin-user", "password")
415
343
        self.conn.admin.authenticate("admin-user", "password")
416
344
        self.conn.admin.system.users.remove({})
417
345
        self.conn.auth_test.system.users.remove({})
 
346
        self.conn.drop_database('auth_test')
 
347
        # Clear connection reference so that RSC's monitor thread
 
348
        # dies.
 
349
        self.conn = None
418
350
 
419
351
    def test_auto_auth_login(self):
420
352
        conn = self._get_connection()
457
389
class TestThreadsAuth(BaseTestThreadsAuth, unittest.TestCase):
458
390
    pass
459
391
 
 
392
 
460
393
if __name__ == "__main__":
461
394
    unittest.main()