~justin-fathomdb/nova/justinsb-openstack-api-volumes

« back to all changes in this revision

Viewing changes to vendor/Twisted-10.0.0/twisted/test/test_policies.py

  • Committer: Jesse Andrews
  • Date: 2010-05-28 06:05:26 UTC
  • Revision ID: git-v1:bf6e6e718cdc7488e2da87b21e258ccc065fe499
initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (c) 2001-2009 Twisted Matrix Laboratories.
 
2
# See LICENSE for details.
 
3
 
 
4
"""
 
5
Test code for policies.
 
6
"""
 
7
 
 
8
from zope.interface import Interface, implements, implementedBy
 
9
 
 
10
from StringIO import StringIO
 
11
 
 
12
from twisted.trial import unittest
 
13
from twisted.test.proto_helpers import StringTransport
 
14
from twisted.test.proto_helpers import StringTransportWithDisconnection
 
15
 
 
16
from twisted.internet import protocol, reactor, address, defer, task
 
17
from twisted.protocols import policies
 
18
 
 
19
 
 
20
 
 
21
class SimpleProtocol(protocol.Protocol):
 
22
 
 
23
    connected = disconnected = 0
 
24
    buffer = ""
 
25
 
 
26
    def __init__(self):
 
27
        self.dConnected = defer.Deferred()
 
28
        self.dDisconnected = defer.Deferred()
 
29
 
 
30
    def connectionMade(self):
 
31
        self.connected = 1
 
32
        self.dConnected.callback('')
 
33
 
 
34
    def connectionLost(self, reason):
 
35
        self.disconnected = 1
 
36
        self.dDisconnected.callback('')
 
37
 
 
38
    def dataReceived(self, data):
 
39
        self.buffer += data
 
40
 
 
41
 
 
42
 
 
43
class SillyFactory(protocol.ClientFactory):
 
44
 
 
45
    def __init__(self, p):
 
46
        self.p = p
 
47
 
 
48
    def buildProtocol(self, addr):
 
49
        return self.p
 
50
 
 
51
 
 
52
class EchoProtocol(protocol.Protocol):
 
53
    paused = False
 
54
 
 
55
    def pauseProducing(self):
 
56
        self.paused = True
 
57
 
 
58
    def resumeProducing(self):
 
59
        self.paused = False
 
60
 
 
61
    def stopProducing(self):
 
62
        pass
 
63
 
 
64
    def dataReceived(self, data):
 
65
        self.transport.write(data)
 
66
 
 
67
 
 
68
 
 
69
class Server(protocol.ServerFactory):
 
70
    """
 
71
    A simple server factory using L{EchoProtocol}.
 
72
    """
 
73
    protocol = EchoProtocol
 
74
 
 
75
 
 
76
 
 
77
class TestableThrottlingFactory(policies.ThrottlingFactory):
 
78
    """
 
79
    L{policies.ThrottlingFactory} using a L{task.Clock} for tests.
 
80
    """
 
81
 
 
82
    def __init__(self, clock, *args, **kwargs):
 
83
        """
 
84
        @param clock: object providing a callLater method that can be used
 
85
            for tests.
 
86
        @type clock: C{task.Clock} or alike.
 
87
        """
 
88
        policies.ThrottlingFactory.__init__(self, *args, **kwargs)
 
89
        self.clock = clock
 
90
 
 
91
 
 
92
    def callLater(self, period, func):
 
93
        """
 
94
        Forward to the testable clock.
 
95
        """
 
96
        return self.clock.callLater(period, func)
 
97
 
 
98
 
 
99
 
 
100
class TestableTimeoutFactory(policies.TimeoutFactory):
 
101
    """
 
102
    L{policies.TimeoutFactory} using a L{task.Clock} for tests.
 
103
    """
 
104
 
 
105
    def __init__(self, clock, *args, **kwargs):
 
106
        """
 
107
        @param clock: object providing a callLater method that can be used
 
108
            for tests.
 
109
        @type clock: C{task.Clock} or alike.
 
110
        """
 
111
        policies.TimeoutFactory.__init__(self, *args, **kwargs)
 
