~bzr/ubuntu/natty/python-testtools/bzr-ppa

« back to all changes in this revision

Viewing changes to testtools/tests/test_spinner.py

  • Committer: Robert Collins
  • Date: 2010-11-14 15:49:58 UTC
  • mfrom: (16.11.4 upstream)
  • Revision ID: robertc@robertcollins.net-20101114154958-lwb16rdhehq6q020
New snapshot for testing.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (c) 2010 Jonathan M. Lange. See LICENSE for details.
 
2
 
 
3
"""Tests for the evil Twisted reactor-spinning we do."""
 
4
 
 
5
import os
 
6
import signal
 
7
 
 
8
from testtools import (
 
9
    skipIf,
 
10
    TestCase,
 
11
    )
 
12
from testtools.matchers import (
 
13
    Equals,
 
14
    Is,
 
15
    MatchesException,
 
16
    Raises,
 
17
    )
 
18
from testtools._spinner import (
 
19
    DeferredNotFired,
 
20
    extract_result,
 
21
    NoResultError,
 
22
    not_reentrant,
 
23
    ReentryError,
 
24
    Spinner,
 
25
    StaleJunkError,
 
26
    TimeoutError,
 
27
    trap_unhandled_errors,
 
28
    )
 
29
 
 
30
from twisted.internet import defer
 
31
from twisted.python.failure import Failure
 
32
 
 
33
 
 
34
class TestNotReentrant(TestCase):
 
35
 
 
36
    def test_not_reentrant(self):
 
37
        # A function decorated as not being re-entrant will raise a
 
38
        # ReentryError if it is called while it is running.
 
39
        calls = []
 
40
        @not_reentrant
 
41
        def log_something():
 
42
            calls.append(None)
 
43
            if len(calls) < 5:
 
44
                log_something()
 
45
        self.assertThat(log_something, Raises(MatchesException(ReentryError)))
 
46
        self.assertEqual(1, len(calls))
 
47
 
 
48
    def test_deeper_stack(self):
 
49
        calls = []
 
50
        @not_reentrant
 
51
        def g():
 
52
            calls.append(None)
 
53
            if len(calls) < 5:
 
54
                f()
 
55
        @not_reentrant
 
56
        def f():
 
57
            calls.append(None)
 
58
            if len(calls) < 5:
 
59
                g()
 
60
        self.assertThat(f, Raises(MatchesException(ReentryError)))
 
61
        self.assertEqual(2, len(calls))
 
62
 
 
63
 
 
64
class TestExtractResult(TestCase):
 
65
 
 
66
    def test_not_fired(self):
 
67
        # extract_result raises DeferredNotFired if it's given a Deferred that
 
68
        # has not fired.
 
69
        self.assertThat(lambda:extract_result(defer.Deferred()),
 
70
            Raises(MatchesException(DeferredNotFired)))
 
71
 
 
72
    def test_success(self):
 
73
        # extract_result returns the value of the Deferred if it has fired
 
74
        # successfully.
 
75
        marker = object()
 
76
        d = defer.succeed(marker)
 
77
        self.assertThat(extract_result(d), Equals(marker))
 
78
 
 
79
    def test_failure(self):
 
80
        # extract_result raises the failure's exception if it's given a
 
81
        # Deferred that is failing.
 
82
        try:
 
83
            1/0
 
84
        except ZeroDivisionError:
 
85
            f = Failure()
 
86
        d = defer.fail(f)
 
87
        self.assertThat(lambda:extract_result(d),
 
88
            Raises(MatchesException(ZeroDivisionError)))
 
89
 
 
90
 
 
91
class TestTrapUnhandledErrors(TestCase):
 
92
 
 
93
    def test_no_deferreds(self):
 
94
        marker = object()
 
95
        result, errors = trap_unhandled_errors(lambda: marker)
 
96
        self.assertEqual([], errors)
 
97
        self.assertIs(marker, result)
 
98
 
 
99
    def test_unhandled_error(self):
 
100
        failures = []
 
101
        def make_deferred_but_dont_handle():
 
102
            try:
 
103
                1/0
 
104
            except ZeroDivisionError:
 
105
                f = Failure()
 
106
                failures.append(f)
 
107
                defer.fail(f)
 
108
        result, errors = trap_unhandled_errors(make_deferred_but_dont_handle)
 
109
        self.assertIs(None, result)
 
110
        self.assertEqual(failures, [error.failResult for error in errors])
 
111
 
 
112
 
 
113
class TestRunInReactor(TestCase):
 
114
 
 
115
    def make_reactor(self):
 
116
        from twisted.internet import reactor
 
117
        return reactor
 
118
 
 
119
    def make_spinner(self, reactor=None):
 
120
        if reactor is None:
 
121
            reactor = self.make_reactor()
 
122
        return Spinner(reactor)
 
123
 
 
124
    def make_timeout(self):
 
125
        return 0.01
 
126
 
 
127
    def test_function_called(self):
 
128
        # run_in_reactor actually calls the function given to it.
 
129
        calls = []
 
