~0x44/nova/bug838466

« back to all changes in this revision

Viewing changes to vendor/Twisted-10.0.0/twisted/trial/unittest.py

  • Committer: Jesse Andrews
  • Date: 2010-05-28 06:05:26 UTC
  • Revision ID: git-v1:bf6e6e718cdc7488e2da87b21e258ccc065fe499
initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# -*- test-case-name: twisted.trial.test.test_tests -*-
 
2
# Copyright (c) 2001-2009 Twisted Matrix Laboratories.
 
3
# See LICENSE for details.
 
4
 
 
5
"""
 
6
Things likely to be used by writers of unit tests.
 
7
 
 
8
Maintainer: Jonathan Lange
 
9
"""
 
10
 
 
11
 
 
12
import doctest, inspect
 
13
import os, warnings, sys, tempfile, gc, types
 
14
from pprint import pformat
 
15
try:
 
16
    from dis import findlinestarts as _findlinestarts
 
17
except ImportError:
 
18
    # Definition copied from Python's Lib/dis.py - findlinestarts was not
 
19
    # available in Python 2.3.  This function is copyright Python Software
 
20
    # Foundation, released under the Python license:
 
21
    # http://www.python.org/psf/license/
 
22
    def _findlinestarts(code):
 
23
        """Find the offsets in a byte code which are start of lines in the source.
 
24
 
 
25
        Generate pairs (offset, lineno) as described in Python/compile.c.
 
26
 
 
27
        """
 
28
        byte_increments = [ord(c) for c in code.co_lnotab[0::2]]
 
29
        line_increments = [ord(c) for c in code.co_lnotab[1::2]]
 
30
 
 
31
        lastlineno = None
 
32
        lineno = code.co_firstlineno
 
33
        addr = 0
 
34
        for byte_incr, line_incr in zip(byte_increments, line_increments):
 
35
            if byte_incr:
 
36
                if lineno != lastlineno:
 
37
                    yield (addr, lineno)
 
38
                    lastlineno = lineno
 
39
                addr += byte_incr
 
40
            lineno += line_incr
 
41
        if lineno != lastlineno:
 
42
            yield (addr, lineno)
 
43
 
 
44
from twisted.internet import defer, utils
 
45
from twisted.python import components, failure, log, monkey
 
46
from twisted.python.deprecate import getDeprecationWarningString
 
47
 
 
48
from twisted.trial import itrial, reporter, util
 
49
 
 
50
pyunit = __import__('unittest')
 
51
 
 
52
from zope.interface import implements
 
53
 
 
54
 
 
55
 
 
56
class SkipTest(Exception):
 
57
    """
 
58
    Raise this (with a reason) to skip the current test. You may also set
 
59
    method.skip to a reason string to skip it, or set class.skip to skip the
 
60
    entire TestCase.
 
61
    """
 
62
 
 
63
 
 
64
class FailTest(AssertionError):
 
65
    """Raised to indicate the current test has failed to pass."""
 
66
 
 
67
 
 
68
class Todo(object):
 
69
    """
 
70
    Internal object used to mark a L{TestCase} as 'todo'. Tests marked 'todo'
 
71
    are reported differently in Trial L{TestResult}s. If todo'd tests fail,
 
72
    they do not fail the suite and the errors are reported in a separate
 
73
    category. If todo'd tests succeed, Trial L{TestResult}s will report an
 
74
    unexpected success.
 
75
    """
 
76
 
 
77
    def __init__(self, reason, errors=None):
 
78
        """
 
79
        @param reason: A string explaining why the test is marked 'todo'
 
80
 
 
81
        @param errors: An iterable of exception types that the test is
 
82
        expected to raise. If one of these errors is raised by the test, it
 
83
        will be trapped. Raising any other kind of error will fail the test.
 
84
        If C{None} is passed, then all errors will be trapped.
 
85
        """
 
86
        self.reason = reason
 
87
        self.errors = errors
 
88
 
 
89
    def __repr__(self):
 
90
        return "<Todo reason=%r errors=%r>" % (self.reason, self.errors)
 
91
 
 
92
    def expected(self, failure):
 
93
        """
 
94
        @param failure: A L{twisted.python.failure.Failure}.
 
95
 
 
96
        @return: C{True} if C{failure} is expected, C{False} otherwise.
 
97
        """
 
98
        if self.errors is None:
 
99
            return True
 
100
        for error in self.errors:
 
101
            if failure.check(error):
 
102
                return True
 
103
        return False
 
104
 
 
105
 
 
106
def makeTodo(value):
 
107
    """
 
108
    Return a L{Todo} object built from C{value}.
 
109
 
 
110
    If C{value} is a string, return a Todo that expects any exception with
 
111
    C{value} as a reason. If C{value} is a tuple, the second element is used
 
112
    as the reason and the first element as the excepted error(s).
 
113
 
 
114
    @param value: A string or a tuple of C{(errors, reason)}, where C{errors}
 
115
    is either a single exception class or an iterable of exception classes.
 
116
 
 
117
    @return: A L{Todo} object.
 
118
    """
 
119
    if isinstance(value, str):
 
120
        return Todo(reason=value)
 
121
    if isinstance(value, tuple):
 
122
        errors, reason = value
 
123
        try:
 
124
            errors = list(errors)
 
125
        except TypeError:
 
126
            errors = [errors]
 
127
        return Todo(reason=reason, errors=errors)
 
128
 
 
129
 
 
130
 
 
131
class _Warning(object):
 
132
    """
 
133
    A L{_Warning} instance represents one warning emitted through the Python
 
134
    warning system (L{warnings}).  This is used to insulate callers of
 
135
    L{_collectWarnings} from changes to the Python warnings system which might
 
136
    otherwise require changes to the warning objects that function passes to
 
137
    the observer object it accepts.
 
138
 
 
139
    @ivar message: The string which was passed as the message parameter to
 
140
        L{warnings.warn}.
 
141
 
 
142
    @ivar category: The L{Warning} subclass which was passed as the category
 
143
        parameter to L{warnings.warn}.
 
144
 
 
145
    @ivar filename: The name of the file containing the definition of the code
 
146
        object which was C{stacklevel} frames above the call to
 
147
        L{warnings.warn}, where C{stacklevel} is the value of the C{stacklevel}
 
148
        parameter passed to L{warnings.warn}.
 
149
 
 
150
    @ivar lineno: The source line associated with the active instruction of the
 
151
        code object object which was C{stacklevel} frames above the call to
 
152
        L{warnings.warn}, where C{stacklevel} is the value of the C{stacklevel}
 
153
        parameter passed to L{warnings.warn}.
 
154
    """
 
155
    def __init__(self, message, category, filename, lineno):
 
156
        self.message = message
 
157
        self.category = category
 
158
        self.filename = filename
 
159
        self.lineno = lineno
 
160
 
 
161
 
 
162
 
 
163
def _collectWarnings(observeWarning, f, *args, **kwargs):
 
164
    """
 
165
    Call C{f} with C{args} positional arguments and C{kwargs} keyword arguments
 
166
    and collect all warnings which are emitted as a result in a list.
 
167
 
 
168
    @param observeWarning: A callable which will be invoked with a L{_Warning}
 
169
        instance each time a warning is emitted.
 
170
 
 
171
    @return: The return value of C{f(*args, **kwargs)}.
 
172
    """
 
173
    def showWarning(message, category, filename, lineno, file=None, line=None):
 
174
        assert isinstance(message, Warning)
 
175
        observeWarning(_Warning(
 
176
                message.args[0], category, filename, lineno))
 
177
 
 
178
    # Disable the per-module cache for every module otherwise if the warning
 
179
    # which the caller is expecting us to collect was already emitted it won't
 
180
    # be re-emitted by the call to f which happens below.
 
181
    for v in sys.modules.itervalues():
 
182
        if v is not None:
 
183
            try:
 
184
                v.__warningregistry__ = None
 
185
            except:
 
186
                # Don't specify a particular exception type to handle in case
 
187
                # some wacky object raises some wacky exception in response to
 
188
                # the setattr attempt.
 
189
                pass
 
190
 
 
191
    origFilters = warnings.filters[:]
 
192
    origShow = warnings.showwarning
 
193
    warnings.simplefilter('always')
 
194
    try:
 
195
        warnings.showwarning = showWarning
 
196
        result = f(*args, **kwargs)
 
197
    finally:
 
198
        warnings.filters[:] = origFilters
 
199
        warnings.showwarning = origShow
 