112
        self.clock = clock
 
113
 
 
114
 
 
115
    def callLater(self, period, func):
 
116
        """
 
117
        Forward to the testable clock.
 
118
        """
 
119
        return self.clock.callLater(period, func)
 
120
 
 
121
 
 
122
 
 
123
class WrapperTestCase(unittest.TestCase):
 
124
    """
 
125
    Tests for L{WrappingFactory} and L{ProtocolWrapper}.
 
126
    """
 
127
    def test_protocolFactoryAttribute(self):
 
128
        """
 
129
        Make sure protocol.factory is the wrapped factory, not the wrapping
 
130
        factory.
 
131
        """
 
132
        f = Server()
 
133
        wf = policies.WrappingFactory(f)
 
134
        p = wf.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 35))
 
135
        self.assertIdentical(p.wrappedProtocol.factory, f)
 
136
 
 
137
 
 
138
    def test_transportInterfaces(self):
 
139
        """
 
140
        The transport wrapper passed to the wrapped protocol's
 
141
        C{makeConnection} provides the same interfaces as are provided by the
 
142
        original transport.
 
143
        """
 
144
        class IStubTransport(Interface):
 
145
            pass
 
146
 
 
147
        class StubTransport:
 
148
            implements(IStubTransport)
 
149
 
 
150
        # Looking up what ProtocolWrapper implements also mutates the class.
 
151
        # It adds __implemented__ and __providedBy__ attributes to it.  These
 
152
        # prevent __getattr__ from causing the IStubTransport.providedBy call
 
153
        # below from returning True.  If, by accident, nothing else causes
 
154
        # these attributes to be added to ProtocolWrapper, the test will pass,
 
155
        # but the interface will only be provided until something does trigger
 
156
        # their addition.  So we just trigger it right now to be sure.
 
157
        implementedBy(policies.ProtocolWrapper)
 
158
 
 
159
        proto = protocol.Protocol()
 
160
        wrapper = policies.ProtocolWrapper(policies.WrappingFactory(None), proto)
 
161
 
 
162
        wrapper.makeConnection(StubTransport())
 
163
        self.assertTrue(IStubTransport.providedBy(proto.transport))
 
164
 
 
165
 
 
166
 
 
167
class WrappingFactory(policies.WrappingFactory):
 
168
    protocol = lambda s, f, p: p
 
169
 
 
170
    def startFactory(self):
 
171
        policies.WrappingFactory.startFactory(self)
 
172
        self.deferred.callback(None)
 
173
 
 
174
 
 
175
 
 
176
class ThrottlingTestCase(unittest.TestCase):
 
177
    """
 
178
    Tests for L{policies.ThrottlingFactory}.
 
179
    """
 
180
 
 
181
    def test_limit(self):
 
182
        """
 
183
        Full test using a custom server limiting number of connections.
 
184
        """
 
185
        server = Server()
 
186
        c1, c2, c3, c4 = [SimpleProtocol() for i in range(4)]
 
187
        tServer = policies.ThrottlingFactory(server, 2)
 
188
        wrapTServer = WrappingFactory(tServer)
 
189
        wrapTServer.deferred = defer.Deferred()
 
190
 
 
191
        # Start listening
 
192
        p = reactor.listenTCP(0, wrapTServer, interface="127.0.0.1")
 
193
        n = p.getHost().port
 
194
 
 
195
        def _connect123(results):
 
196
            reactor.connectTCP("127.0.0.1", n, SillyFactory(c1))
 
197
            c1.dConnected.addCallback(
 
198
                lambda r: reactor.connectTCP("127.0.0.1", n, SillyFactory(c2)))
 
199
            c2.dConnected.addCallback(
 
200
                lambda r: reactor.connectTCP("127.0.0.1", n, SillyFactory(c3)))
 
201
            return c3.dDisconnected
 
202
 
 
203
        def _check123(results):
 
204
            self.assertEquals([c.connected for c in c1, c2, c3], [1, 1, 1])
 
205
            self.assertEquals([c.disconnected for c in c1, c2, c3], [0, 0, 1])
 
