~soren/nova/iptables-security-groups

« back to all changes in this revision

Viewing changes to vendor/Twisted-10.0.0/twisted/python/test/test_util.py

  • Committer: Jesse Andrews
  • Date: 2010-05-28 06:05:26 UTC
  • Revision ID: git-v1:bf6e6e718cdc7488e2da87b21e258ccc065fe499
initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# -*- test-case-name: twisted.test.test_util -*-
 
2
# Copyright (c) 2001-2009 Twisted Matrix Laboratories.
 
3
# See LICENSE for details.
 
4
 
 
5
import os.path, sys
 
6
import shutil, errno
 
7
try:
 
8
    import pwd, grp
 
9
except ImportError:
 
10
    pwd = grp = None
 
11
 
 
12
from twisted.trial import unittest
 
13
 
 
14
from twisted.python import util
 
15
from twisted.internet import reactor
 
16
from twisted.internet.interfaces import IReactorProcess
 
17
from twisted.internet.protocol import ProcessProtocol
 
18
from twisted.internet.defer import Deferred
 
19
from twisted.internet.error import ProcessDone
 
20
 
 
21
from twisted.test.test_process import MockOS
 
22
 
 
23
 
 
24
 
 
25
class UtilTestCase(unittest.TestCase):
 
26
 
 
27
    def testUniq(self):
 
28
        l = ["a", 1, "ab", "a", 3, 4, 1, 2, 2, 4, 6]
 
29
        self.assertEquals(util.uniquify(l), ["a", 1, "ab", 3, 4, 2, 6])
 
30
 
 
31
    def testRaises(self):
 
32
        self.failUnless(util.raises(ZeroDivisionError, divmod, 1, 0))
 
33
        self.failIf(util.raises(ZeroDivisionError, divmod, 0, 1))
 
34
 
 
35
        try:
 
36
            util.raises(TypeError, divmod, 1, 0)
 
37
        except ZeroDivisionError:
 
38
            pass
 
39
        else:
 
40
            raise unittest.FailTest, "util.raises didn't raise when it should have"
 
41
 
 
42
    def testUninterruptably(self):
 
43
        def f(a, b):
 
44
            self.calls += 1
 
45
            exc = self.exceptions.pop()
 
46
            if exc is not None:
 
47
                raise exc(errno.EINTR, "Interrupted system call!")
 
48
            return a + b
 
49
 
 
50
        self.exceptions = [None]
 
51
        self.calls = 0
 
52
        self.assertEquals(util.untilConcludes(f, 1, 2), 3)
 
53
        self.assertEquals(self.calls, 1)
 
54
 
 
55
        self.exceptions = [None, OSError, IOError]
 
56
        self.calls = 0
 
57
        self.assertEquals(util.untilConcludes(f, 2, 3), 5)
 
58
        self.assertEquals(self.calls, 3)
 
59
 
 
60
    def testNameToLabel(self):
 
61
        """
 
62
        Test the various kinds of inputs L{nameToLabel} supports.
 
63
        """
 
64
        nameData = [
 
65
            ('f', 'F'),
 
66
            ('fo', 'Fo'),
 
67
            ('foo', 'Foo'),
 
68
            ('fooBar', 'Foo Bar'),
 
69
            ('fooBarBaz', 'Foo Bar Baz'),
 
70
            ]
 
71
        for inp, out in nameData:
 
72
            got = util.nameToLabel(inp)
 
73
            self.assertEquals(
 
74
                got, out,
 
75
                "nameToLabel(%r) == %r != %r" % (inp, got, out))
 
76
 
 
77
 
 
78
    def test_uidFromNumericString(self):
 
79
        """
 
80
        When L{uidFromString} is called with a base-ten string representation
 
81
        of an integer, it returns the integer.
 
82
        """
 
83
        self.assertEqual(util.uidFromString("100"), 100)
 
84
 
 
85
 
 
86
    def test_uidFromUsernameString(self):
 
87
        """
 
88
        When L{uidFromString} is called with a base-ten string representation
 
89
        of an integer, it returns the integer.
 
90
        """
 
91
        pwent = pwd.getpwuid(os.getuid())
 
92
        self.assertEqual(util.uidFromString(pwent.pw_name), pwent.pw_uid)
 
93
    if pwd is None:
 
94
        test_uidFromUsernameString.skip = (
 
95
            "Username/UID conversion requires the pwd module.")
 
96
 
 
97
 
 
98
    def test_gidFromNumericString(self):
 
99
        """
 
100
        When L{gidFromString} is called with a base-ten string representation
 
101
        of an integer, it returns the integer.
 
102
        """
 