200
    return result
 
201
 
 
202
 
 
203
 
 
204
class _Assertions(pyunit.TestCase, object):
 
205
    """
 
206
    Replaces many of the built-in TestCase assertions. In general, these
 
207
    assertions provide better error messages and are easier to use in
 
208
    callbacks. Also provides new assertions such as L{failUnlessFailure}.
 
209
 
 
210
    Although the tests are defined as 'failIf*' and 'failUnless*', they can
 
211
    also be called as 'assertNot*' and 'assert*'.
 
212
    """
 
213
 
 
214
    def fail(self, msg=None):
 
215
        """
 
216
        Absolutely fail the test.  Do not pass go, do not collect $200.
 
217
 
 
218
        @param msg: the message that will be displayed as the reason for the
 
219
        failure
 
220
        """
 
221
        raise self.failureException(msg)
 
222
 
 
223
    def failIf(self, condition, msg=None):
 
224
        """
 
225
        Fail the test if C{condition} evaluates to True.
 
226
 
 
227
        @param condition: any object that defines __nonzero__
 
228
        """
 
229
        if condition:
 
230
            raise self.failureException(msg)
 
231
        return condition
 
232
    assertNot = assertFalse = failUnlessFalse = failIf
 
233
 
 
234
    def failUnless(self, condition, msg=None):
 
235
        """
 
236
        Fail the test if C{condition} evaluates to False.
 
237
 
 
238
        @param condition: any object that defines __nonzero__
 
239
        """
 
240
        if not condition:
 
241
            raise self.failureException(msg)
 
242
        return condition
 
243
    assert_ = assertTrue = failUnlessTrue = failUnless
 
244
 
 
245
    def failUnlessRaises(self, exception, f, *args, **kwargs):
 
246
        """
 
247
        Fail the test unless calling the function C{f} with the given
 
248
        C{args} and C{kwargs} raises C{exception}. The failure will report
 
249
        the traceback and call stack of the unexpected exception.
 
250
 
 
251
        @param exception: exception type that is to be expected
 
252
        @param f: the function to call
 
253
 
 
254
        @return: The raised exception instance, if it is of the given type.
 
255
        @raise self.failureException: Raised if the function call does
 
256
            not raise an exception or if it raises an exception of a
 
257
            different type.
 
258
        """
 
259
        try:
 
260
            result = f(*args, **kwargs)
 
261
        except exception, inst:
 
262
            return inst
 
263
        except:
 
264
            raise self.failureException('%s raised instead of %s:\n %s'
 
265
                                        % (sys.exc_info()[0],
 
266
                                           exception.__name__,
 
267
                                           failure.Failure().getTraceback()))
 
268
        else:
 
269
            raise self.failureException('%s not raised (%r returned)'
 
270
                                        % (exception.__name__, result))
 
271
    assertRaises = failUnlessRaises
 
272
 
 
273
    def failUnlessEqual(self, first, second, msg=''):
 
274
        """
 
275
        Fail the test if C{first} and C{second} are not equal.
 
276
 
 
277
        @param msg: A string describing the failure that's included in the
 
278
            exception.
 
279
        """
 
280
        if not first == second:
 
281
            if msg is None:
 
282
                msg = ''
 
283
            if len(msg) > 0:
 
284
                msg += '\n'
 
285
            raise self.failureException(
 
286
                '%snot equal:\na = %s\nb = %s\n'
 
287
                % (msg, pformat(first), pformat(second)))
 
288
        return first
 
289
    assertEqual = assertEquals = failUnlessEquals = failUnlessEqual
 
290
 
 
291
    def failUnlessIdentical(self, first, second, msg=None):
 
292
        """
 
293
        Fail the test if C{first} is not C{second}.  This is an
 
294
        obect-identity-equality test, not an object equality
 
295
        (i.e. C{__eq__}) test.
 
296
 
 
297
        @param msg: if msg is None, then the failure message will be
 
298
        '%r is not %r' % (first, second)
 
299
        """
 
300
        if first is not second:
 
301
            raise self.failureException(msg or '%r is not %r' % (first, second))
 
302
        return first
 
303
    assertIdentical = failUnlessIdentical
 
304
 
 
305
    def failIfIdentical(self, first, second, msg=None):
 
306
        """
 
307
        Fail the test if C{first} is C{second}.  This is an
 
308
        obect-identity-equality test, not an object equality
 
309
        (i.e. C{__eq__}) test.
 
310
 
 
311
        @param msg: if msg is None, then the failure message will be
 
312
        '%r is %r' % (first, second)
 
313
        """
 
314
        if first is second:
 
315
            raise self.failureException(msg or '%r is %r' % (first, second))
 
316
        return first
 
317
    assertNotIdentical = failIfIdentical
 
318
 
 
319
    def failIfEqual(self, first, second, msg=None):
 
320
        """
 
321
        Fail the test if C{first} == C{second}.
 
322
 
 
323
        @param msg: if msg is None, then the failure message will be
 
324
        '%r == %r' % (first, second)
 
325
        """
 
326
        if not first != second:
 
327
            raise self.failureException(msg or '%r == %r' % (first, second))
 
328
        return first
 
329
    assertNotEqual = assertNotEquals = failIfEquals = failIfEqual
 
330
 
 
331
    def failUnlessIn(self, containee, container, msg=None):
 
332
        """
 
333
        Fail the test if C{containee} is not found in C{container}.
 
334
 
 
335
        @param containee: the value that should be in C{container}
 
336
        @param container: a sequence type, or in the case of a mapping type,
 
337
                          will follow semantics of 'if key in dict.keys()'
 
338
        @param msg: if msg is None, then the failure message will be
 
339
                    '%r not in %r' % (first, second)
 
340
        """
 
341
        if containee not in container:
 
342
            raise self.failureException(msg or "%r not in %r"
 
343
                                        % (containee, container))
 
344
        return containee
 
345
    assertIn = failUnlessIn
 
346
 
 
347
    def failIfIn(self, containee, container, msg=None):
 
348
        """
 
349
        Fail the test if C{containee} is found in C{container}.
 
350
 
 
351
        @param containee: the value that should not be in C{container}
 
352
        @param container: a sequence type, or in the case of a mapping type,
 
353
                          will follow semantics of 'if key in dict.keys()'
 
354
        @param msg: if msg is None, then the failure message will be
 
355
                    '%r in %r' % (first, second)
 
356
        """
 
357
        if containee in container:
 
358
            raise self.failureException(msg or "%r in %r"
 
359
                                        % (containee, container))
 
360
        return containee
 
361
    assertNotIn = failIfIn
 
362
 
 
363
    def failIfAlmostEqual(self, first, second, places=7, msg=None):
 
364
        """
 
365
        Fail if the two objects are equal as determined by their
 
366
        difference rounded to the given number of decimal places
 
367
        (default 7) and comparing to zero.
 
368
 
 
369
        @note: decimal places (from zero) is usually not the same
 
370
               as significant digits (measured from the most
 
371
               signficant digit).
 
372
 
 
373
        @note: included for compatiblity with PyUnit test cases
 
374
        """
 
375
        if round(second-first, places) == 0:
 
376
            raise self.failureException(msg or '%r == %r within %r places'
 
377
                                        % (first, second, places))
 
378
        return first
 
379
    assertNotAlmostEqual = assertNotAlmostEquals = failIfAlmostEqual
 
380
    failIfAlmostEquals = failIfAlmostEqual
 
381
 
 
382
    def failUnlessAlmostEqual(self, first, second, places=7, msg=None):
 
383
        """
 
384
        Fail if the two objects are unequal as determined by their
 
385
        difference rounded to the given number of decimal places
 
386
        (default 7) and comparing to zero.
 
387
 
 
388
        @note: decimal places (from zero) is usually not the same
 
389
               as significant digits (measured from the most
 
390
               signficant digit).
 
391
 
 
392
        @note: included for compatiblity with PyUnit test cases
 
393
        """
 
394
        if round(second-first, places) != 0:
 
395
            raise self.failureException(msg or '%r != %r within %r places'
 
396
                                        % (first, second, places))
 
397
        return first
 
398
    assertAlmostEqual = assertAlmostEquals = failUnlessAlmostEqual
 
399
    failUnlessAlmostEquals = failUnlessAlmostEqual
 
400
 
 
401
    def failUnlessApproximates(self, first, second, tolerance, msg=None):
 