206
            self.assertEquals(len(tServer.protocols.keys()), 2)
 
207
            return results
 
208
 
 
209
        def _lose1(results):
 
210
            # disconnect one protocol and now another should be able to connect
 
211
            c1.transport.loseConnection()
 
212
            return c1.dDisconnected
 
213
 
 
214
        def _connect4(results):
 
215
            reactor.connectTCP("127.0.0.1", n, SillyFactory(c4))
 
216
            return c4.dConnected
 
217
 
 
218
        def _check4(results):
 
219
            self.assertEquals(c4.connected, 1)
 
220
            self.assertEquals(c4.disconnected, 0)
 
221
            return results
 
222
 
 
223
        def _cleanup(results):
 
224
            for c in c2, c4:
 
225
                c.transport.loseConnection()
 
226
            return defer.DeferredList([
 
227
                defer.maybeDeferred(p.stopListening),
 
228
                c2.dDisconnected,
 
229
                c4.dDisconnected])
 
230
 
 
231
        wrapTServer.deferred.addCallback(_connect123)
 
232
        wrapTServer.deferred.addCallback(_check123)
 
233
        wrapTServer.deferred.addCallback(_lose1)
 
234
        wrapTServer.deferred.addCallback(_connect4)
 
235
        wrapTServer.deferred.addCallback(_check4)
 
236
        wrapTServer.deferred.addCallback(_cleanup)
 
237
        return wrapTServer.deferred
 
238
 
 
239
 
 
240
    def test_writeLimit(self):
 
241
        """
 
242
        Check the writeLimit parameter: write data, and check for the pause
 
243
        status.
 
244
        """
 
245
        server = Server()
 
246
        tServer = TestableThrottlingFactory(task.Clock(), server, writeLimit=10)
 
247
        port = tServer.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 0))
 
248
        tr = StringTransportWithDisconnection()
 
249
        tr.protocol = port
 
250
        port.makeConnection(tr)
 
251
        port.producer = port.wrappedProtocol
 
252
 
 
253
        port.dataReceived("0123456789")
 
254
        port.dataReceived("abcdefghij")
 
255
        self.assertEquals(tr.value(), "0123456789abcdefghij")
 
256
        self.assertEquals(tServer.writtenThisSecond, 20)
 
257
        self.assertFalse(port.wrappedProtocol.paused)
 
258
 
 
259
        # at this point server should've written 20 bytes, 10 bytes
 
260
        # above the limit so writing should be paused around 1 second
 
261
        # from 'now', and resumed a second after that
 
262
        tServer.clock.advance(1.05)
 
263
        self.assertEquals(tServer.writtenThisSecond, 0)
 
264
        self.assertTrue(port.wrappedProtocol.paused)
 
265
 
 
266
        tServer.clock.advance(1.05)
 
267
        self.assertEquals(tServer.writtenThisSecond, 0)
 
268
        self.assertFalse(port.wrappedProtocol.paused)
 
269
 
 
270
 
 
271
    def test_readLimit(self):
 
272
        """
 
273
        Check the readLimit parameter: read data and check for the pause
 
274
        status.
 
275
        """
 
276
        server = Server()
 
277
        tServer = TestableThrottlingFactory(task.Clock(), server, readLimit=10)
 
278
        port = tServer.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 0))
 
279
        tr = StringTransportWithDisconnection()
 
280
        tr.protocol = port
 
281
        port.makeConnection(tr)
 
282
 
 
283
        port.dataReceived("0123456789")
 
284
        port.dataReceived("abcdefghij")
 
285
        self.assertEquals(tr.value(), "0123456789abcdefghij")
 
286
        self.assertEquals(tServer.readThisSecond, 20)
 
287
 
 
288
        tServer.clock.advance(1.05)
 
289
        self.assertEquals(tServer.readThisSecond, 0)
 
290
        self.assertEquals(tr.producerState, 'paused')
 
291
 
 
292
        tServer.clock.advance(1.05)
 
293
        self.assertEquals(tServer.readThisSecond, 0)
 
294
        self.assertEquals(tr.producerState, 'producing')
 
295
 
 
296
        tr.clear()
 