103
        self.assertEqual(util.gidFromString("100"), 100)
 
104
 
 
105
 
 
106
    def test_gidFromGroupnameString(self):
 
107
        """
 
108
        When L{gidFromString} is called with a base-ten string representation
 
109
        of an integer, it returns the integer.
 
110
        """
 
111
        grent = grp.getgrgid(os.getgid())
 
112
        self.assertEqual(util.gidFromString(grent.gr_name), grent.gr_gid)
 
113
    if grp is None:
 
114
        test_gidFromGroupnameString.skip = (
 
115
            "Group Name/GID conversion requires the grp module.")
 
116
 
 
117
 
 
118
    def test_moduleMovedForSplitDeprecation(self):
 
119
        """
 
120
        Calling L{moduleMovedForSplit} results in a deprecation warning.
 
121
        """
 
122
        util.moduleMovedForSplit("foo", "bar", "baz", "quux", "corge", {})
 
123
        warnings = self.flushWarnings(
 
124
            offendingFunctions=[self.test_moduleMovedForSplitDeprecation])
 
125
        self.assertEquals(
 
126
            warnings[0]['message'],
 
127
            "moduleMovedForSplit is deprecated since Twisted 9.0.")
 
128
        self.assertEquals(warnings[0]['category'], DeprecationWarning)
 
129
        self.assertEquals(len(warnings), 1)
 
130
 
 
131
 
 
132
 
 
133
class TestMergeFunctionMetadata(unittest.TestCase):
 
134
    """
 
135
    Tests for L{mergeFunctionMetadata}.
 
136
    """
 
137
 
 
138
    def test_mergedFunctionBehavesLikeMergeTarget(self):
 
139
        """
 
140
        After merging C{foo}'s data into C{bar}, the returned function behaves
 
141
        as if it is C{bar}.
 
142
        """
 
143
        foo_object = object()
 
144
        bar_object = object()
 
145
 
 
146
        def foo():
 
147
            return foo_object
 
148
 
 
149
        def bar(x, y, (a, b), c=10, *d, **e):
 
150
            return bar_object
 
151
 
 
152
        baz = util.mergeFunctionMetadata(foo, bar)
 
153
        self.assertIdentical(baz(1, 2, (3, 4), quux=10), bar_object)
 
154
 
 
155
 
 
156
    def test_moduleIsMerged(self):
 
157
        """
 
158
        Merging C{foo} into C{bar} returns a function with C{foo}'s
 
159
        C{__module__}.
 
160
        """
 
161
        def foo():
 
162
            pass
 
163
 
 
164
        def bar():
 
165
            pass
 
166
        bar.__module__ = 'somewhere.else'
 
167
 
 
168
        baz = util.mergeFunctionMetadata(foo, bar)
 
169
        self.assertEqual(baz.__module__, foo.__module__)
 
170
 
 
171
 
 
172
    def test_docstringIsMerged(self):
 
173
        """
 
174
        Merging C{foo} into C{bar} returns a function with C{foo}'s docstring.
 
175
        """
 
176
 
 
177
        def foo():
 
178
            """
 
179
            This is foo.
 
180
            """
 
181
 
 
182
        def bar():
 
183
            """
 
184
            This is bar.
 
185
            """
 
186
 
 
187
        baz = util.mergeFunctionMetadata(foo, bar)
 
188
        self.assertEqual(baz.__doc__, foo.__doc__)
 
189
 
 
190
 
 
191
    def test_nameIsMerged(self):
 
192
        """
 
193
        Merging C{foo} into C{bar} returns a function with C{foo}'s name.
 
194
        """
 
195
 
 
196
        def foo():
 
197
            pass
 
198
 
 
199
        def bar():
 
200
            pass
 
201
 
 
202
        baz = util.mergeFunctionMetadata(foo, bar)
 
203
        self.assertEqual(baz.__name__, foo.__name__)
 
204
 
 
205
 
 
206
    def test_instanceDictionaryIsMerged(self):
 
207
        """
 
208
        Merging C{foo} into C{bar} returns a function with C{bar}'s
 
209
        dictionary, updated by C{foo}'s.
 
210
        """
 
211
 
 
212
        def foo():
 
213
            pass
 
214
        foo.a = 1
 
215
        foo.b = 2
 
216
 
 
217
        def bar():
 
218
            pass
 
219
        bar.b = 3
 
220
        bar.c = 4
 
221
 
 
222
        baz = util.mergeFunctionMetadata(foo, bar)
 
223
        self.assertEqual(foo.a, baz.a)
 
224
        self.assertEqual(foo.b, baz.b)
 
225
        self.assertEqual(bar.c, baz.c)
 