402
        """
 
403
        Fail if C{first} - C{second} > C{tolerance}
 
404
 
 
405
        @param msg: if msg is None, then the failure message will be
 
406
                    '%r ~== %r' % (first, second)
 
407
        """
 
408
        if abs(first - second) > tolerance:
 
409
            raise self.failureException(msg or "%s ~== %s" % (first, second))
 
410
        return first
 
411
    assertApproximates = failUnlessApproximates
 
412
 
 
413
    def failUnlessFailure(self, deferred, *expectedFailures):
 
414
        """
 
415
        Fail if C{deferred} does not errback with one of C{expectedFailures}.
 
416
        Returns the original Deferred with callbacks added. You will need
 
417
        to return this Deferred from your test case.
 
418
        """
 
419
        def _cb(ignore):
 
420
            raise self.failureException(
 
421
                "did not catch an error, instead got %r" % (ignore,))
 
422
 
 
423
        def _eb(failure):
 
424
            if failure.check(*expectedFailures):
 
425
                return failure.value
 
426
            else:
 
427
                output = ('\nExpected: %r\nGot:\n%s'
 
428
                          % (expectedFailures, str(failure)))
 
429
                raise self.failureException(output)
 
430
        return deferred.addCallbacks(_cb, _eb)
 
431
    assertFailure = failUnlessFailure
 
432
 
 
433
    def failUnlessSubstring(self, substring, astring, msg=None):
 
434
        """
 
435
        Fail if C{substring} does not exist within C{astring}.
 
436
        """
 
437
        return self.failUnlessIn(substring, astring, msg)
 
438
    assertSubstring = failUnlessSubstring
 
439
 
 
440
    def failIfSubstring(self, substring, astring, msg=None):
 
441
        """
 
442
        Fail if C{astring} contains C{substring}.
 
443
        """
 
444
        return self.failIfIn(substring, astring, msg)
 
445
    assertNotSubstring = failIfSubstring
 
446
 
 
447
    def failUnlessWarns(self, category, message, filename, f,
 
448
                       *args, **kwargs):
 
449
        """
 
450
        Fail if the given function doesn't generate the specified warning when
 
451
        called. It calls the function, checks the warning, and forwards the
 
452
        result of the function if everything is fine.
 
453
 
 
454
        @param category: the category of the warning to check.
 
455
        @param message: the output message of the warning to check.
 
456
        @param filename: the filename where the warning should come from.
 
457
        @param f: the function which is supposed to generate the warning.
 
458
        @type f: any callable.
 
459
        @param args: the arguments to C{f}.
 
460
        @param kwargs: the keywords arguments to C{f}.
 
461
 
 
462
        @return: the result of the original function C{f}.
 
463
        """
 
464
        warningsShown = []
 
465
        result = _collectWarnings(warningsShown.append, f, *args, **kwargs)
 
466
 
 
467
        if not warningsShown:
 
468
            self.fail("No warnings emitted")
 
469
        first = warningsShown[0]
 
470
        for other in warningsShown[1:]:
 
471
            if ((other.message, other.category)
 
472
                != (first.message, first.category)):
 
473
                self.fail("Can't handle different warnings")
 
474
        self.assertEqual(first.message, message)
 
475
        self.assertIdentical(first.category, category)
 
476
 
 
477
        # Use starts with because of .pyc/.pyo issues.
 
478
        self.failUnless(
 
479
            filename.startswith(first.filename),
 
480
            'Warning in %r, expected %r' % (first.filename, filename))
 
481
 
 
482
        # It would be nice to be able to check the line number as well, but
 
483
        # different configurations actually end up reporting different line
 
484
        # numbers (generally the variation is only 1 line, but that's enough
 
485
        # to fail the test erroneously...).
 
486
        # self.assertEqual(lineno, xxx)
 
487
 
 
488
        return result
 
489
    assertWarns = failUnlessWarns
 
490
 
 
491
    def failUnlessIsInstance(self, instance, classOrTuple):
 
492
        """
 
493
        Fail if C{instance} is not an instance of the given class or of
 
494
        one of the given classes.
 
495
 
 
496
        @param instance: the object to test the type (first argument of the
 
497
            C{isinstance} call).
 
498
        @type instance: any.
 
499
        @param classOrTuple: the class or classes to test against (second
 
500
            argument of the C{isinstance} call).
 
501
        @type classOrTuple: class, type, or tuple.
 
502
        """
 
503
        if not isinstance(instance, classOrTuple):
 
504
            self.fail("%r is not an instance of %s" % (instance, classOrTuple))
 
505
    assertIsInstance = failUnlessIsInstance
 
506
 
 
507
    def failIfIsInstance(self, instance, classOrTuple):
 
508
        """
 
509
        Fail if C{instance} is not an instance of the given class or of
 
510
        one of the given classes.
 
511
 
 
512
        @param instance: the object to test the type (first argument of the
 
513
            C{isinstance} call).
 
514
        @type instance: any.
 
515
        @param classOrTuple: the class or classes to test against (second
 
516
            argument of the C{isinstance} call).
 
517
        @type classOrTuple: class, type, or tuple.
 
518
        """
 
519
        if isinstance(instance, classOrTuple):
 
520
            self.fail("%r is an instance of %s" % (instance, classOrTuple))
 
521
    assertNotIsInstance = failIfIsInstance
 
522
 
 
523
 
 
524
class _LogObserver(object):
 
525
    """
 
526
    Observes the Twisted logs and catches any errors.
 
527
 
 
528
    @ivar _errors: A C{list} of L{Failure} instances which were received as
 
529
        error events from the Twisted logging system.
 
530
 
 
531
    @ivar _added: A C{int} giving the number of times C{_add} has been called
 
532
        less the number of times C{_remove} has been called; used to only add
 
533
        this observer to the Twisted logging since once, regardless of the
 
534
        number of calls to the add method.
 
535
 
 
536
    @ivar _ignored: A C{list} of exception types which will not be recorded.
 
537
    """
 
538
 
 
539
    def __init__(self):
 
540
        self._errors = []
 
541
        self._added = 0
 
542
        self._ignored = []
 
543
 
 
544
 
 
545
    def _add(self):
 
546
        if self._added == 0:
 
547
            log.addObserver(self.gotEvent)
 
548
            self._oldFE, log._flushErrors = (log._flushErrors, self.flushErrors)
 
549
            self._oldIE, log._ignore = (log._ignore, self._ignoreErrors)
 
550
            self._oldCI, log._clearIgnores = (log._clearIgnores,
 
551
                                              self._clearIgnores)
 
552
        self._added += 1
 
553
 
 
554
    def _remove(self):
 
555
        self._added -= 1
 
556
        if self._added == 0:
 
557
            log.removeObserver(self.gotEvent)
 
558
            log._flushErrors = self._oldFE
 
559
            log._ignore = self._oldIE
 
560
            log._clearIgnores = self._oldCI
 
561
 
 
562
 
 
563
    def _ignoreErrors(self, *errorTypes):
 
564
        """
 
565
        Do not store any errors with any of the given types.
 
566
        """
 
567
        self._ignored.extend(errorTypes)
 
568
 
 
569
 
 
570
    def _clearIgnores(self):
 
571
        """
 
572
        Stop ignoring any errors we might currently be ignoring.
 
573
        """
 
574
        self._ignored = []
 
575
 
 
576
 
 
577
    def flushErrors(self, *errorTypes):
 
578
        """
 
579
        Flush errors from the list of caught errors. If no arguments are
 
580
        specified, remove all errors. If arguments are specified, only remove
 
581
        errors of those types from the stored list.
 
582
        """
 
583
        if errorTypes:
 
584
            flushed = []
 
585
            remainder = []
 
586
            for f in self._errors:
 
587
                if f.check(*errorTypes):
 
588
                    flushed.append(f)
 
589
                else:
 
590
                    remainder.append(f)
 
591
            self._errors = remainder
 
592
        else:
 
593
            flushed = self._errors
 
594
            self._errors = []
 
595
        return flushed
 
596
 
 
597
 
 
598
    def getErrors(self):
 
599
        """
 
600
        Return a list of errors caught by this observer.
 
601
        """
 
602
        return self._errors
 
603
 
 
604
 
 
605
    def gotEvent(self, event):
 
606
        """
 
607
        The actual observer method. Called whenever a message is logged.
 
608
 
 
609
        @param event: A dictionary containing the log message. Actual
 
610
        structure undocumented (see source for L{twisted.python.log}).
 
611
        """
 
612
        if event.get('isError', False) and 'failure' in event:
 
