# Copyright 2012-2014 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test the thread_util module."""

import gc
import sys
import threading
import time
import unittest

from nose.plugins.skip import SkipTest

from pymongo import thread_util
if thread_util.have_gevent:
    import greenlet         # Plain greenlets.
    import gevent.greenlet  # Gevent's enhanced Greenlets.
    import gevent.hub       # get_hub(), used by TestGreenletIdent.

from test.utils import looplet, my_partial, RendezvousThread


class TestIdent(unittest.TestCase):
    """Ensure thread_util.Ident works for threads and greenlets. This has
    gotten intricate from refactoring: we have classes, Watched and Unwatched,
    that implement the logic for the two child threads / greenlets. For the
    greenlet case it's easy to ensure the two children are alive at once, so
    we run the Watched and Unwatched logic directly. For the thread case we
    mix in the RendezvousThread class so we're sure both children are alive
    when they call Ident.get().

    1. Store main thread's / greenlet's id
    2. Start 2 child threads / greenlets
    3. Store their values for Ident.get()
    4. Children reach rendezvous point
    5. Children call Ident.watch()
    6. One of the children calls Ident.unwatch()
    7. Children terminate
    8. Assert that children got different ids from each other and from main,
       and assert watched child's callback was executed, and that unwatched
       child's callback was not
    """
    def _test_ident(self, use_greenlets):
        """Run the scenario described on the class.

        :Parameters:
          - `use_greenlets`: passed to ``thread_util.create_ident``; when
            true the two children are plain greenlets driven by ``looplet``,
            otherwise they are ``RendezvousThread`` subclasses.

        Raises ``SkipTest`` on Jython.

        NOTE(review): this block was recovered from a damaged listing
        (interleaved line-number residue, lost indentation, dropped lines).
        Statements tagged ``# restored`` are not in the surviving text; they
        were inferred from the names the surviving lines read and from the
        gaps in the line numbering. Confirm them against version control.
        """
        if 'java' in sys.platform:
            raise SkipTest("Can't rely on weakref callbacks in Jython")

        ident = thread_util.create_ident(use_greenlets)

        ids = set([ident.get()])
        unwatched_id = []  # restored
        done = set([ident.get()])  # Start with main thread's / greenlet's id.
        died = set()  # restored

        class Watched(object):
            """Child logic: record own id, then watch() it for death."""
            def __init__(self, ident):
                self._my_ident = ident

            def before_rendezvous(self):
                self.my_id = self._my_ident.get()
                ids.add(self.my_id)  # restored

            def after_rendezvous(self):
                assert not self._my_ident.watching()
                # The callback records this child's id once it has died.
                self._my_ident.watch(lambda ref: died.add(self.my_id))
                assert self._my_ident.watching()
                done.add(self.my_id)  # restored

        class Unwatched(Watched):
            """Child logic: like Watched, but unwatch() again right away."""
            def before_rendezvous(self):
                Watched.before_rendezvous(self)
                unwatched_id.append(self.my_id)

            def after_rendezvous(self):
                Watched.after_rendezvous(self)
                self._my_ident.unwatch(self.my_id)
                assert not self._my_ident.watching()

        if use_greenlets:  # restored
            class WatchedGreenlet(Watched):
                def run(self):  # restored
                    self.before_rendezvous()
                    self.after_rendezvous()

            class UnwatchedGreenlet(Unwatched):
                def run(self):  # restored
                    self.before_rendezvous()
                    self.after_rendezvous()

            t_watched = greenlet.greenlet(WatchedGreenlet(ident).run)
            t_unwatched = greenlet.greenlet(UnwatchedGreenlet(ident).run)
            looplet([t_watched, t_unwatched])
        else:  # restored
            class WatchedThread(Watched, RendezvousThread):
                def __init__(self, ident, state):
                    Watched.__init__(self, ident)
                    RendezvousThread.__init__(self, state)

            class UnwatchedThread(Unwatched, RendezvousThread):
                def __init__(self, ident, state):
                    Unwatched.__init__(self, ident)
                    RendezvousThread.__init__(self, state)

            state = RendezvousThread.create_shared_state(2)
            t_watched = WatchedThread(ident, state)
            t_watched.start()  # restored

            t_unwatched = UnwatchedThread(ident, state)
            t_unwatched.start()  # restored

            RendezvousThread.wait_for_rendezvous(state)
            RendezvousThread.resume_after_rendezvous(state)

            t_watched.join()  # restored
            t_unwatched.join()  # restored

            self.assertTrue(t_watched.passed)
            self.assertTrue(t_unwatched.passed)

        # Remove references, let weakref callbacks run
        del t_watched  # restored
        del t_unwatched  # restored

        # Trigger final cleanup in Python <= 2.7.0.
        # http://bugs.python.org/issue1868
        ident.get()  # restored -- NOTE(review): least certain restoration
        self.assertEqual(3, len(ids))
        self.assertEqual(3, len(done))

        # Make sure thread is really gone
        slept = 0  # restored
        while not died and slept < 10:
            time.sleep(1)  # restored
            gc.collect()  # restored
            slept += 1  # restored

        self.assertEqual(1, len(died))
        self.assertFalse(unwatched_id[0] in died)

    def test_thread_ident(self):
        """Run the Ident scenario with threads."""
        self._test_ident(False)

    def test_greenlet_ident(self):
        """Run the Ident scenario with greenlets; skipped without Gevent."""
        if not thread_util.have_gevent:
            raise SkipTest('greenlet not installed')

        self._test_ident(True)