226
 
 
227
 
 
228
 
 
229
class OrderedDictTest(unittest.TestCase):
 
230
    def testOrderedDict(self):
 
231
        d = util.OrderedDict()
 
232
        d['a'] = 'b'
 
233
        d['b'] = 'a'
 
234
        d[3] = 12
 
235
        d[1234] = 4321
 
236
        self.assertEquals(repr(d), "{'a': 'b', 'b': 'a', 3: 12, 1234: 4321}")
 
237
        self.assertEquals(d.values(), ['b', 'a', 12, 4321])
 
238
        del d[3]
 
239
        self.assertEquals(repr(d), "{'a': 'b', 'b': 'a', 1234: 4321}")
 
240
        self.assertEquals(d, {'a': 'b', 'b': 'a', 1234:4321})
 
241
        self.assertEquals(d.keys(), ['a', 'b', 1234])
 
242
        self.assertEquals(list(d.iteritems()),
 
243
                          [('a', 'b'), ('b','a'), (1234, 4321)])
 
244
        item = d.popitem()
 
245
        self.assertEquals(item, (1234, 4321))
 
246
 
 
247
    def testInitialization(self):
 
248
        d = util.OrderedDict({'monkey': 'ook',
 
249
                              'apple': 'red'})
 
250
        self.failUnless(d._order)
 
251
 
 
252
        d = util.OrderedDict(((1,1),(3,3),(2,2),(0,0)))
 
253
        self.assertEquals(repr(d), "{1: 1, 3: 3, 2: 2, 0: 0}")
 
254
 
 
255
class InsensitiveDictTest(unittest.TestCase):
 
256
    def testPreserve(self):
 
257
        InsensitiveDict=util.InsensitiveDict
 
258
        dct=InsensitiveDict({'Foo':'bar', 1:2, 'fnz':{1:2}}, preserve=1)
 
259
        self.assertEquals(dct['fnz'], {1:2})
 
260
        self.assertEquals(dct['foo'], 'bar')
 
261
        self.assertEquals(dct.copy(), dct)
 
262
        self.assertEquals(dct['foo'], dct.get('Foo'))
 
263
        assert 1 in dct and 'foo' in dct
 
264
        self.assertEquals(eval(repr(dct)), dct)
 
265
        keys=['Foo', 'fnz', 1]
 
266
        for x in keys:
 
267
            assert x in dct.keys()
 
268
            assert (x, dct[x]) in dct.items()
 
269
        self.assertEquals(len(keys), len(dct))
 
270
        del dct[1]
 
271
        del dct['foo']
 
272
 
 
273
    def testNoPreserve(self):
 
274
        InsensitiveDict=util.InsensitiveDict
 
275
        dct=InsensitiveDict({'Foo':'bar', 1:2, 'fnz':{1:2}}, preserve=0)
 
276
        keys=['foo', 'fnz', 1]
 
277
        for x in keys:
 
278
            assert x in dct.keys()
 
279
            assert (x, dct[x]) in dct.items()
 
280
        self.assertEquals(len(keys), len(dct))
 
281
        del dct[1]
 
282
        del dct['foo']
 
283
 
 
284
 
 
285
 
 
286
 
 
287
class PasswordTestingProcessProtocol(ProcessProtocol):
 
288
    """
 
289
    Write the string C{"secret\n"} to a subprocess and then collect all of
 
290
    its output and fire a Deferred with it when the process ends.
 
291
    """
 
292
    def connectionMade(self):
 
293
        self.output = []
 
294
        self.transport.write('secret\n')
 
295
 
 
296
    def childDataReceived(self, fd, output):
 
297
        self.output.append((fd, output))
 
298
 
 
299
    def processEnded(self, reason):
 
300
        self.finished.callback((reason, self.output))
 
301
 
 
302
 
 
303
class GetPasswordTest(unittest.TestCase):
 
304
    if not IReactorProcess.providedBy(reactor):
 
305
        skip = "Process support required to test getPassword"
 
306
 
 
307
    def test_stdin(self):
 
308
        """
 
309
        Making sure getPassword accepts a password from standard input by
 
310
        running a child process which uses getPassword to read in a string
 
311
        which it then writes it out again.  Write a string to the child
 
312
        process and then read one and make sure it is the right string.
 
313
        """
 
314
        p = PasswordTestingProcessProtocol()
 
315
        p.finished = Deferred()
 
316
        reactor.spawnProcess(
 
317
            p,
 
318
            sys.executable,
 
319
            [sys.executable,
 
320
             '-c',
 
321
             ('import sys\n'
 
322
             'from twisted.python.util import getPassword\n'
 
323
              'sys.stdout.write(getPassword())\n'
 
324
              'sys.stdout.flush()\n')],
 
325
            env={'PYTHONPATH': os.pathsep.join(sys.path)})
 
