1
# Copyright (c) 2001-2009 Twisted Matrix Laboratories.
2
# See LICENSE for details.
5
Test code for policies.
8
from zope.interface import Interface, implements, implementedBy
10
from StringIO import StringIO
12
from twisted.trial import unittest
13
from twisted.test.proto_helpers import StringTransport
14
from twisted.test.proto_helpers import StringTransportWithDisconnection
16
from twisted.internet import protocol, reactor, address, defer, task
17
from twisted.protocols import policies
21
# Test protocol used by the throttling and timeout tests. It fires
# self.dConnected from connectionMade and self.dDisconnected from
# connectionLost. It also exposes `connected`/`disconnected` flags, which are
# 0 at class level and are read by the tests.
# NOTE(review): this copy is damaged. The bare integer lines are stray
# original line numbers, indentation has been stripped, and lines appear to
# be missing: the method header that should own the two Deferred
# assignments, the statements that set the connected/disconnected flags, and
# the dataReceived body. Restore from version control; TODO confirm upstream.
class SimpleProtocol(protocol.Protocol):
23
connected = disconnected = 0
27
self.dConnected = defer.Deferred()
28
self.dDisconnected = defer.Deferred()
30
def connectionMade(self):
32
self.dConnected.callback('')
34
def connectionLost(self, reason):
36
self.dDisconnected.callback('')
38
def dataReceived(self, data):
43
# Client factory built around an already-constructed protocol instance `p`.
# test_limit passes a pre-built SimpleProtocol to each connectTCP call.
# NOTE(review): both method bodies are missing from this copy. Presumably
# __init__ stores p and buildProtocol returns it; TODO confirm upstream.
class SillyFactory(protocol.ClientFactory):
45
def __init__(self, p):
48
def buildProtocol(self, addr):
52
# Protocol that writes every received chunk straight back to its transport.
# It also defines producer methods (pauseProducing/resumeProducing/
# stopProducing). test_writeLimit installs it as the producer and reads its
# `paused` attribute.
# NOTE(review): the producer method bodies, and whatever defines `paused`,
# are missing from this copy; TODO restore from upstream.
class EchoProtocol(protocol.Protocol):
55
def pauseProducing(self):
58
def resumeProducing(self):
61
def stopProducing(self):
64
def dataReceived(self, data):
65
self.transport.write(data)
69
# Server factory whose protocol class is EchoProtocol.
# NOTE(review): the triple-quote delimiters around the description line
# below are missing in this copy, so that line is not a valid docstring.
class Server(protocol.ServerFactory):
71
A simple server factory using L{EchoProtocol}.
73
protocol = EchoProtocol
77
# ThrottlingFactory whose callLater is redirected to an injected clock.
# test_writeLimit and test_readLimit drive time with clock.advance() instead
# of the real reactor.
# NOTE(review): this copy is missing the docstring delimiters and the line
# that stores `clock` on the instance (callLater below reads self.clock).
# TODO restore from upstream.
class TestableThrottlingFactory(policies.ThrottlingFactory):
79
L{policies.ThrottlingFactory} using a L{task.Clock} for tests.
82
def __init__(self, clock, *args, **kwargs):
84
@param clock: object providing a callLater method that can be used
86
@type clock: C{task.Clock} or alike.
88
policies.ThrottlingFactory.__init__(self, *args, **kwargs)
92
def callLater(self, period, func):
94
Forward to the testable clock.
96
# Schedule on the injected clock rather than the global reactor.
return self.clock.callLater(period, func)
100
# TimeoutFactory whose callLater is redirected to an injected clock.
# TimeoutTestCase can then trigger timeouts deterministically with
# clock.pump().
# NOTE(review): this copy is missing the docstring delimiters and the line
# that stores `clock` on the instance (callLater below reads self.clock).
# TODO restore from upstream.
class TestableTimeoutFactory(policies.TimeoutFactory):
102
L{policies.TimeoutFactory} using a L{task.Clock} for tests.
105
def __init__(self, clock, *args, **kwargs):
107
@param clock: object providing a callLater method that can be used
109
@type clock: C{task.Clock} or alike.
111
policies.TimeoutFactory.__init__(self, *args, **kwargs)
115
def callLater(self, period, func):
117
Forward to the testable clock.
119
# Schedule on the injected clock rather than the global reactor.
return self.clock.callLater(period, func)
123
# Tests for policies.WrappingFactory and policies.ProtocolWrapper. They check
# two things:
# - the wrapped protocol's `factory` is the wrapped factory;
# - the transport handed to the wrapped protocol provides the interfaces
#   declared by the real transport.
# NOTE(review): this copy is damaged. Docstring delimiters and some lines are
# missing, including the creation of `f` used in
# test_protocolFactoryAttribute and the `StubTransport` class header that
# should own `implements(IStubTransport)`. TODO restore from upstream.
class WrapperTestCase(unittest.TestCase):
125
Tests for L{WrappingFactory} and L{ProtocolWrapper}.
127
def test_protocolFactoryAttribute(self):
129
Make sure protocol.factory is the wrapped factory, not the wrapping
133
wf = policies.WrappingFactory(f)
134
p = wf.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 35))
135
self.assertIdentical(p.wrappedProtocol.factory, f)
138
def test_transportInterfaces(self):
140
The transport wrapper passed to the wrapped protocol's
141
C{makeConnection} provides the same interfaces as are provided by the
144
class IStubTransport(Interface):
148
implements(IStubTransport)
150
# Looking up what ProtocolWrapper implements also mutates the class.
151
# It adds __implemented__ and __providedBy__ attributes to it. These
152
# prevent __getattr__ from causing the IStubTransport.providedBy call
153
# below from returning True. If, by accident, nothing else causes
154
# these attributes to be added to ProtocolWrapper, the test will pass,
155
# but the interface will only be provided until something does trigger
156
# their addition. So we just trigger it right now to be sure.
157
implementedBy(policies.ProtocolWrapper)
159
proto = protocol.Protocol()
160
wrapper = policies.ProtocolWrapper(policies.WrappingFactory(None), proto)
162
wrapper.makeConnection(StubTransport())
163
self.assertTrue(IStubTransport.providedBy(proto.transport))
167
# Test-local WrappingFactory subclass used by test_limit.
# - Its `protocol` callable returns its last argument (presumably the wrapped
#   protocol) unchanged.
# - startFactory fires self.deferred after the base implementation runs.
#   test_limit assigns self.deferred before listening and chains all its
#   steps on it.
class WrappingFactory(policies.WrappingFactory):
168
protocol = lambda s, f, p: p
170
def startFactory(self):
171
policies.WrappingFactory.startFactory(self)
172
self.deferred.callback(None)
176
# Tests for policies.ThrottlingFactory:
# - test_limit: real TCP connections against a limit of 2 concurrent
#   connections.
# - test_writeLimit / test_readLimit: in-memory transports and a
#   TestableThrottlingFactory with a task.Clock, so pause/resume timing is
#   deterministic.
# NOTE(review): this copy is damaged (stray line-number lines, no
# indentation, missing docstring delimiters). Names used below are not
# defined anywhere visible: `server`, `n` (the listening port number), the
# `_lose1` callback, and the loop variable `c` in _cleanup. The tail of the
# DeferredList in _cleanup is also missing. TODO restore from upstream.
class ThrottlingTestCase(unittest.TestCase):
178
Tests for L{policies.ThrottlingFactory}.
181
def test_limit(self):
183
Full test using a custom server limiting number of connections.
186
c1, c2, c3, c4 = [SimpleProtocol() for i in range(4)]
187
tServer = policies.ThrottlingFactory(server, 2)
188
wrapTServer = WrappingFactory(tServer)
189
wrapTServer.deferred = defer.Deferred()
192
p = reactor.listenTCP(0, wrapTServer, interface="127.0.0.1")
195
# Connect c1, then c2 once c1 is up, then c3 once c2 is up; with a limit
# of 2, the third connection is expected to be dropped (see _check123).
def _connect123(results):
196
reactor.connectTCP("127.0.0.1", n, SillyFactory(c1))
197
c1.dConnected.addCallback(
198
lambda r: reactor.connectTCP("127.0.0.1", n, SillyFactory(c2)))
199
c2.dConnected.addCallback(
200
lambda r: reactor.connectTCP("127.0.0.1", n, SillyFactory(c3)))
201
return c3.dDisconnected
203
def _check123(results):
204
self.assertEquals([c.connected for c in c1, c2, c3], [1, 1, 1])
205
self.assertEquals([c.disconnected for c in c1, c2, c3], [0, 0, 1])
206
self.assertEquals(len(tServer.protocols.keys()), 2)
210
# disconnect one protocol and now another should be able to connect
211
c1.transport.loseConnection()
212
return c1.dDisconnected
214
def _connect4(results):
215
reactor.connectTCP("127.0.0.1", n, SillyFactory(c4))
218
def _check4(results):
219
self.assertEquals(c4.connected, 1)
220
self.assertEquals(c4.disconnected, 0)
223
def _cleanup(results):
225
c.transport.loseConnection()
226
return defer.DeferredList([
227
defer.maybeDeferred(p.stopListening),
231
wrapTServer.deferred.addCallback(_connect123)
232
wrapTServer.deferred.addCallback(_check123)
233
wrapTServer.deferred.addCallback(_lose1)
234
wrapTServer.deferred.addCallback(_connect4)
235
wrapTServer.deferred.addCallback(_check4)
236
wrapTServer.deferred.addCallback(_cleanup)
237
return wrapTServer.deferred
240
def test_writeLimit(self):
242
Check the writeLimit parameter: write data, and check for the pause
246
tServer = TestableThrottlingFactory(task.Clock(), server, writeLimit=10)
247
port = tServer.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 0))
248
tr = StringTransportWithDisconnection()
250
port.makeConnection(tr)
251
# Register the echo protocol as the producer so its `paused` flag shows
# whether the throttling factory paused writing.
port.producer = port.wrappedProtocol
253
port.dataReceived("0123456789")
254
port.dataReceived("abcdefghij")
255
self.assertEquals(tr.value(), "0123456789abcdefghij")
256
self.assertEquals(tServer.writtenThisSecond, 20)
257
self.assertFalse(port.wrappedProtocol.paused)
259
# at this point server should've written 20 bytes, 10 bytes
260
# above the limit so writing should be paused around 1 second
261
# from 'now', and resumed a second after that
262
tServer.clock.advance(1.05)
263
self.assertEquals(tServer.writtenThisSecond, 0)
264
self.assertTrue(port.wrappedProtocol.paused)
266
tServer.clock.advance(1.05)
267
self.assertEquals(tServer.writtenThisSecond, 0)
268
self.assertFalse(port.wrappedProtocol.paused)
271
def test_readLimit(self):
273
Check the readLimit parameter: read data and check for the pause
277
tServer = TestableThrottlingFactory(task.Clock(), server, readLimit=10)
278
port = tServer.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 0))
279
tr = StringTransportWithDisconnection()
281
port.makeConnection(tr)
283
port.dataReceived("0123456789")
284
port.dataReceived("abcdefghij")
285
self.assertEquals(tr.value(), "0123456789abcdefghij")
286
self.assertEquals(tServer.readThisSecond, 20)
288
# Reading 20 bytes against a limit of 10: after ~1 second the transport
# is expected to be paused, and after another ~1 second to be producing.
tServer.clock.advance(1.05)
289
self.assertEquals(tServer.readThisSecond, 0)
290
self.assertEquals(tr.producerState, 'paused')
292
tServer.clock.advance(1.05)
293
self.assertEquals(tServer.readThisSecond, 0)
294
self.assertEquals(tr.producerState, 'producing')
297
# NOTE(review): this second round expects tr.value() to hold only the new
# 20 bytes. That implies the transport buffer is reset between rounds,
# which is not visible in this copy; TODO confirm upstream.
port.dataReceived("0123456789")
298
port.dataReceived("abcdefghij")
299
self.assertEquals(tr.value(), "0123456789abcdefghij")
300
self.assertEquals(tServer.readThisSecond, 20)
302
tServer.clock.advance(1.05)
303
self.assertEquals(tServer.readThisSecond, 0)
304
self.assertEquals(tr.producerState, 'paused')
306
tServer.clock.advance(1.05)
307
self.assertEquals(tServer.readThisSecond, 0)
308
self.assertEquals(tr.producerState, 'producing')
312
# Tests for policies.TimeoutFactory, driven by a task.Clock so timing is
# deterministic. The fixture:
# - wraps a ServerFactory producing SimpleProtocol in a
#   TestableTimeoutFactory with a timeout of 3 time units;
# - connects the built protocol to a StringTransportWithDisconnection.
# NOTE(review): this copy is damaged (stray line-number lines, no
# indentation, missing docstring delimiters). The fixture method header
# above the setup statements is also missing. TODO restore from upstream.
class TimeoutTestCase(unittest.TestCase):
314
Tests for L{policies.TimeoutFactory}.
319
Create a testable, deterministic clock, and a set of
320
server factory/protocol/transport.
322
self.clock = task.Clock()
323
wrappedFactory = protocol.ServerFactory()
324
wrappedFactory.protocol = SimpleProtocol
325
self.factory = TestableTimeoutFactory(self.clock, wrappedFactory, 3)
326
self.proto = self.factory.buildProtocol(
327
address.IPv4Address('TCP', '127.0.0.1', 12345))
328
self.transport = StringTransportWithDisconnection()
329
self.transport.protocol = self.proto
330
self.proto.makeConnection(self.transport)
333
def test_timeout(self):
335
Make sure that when a TimeoutFactory accepts a connection, it will
336
time out that connection if no data is read or written within the
339
# Let almost 3 time units pass
340
self.clock.pump([0.0, 0.5, 1.0, 1.0, 0.4])
341
self.failIf(self.proto.wrappedProtocol.disconnected)
343
# Now let the timer elapse
344
self.clock.pump([0.0, 0.2])
345
self.failUnless(self.proto.wrappedProtocol.disconnected)
348
def test_sendAvoidsTimeout(self):
350
Make sure that writing data to a transport from a protocol
351
constructed by a TimeoutFactory resets the timeout countdown.
353
# Let half the countdown period elapse
354
self.clock.pump([0.0, 0.5, 1.0])
355
self.failIf(self.proto.wrappedProtocol.disconnected)
357
# Send some data (self.proto is the /real/ proto's transport, so this
358
# is the write that gets called)
359
self.proto.write('bytes bytes bytes')
361
# More time passes, putting us past the original timeout
362
self.clock.pump([0.0, 1.0, 1.0])
363
self.failIf(self.proto.wrappedProtocol.disconnected)
365
# Make sure writeSequence delays timeout as well
366
self.proto.writeSequence(['bytes'] * 3)
369
self.clock.pump([0.0, 1.0, 1.0])
370
self.failIf(self.proto.wrappedProtocol.disconnected)
372
# Don't write anything more, just let the timeout expire
373
self.clock.pump([0.0, 2.0])
374
self.failUnless(self.proto.wrappedProtocol.disconnected)
377
def test_receiveAvoidsTimeout(self):
379
Make sure that receiving data also resets the timeout countdown.
381
# Let half the countdown period elapse
382
self.clock.pump([0.0, 1.0, 0.5])
383
self.failIf(self.proto.wrappedProtocol.disconnected)
385
# Some bytes arrive, they should reset the counter
386
self.proto.dataReceived('bytes bytes bytes')
388
# We pass the original timeout
389
self.clock.pump([0.0, 1.0, 1.0])
390
self.failIf(self.proto.wrappedProtocol.disconnected)
392
# Nothing more arrives though, the new timeout deadline is passed,
393
# the connection should be dropped.
394
self.clock.pump([0.0, 1.0, 1.0])
395
self.failUnless(self.proto.wrappedProtocol.disconnected)
399
# Protocol mixing in policies.TimeoutMixin, with callLater routed to an
# injected clock.
# - connectionMade arms the timeout with self.timeOut.
# - connectionLost cancels it via setTimeout(None).
# - timeoutConnection flags timedOut (per its docstring).
# NOTE(review): this copy is missing several lines:
# - the docstring delimiters;
# - the class attributes for timeOut/timedOut;
# - the __init__ body (callLater below reads self.clock);
# - the timeoutConnection body;
# - possibly a resetTimeout() call in dataReceived. Its docstring says
#   "Reset the timeout on data", but only the base-class call is visible.
# TODO restore from upstream.
class TimeoutTester(protocol.Protocol, policies.TimeoutMixin):
401
A testable protocol with timeout facility.
403
@ivar timedOut: set to C{True} if a timeout has been detected.
404
@type timedOut: C{bool}
409
def __init__(self, clock):
411
Initialize the protocol with a C{task.Clock} object.
416
def connectionMade(self):
418
Upon connection, set the timeout.
420
self.setTimeout(self.timeOut)
423
def dataReceived(self, data):
425
Reset the timeout on data.
428
protocol.Protocol.dataReceived(self, data)
431
def connectionLost(self, reason=None):
433
On connection lost, cancel all timeout operations.
435
self.setTimeout(None)
438
def timeoutConnection(self):
440
Flags the timedOut variable to indicate the timeout of the connection.
445
def callLater(self, timeout, func, *args, **kwargs):
447
Override callLater to use the deterministic clock.
449
return self.clock.callLater(timeout, func, *args, **kwargs)
453
# Tests for policies.TimeoutMixin, using TimeoutTester and a task.Clock.
# They cover:
# - scheduling goes through the clock;
# - expiry and reset-on-data;
# - changing and cancelling the timeout;
# - setTimeout's return value (the previous timeout).
# NOTE(review): this copy is damaged (stray line-number lines, no
# indentation, missing docstring delimiters). The fixture method header
# above the clock/proto setup lines is also missing. TODO restore from
# upstream.
class TestTimeout(unittest.TestCase):
455
Tests for L{policies.TimeoutMixin}.
460
Create a testable, deterministic clock and a C{TimeoutTester} instance.
462
self.clock = task.Clock()
463
self.proto = TimeoutTester(self.clock)
466
def test_overriddenCallLater(self):
468
Test that the callLater of the clock is used instead of
469
C{reactor.callLater}.
471
self.proto.setTimeout(10)
472
self.assertEquals(len(self.clock.calls), 1)
475
def test_timeout(self):
477
Check that the protocol does timeout at the time specified by its
478
C{timeOut} attribute.
480
self.proto.makeConnection(StringTransport())
483
self.clock.pump([0, 0.5, 1.0, 1.0])
484
self.failIf(self.proto.timedOut)
485
self.clock.pump([0, 1.0])
486
self.failUnless(self.proto.timedOut)
489
def test_noTimeout(self):
491
Check that receiving data is delaying the timeout of the connection.
493
self.proto.makeConnection(StringTransport())
495
self.clock.pump([0, 0.5, 1.0, 1.0])
496
self.failIf(self.proto.timedOut)
497
self.proto.dataReceived('hello there')
498
self.clock.pump([0, 1.0, 1.0, 0.5])
499
self.failIf(self.proto.timedOut)
500
self.clock.pump([0, 1.0])
501
self.failUnless(self.proto.timedOut)
504
def test_resetTimeout(self):
506
Check that setting a new value for timeout cancel the previous value
507
and install a new timeout.
509
self.proto.timeOut = None
510
self.proto.makeConnection(StringTransport())
512
self.proto.setTimeout(1)
513
self.assertEquals(self.proto.timeOut, 1)
515
self.clock.pump([0, 0.9])
516
self.failIf(self.proto.timedOut)
517
self.clock.pump([0, 0.2])
518
self.failUnless(self.proto.timedOut)
521
def test_cancelTimeout(self):
523
Setting the timeout to C{None} cancel any timeout operations.
525
self.proto.timeOut = 5
526
self.proto.makeConnection(StringTransport())
528
self.proto.setTimeout(None)
529
self.assertEquals(self.proto.timeOut, None)
531
self.clock.pump([0, 5, 5, 5])
532
self.failIf(self.proto.timedOut)
535
def test_return(self):
537
setTimeout should return the value of the previous timeout.
539
self.proto.timeOut = 5
541
self.assertEquals(self.proto.setTimeout(10), 5)
542
self.assertEquals(self.proto.setTimeout(None), 10)
543
self.assertEquals(self.proto.setTimeout(1), None)
544
self.assertEquals(self.proto.timeOut, 1)
546
# Clean up the DelayedCall
547
self.proto.setTimeout(None)
551
# Tests for policies.LimitTotalConnectionsFactory:
# - connectionCount rises as protocols are built and falls as their
#   connections are lost;
# - with connectionLimit set, extra buildProtocol calls return None unless
#   an overflowProtocol is configured. Overflow connections are also
#   counted.
# NOTE(review): this copy has stray line-number lines and no indentation.
# The end of the "count towards the connection" comment sentence is also
# missing. TODO restore from upstream.
class LimitTotalConnectionsFactoryTestCase(unittest.TestCase):
552
"""Tests for policies.LimitTotalConnectionsFactory"""
553
def testConnectionCounting(self):
554
# Make a basic factory
555
factory = policies.LimitTotalConnectionsFactory()
556
factory.protocol = protocol.Protocol
558
# connectionCount starts at zero
559
self.assertEqual(0, factory.connectionCount)
561
# connectionCount increments as connections are made
562
p1 = factory.buildProtocol(None)
563
self.assertEqual(1, factory.connectionCount)
564
p2 = factory.buildProtocol(None)
565
self.assertEqual(2, factory.connectionCount)
567
# and decrements as they are lost
568
p1.connectionLost(None)
569
self.assertEqual(1, factory.connectionCount)
570
p2.connectionLost(None)
571
self.assertEqual(0, factory.connectionCount)
573
def testConnectionLimiting(self):
574
# Make a basic factory with a connection limit of 1
575
factory = policies.LimitTotalConnectionsFactory()
576
factory.protocol = protocol.Protocol
577
factory.connectionLimit = 1
580
p = factory.buildProtocol(None)
581
self.assertNotEqual(None, p)
582
self.assertEqual(1, factory.connectionCount)
584
# Try to make a second connection, which will exceed the connection
585
# limit. This should return None, because overflowProtocol is None.
586
self.assertEqual(None, factory.buildProtocol(None))
587
self.assertEqual(1, factory.connectionCount)
589
# Define an overflow protocol
590
class OverflowProtocol(protocol.Protocol):
591
def connectionMade(self):
592
factory.overflowed = True
593
factory.overflowProtocol = OverflowProtocol
594
factory.overflowed = False
596
# Try to make a second connection again, now that we have an overflow
597
# protocol. Note that overflow connections count towards the connection
599
op = factory.buildProtocol(None)
600
op.makeConnection(None) # to trigger connectionMade
601
self.assertEqual(True, factory.overflowed)
602
self.assertEqual(2, factory.connectionCount)
604
# Close the connections.
605
p.connectionLost(None)
606
self.assertEqual(1, factory.connectionCount)
607
op.connectionLost(None)
608
self.assertEqual(0, factory.connectionCount)
611
# Echo protocol that echoes chunks containing 'vector!' via
# transport.writeSequence, so the logging test can check the "SV" record.
# Other chunks are echoed through EchoProtocol.dataReceived.
# NOTE(review): a line appears to be missing between the writeSequence call
# and the EchoProtocol.dataReceived call (likely an `else:`). As written,
# 'vector!' chunks would be echoed twice, which contradicts the single-copy
# assertion in test_thingsGetLogged. TODO confirm upstream.
class WriteSequenceEchoProtocol(EchoProtocol):
612
def dataReceived(self, bytes):
613
if bytes.find('vector!') != -1:
614
self.transport.writeSequence([bytes])
616
EchoProtocol.dataReceived(self, bytes)
618
# TrafficLoggingFactory that logs into an in-memory StringIO (self.openFile)
# instead of a real file. The tests read the log back from it. The assert
# guards against open() being called more than once.
# NOTE(review): the class-level default for openFile (the assert reads it
# before first assignment) and any return of the opened file are not
# visible in this copy. TODO restore from upstream.
class TestLoggingFactory(policies.TrafficLoggingFactory):
620
def open(self, name):
621
assert self.openFile is None, "open() called too many times"
622
self.openFile = StringIO()
627
# Tests for policies.TrafficLoggingFactory, using TestLoggingFactory's
# in-memory log. They cover:
# - connect, client/server data, and writeSequence records appear in the log;
# - the connection counter increments per buildProtocol and can be reset.
# NOTE(review): this copy is damaged (stray line-number lines, no
# indentation, missing docstring delimiters). Steps implied by the
# assertions are not visible:
# - connecting `p` to `t` before the first log check;
# - clearing `t` between rounds;
# - losing the connection before the 'ConnectionDone' check;
# - the resetCounter() call before the final counter assertion.
# TODO restore from upstream.
class LoggingFactoryTestCase(unittest.TestCase):
629
Tests for L{policies.TrafficLoggingFactory}.
632
def test_thingsGetLogged(self):
634
Check the output produced by L{policies.TrafficLoggingFactory}.
636
wrappedFactory = Server()
637
wrappedFactory.protocol = WriteSequenceEchoProtocol
638
t = StringTransportWithDisconnection()
639
f = TestLoggingFactory(wrappedFactory, 'test')
640
p = f.buildProtocol(('1.2.3.4', 5678))
644
v = f.openFile.getvalue()
645
self.failUnless('*' in v, "* not found in %r" % (v,))
646
self.failIf(t.value())
648
p.dataReceived('here are some bytes')
650
v = f.openFile.getvalue()
651
self.assertIn("C 1: 'here are some bytes'", v)
652
self.assertIn("S 1: 'here are some bytes'", v)
653
self.assertEquals(t.value(), 'here are some bytes')
656
p.dataReceived('prepare for vector! to the extreme')
657
v = f.openFile.getvalue()
658
self.assertIn("SV 1: ['prepare for vector! to the extreme']", v)
659
self.assertEquals(t.value(), 'prepare for vector! to the extreme')
663
v = f.openFile.getvalue()
664
self.assertIn('ConnectionDone', v)
667
def test_counter(self):
669
Test counter management with the resetCounter method.
671
wrappedFactory = Server()
672
f = TestLoggingFactory(wrappedFactory, 'test')
673
self.assertEqual(f._counter, 0)
674
f.buildProtocol(('1.2.3.4', 5678))
675
self.assertEqual(f._counter, 1)
678
f.buildProtocol(('1.2.3.4', 5679))
679
self.assertEqual(f._counter, 2)
682
self.assertEqual(f._counter, 0)