613
            f = event['failure']
 
614
            if len(self._ignored) == 0 or not f.check(*self._ignored):
 
615
                self._errors.append(f)
 
616
 
 
617
 
 
618
 
 
619
_logObserver = _LogObserver()
 
620
 
 
621
_wait_is_running = []
 
622
 
 
623
class TestCase(_Assertions):
 
624
    """
 
625
    A unit test. The atom of the unit testing universe.
 
626
 
 
627
    This class extends C{unittest.TestCase} from the standard library. The
 
628
    main feature is the ability to return C{Deferred}s from tests and fixture
 
629
    methods and to have the suite wait for those C{Deferred}s to fire.
 
630
 
 
631
    To write a unit test, subclass C{TestCase} and define a method (say,
 
632
    'test_foo') on the subclass. To run the test, instantiate your subclass
 
633
    with the name of the method, and call L{run} on the instance, passing a
 
634
    L{TestResult} object.
 
635
 
 
636
    The C{trial} script will automatically find any C{TestCase} subclasses
 
637
    defined in modules beginning with 'test_' and construct test cases for all
 
638
    methods beginning with 'test'.
 
639
 
 
640
    If an error is logged during the test run, the test will fail with an
 
641
    error. See L{log.err}.
 
642
 
 
643
    @ivar failureException: An exception class, defaulting to C{FailTest}. If
 
644
    the test method raises this exception, it will be reported as a failure,
 
645
    rather than an exception. All of the assertion methods raise this if the
 
646
    assertion fails.
 
647
 
 
648
    @ivar skip: C{None} or a string explaining why this test is to be
 
649
    skipped. If defined, the test will not be run. Instead, it will be
 
650
    reported to the result object as 'skipped' (if the C{TestResult} supports
 
651
    skipping).
 
652
 
 
653
    @ivar suppress: C{None} or a list of tuples of C{(args, kwargs)} to be
 
654
    passed to C{warnings.filterwarnings}. Use these to suppress warnings
 
655
    raised in a test. Useful for testing deprecated code. See also
 
656
    L{util.suppress}.
 
657
 
 
658
    @ivar timeout: A real number of seconds. If set, the test will
 
659
    raise an error if it takes longer than C{timeout} seconds.
 
660
    If not set, util.DEFAULT_TIMEOUT_DURATION is used.
 
661
 
 
662
    @ivar todo: C{None}, a string or a tuple of C{(errors, reason)} where
 
663
    C{errors} is either an exception class or an iterable of exception
 
664
    classes, and C{reason} is a string. See L{Todo} or L{makeTodo} for more
 
665
    information.
 
666
    """
 
667
 
 
668
    implements(itrial.ITestCase)
 
669
    failureException = FailTest
 
670
 
 
671
    def __init__(self, methodName='runTest'):
 
672
        """
 
673
        Construct an asynchronous test case for C{methodName}.
 
674
 
 
675
        @param methodName: The name of a method on C{self}. This method should
 
676
        be a unit test. That is, it should be a short method that calls some of
 
677
        the assert* methods. If C{methodName} is unspecified, L{runTest} will
 
678
        be used as the test method. This is mostly useful for testing Trial.
 
679
        """
 
680
        super(TestCase, self).__init__(methodName)
 
681
        self._testMethodName = methodName
 
682
        testMethod = getattr(self, methodName)
 
683
        self._parents = [testMethod, self]
 
684
        self._parents.extend(util.getPythonContainers(testMethod))
 
685
        self._passed = False
 
686
        self._cleanups = []
 
687
 
 
688
    if sys.version_info >= (2, 6):
 
689
        # Override the comparison defined by the base TestCase which considers
 
690
        # instances of the same class with the same _testMethodName to be
 
691
        # equal.  Since trial puts TestCase instances into a set, that
 
692
        # definition of comparison makes it impossible to run the same test
 
693
        # method twice.  Most likely, trial should stop using a set to hold
 
694
        # tests, but until it does, this is necessary on Python 2.6.  Only
 
695
        # __eq__ and __ne__ are required here, not __hash__, since the
 
696
        # inherited __hash__ is compatible with these equality semantics.  A
 
697
        # different __hash__ might be slightly more efficient (by reducing
 
698
        # collisions), but who cares? -exarkun
 
699
        def __eq__(self, other):
 
700
            return self is other
 
701
 
 
702
        def __ne__(self, other):
 
703
            return self is not other
 
704
 
 
705
 
 
706
    def _run(self, methodName, result):
 
707
        from twisted.internet import reactor
 
708
        timeout = self.getTimeout()
 
709
        def onTimeout(d):
 
710
            e = defer.TimeoutError("%r (%s) still running at %s secs"
 
711
                % (self, methodName, timeout))
 
712
            f = failure.Failure(e)
 
713
            # try to errback the deferred that the test returns (for no gorram
 
714
            # reason) (see issue1005 and test_errorPropagation in
 
715
            # test_deferred)
 
716
            try:
 
717
                d.errback(f)
 
718
            except defer.AlreadyCalledError:
 
719
                # if the deferred has been called already but the *back chain
 
720
                # is still unfinished, crash the reactor and report timeout
 
721
                # error ourself.
 
722
                reactor.crash()
 
723
                self._timedOut = True # see self._wait
 
724
                todo = self.getTodo()
 
725
                if todo is not None and todo.expected(f):
 
726
                    result.addExpectedFailure(self, f, todo)
 
727
                else:
 
728
                    result.addError(self, f)
 
729
        onTimeout = utils.suppressWarnings(
 
730
            onTimeout, util.suppress(category=DeprecationWarning))
 
731
        method = getattr(self, methodName)
 
732
        d = defer.maybeDeferred(utils.runWithWarningsSuppressed,
 
733
                                self.getSuppress(), method)
 
734
        call = reactor.callLater(timeout, onTimeout, d)
 
735
        d.addBoth(lambda x : call.active() and call.cancel() or x)
 
736
        return d
 
737
 
 
738
    def shortDescription(self):
 
739
        desc = super(TestCase, self).shortDescription()
 
740
        if desc is None:
 
741
            return self._testMethodName
 
742
        return desc
 
743
 
 
744
    def __call__(self, *args, **kwargs):
 
745
        return self.run(*args, **kwargs)
 
746
 
 
747
    def deferSetUp(self, ignored, result):
 
748
        d = self._run('setUp', result)
 
749
        d.addCallbacks(self.deferTestMethod, self._ebDeferSetUp,
 
750
                       callbackArgs=(result,),
 
751
                       errbackArgs=(result,))
 
752
        return d
 
753
 
 
754
    def _ebDeferSetUp(self, failure, result):
 
755
        if failure.check(SkipTest):
 
756
            result.addSkip(self, self._getReason(failure))
 
757
        else:
 
758
            result.addError(self, failure)
 
759
            if failure.check(KeyboardInterrupt):
 
760
                result.stop()
 
761
        return self.deferRunCleanups(None, result)
 
762
 
 
763
    def deferTestMethod(self, ignored, result):
 
764
        d = self._run(self._testMethodName, result)
 
765
        d.addCallbacks(self._cbDeferTestMethod, self._ebDeferTestMethod,
 
766
                       callbackArgs=(result,),
 
767
                       errbackArgs=(result,))
 
768
        d.addBoth(self.deferRunCleanups, result)
 
769
        d.addBoth(self.deferTearDown, result)
 
770
        return d
 
771
 
 
772
    def _cbDeferTestMethod(self, ignored, result):
 
773
        if self.getTodo() is not None:
 
774
            result.addUnexpectedSuccess(self, self.getTodo())
 
775
        else:
 
776
            self._passed = True
 
777
        return ignored
 
778
 
 
779
    def _ebDeferTestMethod(self, f, result):
 
780
        todo = self.getTodo()
 
781
        if todo is not None and todo.expected(f):
 
782
            result.addExpectedFailure(self, f, todo)
 
783
        elif f.check(self.failureException, FailTest):
 
784
            result.addFailure(self, f)
 
785
        elif f.check(KeyboardInterrupt):
 
786
            result.addError(self, f)
 
787
            result.stop()
 
788
        elif f.check(SkipTest):
 
789
            result.addSkip(self, self._getReason(f))
 
790
        else:
 
791
            result.addError(self, f)
 
792
 
 
793
    def deferTearDown(self, ignored, result):
 
794
        d = self._run('tearDown', result)
 
795
        d.addErrback(self._ebDeferTearDown, result)
 
796
        return d
 