326
 
 
327
        def processFinished((reason, output)):
 
328
            reason.trap(ProcessDone)
 
329
            self.assertIn((1, 'secret'), output)
 
330
 
 
331
        return p.finished.addCallback(processFinished)
 
332
 
 
333
 
 
334
 
 
335
class SearchUpwardsTest(unittest.TestCase):
 
336
    def testSearchupwards(self):
 
337
        os.makedirs('searchupwards/a/b/c')
 
338
        file('searchupwards/foo.txt', 'w').close()
 
339
        file('searchupwards/a/foo.txt', 'w').close()
 
340
        file('searchupwards/a/b/c/foo.txt', 'w').close()
 
341
        os.mkdir('searchupwards/bar')
 
342
        os.mkdir('searchupwards/bam')
 
343
        os.mkdir('searchupwards/a/bar')
 
344
        os.mkdir('searchupwards/a/b/bam')
 
345
        actual=util.searchupwards('searchupwards/a/b/c',
 
346
                                  files=['foo.txt'],
 
347
                                  dirs=['bar', 'bam'])
 
348
        expected=os.path.abspath('searchupwards') + os.sep
 
349
        self.assertEqual(actual, expected)
 
350
        shutil.rmtree('searchupwards')
 
351
        actual=util.searchupwards('searchupwards/a/b/c',
 
352
                                  files=['foo.txt'],
 
353
                                  dirs=['bar', 'bam'])
 
354
        expected=None
 
355
        self.assertEqual(actual, expected)
 
356
 
 
357
class Foo:
 
358
    def __init__(self, x):
 
359
        self.x = x
 
360
 
 
361
class DSU(unittest.TestCase):
 
362
    def testDSU(self):
 
363
        L = [Foo(x) for x in range(20, 9, -1)]
 
364
        L2 = util.dsu(L, lambda o: o.x)
 
365
        self.assertEquals(range(10, 21), [o.x for o in L2])
 
366
 
 
367
class IntervalDifferentialTestCase(unittest.TestCase):
 
368
    def testDefault(self):
 
369
        d = iter(util.IntervalDifferential([], 10))
 
370
        for i in range(100):
 
371
            self.assertEquals(d.next(), (10, None))
 
372
 
 
373
    def testSingle(self):
 
374
        d = iter(util.IntervalDifferential([5], 10))
 
375
        for i in range(100):
 
376
            self.assertEquals(d.next(), (5, 0))
 
377
 
 
378
    def testPair(self):
 
379
        d = iter(util.IntervalDifferential([5, 7], 10))
 
380
        for i in range(100):
 
381
            self.assertEquals(d.next(), (5, 0))
 
382
            self.assertEquals(d.next(), (2, 1))
 
383
            self.assertEquals(d.next(), (3, 0))
 
384
            self.assertEquals(d.next(), (4, 1))
 
385
            self.assertEquals(d.next(), (1, 0))
 
386
            self.assertEquals(d.next(), (5, 0))
 
387
            self.assertEquals(d.next(), (1, 1))
 
388
            self.assertEquals(d.next(), (4, 0))
 
389
            self.assertEquals(d.next(), (3, 1))
 
390
            self.assertEquals(d.next(), (2, 0))
 
391
            self.assertEquals(d.next(), (5, 0))
 
392
            self.assertEquals(d.next(), (0, 1))
 
393
 
 
394
    def testTriple(self):
 
395
        d = iter(util.IntervalDifferential([2, 4, 5], 10))
 
396
        for i in range(100):
 
397
            self.assertEquals(d.next(), (2, 0))
 
398
            self.assertEquals(d.next(), (2, 0))
 
399
            self.assertEquals(d.next(), (0, 1))
 
400
            self.assertEquals(d.next(), (1, 2))
 
401
            self.assertEquals(d.next(), (1, 0))
 
402
            self.assertEquals(d.next(), (2, 0))
 
403
            self.assertEquals(d.next(), (0, 1))
 
404
            self.assertEquals(d.next(), (2, 0))
 
405
            self.assertEquals(d.next(), (0, 2))
 
406
            self.assertEquals(d.next(), (2, 0))
 
407
            self.assertEquals(d.next(), (0, 1))
 
408
            self.assertEquals(d.next(), (2, 0))
 
409
            self.assertEquals(d.next(), (1, 2))
 
410
            self.assertEquals(d.next(), (1, 0))
 