297
        port.dataReceived("0123456789")
 
298
        port.dataReceived("abcdefghij")
 
299
        self.assertEquals(tr.value(), "0123456789abcdefghij")
 
300
        self.assertEquals(tServer.readThisSecond, 20)
 
301
 
 
302
        tServer.clock.advance(1.05)
 
303
        self.assertEquals(tServer.readThisSecond, 0)
 
304
        self.assertEquals(tr.producerState, 'paused')
 
305
 
 
306
        tServer.clock.advance(1.05)
 
307
        self.assertEquals(tServer.readThisSecond, 0)
 
308
        self.assertEquals(tr.producerState, 'producing')
 
309
 
 
310
 
 
311
 
 
312
class TimeoutTestCase(unittest.TestCase):
 
313
    """
 
314
    Tests for L{policies.TimeoutFactory}.
 
315
    """
 
316
 
 
317
    def setUp(self):
 
318
        """
 
319
        Create a testable, deterministic clock, and a set of
 
320
        server factory/protocol/transport.
 
321
        """
 
322
        self.clock = task.Clock()
 
323
        wrappedFactory = protocol.ServerFactory()
 
324
        wrappedFactory.protocol = SimpleProtocol
 
325
        self.factory = TestableTimeoutFactory(self.clock, wrappedFactory, 3)
 
326
        self.proto = self.factory.buildProtocol(
 
327
            address.IPv4Address('TCP', '127.0.0.1', 12345))
 
328
        self.transport = StringTransportWithDisconnection()
 
329
        self.transport.protocol = self.proto
 
330
        self.proto.makeConnection(self.transport)
 
331
 
 
332
 
 
333
    def test_timeout(self):
 
334
        """
 
335
        Make sure that when a TimeoutFactory accepts a connection, it will
 
336
        time out that connection if no data is read or written within the
 
337
        timeout period.
 
338
        """
 
339
        # Let almost 3 time units pass
 
340
        self.clock.pump([0.0, 0.5, 1.0, 1.0, 0.4])
 
341
        self.failIf(self.proto.wrappedProtocol.disconnected)
 
342
 
 
343
        # Now let the timer elapse
 
344
        self.clock.pump([0.0, 0.2])
 
345
        self.failUnless(self.proto.wrappedProtocol.disconnected)
 
346
 
 
347
 
 
348
    def test_sendAvoidsTimeout(self):
 
349
        """
 
350
        Make sure that writing data to a transport from a protocol
 
351
        constructed by a TimeoutFactory resets the timeout countdown.
 
352
        """
 
353
        # Let half the countdown period elapse
 
354
        self.clock.pump([0.0, 0.5, 1.0])
 
355
        self.failIf(self.proto.wrappedProtocol.disconnected)
 
356
 
 
357
        # Send some data (self.proto is the /real/ proto's transport, so this
 
358
        # is the write that gets called)
 
359
        self.proto.write('bytes bytes bytes')
 
360
 
 
361
        # More time passes, putting us past the original timeout
 
362
        self.clock.pump([0.0, 1.0, 1.0])
 
363
        self.failIf(self.proto.wrappedProtocol.disconnected)
 
364
 
 
365
        # Make sure writeSequence delays timeout as well
 
366
        self.proto.writeSequence(['bytes'] * 3)
 
367
 
 
368
        # Tick tock
 
369
        self.clock.pump([0.0, 1.0, 1.0])
 
370
        self.failIf(self.proto.wrappedProtocol.disconnected)
 
371
 
 
372
        # Don't write anything more, just let the timeout expire
 
373
        self.clock.pump([0.0, 2.0])
 
374
        self.failUnless(self.proto.wrappedProtocol.disconnected)
 
375
 
 
376
 
 
377
    def test_receiveAvoidsTimeout(self):
 
378
        """
 
379
        Make sure that receiving data also resets the timeout countdown.
 
380
        """
 
381
        # Let half the countdown period elapse
 
382
        self.clock.pump([0.0, 1.0, 0.5])
 
383
        self.failIf(self.proto.wrappedProtocol.disconnected)
 