797
 
 
798
    def _ebDeferTearDown(self, failure, result):
 
799
        result.addError(self, failure)
 
800
        if failure.check(KeyboardInterrupt):
 
801
            result.stop()
 
802
        self._passed = False
 
803
 
 
804
    def deferRunCleanups(self, ignored, result):
 
805
        """
 
806
        Run any scheduled cleanups and report errors (if any to the result
 
807
        object.
 
808
        """
 
809
        d = self._runCleanups()
 
810
        d.addCallback(self._cbDeferRunCleanups, result)
 
811
        return d
 
812
 
 
813
    def _cbDeferRunCleanups(self, cleanupResults, result):
 
814
        for flag, failure in cleanupResults:
 
815
            if flag == defer.FAILURE:
 
816
                result.addError(self, failure)
 
817
                if failure.check(KeyboardInterrupt):
 
818
                    result.stop()
 
819
                self._passed = False
 
820
 
 
821
    def _cleanUp(self, result):
 
822
        try:
 
823
            clean = util._Janitor(self, result).postCaseCleanup()
 
824
            if not clean:
 
825
                self._passed = False
 
826
        except:
 
827
            result.addError(self, failure.Failure())
 
828
            self._passed = False
 
829
        for error in self._observer.getErrors():
 
830
            result.addError(self, error)
 
831
            self._passed = False
 
832
        self.flushLoggedErrors()
 
833
        self._removeObserver()
 
834
        if self._passed:
 
835
            result.addSuccess(self)
 
836
 
 
837
    def _classCleanUp(self, result):
 
838
        try:
 
839
            util._Janitor(self, result).postClassCleanup()
 
840
        except:
 
841
            result.addError(self, failure.Failure())
 
842
 
 
843
    def _makeReactorMethod(self, name):
 
844
        """
 
845
        Create a method which wraps the reactor method C{name}. The new
 
846
        method issues a deprecation warning and calls the original.
 
847
        """
 
848
        def _(*a, **kw):
 
849
            warnings.warn("reactor.%s cannot be used inside unit tests. "
 
850
                          "In the future, using %s will fail the test and may "
 
851
                          "crash or hang the test run."
 
852
                          % (name, name),
 
853
                          stacklevel=2, category=DeprecationWarning)
 
854
            return self._reactorMethods[name](*a, **kw)
 
855
        return _
 
856
 
 
857
    def _deprecateReactor(self, reactor):
 
858
        """
 
859
        Deprecate C{iterate}, C{crash} and C{stop} on C{reactor}. That is,
 
860
        each method is wrapped in a function that issues a deprecation
 
861
        warning, then calls the original.
 
862
 
 
863
        @param reactor: The Twisted reactor.
 
864
        """
 
865
        self._reactorMethods = {}
 
866
        for name in ['crash', 'iterate', 'stop']:
 
867
            self._reactorMethods[name] = getattr(reactor, name)
 
868
            setattr(reactor, name, self._makeReactorMethod(name))
 
869
 
 
870
    def _undeprecateReactor(self, reactor):
 
871
        """
 
872
        Restore the deprecated reactor methods. Undoes what
 
873
        L{_deprecateReactor} did.
 
874
 
 
875
        @param reactor: The Twisted reactor.
 
876
        """
 
877
        for name, method in self._reactorMethods.iteritems():
 
878
            setattr(reactor, name, method)
 
879
        self._reactorMethods = {}
 
880
 
 
881
    def _installObserver(self):
 
882
        self._observer = _logObserver
 
883
        self._observer._add()
 
884
 
 
885
    def _removeObserver(self):
 
886
        self._observer._remove()
 
887
 
 
888
    def flushLoggedErrors(self, *errorTypes):
 
889
        """
 
890
        Remove stored errors received from the log.
 
891
 
 
892
        C{TestCase} stores each error logged during the run of the test and
 
893
        reports them as errors during the cleanup phase (after C{tearDown}).
 
894
 
 
895
        @param *errorTypes: If unspecifed, flush all errors. Otherwise, only
 
896
        flush errors that match the given types.
 
897
 
 
898
        @return: A list of failures that have been removed.
 
899
        """
 
900
        return self._observer.flushErrors(*errorTypes)
 
901
 
 
902
 
 
903
    def flushWarnings(self, offendingFunctions=None):
 
904
        """
 
905
        Remove stored warnings from the list of captured warnings and return
 
906
        them.
 
907
 
 
908
        @param offendingFunctions: If C{None}, all warnings issued during the
 
909
            currently running test will be flushed.  Otherwise, only warnings
 
910
            which I{point} to a function included in this list will be flushed.
 
911
            All warnings include a filename and source line number; if these
 
912
            parts of a warning point to a source line which is part of a
 
913
            function, then the warning I{points} to that function.
 
914
        @type offendingFunctions: L{NoneType} or L{list} of functions or methods.
 
915
 
 
916
        @raise ValueError: If C{offendingFunctions} is not C{None} and includes
 
917
            an object which is not a L{FunctionType} or L{MethodType} instance.
 
918
 
 
919
        @return: A C{list}, each element of which is a C{dict} giving
 
920
            information about one warning which was flushed by this call.  The
 
921
            keys of each C{dict} are:
 
922
 
 
923
                - C{'message'}: The string which was passed as the I{message}
 
924
                  parameter to L{warnings.warn}.
 
925
 
 
926
                - C{'category'}: The warning subclass which was passed as the
 
927
                  I{category} parameter to L{warnings.warn}.
 
928
 
 
929
                - C{'filename'}: The name of the file containing the definition
 
930
                  of the code object which was C{stacklevel} frames above the
 
931
                  call to L{warnings.warn}, where C{stacklevel} is the value of
 
932
                  the C{stacklevel} parameter passed to L{warnings.warn}.
 
933
 
 
934
                - C{'lineno'}: The source line associated with the active
 
935
                  instruction of the code object object which was C{stacklevel}
 
936
                  frames above the call to L{warnings.warn}, where
 
937
                  C{stacklevel} is the value of the C{stacklevel} parameter
 
938
                  passed to L{warnings.warn}.
 
939
        """
 
940
        if offendingFunctions is None:
 
941
            toFlush = self._warnings[:]
 
942
            self._warnings[:] = []
 
943
        else:
 
944
            toFlush = []
 
945
            for aWarning in self._warnings:
 
946
                for aFunction in offendingFunctions:
 
947
                    if not isinstance(aFunction, (
 
948
                            types.FunctionType, types.MethodType)):
 
949
                        raise ValueError("%r is not a function or method" % (
 
950
                                aFunction,))
 
951
 
 
952
                    # inspect.getabsfile(aFunction) sometimes returns a
 
953
                    # filename which disagrees with the filename the warning
 
954
                    # system generates.  This seems to be because a
 
955
                    # function's code object doesn't deal with source files
 
956
                    # being renamed.  inspect.getabsfile(module) seems
 
957
                    # better (or at least agrees with the warning system
 
958
                    # more often), and does some normalization for us which
 
959
                    # is desirable.  inspect.getmodule() is attractive, but
 
960
                    # somewhat broken in Python 2.3.  See Python bug 4845.
 
961
                    aModule = sys.modules[aFunction.__module__]
 
962
                    filename = inspect.getabsfile(aModule)
 
963
 
 
964
                    if filename != os.path.normcase(aWarning.filename):
 
965
                        continue
 
966
                    lineStarts = list(_findlinestarts(aFunction.func_code))
 
967
                    first = lineStarts[0][1]
 
968
                    last = lineStarts[-1][1]
 
969
                    if not (first <= aWarning.lineno <= last):
 
970
                        continue
 
971
                    # The warning points to this function, flush it and move on
 
972
                    # to the next warning.
 
973
                    toFlush.append(aWarning)
 
974
                    break
 
975
            # Remove everything which is being flushed.
 
976
            map(self._warnings.remove, toFlush)
 
977
 
 
978
        return [
 
979
            {'message': w.message, 'category': w.category,
 
980
             'filename': w.filename, 'lineno': w.lineno}
 
981
            for w in toFlush]
 
982
 
 
983
 
 
984
    def addCleanup(self, f, *args, **kwargs):
 
985
        """
 
986
        Add the given function to a list of functions to be called after the
 
987
        test has run, but before C{tearDown}.
 
988
 
 
989
        Functions will be run in reverse order of being added. This helps
 
990
        ensure that tear down complements set up.
 
991
 
 
992
        The function C{f} may return a Deferred. If so, C{TestCase} will wait
 
993
        until the Deferred has fired before proceeding to the next function.
 
994
        """
 