411
            self.assertEquals(d.next(), (0, 1))
 
412
            self.assertEquals(d.next(), (2, 0))
 
413
            self.assertEquals(d.next(), (2, 0))
 
414
            self.assertEquals(d.next(), (0, 1))
 
415
            self.assertEquals(d.next(), (0, 2))
 
416
 
 
417
    def testInsert(self):
 
418
        d = iter(util.IntervalDifferential([], 10))
 
419
        self.assertEquals(d.next(), (10, None))
 
420
        d.addInterval(3)
 
421
        self.assertEquals(d.next(), (3, 0))
 
422
        self.assertEquals(d.next(), (3, 0))
 
423
        d.addInterval(6)
 
424
        self.assertEquals(d.next(), (3, 0))
 
425
        self.assertEquals(d.next(), (3, 0))
 
426
        self.assertEquals(d.next(), (0, 1))
 
427
        self.assertEquals(d.next(), (3, 0))
 
428
        self.assertEquals(d.next(), (3, 0))
 
429
        self.assertEquals(d.next(), (0, 1))
 
430
 
 
431
    def testRemove(self):
 
432
        d = iter(util.IntervalDifferential([3, 5], 10))
 
433
        self.assertEquals(d.next(), (3, 0))
 
434
        self.assertEquals(d.next(), (2, 1))
 
435
        self.assertEquals(d.next(), (1, 0))
 
436
        d.removeInterval(3)
 
437
        self.assertEquals(d.next(), (4, 0))
 
438
        self.assertEquals(d.next(), (5, 0))
 
439
        d.removeInterval(5)
 
440
        self.assertEquals(d.next(), (10, None))
 
441
        self.assertRaises(ValueError, d.removeInterval, 10)
 
442
 
 
443
 
 
444
 
 
445
class Record(util.FancyEqMixin):
 
446
    """
 
447
    Trivial user of L{FancyEqMixin} used by tests.
 
448
    """
 
449
    compareAttributes = ('a', 'b')
 
450
 
 
451
    def __init__(self, a, b):
 
452
        self.a = a
 
453
        self.b = b
 
454
 
 
455
 
 
456
 
 
457
class DifferentRecord(util.FancyEqMixin):
 
458
    """
 
459
    Trivial user of L{FancyEqMixin} which is not related to L{Record}.
 
460
    """
 
461
    compareAttributes = ('a', 'b')
 
462
 
 
463
    def __init__(self, a, b):
 
464
        self.a = a
 
465
        self.b = b
 
466
 
 
467
 
 
468
 
 
469
class DerivedRecord(Record):
 
470
    """
 
471
    A class with an inheritance relationship to L{Record}.
 
472
    """
 
473
 
 
474
 
 
475
 
 
476
class EqualToEverything(object):
 
477
    """
 
478
    A class the instances of which consider themselves equal to everything.
 
479
    """
 
480
    def __eq__(self, other):
 
481
        return True
 
482
 
 
483
 
 
484
    def __ne__(self, other):
 
485
        return False
 
486
 
 
487
 
 
488
 
 
489
class EqualToNothing(object):
 
490
    """
 
491
    A class the instances of which consider themselves equal to nothing.
 
492
    """
 
493
    def __eq__(self, other):
 
494
        return False
 
495
 
 
496
 
 
497
    def __ne__(self, other):
 
498
        return True
 
499
 
 
500
 
 
501
 
 
502
class EqualityTests(unittest.TestCase):
 
503
    """
 
504
    Tests for L{FancyEqMixin}.
 
505
    """
 
506
    def test_identity(self):
 
507
        """
 
508
        Instances of a class which mixes in L{FancyEqMixin} but which
 
509
        defines no comparison attributes compare by identity.
 
510
        """
 
511
        class Empty(util.FancyEqMixin):
 
512
            pass
 
513
 
 
514
        self.assertFalse(Empty() == Empty())
 
515
        self.assertTrue(Empty() != Empty())
 
516
        empty = Empty()
 
517
        self.assertTrue(empty == empty)
 
518
        self.assertFalse(empty != empty)
 
519
 
 
520
 
 
521
    def test_equality(self):
 
522
        """
 
523
        Instances of a class which mixes in L{FancyEqMixin} should compare
 
524
        equal if all of their attributes compare equal.  They should not
 
525
        compare equal if any of their attributes do not compare equal.
 
526
        """
 
527
        self.assertTrue(Record(1, 2) == Record(1, 2))
 
528
        self.assertFalse(Record(1, 2) == Record(1, 3))
 
529
        self.assertFalse(Record(1, 2) == Record(2, 2))
 
530
        self.assertFalse(Record(1, 2) == Record(3, 4))
 