130
        marker = object()
 
131
        self.make_spinner().run(self.make_timeout(), calls.append, marker)
 
132
        self.assertThat(calls, Equals([marker]))
 
133
 
 
134
    def test_return_value_returned(self):
 
135
        # run_in_reactor returns the value returned by the function given to
 
136
        # it.
 
137
        marker = object()
 
138
        result = self.make_spinner().run(self.make_timeout(), lambda: marker)
 
139
        self.assertThat(result, Is(marker))
 
140
 
 
141
    def test_exception_reraised(self):
 
142
        # If the given function raises an error, run_in_reactor re-raises that
 
143
        # error.
 
144
        self.assertThat(
 
145
            lambda:self.make_spinner().run(self.make_timeout(), lambda: 1/0),
 
146
            Raises(MatchesException(ZeroDivisionError)))
 
147
 
 
148
    def test_keyword_arguments(self):
 
149
        # run_in_reactor passes keyword arguments on.
 
150
        calls = []
 
151
        function = lambda *a, **kw: calls.extend([a, kw])
 
152
        self.make_spinner().run(self.make_timeout(), function, foo=42)
 
153
        self.assertThat(calls, Equals([(), {'foo': 42}]))
 
154
 
 
155
    def test_not_reentrant(self):
 
156
        # run_in_reactor raises an error if it is called inside another call
 
157
        # to run_in_reactor.
 
158
        spinner = self.make_spinner()
 
159
        self.assertThat(lambda: spinner.run(
 
160
            self.make_timeout(), spinner.run, self.make_timeout(), lambda: None),
 
161
            Raises(MatchesException(ReentryError)))
 
162
 
 
163
    def test_deferred_value_returned(self):
 
164
        # If the given function returns a Deferred, run_in_reactor returns the
 
165
        # value in the Deferred at the end of the callback chain.
 
166
        marker = object()
 
167
        result = self.make_spinner().run(
 
168
            self.make_timeout(), lambda: defer.succeed(marker))
 
169
        self.assertThat(result, Is(marker))
 
170
 
 
171
    def test_preserve_signal_handler(self):
 
172
        signals = ['SIGINT', 'SIGTERM', 'SIGCHLD']
 
173
        signals = filter(
 
174
            None, (getattr(signal, name, None) for name in signals))
 
175
        for sig in signals:
 
176
            self.addCleanup(signal.signal, sig, signal.getsignal(sig))
 
177
        new_hdlrs = list(lambda *a: None for _ in signals)
 
178
        for sig, hdlr in zip(signals, new_hdlrs):
 
179
            signal.signal(sig, hdlr)
 
180
        spinner = self.make_spinner()
 
181
        spinner.run(self.make_timeout(), lambda: None)
 
182
        self.assertEqual(new_hdlrs, map(signal.getsignal, signals))
 
183
 
 
184
    def test_timeout(self):
 
185
        # If the function takes too long to run, we raise a TimeoutError.
 
186
        timeout = self.make_timeout()
 
187
        self.assertThat(
 
188
            lambda:self.make_spinner().run(timeout, lambda: defer.Deferred()),
 
189
            Raises(MatchesException(TimeoutError)))
 
190
 
 
191
    def test_no_junk_by_default(self):
 
192
        # If the reactor hasn't spun yet, then there cannot be any junk.
 
193
        spinner = self.make_spinner()
 
194
        self.assertThat(spinner.get_junk(), Equals([]))
 
195
 
 
196
    def test_clean_do_nothing(self):
 
197
        # If there's nothing going on in the reactor, then clean does nothing
 
198
        # and returns an empty list.
 
199
        spinner = self.make_spinner()
 
200
        result = spinner._clean()
 
201
        self.assertThat(result, Equals([]))
 
202
 
 
203
    def test_clean_delayed_call(self):
 
204
        # If there's a delayed call in the reactor, then clean cancels it and
 
205
        # returns an empty list.
 
206
        reactor = self.make_reactor()
 
207
        spinner = self.make_spinner(reactor)
 
208
        call = reactor.callLater(10, lambda: None)
 
209
        results = spinner._clean()
 
210
        self.assertThat(results, Equals([call]))
 
211
        self.assertThat(call.active(), Equals(False))
 
212
 
 
213
    def test_clean_delayed_call_cancelled(self):
 
214
        # If there's a delayed call that's just been cancelled, then it's no
 
215
        # longer there.
 
216
        reactor = self.make_reactor()
 
217
        spinner = self.make_spinner(reactor)
 
218
        call = reactor.callLater(10, lambda: None)
 
219
        call.cancel()
 
220
        results = spinner._clean()
 
221
        self.assertThat(results, Equals([]))
 
222
 
 
223
    def test_clean_selectables(self):
 
224
        # If there's still a selectable (e.g. a listening socket), then
 
225
        # clean() removes it from the reactor's registry.
 
226
        #
 
227
        # Note that the socket is left open. This emulates a bug in trial.
 
228
        from twisted.internet.protocol import ServerFactory
 