995
        self._cleanups.append((f, args, kwargs))
 
996
 
 
997
 
 
998
    def callDeprecated(self, version, f, *args, **kwargs):
 
999
        """
 
1000
        Call a function that was deprecated at a specific version.
 
1001
 
 
1002
        @param version: The version that the function was deprecated in.
 
1003
        @param f: The deprecated function to call.
 
1004
        @return: Whatever the function returns.
 
1005
        """
 
1006
        result = f(*args, **kwargs)
 
1007
        warningsShown = self.flushWarnings([self.callDeprecated])
 
1008
 
 
1009
        if len(warningsShown) == 0:
 
1010
            self.fail('%r is not deprecated.' % (f,))
 
1011
 
 
1012
        observedWarning = warningsShown[0]['message']
 
1013
        expectedWarning = getDeprecationWarningString(f, version)
 
1014
        self.assertEqual(expectedWarning, observedWarning)
 
1015
 
 
1016
        return result
 
1017
 
 
1018
 
 
1019
    def _runCleanups(self):
 
1020
        """
 
1021
        Run the cleanups added with L{addCleanup} in order.
 
1022
 
 
1023
        @return: A C{Deferred} that fires when all cleanups are run.
 
1024
        """
 
1025
        def _makeFunction(f, args, kwargs):
 
1026
            return lambda: f(*args, **kwargs)
 
1027
        callables = []
 
1028
        while len(self._cleanups) > 0:
 
1029
            f, args, kwargs = self._cleanups.pop()
 
1030
            callables.append(_makeFunction(f, args, kwargs))
 
1031
        return util._runSequentially(callables)
 
1032
 
 
1033
 
 
1034
    def patch(self, obj, attribute, value):
 
1035
        """
 
1036
        Monkey patch an object for the duration of the test.
 
1037
 
 
1038
        The monkey patch will be reverted at the end of the test using the
 
1039
        L{addCleanup} mechanism.
 
1040
 
 
1041
        The L{MonkeyPatcher} is returned so that users can restore and
 
1042
        re-apply the monkey patch within their tests.
 
1043
 
 
1044
        @param obj: The object to monkey patch.
 
1045
        @param attribute: The name of the attribute to change.
 
1046
        @param value: The value to set the attribute to.
 
1047
        @return: A L{monkey.MonkeyPatcher} object.
 
1048
        """
 
1049
        monkeyPatch = monkey.MonkeyPatcher((obj, attribute, value))
 
1050
        monkeyPatch.patch()
 
1051
        self.addCleanup(monkeyPatch.restore)
 
1052
        return monkeyPatch
 
1053
 
 
1054
 
 
1055
    def runTest(self):
 
1056
        """
 
1057
        If no C{methodName} argument is passed to the constructor, L{run} will
 
1058
        treat this method as the thing with the actual test inside.
 
1059
        """
 
1060
 
 
1061
 
 
1062
    def run(self, result):
 
1063
        """
 
1064
        Run the test case, storing the results in C{result}.
 
1065
 
 
1066
        First runs C{setUp} on self, then runs the test method (defined in the
 
1067
        constructor), then runs C{tearDown}. Any of these may return
 
1068
        L{Deferred}s. After they complete, does some reactor cleanup.
 
1069
 
 
1070
        @param result: A L{TestResult} object.
 
1071
        """
 
1072
        log.msg("--> %s <--" % (self.id()))
 
1073
        from twisted.internet import reactor
 
1074
        new_result = itrial.IReporter(result, None)
 
1075
        if new_result is None:
 
1076
            result = PyUnitResultAdapter(result)
 
1077
        else:
 
1078
            result = new_result
 
1079
        self._timedOut = False
 
1080
        result.startTest(self)
 
1081
        if self.getSkip(): # don't run test methods that are marked as .skip
 
1082
            result.addSkip(self, self.getSkip())
 
1083
            result.stopTest(self)
 
1084
            return
 
1085
        self._installObserver()
 
1086
 
 
1087
        # All the code inside runThunk will be run such that warnings emitted
 
1088
        # by it will be collected and retrievable by flushWarnings.
 
1089
        def runThunk():
 
1090
            self._passed = False
 
1091
            self._deprecateReactor(reactor)
 
1092
            try:
 
1093
                d = self.deferSetUp(None, result)
 
1094
                try:
 
1095
                    self._wait(d)
 
1096
                finally:
 
1097
                    self._cleanUp(result)
 
1098
                    self._classCleanUp(result)
 
1099
            finally:
 
1100
                self._undeprecateReactor(reactor)
 
1101
 
 
1102
        self._warnings = []
 
1103
        _collectWarnings(self._warnings.append, runThunk)
 
1104
 
 
1105
        # Any collected warnings which the test method didn't flush get
 
1106
        # re-emitted so they'll be logged or show up on stdout or whatever.
 
1107
        for w in self.flushWarnings():
 
1108
            try:
 
1109
                warnings.warn_explicit(**w)
 
1110
            except:
 
1111
                result.addError(self, failure.Failure())
 
1112
 
 
1113
        result.stopTest(self)
 
1114
 
 
1115
 
 
1116
    def _getReason(self, f):
 
1117
        if len(f.value.args) > 0:
 
1118
            reason = f.value.args[0]
 
1119
        else:
 
1120
            warnings.warn(("Do not raise unittest.SkipTest with no "
 
1121
                           "arguments! Give a reason for skipping tests!"),
 
1122
                          stacklevel=2)
 
1123
            reason = f
 
1124
        return reason
 
1125
 
 
1126
    def getSkip(self):
 
1127
        """
 
1128
        Return the skip reason set on this test, if any is set. Checks on the
 
1129
        instance first, then the class, then the module, then packages. As
 
1130
        soon as it finds something with a C{skip} attribute, returns that.
 
1131
        Returns C{None} if it cannot find anything. See L{TestCase} docstring
 
1132
        for more details.
 
1133
        """
 
1134
        return util.acquireAttribute(self._parents, 'skip', None)
 
1135
 
 
1136
    def getTodo(self):
 
1137
        """
 
1138
        Return a L{Todo} object if the test is marked todo. Checks on the
 
1139
        instance first, then the class, then the module, then packages. As
 
1140
        soon as it finds something with a C{todo} attribute, returns that.
 
1141
        Returns C{None} if it cannot find anything. See L{TestCase} docstring
 
1142
        for more details.
 
1143
        """
 
1144
        todo = util.acquireAttribute(self._parents, 'todo', None)
 
1145
        if todo is None:
 
1146
            return None
 
1147
        return makeTodo(todo)
 
1148
 
 
1149
    def getTimeout(self):
 
1150
        """
 
1151
        Returns the timeout value set on this test. Checks on the instance
 
1152
        first, then the class, then the module, then packages. As soon as it
 
1153
        finds something with a C{timeout} attribute, returns that. Returns
 
1154
        L{util.DEFAULT_TIMEOUT_DURATION} if it cannot find anything. See
 
1155
        L{TestCase} docstring for more details.
 
1156
        """
 
1157
        timeout =  util.acquireAttribute(self._parents, 'timeout',
 
1158
                                         util.DEFAULT_TIMEOUT_DURATION)
 
1159
        try:
 
1160
            return float(timeout)
 
1161
        except (ValueError, TypeError):
 
1162
            # XXX -- this is here because sometimes people will have methods
 
1163
            # called 'timeout', or set timeout to 'orange', or something
 
1164
            # Particularly, test_news.NewsTestCase and ReactorCoreTestCase
 
1165
            # both do this.
 
1166
            warnings.warn("'timeout' attribute needs to be a number.",
 
1167
                          category=DeprecationWarning)
 
1168
            return util.DEFAULT_TIMEOUT_DURATION
 
1169
 
 
1170
    def getSuppress(self):
 
1171
        """
 
1172
        Returns any warning suppressions set for this test. Checks on the
 
1173
        instance first, then the class, then the module, then packages. As
 
1174
        soon as it finds something with a C{suppress} attribute, returns that.
 
1175
        Returns any empty list (i.e. suppress no warnings) if it cannot find
 
1176
        anything. See L{TestCase} docstring for more details.
 
1177
        """
 
1178
        return util.acquireAttribute(self._parents, 'suppress', [])
 
1179
 
 
1180
 
 
1181
    def visit(self, visitor):
 