class TestGreenletIdent(unittest.TestCase):
    """Check watch() / unwatch() against Gevent's enhanced Greenlets.

    NOTE(review): this block was recovered from a damaged listing
    (interleaved line-number residue, lost indentation, dropped lines).
    Statements tagged ``# restored`` are not in the surviving text; confirm
    them against version control.
    """
    def setUp(self):  # restored: keeps the skip out of the class body
        if not thread_util.have_gevent:
            raise SkipTest("need Gevent")

    def test_unwatch_cleans_up(self):
        # GreenletIdent.unwatch() should remove the on_thread_died callback
        # from an enhanced Gevent Greenlet's list of links.
        callback_ran = [False]  # One-element list so the closure can set it.

        def on_greenlet_died(_):
            callback_ran[0] = True

        ident = thread_util.create_ident(use_greenlets=True)

        def watch_and_unwatch():
            ident.watch(on_greenlet_died)
            ident.unwatch(ident.get())

        g = gevent.greenlet.Greenlet(run=watch_and_unwatch)
        g.start()  # restored
        g.join(10)  # restored
        the_hub = gevent.hub.get_hub()
        if hasattr(the_hub, 'join'):
            # Gevent 1.0
            the_hub.join()  # restored
        else:
            # Gevent 0.13 and less
            the_hub.shutdown()  # restored

        self.assertTrue(g.successful())

        # unwatch() canceled the callback.
        self.assertFalse(callback_ran[0])


class TestCounter(unittest.TestCase):
    """Exercise thread_util.Counter from several threads or greenlets.

    NOTE(review): this block was recovered from a damaged listing
    (interleaved line-number residue, lost indentation, dropped lines).
    Statements tagged ``# restored`` are not in the surviving text; confirm
    them against version control.
    """
    def _test_counter(self, use_greenlets):
        """Check inc() / dec() / get() in the main context and in 10 workers.

        :Parameters:
          - `use_greenlets`: passed to ``thread_util.Counter``; also selects
            whether the workers are plain greenlets or threads.

        Every worker asserts that its own count starts at 0, whatever the
        other workers are doing.
        """
        counter = thread_util.Counter(use_greenlets)

        # Decrementing a fresh counter leaves it at zero.
        self.assertEqual(0, counter.dec())
        self.assertEqual(0, counter.get())
        self.assertEqual(0, counter.dec())
        self.assertEqual(0, counter.get())

        done = set()  # restored

        def f(n):  # restored
            for i in xrange(n):  # restored
                self.assertEqual(i, counter.get())
                self.assertEqual(i + 1, counter.inc())

            for i in xrange(n, 0, -1):
                self.assertEqual(i, counter.get())
                self.assertEqual(i - 1, counter.dec())

            self.assertEqual(0, counter.get())

            # Extra decrements have no effect
            self.assertEqual(0, counter.dec())
            self.assertEqual(0, counter.get())
            self.assertEqual(0, counter.dec())
            self.assertEqual(0, counter.get())

            done.add(n)  # restored

        if use_greenlets:  # restored
            greenlets = [  # restored
                greenlet.greenlet(my_partial(f, i)) for i in xrange(10)]
            looplet(greenlets)  # restored
        else:  # restored
            threads = [  # restored
                threading.Thread(target=my_partial(f, i)) for i in xrange(10)]
            for t in threads:  # restored
                t.start()  # restored
            for t in threads:  # restored
                t.join()  # restored

        # Each worker adds its distinct n only after all its asserts passed.
        self.assertEqual(10, len(done))

    def test_thread_counter(self):
        """Run the Counter scenario with threads."""
        self._test_counter(False)

    def test_greenlet_counter(self):
        """Run the Counter scenario with greenlets; skipped without Gevent."""
        if not thread_util.have_gevent:
            raise SkipTest('greenlet not installed')

        self._test_counter(True)


if __name__ == "__main__":