229
        reactor = self.make_reactor()
 
230
        spinner = self.make_spinner(reactor)
 
231
        port = reactor.listenTCP(0, ServerFactory())
 
232
        spinner.run(self.make_timeout(), lambda: None)
 
233
        results = spinner.get_junk()
 
234
        self.assertThat(results, Equals([port]))
 
235
 
 
236
    def test_clean_running_threads(self):
 
237
        import threading
 
238
        import time
 
239
        current_threads = list(threading.enumerate())
 
240
        reactor = self.make_reactor()
 
241
        timeout = self.make_timeout()
 
242
        spinner = self.make_spinner(reactor)
 
243
        spinner.run(timeout, reactor.callInThread, time.sleep, timeout / 2.0)
 
244
        self.assertThat(list(threading.enumerate()), Equals(current_threads))
 
245
 
 
246
    def test_leftover_junk_available(self):
 
247
        # If 'run' is given a function that leaves the reactor dirty in some
 
248
        # way, 'run' will clean up the reactor and then store information
 
249
        # about the junk. This information can be got using get_junk.
 
250
        from twisted.internet.protocol import ServerFactory
 
251
        reactor = self.make_reactor()
 
252
        spinner = self.make_spinner(reactor)
 
253
        port = spinner.run(
 
254
            self.make_timeout(), reactor.listenTCP, 0, ServerFactory())
 
255
        self.assertThat(spinner.get_junk(), Equals([port]))
 
256
 
 
257
    def test_will_not_run_with_previous_junk(self):
 
258
        # If 'run' is called and there's still junk in the spinner's junk
 
259
        # list, then the spinner will refuse to run.
 
260
        from twisted.internet.protocol import ServerFactory
 
261
        reactor = self.make_reactor()
 
262
        spinner = self.make_spinner(reactor)
 
263
        timeout = self.make_timeout()
 
264
        spinner.run(timeout, reactor.listenTCP, 0, ServerFactory())
 
265
        self.assertThat(lambda: spinner.run(timeout, lambda: None),
 
266
            Raises(MatchesException(StaleJunkError)))
 
267
 
 
268
    def test_clear_junk_clears_previous_junk(self):
 
269
        # If 'run' is called and there's still junk in the spinner's junk
 
270
        # list, then the spinner will refuse to run.
 
271
        from twisted.internet.protocol import ServerFactory
 
272
        reactor = self.make_reactor()
 
273
        spinner = self.make_spinner(reactor)
 
274
        timeout = self.make_timeout()
 
275
        port = spinner.run(timeout, reactor.listenTCP, 0, ServerFactory())
 
276
        junk = spinner.clear_junk()
 
277
        self.assertThat(junk, Equals([port]))
 
278
        self.assertThat(spinner.get_junk(), Equals([]))
 
279
 
 
280
    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
 
281
    def test_sigint_raises_no_result_error(self):
 
282
        # If we get a SIGINT during a run, we raise NoResultError.
 
283
        SIGINT = getattr(signal, 'SIGINT', None)
 
284
        if not SIGINT:
 
285
            self.skipTest("SIGINT not available")
 
286
        reactor = self.make_reactor()
 
287
        spinner = self.make_spinner(reactor)
 
288
        timeout = self.make_timeout()
 
289
        reactor.callLater(timeout, os.kill, os.getpid(), SIGINT)
 
290
        self.assertThat(lambda:spinner.run(timeout * 5, defer.Deferred),
 
291
            Raises(MatchesException(NoResultError)))
 
292
        self.assertEqual([], spinner._clean())
 
293
 
 
294
    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
 
295
    def test_sigint_raises_no_result_error_second_time(self):
 
296
        # If we get a SIGINT during a run, we raise NoResultError.  This test
 
297
        # is exactly the same as test_sigint_raises_no_result_error, and
 
298
        # exists to make sure we haven't futzed with state.
 
299
        self.test_sigint_raises_no_result_error()
 
300
 
 
301
    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
 
302
    def test_fast_sigint_raises_no_result_error(self):
 
303
        # If we get a SIGINT during a run, we raise NoResultError.
 
304
        SIGINT = getattr(signal, 'SIGINT', None)
 
305
        if not SIGINT:
 
306
            self.skipTest("SIGINT not available")
 
307
        reactor = self.make_reactor()
 
308
        spinner = self.make_spinner(reactor)
 
309
        timeout = self.make_timeout()
 
310
        reactor.callWhenRunning(os.kill, os.getpid(), SIGINT)
 
311
        self.assertThat(lambda:spinner.run(timeout * 5, defer.Deferred),
 
312
            Raises(MatchesException(NoResultError)))
 
313
        self.assertEqual([], spinner._clean())
 
314
 
 
315
    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
 
316
    def test_fast_sigint_raises_no_result_error_second_time(self):
 
317
        self.test_fast_sigint_raises_no_result_error()
 
318
 
 
319
 
 
320
def test_suite():
 
321
    from unittest import TestLoader
 
322
    return TestLoader().loadTestsFromName(__name__)