384
 
 
385
        # Some bytes arrive, they should reset the counter
 
386
        self.proto.dataReceived('bytes bytes bytes')
 
387
 
 
388
        # We pass the original timeout
 
389
        self.clock.pump([0.0, 1.0, 1.0])
 
390
        self.failIf(self.proto.wrappedProtocol.disconnected)
 
391
 
 
392
        # Nothing more arrives though, the new timeout deadline is passed,
 
393
        # the connection should be dropped.
 
394
        self.clock.pump([0.0, 1.0, 1.0])
 
395
        self.failUnless(self.proto.wrappedProtocol.disconnected)
 
396
 
 
397
 
 
398
 
 
399
class TimeoutTester(protocol.Protocol, policies.TimeoutMixin):
 
400
    """
 
401
    A testable protocol with timeout facility.
 
402
 
 
403
    @ivar timedOut: set to C{True} if a timeout has been detected.
 
404
    @type timedOut: C{bool}
 
405
    """
 
406
    timeOut  = 3
 
407
    timedOut = False
 
408
 
 
409
    def __init__(self, clock):
 
410
        """
 
411
        Initialize the protocol with a C{task.Clock} object.
 
412
        """
 
413
        self.clock = clock
 
414
 
 
415
 
 
416
    def connectionMade(self):
 
417
        """
 
418
        Upon connection, set the timeout.
 
419
        """
 
420
        self.setTimeout(self.timeOut)
 
421
 
 
422
 
 
423
    def dataReceived(self, data):
 
424
        """
 
425
        Reset the timeout on data.
 
426
        """
 
427
        self.resetTimeout()
 
428
        protocol.Protocol.dataReceived(self, data)
 
429
 
 
430
 
 
431
    def connectionLost(self, reason=None):
 
432
        """
 
433
        On connection lost, cancel all timeout operations.
 
434
        """
 
435
        self.setTimeout(None)
 
436
 
 
437
 
 
438
    def timeoutConnection(self):
 
439
        """
 
440
        Flags the timedOut variable to indicate the timeout of the connection.
 
441
        """
 
442
        self.timedOut = True
 
443
 
 
444
 
 
445
    def callLater(self, timeout, func, *args, **kwargs):
 
446
        """
 
447
        Override callLater to use the deterministic clock.
 
448
        """
 
449
        return self.clock.callLater(timeout, func, *args, **kwargs)
 
450
 
 
451
 
 
452
 
 
453
class TestTimeout(unittest.TestCase):
 
454
    """
 
455
    Tests for L{policies.TimeoutMixin}.
 
456
    """
 
457
 
 
458
    def setUp(self):
 
459
        """
 
460
        Create a testable, deterministic clock and a C{TimeoutTester} instance.
 
461
        """
 
462
        self.clock = task.Clock()
 
463
        self.proto = TimeoutTester(self.clock)
 
464
 
 
465
 
 
466
    def test_overriddenCallLater(self):
 
467
        """
 
468
        Test that the callLater of the clock is used instead of
 
469
        C{reactor.callLater}.
 
470
        """
 
471
        self.proto.setTimeout(10)
 
472
        self.assertEquals(len(self.clock.calls), 1)
 
473
 
 
474
 
 
475
    def test_timeout(self):
 
476
        """
 
477
        Check that the protocol does timeout at the time specified by its
 
478
        C{timeOut} attribute.
 
479
        """
 
480
        self.proto.makeConnection(StringTransport())
 
481
 
 
482
        # timeOut value is 3
 
483
        self.clock.pump([0, 0.5, 1.0, 1.0])
 
484
        self.failIf(self.proto.timedOut)
 
485
        self.clock.pump([0, 1.0])
 
486
        self.failUnless(self.proto.timedOut)
 
487
 
 
488
 
 
489
    def test_noTimeout(self):
 
490
        """
 
491
        Check that receiving data is delaying the timeout of the connection.
 
492
        """
 
493
        self.proto.makeConnection(StringTransport())
 
494
 
 
495
        self.clock.pump([0, 0.5, 1.0, 1.0])
 
496
        self.failIf(self.proto.timedOut)
 