531
 
 
532
 
 
533
    def test_unequality(self):
 
534
        """
 
535
        Unequality between instances of a particular L{record} should be
 
536
        defined as the negation of equality.
 
537
        """
 
538
        self.assertFalse(Record(1, 2) != Record(1, 2))
 
539
        self.assertTrue(Record(1, 2) != Record(1, 3))
 
540
        self.assertTrue(Record(1, 2) != Record(2, 2))
 
541
        self.assertTrue(Record(1, 2) != Record(3, 4))
 
542
 
 
543
 
 
544
    def test_differentClassesEquality(self):
 
545
        """
 
546
        Instances of different classes which mix in L{FancyEqMixin} should not
 
547
        compare equal.
 
548
        """
 
549
        self.assertFalse(Record(1, 2) == DifferentRecord(1, 2))
 
550
 
 
551
 
 
552
    def test_differentClassesInequality(self):
 
553
        """
 
554
        Instances of different classes which mix in L{FancyEqMixin} should
 
555
        compare unequal.
 
556
        """
 
557
        self.assertTrue(Record(1, 2) != DifferentRecord(1, 2))
 
558
 
 
559
 
 
560
    def test_inheritedClassesEquality(self):
 
561
        """
 
562
        An instance of a class which derives from a class which mixes in
 
563
        L{FancyEqMixin} should compare equal to an instance of the base class
 
564
        if and only if all of their attributes compare equal.
 
565
        """
 
566
        self.assertTrue(Record(1, 2) == DerivedRecord(1, 2))
 
567
        self.assertFalse(Record(1, 2) == DerivedRecord(1, 3))
 
568
        self.assertFalse(Record(1, 2) == DerivedRecord(2, 2))
 
569
        self.assertFalse(Record(1, 2) == DerivedRecord(3, 4))
 
570
 
 
571
 
 
572
    def test_inheritedClassesInequality(self):
 
573
        """
 
574
        An instance of a class which derives from a class which mixes in
 
575
        L{FancyEqMixin} should compare unequal to an instance of the base
 
576
        class if any of their attributes compare unequal.
 
577
        """
 
578
        self.assertFalse(Record(1, 2) != DerivedRecord(1, 2))
 
579
        self.assertTrue(Record(1, 2) != DerivedRecord(1, 3))
 
580
        self.assertTrue(Record(1, 2) != DerivedRecord(2, 2))
 
581
        self.assertTrue(Record(1, 2) != DerivedRecord(3, 4))
 
582
 
 
583
 
 
584
    def test_rightHandArgumentImplementsEquality(self):
 
585
        """
 
586
        The right-hand argument to the equality operator is given a chance
 
587
        to determine the result of the operation if it is of a type
 
588
        unrelated to the L{FancyEqMixin}-based instance on the left-hand
 
589
        side.
 
590
        """
 
591
        self.assertTrue(Record(1, 2) == EqualToEverything())
 
592
        self.assertFalse(Record(1, 2) == EqualToNothing())
 
593
 
 
594
 
 
595
    def test_rightHandArgumentImplementsUnequality(self):
 
596
        """
 
597
        The right-hand argument to the non-equality operator is given a
 
598
        chance to determine the result of the operation if it is of a type
 
599
        unrelated to the L{FancyEqMixin}-based instance on the left-hand
 
600
        side.
 
601
        """
 
602
        self.assertFalse(Record(1, 2) != EqualToEverything())
 
603
        self.assertTrue(Record(1, 2) != EqualToNothing())
 
604
 
 
605
 
 
606
 
 
607
class RunAsEffectiveUserTests(unittest.TestCase):
 
608
    """
 
609
    Test for the L{util.runAsEffectiveUser} function.
 
610
    """
 
611
 
 
612
    if getattr(os, "geteuid", None) is None:
 
613
        skip = "geteuid/seteuid not available"
 
614
 
 
615
    def setUp(self):
 
616
        self.mockos = MockOS()
 
617
        self.patch(os, "geteuid", self.mockos.geteuid)
 
618
        self.patch(os, "getegid", self.mockos.getegid)
 
619
        self.patch(os, "seteuid", self.mockos.seteuid)
 
620
        self.patch(os, "setegid", self.mockos.setegid)
 
621
 
 
622
 
 
623
    def _securedFunction(self, startUID, startGID, wantUID, wantGID):
 
624
        """
 
625
        Check if wanted UID/GID matched start or saved ones.
 
626
        """
 
627
        self.assertTrue(wantUID == startUID or
 
628
                        wantUID == self.mockos.seteuidCalls[-1])
 