1182
        """
 
1183
        Visit this test case. Call C{visitor} with C{self} as a parameter.
 
1184
 
 
1185
        Deprecated in Twisted 8.0.
 
1186
 
 
1187
        @param visitor: A callable which expects a single parameter: a test
 
1188
        case.
 
1189
 
 
1190
        @return: None
 
1191
        """
 
1192
        warnings.warn("Test visitors deprecated in Twisted 8.0",
 
1193
                      category=DeprecationWarning)
 
1194
        visitor(self)
 
1195
 
 
1196
 
 
1197
    def mktemp(self):
 
1198
        """Returns a unique name that may be used as either a temporary
 
1199
        directory or filename.
 
1200
 
 
1201
        @note: you must call os.mkdir on the value returned from this
 
1202
               method if you wish to use it as a directory!
 
1203
        """
 
1204
        MAX_FILENAME = 32 # some platforms limit lengths of filenames
 
1205
        base = os.path.join(self.__class__.__module__[:MAX_FILENAME],
 
1206
                            self.__class__.__name__[:MAX_FILENAME],
 
1207
                            self._testMethodName[:MAX_FILENAME])
 
1208
        if not os.path.exists(base):
 
1209
            os.makedirs(base)
 
1210
        dirname = tempfile.mkdtemp('', '', base)
 
1211
        return os.path.join(dirname, 'temp')
 
1212
 
 
1213
    def _wait(self, d, running=_wait_is_running):
 
1214
        """Take a Deferred that only ever callbacks. Block until it happens.
 
1215
        """
 
1216
        from twisted.internet import reactor
 
1217
        if running:
 
1218
            raise RuntimeError("_wait is not reentrant")
 
1219
 
 
1220
        results = []
 
1221
        def append(any):
 
1222
            if results is not None:
 
1223
                results.append(any)
 
1224
        def crash(ign):
 
1225
            if results is not None:
 
1226
                reactor.crash()
 
1227
        crash = utils.suppressWarnings(
 
1228
            crash, util.suppress(message=r'reactor\.crash cannot be used.*',
 
1229
                                 category=DeprecationWarning))
 
1230
        def stop():
 
1231
            reactor.crash()
 
1232
        stop = utils.suppressWarnings(
 
1233
            stop, util.suppress(message=r'reactor\.crash cannot be used.*',
 
1234
                                category=DeprecationWarning))
 
1235
 
 
1236
        running.append(None)
 
1237
        try:
 
1238
            d.addBoth(append)
 
1239
            if results:
 
1240
                # d might have already been fired, in which case append is
 
1241
                # called synchronously. Avoid any reactor stuff.
 
1242
                return
 
1243
            d.addBoth(crash)
 
1244
            reactor.stop = stop
 
1245
            try:
 
1246
                reactor.run()
 
1247
            finally:
 
1248
                del reactor.stop
 
1249
 
 
1250
            # If the reactor was crashed elsewhere due to a timeout, hopefully
 
1251
            # that crasher also reported an error. Just return.
 
1252
            # _timedOut is most likely to be set when d has fired but hasn't
 
1253
            # completed its callback chain (see self._run)
 
1254
            if results or self._timedOut: #defined in run() and _run()
 
1255
                return
 
1256
 
 
1257
            # If the timeout didn't happen, and we didn't get a result or
 
1258
            # a failure, then the user probably aborted the test, so let's
 
1259
            # just raise KeyboardInterrupt.
 
1260
 
 
1261
            # FIXME: imagine this:
 
1262
            # web/test/test_webclient.py:
 
1263
            # exc = self.assertRaises(error.Error, wait, method(url))
 
1264
            #
 
1265
            # wait() will raise KeyboardInterrupt, and assertRaises will
 
1266
            # swallow it. Therefore, wait() raising KeyboardInterrupt is
 
1267
            # insufficient to stop trial. A suggested solution is to have
 
1268
            # this code set a "stop trial" flag, or otherwise notify trial
 
1269
            # that it should really try to stop as soon as possible.
 
1270
            raise KeyboardInterrupt()
 
1271
        finally:
 
1272
            results = None
 
1273
            running.pop()
 
1274
 
 
1275
 
 
1276
class UnsupportedTrialFeature(Exception):
 
1277
    """A feature of twisted.trial was used that pyunit cannot support."""
 
1278
 
 
1279
 
 
1280
 
 
1281
class PyUnitResultAdapter(object):
 
1282
    """
 
1283
    Wrap a C{TestResult} from the standard library's C{unittest} so that it
 
1284
    supports the extended result types from Trial, and also supports
 
1285
    L{twisted.python.failure.Failure}s being passed to L{addError} and
 
1286
    L{addFailure}.
 
1287
    """
 
1288
 
 
1289
    def __init__(self, original):
 
1290
        """
 
1291
        @param original: A C{TestResult} instance from C{unittest}.
 
1292
        """
 
1293
        self.original = original
 
1294
 
 
1295
    def _exc_info(self, err):
 
1296
        return util.excInfoOrFailureToExcInfo(err)
 
1297
 
 
1298
    def startTest(self, method):
 
1299
        self.original.startTest(method)
 
1300
 
 
1301
    def stopTest(self, method):
 
1302
        self.original.stopTest(method)
 
1303
 
 
1304
    def addFailure(self, test, fail):
 
1305
        self.original.addFailure(test, self._exc_info(fail))
 
1306
 
 
1307
    def addError(self, test, error):
 
1308
        self.original.addError(test, self._exc_info(error))
 
1309
 
 
1310
    def _unsupported(self, test, feature, info):
 
1311
        self.original.addFailure(
 
1312
            test,
 
1313
            (UnsupportedTrialFeature,
 
1314
             UnsupportedTrialFeature(feature, info),
 
1315
             None))
 
1316
 
 
1317
    def addSkip(self, test, reason):
 
1318
        """
 
1319
        Report the skip as a failure.
 
1320
        """
 
1321
        self._unsupported(test, 'skip', reason)
 
1322
 
 
1323
    def addUnexpectedSuccess(self, test, todo):
 
1324
        """
 
1325
        Report the unexpected success as a failure.
 
1326
        """
 
1327
        self._unsupported(test, 'unexpected success', todo)
 
1328
 
 
1329
    def addExpectedFailure(self, test, error):
 
1330
        """
 
1331
        Report the expected failure (i.e. todo) as a failure.
 
1332
        """
 
1333
        self._unsupported(test, 'expected failure', error)
 
1334
 
 
1335
    def addSuccess(self, test):
 
1336
        self.original.addSuccess(test)
 
1337
 
 
1338
    def upDownError(self, method, error, warn, printStatus):
 
1339
        pass
 
1340
 
 
1341
 
 
1342
 
 
1343
def suiteVisit(suite, visitor):
 
1344
    """
 
1345
    Visit each test in C{suite} with C{visitor}.
 
1346
 
 
1347
    Deprecated in Twisted 8.0.
 
1348
 
 
1349
    @param visitor: A callable which takes a single argument, the L{TestCase}
 
1350
    instance to visit.
 
1351
    @return: None
 
1352
    """
 
1353
    warnings.warn("Test visitors deprecated in Twisted 8.0",
 
1354
                  category=DeprecationWarning)
 
1355
    for case in suite._tests:
 
1356
        visit = getattr(case, 'visit', None)
 
1357
        if visit is not None:
 
1358
            visit(visitor)
 
1359
        elif isinstance(case, pyunit.TestCase):
 
1360
            case = itrial.ITestCase(case)
 
1361
            case.visit(visitor)
 
1362
        elif isinstance(case, pyunit.TestSuite):
 
1363
            suiteVisit(case, visitor)
 
1364
        else:
 
1365
            case.visit(visitor)
 
1366
 
 
1367
 
 
1368
 
 
1369
class TestSuite(pyunit.TestSuite):
 
1370
    """
 
1371
    Extend the standard library's C{TestSuite} with support for the visitor
 
1372
    pattern and a consistently overrideable C{run} method.
 
1373
    """
 
1374
 
 
1375
    visit = suiteVisit
 
1376
 
 
1377
    def __call__(self, result):
 
1378
        return self.run(result)
 
1379
 
 
1380
 
 
1381
    def run(self, result):
 
1382
        """
 
1383
        Call C{run} on every member of the suite.
 
1384
        """
 
1385
        # we implement this because Python 2.3 unittest defines this code
 
1386
        # in __call__, whereas 2.4 defines the code in run.
 
1387
        for test in self._tests:
 