497
        self.proto.dataReceived('hello there')
 
498
        self.clock.pump([0, 1.0, 1.0, 0.5])
 
499
        self.failIf(self.proto.timedOut)
 
500
        self.clock.pump([0, 1.0])
 
501
        self.failUnless(self.proto.timedOut)
 
502
 
 
503
 
 
504
    def test_resetTimeout(self):
 
505
        """
 
506
        Check that setting a new value for timeout cancel the previous value
 
507
        and install a new timeout.
 
508
        """
 
509
        self.proto.timeOut = None
 
510
        self.proto.makeConnection(StringTransport())
 
511
 
 
512
        self.proto.setTimeout(1)
 
513
        self.assertEquals(self.proto.timeOut, 1)
 
514
 
 
515
        self.clock.pump([0, 0.9])
 
516
        self.failIf(self.proto.timedOut)
 
517
        self.clock.pump([0, 0.2])
 
518
        self.failUnless(self.proto.timedOut)
 
519
 
 
520
 
 
521
    def test_cancelTimeout(self):
 
522
        """
 
523
        Setting the timeout to C{None} cancel any timeout operations.
 
524
        """
 
525
        self.proto.timeOut = 5
 
526
        self.proto.makeConnection(StringTransport())
 
527
 
 
528
        self.proto.setTimeout(None)
 
529
        self.assertEquals(self.proto.timeOut, None)
 
530
 
 
531
        self.clock.pump([0, 5, 5, 5])
 
532
        self.failIf(self.proto.timedOut)
 
533
 
 
534
 
 
535
    def test_return(self):
 
536
        """
 
537
        setTimeout should return the value of the previous timeout.
 
538
        """
 
539
        self.proto.timeOut = 5
 
540
 
 
541
        self.assertEquals(self.proto.setTimeout(10), 5)
 
542
        self.assertEquals(self.proto.setTimeout(None), 10)
 
543
        self.assertEquals(self.proto.setTimeout(1), None)
 
544
        self.assertEquals(self.proto.timeOut, 1)
 
545
 
 
546
        # Clean up the DelayedCall
 
547
        self.proto.setTimeout(None)
 
548
 
 
549
 
 
550
 
 
551
class LimitTotalConnectionsFactoryTestCase(unittest.TestCase):
 
552
    """Tests for policies.LimitTotalConnectionsFactory"""
 
553
    def testConnectionCounting(self):
 
554
        # Make a basic factory
 
555
        factory = policies.LimitTotalConnectionsFactory()
 
556
        factory.protocol = protocol.Protocol
 
557
 
 
558
        # connectionCount starts at zero
 
559
        self.assertEqual(0, factory.connectionCount)
 
560
 
 
561
        # connectionCount increments as connections are made
 
562
        p1 = factory.buildProtocol(None)
 
563
        self.assertEqual(1, factory.connectionCount)
 
564
        p2 = factory.buildProtocol(None)
 
565
        self.assertEqual(2, factory.connectionCount)
 
566
 
 
567
        # and decrements as they are lost
 
568
        p1.connectionLost(None)
 
569
        self.assertEqual(1, factory.connectionCount)
 
570
        p2.connectionLost(None)
 
571
        self.assertEqual(0, factory.connectionCount)
 
572
 
 
573
    def testConnectionLimiting(self):
 
574
        # Make a basic factory with a connection limit of 1
 
575
        factory = policies.LimitTotalConnectionsFactory()
 
576
        factory.protocol = protocol.Protocol
 
577
        factory.connectionLimit = 1
 
578
 
 
579
        # Make a connection
 
580
        p = factory.buildProtocol(None)
 
581
        self.assertNotEqual(None, p)
 
582
        self.assertEqual(1, factory.connectionCount)
 
583
 
 
584
        # Try to make a second connection, which will exceed the connection
 
585
        # limit.  This should return None, because overflowProtocol is None.
 
586
        self.assertEqual(None, factory.buildProtocol(None))
 
587
        self.assertEqual(1, factory.connectionCount)
 
588
 
 
589
        # Define an overflow protocol
 