629
        self.assertTrue(wantGID == startGID or
 
630
                        wantGID == self.mockos.setegidCalls[-1])
 
631
 
 
632
 
 
633
    def test_forwardResult(self):
 
634
        """
 
635
        L{util.runAsEffectiveUser} forwards the result obtained by calling the
 
636
        given function
 
637
        """
 
638
        result = util.runAsEffectiveUser(0, 0, lambda: 1)
 
639
        self.assertEquals(result, 1)
 
640
 
 
641
 
 
642
    def test_takeParameters(self):
 
643
        """
 
644
        L{util.runAsEffectiveUser} pass the given parameters to the given
 
645
        function.
 
646
        """
 
647
        result = util.runAsEffectiveUser(0, 0, lambda x: 2*x, 3)
 
648
        self.assertEquals(result, 6)
 
649
 
 
650
 
 
651
    def test_takesKeyworkArguments(self):
 
652
        """
 
653
        L{util.runAsEffectiveUser} pass the keyword parameters to the given
 
654
        function.
 
655
        """
 
656
        result = util.runAsEffectiveUser(0, 0, lambda x, y=1, z=1: x*y*z, 2, z=3)
 
657
        self.assertEquals(result, 6)
 
658
 
 
659
 
 
660
    def _testUIDGIDSwitch(self, startUID, startGID, wantUID, wantGID,
 
661
                          expectedUIDSwitches, expectedGIDSwitches):
 
662
        """
 
663
        Helper method checking the calls to C{os.seteuid} and C{os.setegid}
 
664
        made by L{util.runAsEffectiveUser}, when switching from startUID to
 
665
        wantUID and from startGID to wantGID.
 
666
        """
 
667
        self.mockos.euid = startUID
 
668
        self.mockos.egid = startGID
 
669
        util.runAsEffectiveUser(
 
670
            wantUID, wantGID,
 
671
            self._securedFunction, startUID, startGID, wantUID, wantGID)
 
672
        self.assertEquals(self.mockos.seteuidCalls, expectedUIDSwitches)
 
673
        self.assertEquals(self.mockos.setegidCalls, expectedGIDSwitches)
 
674
        self.mockos.seteuidCalls = []
 
675
        self.mockos.setegidCalls = []
 
676
 
 
677
 
 
678
    def test_root(self):
 
679
        """
 
680
        Check UID/GID switches when current effective UID is root.
 
681
        """
 
682
        self._testUIDGIDSwitch(0, 0, 0, 0, [], [])
 
683
        self._testUIDGIDSwitch(0, 0, 1, 0, [1, 0], [])
 
684
        self._testUIDGIDSwitch(0, 0, 0, 1, [], [1, 0])
 
685
        self._testUIDGIDSwitch(0, 0, 1, 1, [1, 0], [1, 0])
 
686
 
 
687
 
 
688
    def test_UID(self):
 
689
        """
 
690
        Check UID/GID switches when current effective UID is non-root.
 
691
        """
 
692
        self._testUIDGIDSwitch(1, 0, 0, 0, [0, 1], [])
 
693
        self._testUIDGIDSwitch(1, 0, 1, 0, [], [])
 
694
        self._testUIDGIDSwitch(1, 0, 1, 1, [0, 1, 0, 1], [1, 0])
 
695
        self._testUIDGIDSwitch(1, 0, 2, 1, [0, 2, 0, 1], [1, 0])
 
696
 
 
697
 
 
698
    def test_GID(self):
 
699
        """
 
700
        Check UID/GID switches when current effective GID is non-root.
 
701
        """
 
702
        self._testUIDGIDSwitch(0, 1, 0, 0, [], [0, 1])
 
703
        self._testUIDGIDSwitch(0, 1, 0, 1, [], [])
 
704
        self._testUIDGIDSwitch(0, 1, 1, 1, [1, 0], [])
 
705
        self._testUIDGIDSwitch(0, 1, 1, 2, [1, 0], [2, 1])
 
706
 
 
707
 
 
708
    def test_UIDGID(self):
 
709
        """
 
710
        Check UID/GID switches when current effective UID/GID is non-root.
 
711
        """
 
712
        self._testUIDGIDSwitch(1, 1, 0, 0, [0, 1], [0, 1])
 
713
        self._testUIDGIDSwitch(1, 1, 0, 1, [0, 1], [])
 
714
        self._testUIDGIDSwitch(1, 1, 1, 0, [0, 1, 0, 1], [0, 1])
 
715
        self._testUIDGIDSwitch(1, 1, 1, 1, [], [])
 
716
        self._testUIDGIDSwitch(1, 1, 2, 1, [0, 2, 0, 1], [])
 
