21
21
from nose.plugins.skip import SkipTest
23
from test.utils import server_started_with_auth, joinall
23
from test.utils import server_started_with_auth, joinall, RendezvousThread
24
24
from test.test_connection import get_connection
25
25
from pymongo.connection import Connection
26
26
from pymongo.replica_set_connection import ReplicaSetConnection
27
27
from pymongo.pool import SocketInfo, _closed
28
from pymongo.errors import (AutoReconnect,
28
from pymongo.errors import AutoReconnect, OperationFailure
33
31
def get_pool(connection):
34
32
if isinstance(connection, Connection):
35
return connection._Connection__pool
33
return connection._MongoClient__pool
36
34
elif isinstance(connection, ReplicaSetConnection):
37
writer = connection._ReplicaSetConnection__writer
38
pools = connection._ReplicaSetConnection__pools
39
return pools[writer]['pool']
35
writer = connection._MongoReplicaSetClient__writer
36
pools = connection._MongoReplicaSetClient__members
37
return pools[writer].pool
41
39
raise TypeError(str(connection))
142
class FindPauseFind(threading.Thread):
140
class FindPauseFind(RendezvousThread):
143
141
"""See test_server_disconnect() for details"""
145
def shared_state(cls, nthreads):
146
class SharedState(object):
149
state = SharedState()
151
# Number of threads total
152
state.nthreads = nthreads
154
# Number of threads that have arrived at rendezvous point
155
state.arrived_threads = 0
156
state.arrived_threads_lock = threading.Lock()
158
# set when all threads reach rendezvous
159
state.ev_arrived = threading.Event()
161
# set from outside FindPauseFind to let threads resume after
163
state.ev_resume = threading.Event()
166
142
def __init__(self, collection, state):
167
"""Params: A collection, an event to signal when all threads have
168
done the first find(), an event to signal when threads should resume,
169
and the total number of threads
144
`collection`: A collection for testing
145
`state`: A shared state object from RendezvousThread.shared_state()
171
super(FindPauseFind, self).__init__()
147
super(FindPauseFind, self).__init__(state)
172
148
self.collection = collection
176
# If this thread fails to terminate, don't hang the whole program
179
def rendezvous(self):
180
# pause until all threads arrive here
182
s.arrived_threads_lock.acquire()
183
s.arrived_threads += 1
184
if s.arrived_threads == s.nthreads:
185
s.arrived_threads_lock.release()
188
s.arrived_threads_lock.release()
194
list(self.collection.find())
196
pool = get_pool(self.collection.database.connection)
197
socket_info = pool._get_request_state()
198
assert isinstance(socket_info, SocketInfo)
199
self.request_sock = socket_info.sock
200
assert not _closed(self.request_sock)
202
# Dereference socket_info so it can potentially return to the pool
207
# all threads have passed the rendezvous, wait for
208
# test_server_disconnect() to disconnect the connection
209
self.state.ev_resume.wait()
150
def before_rendezvous(self):
152
list(self.collection.find())
154
self.pool = get_pool(self.collection.database.connection)
155
socket_info = self.pool._get_request_state()
156
assert isinstance(socket_info, SocketInfo)
157
self.request_sock = socket_info.sock
158
assert not _closed(self.request_sock)
160
def after_rendezvous(self):
211
161
# test_server_disconnect() has closed this socket, but that's ok
212
162
# because it's not our request socket anymore
213
163
assert _closed(self.request_sock)
215
165
# if disconnect() properly closed all threads' request sockets, then
216
166
# this won't raise AutoReconnect because it will acquire a new socket
217
assert self.request_sock == pool._get_request_state().sock
167
assert self.request_sock == self.pool._get_request_state().sock
218
168
list(self.collection.find())
219
169
assert self.collection.database.connection.in_request()
220
assert self.request_sock != pool._get_request_state().sock
170
assert self.request_sock != self.pool._get_request_state().sock
224
173
class BaseTestThreads(object):