590
        class OverflowProtocol(protocol.Protocol):
 
591
            def connectionMade(self):
 
592
                factory.overflowed = True
 
593
        factory.overflowProtocol = OverflowProtocol
 
594
        factory.overflowed = False
 
595
 
 
596
        # Try to make a second connection again, now that we have an overflow
 
597
        # protocol.  Note that overflow connections count towards the connection
 
598
        # count.
 
599
        op = factory.buildProtocol(None)
 
600
        op.makeConnection(None) # to trigger connectionMade
 
601
        self.assertEqual(True, factory.overflowed)
 
602
        self.assertEqual(2, factory.connectionCount)
 
603
 
 
604
        # Close the connections.
 
605
        p.connectionLost(None)
 
606
        self.assertEqual(1, factory.connectionCount)
 
607
        op.connectionLost(None)
 
608
        self.assertEqual(0, factory.connectionCount)
 
609
 
 
610
 
 
611
class WriteSequenceEchoProtocol(EchoProtocol):
 
612
    def dataReceived(self, bytes):
 
613
        if bytes.find('vector!') != -1:
 
614
            self.transport.writeSequence([bytes])
 
615
        else:
 
616
            EchoProtocol.dataReceived(self, bytes)
 
617
 
 
618
class TestLoggingFactory(policies.TrafficLoggingFactory):
 
619
    openFile = None
 
620
    def open(self, name):
 
621
        assert self.openFile is None, "open() called too many times"
 
622
        self.openFile = StringIO()
 
623
        return self.openFile
 
624
 
 
625
 
 
626
 
 
627
class LoggingFactoryTestCase(unittest.TestCase):
 
628
    """
 
629
    Tests for L{policies.TrafficLoggingFactory}.
 
630
    """
 
631
 
 
632
    def test_thingsGetLogged(self):
 
633
        """
 
634
        Check the output produced by L{policies.TrafficLoggingFactory}.
 
635
        """
 
636
        wrappedFactory = Server()
 
637
        wrappedFactory.protocol = WriteSequenceEchoProtocol
 
638
        t = StringTransportWithDisconnection()
 
639
        f = TestLoggingFactory(wrappedFactory, 'test')
 
640
        p = f.buildProtocol(('1.2.3.4', 5678))
 
641
        t.protocol = p
 
642
        p.makeConnection(t)
 
643
 
 
644
        v = f.openFile.getvalue()
 
645
        self.failUnless('*' in v, "* not found in %r" % (v,))
 
646
        self.failIf(t.value())
 
647
 
 
648
        p.dataReceived('here are some bytes')
 
649
 
 
650
        v = f.openFile.getvalue()
 
651
        self.assertIn("C 1: 'here are some bytes'", v)
 
652
        self.assertIn("S 1: 'here are some bytes'", v)
 
653
        self.assertEquals(t.value(), 'here are some bytes')
 
654
 
 
655
        t.clear()
 
656
        p.dataReceived('prepare for vector! to the extreme')
 
657
        v = f.openFile.getvalue()
 
658
        self.assertIn("SV 1: ['prepare for vector! to the extreme']", v)
 
659
        self.assertEquals(t.value(), 'prepare for vector! to the extreme')
 
660
 
 
661
        p.loseConnection()
 
662
 
 
663
        v = f.openFile.getvalue()
 
664
        self.assertIn('ConnectionDone', v)
 
665
 
 
666
 
 
667
    def test_counter(self):
 
668
        """
 
669
        Test counter management with the resetCounter method.
 
670
        """
 
671
        wrappedFactory = Server()
 
672
        f = TestLoggingFactory(wrappedFactory, 'test')
 
673
        self.assertEqual(f._counter, 0)
 
674
        f.buildProtocol(('1.2.3.4', 5678))
 
675
        self.assertEqual(f._counter, 1)
 
676
        # Reset log file
 
677
        f.openFile = None
 
678
        f.buildProtocol(('1.2.3.4', 5679))
 
679
        self.assertEqual(f._counter, 2)
 
680
 
 
681
        f.resetCounter()
 
682
        self.assertEqual(f._counter, 0)
 
683