1388
            if result.shouldStop:
 
1389
                break
 
1390
            test(result)
 
1391
        return result
 
1392
 
 
1393
 
 
1394
 
 
1395
class TestDecorator(components.proxyForInterface(itrial.ITestCase,
 
1396
                                                 "_originalTest")):
 
1397
    """
 
1398
    Decorator for test cases.
 
1399
 
 
1400
    @param _originalTest: The wrapped instance of test.
 
1401
    @type _originalTest: A provider of L{itrial.ITestCase}
 
1402
    """
 
1403
 
 
1404
    implements(itrial.ITestCase)
 
1405
 
 
1406
 
 
1407
    def __call__(self, result):
 
1408
        """
 
1409
        Run the unit test.
 
1410
 
 
1411
        @param result: A TestResult object.
 
1412
        """
 
1413
        return self.run(result)
 
1414
 
 
1415
 
 
1416
    def run(self, result):
 
1417
        """
 
1418
        Run the unit test.
 
1419
 
 
1420
        @param result: A TestResult object.
 
1421
        """
 
1422
        return self._originalTest.run(
 
1423
            reporter._AdaptedReporter(result, self.__class__))
 
1424
 
 
1425
 
 
1426
 
 
1427
def _clearSuite(suite):
 
1428
    """
 
1429
    Clear all tests from C{suite}.
 
1430
 
 
1431
    This messes with the internals of C{suite}. In particular, it assumes that
 
1432
    the suite keeps all of its tests in a list in an instance variable called
 
1433
    C{_tests}.
 
1434
    """
 
1435
    suite._tests = []
 
1436
 
 
1437
 
 
1438
def decorate(test, decorator):
 
1439
    """
 
1440
    Decorate all test cases in C{test} with C{decorator}.
 
1441
 
 
1442
    C{test} can be a test case or a test suite. If it is a test suite, then the
 
1443
    structure of the suite is preserved.
 
1444
 
 
1445
    L{decorate} tries to preserve the class of the test suites it finds, but
 
1446
    assumes the presence of the C{_tests} attribute on the suite.
 
1447
 
 
1448
    @param test: The C{TestCase} or C{TestSuite} to decorate.
 
1449
 
 
1450
    @param decorator: A unary callable used to decorate C{TestCase}s.
 
1451
 
 
1452
    @return: A decorated C{TestCase} or a C{TestSuite} containing decorated
 
1453
        C{TestCase}s.
 
1454
    """
 
1455
 
 
1456
    try:
 
1457
        tests = iter(test)
 
1458
    except TypeError:
 
1459
        return decorator(test)
 
1460
 
 
1461
    # At this point, we know that 'test' is a test suite.
 
1462
    _clearSuite(test)
 
1463
 
 
1464
    for case in tests:
 
1465
        test.addTest(decorate(case, decorator))
 
1466
    return test
 
1467
 
 
1468
 
 
1469
 
 
1470
class _PyUnitTestCaseAdapter(TestDecorator):
 
1471
    """
 
1472
    Adapt from pyunit.TestCase to ITestCase.
 
1473
    """
 
1474
 
 
1475
 
 
1476
    def visit(self, visitor):
 
1477
        """
 
1478
        Deprecated in Twisted 8.0.
 
1479
        """
 
1480
        warnings.warn("Test visitors deprecated in Twisted 8.0",
 
1481
                      category=DeprecationWarning)
 
1482
        visitor(self)
 
1483
 
 
1484
 
 
1485
 
 
1486
class _BrokenIDTestCaseAdapter(_PyUnitTestCaseAdapter):
 
1487
    """
 
1488
    Adapter for pyunit-style C{TestCase} subclasses that have undesirable id()
 
1489
    methods. That is L{pyunit.FunctionTestCase} and L{pyunit.DocTestCase}.
 
1490
    """
 
1491
 
 
1492
    def id(self):
 
1493
        """
 
1494
        Return the fully-qualified Python name of the doctest.
 
1495
        """
 
1496
        testID = self._originalTest.shortDescription()
 
1497
        if testID is not None:
 
1498
            return testID
 
1499
        return self._originalTest.id()
 
1500
 
 
1501
 
 
1502
 
 
1503
class _ForceGarbageCollectionDecorator(TestDecorator):
 
1504
    """
 
1505
    Forces garbage collection to be run before and after the test. Any errors
 
1506
    logged during the post-test collection are added to the test result as
 
1507
    errors.
 
1508
    """
 
1509
 
 
1510
    def run(self, result):
 
1511
        gc.collect()
 
1512
        TestDecorator.run(self, result)
 
1513
        _logObserver._add()
 
1514
        gc.collect()
 
1515
        for error in _logObserver.getErrors():
 
1516
            result.addError(self, error)
 
1517
        _logObserver.flushErrors()
 
1518
        _logObserver._remove()
 
1519
 
 
1520
 
 
1521
components.registerAdapter(
 
1522
    _PyUnitTestCaseAdapter, pyunit.TestCase, itrial.ITestCase)
 
1523
 
 
1524
 
 
1525
components.registerAdapter(
 
1526
    _BrokenIDTestCaseAdapter, pyunit.FunctionTestCase, itrial.ITestCase)
 
1527
 
 
1528
 
 
1529
_docTestCase = getattr(doctest, 'DocTestCase', None)
 
1530
if _docTestCase:
 
1531
    components.registerAdapter(
 
1532
        _BrokenIDTestCaseAdapter, _docTestCase, itrial.ITestCase)
 
1533
 
 
1534
 
 
1535
def _iterateTests(testSuiteOrCase):
 
1536
    """
 
1537
    Iterate through all of the test cases in C{testSuiteOrCase}.
 
1538
    """
 
1539
    try:
 
1540
        suite = iter(testSuiteOrCase)
 
1541
    except TypeError:
 
1542
        yield testSuiteOrCase
 
1543
    else:
 
1544
        for test in suite:
 
1545
            for subtest in _iterateTests(test):
 
1546
                yield subtest
 
1547
 
 
1548
 
 
1549
 
 
1550
# Support for Python 2.3
 
1551
try:
 
1552
    iter(pyunit.TestSuite())
 
1553
except TypeError:
 
1554
    # Python 2.3's TestSuite doesn't support iteration. Let's monkey patch it!
 
1555
    def __iter__(self):
 
1556
        return iter(self._tests)
 
1557
    pyunit.TestSuite.__iter__ = __iter__
 
1558
 
 
1559
 
 
1560
 
 
1561
class _SubTestCase(TestCase):
 
1562
    def __init__(self):
 
1563
        TestCase.__init__(self, 'run')
 
1564
 
 
1565
_inst = _SubTestCase()
 
1566
 
 
1567
def _deprecate(name):
 
1568
    """
 
1569
    Internal method used to deprecate top-level assertions. Do not use this.
 
1570
    """
 
1571
    def _(*args, **kwargs):
 
1572
        warnings.warn("unittest.%s is deprecated.  Instead use the %r "
 
1573
                      "method on unittest.TestCase" % (name, name),
 
1574
                      stacklevel=2, category=DeprecationWarning)
 
1575
        return getattr(_inst, name)(*args, **kwargs)
 
1576
    return _
 
1577
 
 
1578
 
 
1579
_assertions = ['fail', 'failUnlessEqual', 'failIfEqual', 'failIfEquals',
 
1580
               'failUnless', 'failUnlessIdentical', 'failUnlessIn',
 
1581
               'failIfIdentical', 'failIfIn', 'failIf',
 
1582
               'failUnlessAlmostEqual', 'failIfAlmostEqual',
 
1583
               'failUnlessRaises', 'assertApproximates',
 
1584
               'assertFailure', 'failUnlessSubstring', 'failIfSubstring',
 
1585
               'assertAlmostEqual', 'assertAlmostEquals',
 
1586
               'assertNotAlmostEqual', 'assertNotAlmostEquals', 'assertEqual',
 
1587
               'assertEquals', 'assertNotEqual', 'assertNotEquals',
 
1588
               'assertRaises', 'assert_', 'assertIdentical',
 
1589
               'assertNotIdentical', 'assertIn', 'assertNotIn',
 
1590
               'failUnlessFailure', 'assertSubstring', 'assertNotSubstring']
 
1591
 
 
1592
 
 
1593
for methodName in _assertions:
 
1594
    globals()[methodName] = _deprecate(methodName)
 
1595
 
 
1596
 
 
1597
__all__ = ['TestCase', 'FailTest', 'SkipTest']