717
        self._testUIDGIDSwitch(1, 1, 1, 2, [0, 1, 0, 1], [2, 1])
 
718
        self._testUIDGIDSwitch(1, 1, 2, 2, [0, 2, 0, 1], [2, 1])
 
719
 
 
720
 
 
721
 
 
722
class UnsignedIDTests(unittest.TestCase):
 
723
    """
 
724
    Tests for L{util.unsignedID} and L{util.setIDFunction}.
 
725
    """
 
726
    def setUp(self):
 
727
        """
 
728
        Save the value of L{util._idFunction} and arrange for it to be restored
 
729
        after the test runs.
 
730
        """
 
731
        self.addCleanup(setattr, util, '_idFunction', util._idFunction)
 
732
 
 
733
 
 
734
    def test_setIDFunction(self):
 
735
        """
 
736
        L{util.setIDFunction} returns the last value passed to it.
 
737
        """
 
738
        value = object()
 
739
        previous = util.setIDFunction(value)
 
740
        result = util.setIDFunction(previous)
 
741
        self.assertIdentical(value, result)
 
742
 
 
743
 
 
744
    def test_unsignedID(self):
 
745
        """
 
746
        L{util.unsignedID} uses the function passed to L{util.setIDFunction} to
 
747
        determine the unique integer id of an object and then adjusts it to be
 
748
        positive if necessary.
 
749
        """
 
750
        foo = object()
 
751
        bar = object()
 
752
 
 
753
        # A fake object identity mapping
 
754
        objects = {foo: 17, bar: -73}
 
755
        def fakeId(obj):
 
756
            return objects[obj]
 
757
 
 
758
        util.setIDFunction(fakeId)
 
759
 
 
760
        self.assertEquals(util.unsignedID(foo), 17)
 
761
        self.assertEquals(util.unsignedID(bar), (sys.maxint + 1) * 2 - 73)
 
762
 
 
763
 
 
764
    def test_defaultIDFunction(self):
 
765
        """
 
766
        L{util.unsignedID} uses the built in L{id} by default.
 
767
        """
 
768
        obj = object()
 
769
        idValue = id(obj)
 
770
        if idValue < 0:
 
771
            idValue += (sys.maxint + 1) * 2
 
772
 
 
773
        self.assertEquals(util.unsignedID(obj), idValue)
 
774
 
 
775
 
 
776
 
 
777
class InitGroupsTests(unittest.TestCase):
 
778
    """
 
779
    Tests for L{util.initgroups}.
 
780
    """
 
781
 
 
782
    if pwd is None:
 
783
        skip = "pwd not available"
 
784
 
 
785
 
 
786
    def setUp(self):
 
787
        self.addCleanup(setattr, util, "_c_initgroups", util._c_initgroups)
 
788
        self.addCleanup(setattr, util, "setgroups", util.setgroups)
 
789
 
 
790
 
 
791
    def test_initgroupsForceC(self):
 
792
        """
 
793
        If we fake the presence of the C extension, it's called instead of the
 
794
        Python implementation.
 
795
        """
 
796
        calls = []
 
797
        util._c_initgroups = lambda x, y: calls.append((x, y))
 
798
        setgroupsCalls = []
 
799
        util.setgroups = calls.append
 
800
 
 
801
        util.initgroups(os.getuid(), 4)
 
802
        self.assertEquals(calls, [(pwd.getpwuid(os.getuid())[0], 4)])
 
803
        self.assertFalse(setgroupsCalls)
 
804
 
 
805
 
 
806
    def test_initgroupsForcePython(self):
 
807
        """
 
808
        If we fake the absence of the C extension, the Python implementation is
 
809
        called instead, calling C{os.setgroups}.
 
810
        """
 
811
        util._c_initgroups = None
 
812
        calls = []
 
813
        util.setgroups = calls.append
 
814
        util.initgroups(os.getuid(), os.getgid())
 
815
        # Something should be in the calls, we don't really care what
 
816
        self.assertTrue(calls)
 
817
 
 
818
 
 
819
    def test_initgroupsInC(self):
 
820
        """
 
821
        If the C extension is present, it's called instead of the Python
 
822
        version.  We check that by making sure C{os.setgroups} is not called.
 
823
        """
 
824
        calls = []
 
825
        util.setgroups = calls.append
 
826
        try:
 
827
            util.initgroups(os.getuid(), os.getgid())
 
828
        except OSError:
 
829
            pass
 
830
        self.assertFalse(calls)
 
831
 
 
832
 
 
833
    if util._c_initgroups is None:
 
834
        test_initgroupsInC.skip = "C initgroups not available"