1
# Copyright (c) 2001-2009 Twisted Matrix Laboratories.
2
# See LICENSE for details.
5
Test cases for twisted.protocols package.
10
from twisted.trial import unittest
11
from twisted.protocols import basic, wire, portforward
12
from twisted.internet import reactor, protocol, defer, task, error
13
from twisted.test import proto_helpers
16
class LineTester(basic.LineReceiver):
18
A line receiver that parses data received and make actions on some tokens.
20
@type delimiter: C{str}
21
@ivar delimiter: character used between received lines.
22
@type MAX_LENGTH: C{int}
23
@ivar MAX_LENGTH: size of a line when C{lineLengthExceeded} will be called.
24
@type clock: L{twisted.internet.task.Clock}
25
@ivar clock: clock simulating reactor callLater. Pass it to constructor if
26
you want to use the pause/rawpause functionalities.
32
def __init__(self, clock=None):
34
If given, use a clock to make callLater calls.
38
def connectionMade(self):
40
Create/clean data received on connection.
44
def lineReceived(self, line):
46
Receive line and make some action for some tokens: pause, rawpause,
47
stop, len, produce, unproduce.
49
self.received.append(line)
54
self.clock.callLater(0, self.resumeProducing)
55
elif line == 'rawpause':
58
self.received.append('')
59
self.clock.callLater(0, self.resumeProducing)
62
elif line[:4] == 'len ':
63
self.length = int(line[4:])
64
elif line.startswith('produce'):
65
self.transport.registerProducer(self, False)
66
elif line.startswith('unproduce'):
67
self.transport.unregisterProducer()
69
def rawDataReceived(self, data):
71
Read raw data, until the quantity specified by a previous 'len' line is
74
data, rest = data[:self.length], data[self.length:]
75
self.length = self.length - len(data)
76
self.received[-1] = self.received[-1] + data
78
self.setLineMode(rest)
80
def lineLengthExceeded(self, line):
82
Adjust line mode when long lines received.
84
if len(line) > self.MAX_LENGTH + 1:
85
self.setLineMode(line[self.MAX_LENGTH + 1:])
88
class LineOnlyTester(basic.LineOnlyReceiver):
90
A buffering line only receiver.
95
def connectionMade(self):
97
Create/clean data received on connection.
101
def lineReceived(self, line):
105
self.received.append(line)
107
class WireTestCase(unittest.TestCase):
113
Test wire.Echo protocol: send some data and check it send it back.
115
t = proto_helpers.StringTransport()
118
a.dataReceived("hello")
119
a.dataReceived("world")
120
a.dataReceived("how")
121
a.dataReceived("are")
122
a.dataReceived("you")
123
self.assertEquals(t.value(), "helloworldhowareyou")
128
Test wire.Who protocol.
130
t = proto_helpers.StringTransport()
133
self.assertEquals(t.value(), "root\r\n")
138
Test wire.QOTD protocol.
140
t = proto_helpers.StringTransport()
143
self.assertEquals(t.value(),
144
"An apple a day keeps the doctor away.\r\n")
147
def test_discard(self):
149
Test wire.Discard protocol.
151
t = proto_helpers.StringTransport()
154
a.dataReceived("hello")
155
a.dataReceived("world")
156
a.dataReceived("how")
157
a.dataReceived("are")
158
a.dataReceived("you")
159
self.assertEqual(t.value(), "")
163
class LineReceiverTestCase(unittest.TestCase):
165
Test LineReceiver, using the C{LineTester} wrapper.
180
1234567890123456789012345678901234567890123456789012345678901234567890
185
output = ['len 10', '0123456789', 'len 5', '1234\n',
186
'len 20', 'foo 123', '0123456789\n012345678',
187
'len 0', 'foo 5', '', '67890', 'len 1', 'a']
189
def testBuffer(self):
191
Test buffering for different packet size, checking received matches
194
for packet_size in range(1, 10):
195
t = proto_helpers.StringIOWithoutClosing()
197
a.makeConnection(protocol.FileWrapper(t))
198
for i in range(len(self.buffer)/packet_size + 1):
199
s = self.buffer[i*packet_size:(i+1)*packet_size]
201
self.failUnlessEqual(self.output, a.received)
204
pause_buf = 'twiddle1\ntwiddle2\npause\ntwiddle3\n'
206
pause_output1 = ['twiddle1', 'twiddle2', 'pause']
207
pause_output2 = pause_output1+['twiddle3']
209
def test_pausing(self):
211
Test pause inside data receiving. It uses fake clock to see if
212
pausing/resuming work.
214
for packet_size in range(1, 10):
215
t = proto_helpers.StringIOWithoutClosing()
217
a = LineTester(clock)
218
a.makeConnection(protocol.FileWrapper(t))
219
for i in range(len(self.pause_buf)/packet_size + 1):
220
s = self.pause_buf[i*packet_size:(i+1)*packet_size]
222
self.assertEquals(self.pause_output1, a.received)
224
self.assertEquals(self.pause_output2, a.received)
226
rawpause_buf = 'twiddle1\ntwiddle2\nlen 5\nrawpause\n12345twiddle3\n'
228
rawpause_output1 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '']
229
rawpause_output2 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '12345',
232
def test_rawPausing(self):
234
Test pause inside raw date receiving.
236
for packet_size in range(1, 10):
237
t = proto_helpers.StringIOWithoutClosing()
239
a = LineTester(clock)
240
a.makeConnection(protocol.FileWrapper(t))
241
for i in range(len(self.rawpause_buf)/packet_size + 1):
242
s = self.rawpause_buf[i*packet_size:(i+1)*packet_size]
244
self.assertEquals(self.rawpause_output1, a.received)
246
self.assertEquals(self.rawpause_output2, a.received)
248
stop_buf = 'twiddle1\ntwiddle2\nstop\nmore\nstuff\n'
250
stop_output = ['twiddle1', 'twiddle2', 'stop']
252
def test_stopProducing(self):
254
Test stop inside producing.
256
for packet_size in range(1, 10):
257
t = proto_helpers.StringIOWithoutClosing()
259
a.makeConnection(protocol.FileWrapper(t))
260
for i in range(len(self.stop_buf)/packet_size + 1):
261
s = self.stop_buf[i*packet_size:(i+1)*packet_size]
263
self.assertEquals(self.stop_output, a.received)
266
def test_lineReceiverAsProducer(self):
268
Test produce/unproduce in receiving.
271
t = proto_helpers.StringIOWithoutClosing()
272
a.makeConnection(protocol.FileWrapper(t))
273
a.dataReceived('produce\nhello world\nunproduce\ngoodbye\n')
274
self.assertEquals(a.received,
275
['produce', 'hello world', 'unproduce', 'goodbye'])
278
def test_clearLineBuffer(self):
280
L{LineReceiver.clearLineBuffer} removes all buffered data and returns
281
it as a C{str} and can be called from beneath C{dataReceived}.
283
class ClearingReceiver(basic.LineReceiver):
284
def lineReceived(self, line):
286
self.rest = self.clearLineBuffer()
288
protocol = ClearingReceiver()
289
protocol.dataReceived('foo\r\nbar\r\nbaz')
290
self.assertEqual(protocol.line, 'foo')
291
self.assertEqual(protocol.rest, 'bar\r\nbaz')
293
# Deliver another line to make sure the previously buffered data is
295
protocol.dataReceived('quux\r\n')
296
self.assertEqual(protocol.line, 'quux')
297
self.assertEqual(protocol.rest, '')
301
class LineOnlyReceiverTestCase(unittest.TestCase):
303
Test line only receiveer.
311
def test_buffer(self):
313
Test buffering over line protocol: data received should match buffer.
315
t = proto_helpers.StringTransport()
318
for c in self.buffer:
320
self.assertEquals(a.received, self.buffer.split('\n')[:-1])
322
def test_lineTooLong(self):
324
Test sending a line too long: it should close the connection.
326
t = proto_helpers.StringTransport()
329
res = a.dataReceived('x'*200)
330
self.assertIsInstance(res, error.ConnectionLost)
336
def connectionMade(self):
339
def stringReceived(self, s):
340
self.received.append(s)
345
def connectionLost(self, reason):
349
class TestNetstring(TestMixin, basic.NetstringReceiver):
353
class LPTestCaseMixin:
358
def getProtocol(self):
360
Return a new instance of C{self.protocol} connected to a new instance
361
of L{proto_helpers.StringTransport}.
363
t = proto_helpers.StringTransport()
369
def test_illegal(self):
371
Assert that illegal strings cause the transport to be closed.
373
for s in self.illegalStrings:
374
r = self.getProtocol()
377
self.assertTrue(r.transport.disconnecting)
380
class NetstringReceiverTestCase(unittest.TestCase, LPTestCaseMixin):
382
strings = ['hello', 'world', 'how', 'are', 'you123', ':today', "a"*515]
385
'9999999999999999999999', 'abc', '4:abcde',
386
'51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',]
388
protocol = TestNetstring
390
def test_buffer(self):
392
Test that when strings are received in chunks of different lengths,
393
they are still parsed correctly.
395
for packet_size in range(1, 10):
396
t = proto_helpers.StringTransport()
400
for s in self.strings:
403
for i in range(len(out)/packet_size + 1):
404
s = out[i*packet_size:(i+1)*packet_size]
407
self.assertEquals(a.received, self.strings)
409
def test_sendNonStrings(self):
411
L{basic.NetstringReceiver.sendString} will send objects that are not
412
strings by sending their string representation according to str().
414
nonStrings = [ [], { 1 : 'a', 2 : 'b' }, ['a', 'b', 'c'], 673,
415
(12, "fine", "and", "you?") ]
417
t = proto_helpers.StringTransport()
424
length = out[:out.find(":")]
425
data = out[out.find(":") + 1:-1] #[:-1] to ignore the trailing ","
426
self.assertEquals(int(length), len(str(s)))
427
self.assertEquals(data, str(s))
429
warnings = self.flushWarnings(
430
offendingFunctions=[self.test_sendNonStrings])
431
self.assertEqual(len(warnings), 5)
433
warnings[0]["message"],
434
"data passed to sendString() must be a string. Non-string support "
435
"is deprecated since Twisted 10.0")
437
warnings[0]['category'],
441
class IntNTestCaseMixin(LPTestCaseMixin):
443
TestCase mixin for int-prefixed protocols.
448
illegalStrings = None
449
partialStrings = None
451
def test_receive(self):
453
Test receiving data find the same data send.
455
r = self.getProtocol()
456
for s in self.strings:
457
for c in struct.pack(r.structFormat,len(s)) + s:
459
self.assertEquals(r.received, self.strings)
461
def test_partial(self):
463
Send partial data, nothing should be definitely received.
465
for s in self.partialStrings:
466
r = self.getProtocol()
469
self.assertEquals(r.received, [])
473
Test sending data over protocol.
475
r = self.getProtocol()
476
r.sendString("b" * 16)
477
self.assertEquals(r.transport.value(),
478
struct.pack(r.structFormat, 16) + "b" * 16)
481
def test_lengthLimitExceeded(self):
483
When a length prefix is received which is greater than the protocol's
484
C{MAX_LENGTH} attribute, the C{lengthLimitExceeded} method is called
485
with the received length prefix.
488
r = self.getProtocol()
489
r.lengthLimitExceeded = length.append
491
r.dataReceived(struct.pack(r.structFormat, 11))
492
self.assertEqual(length, [11])
495
def test_longStringNotDelivered(self):
497
If a length prefix for a string longer than C{MAX_LENGTH} is delivered
498
to C{dataReceived} at the same time as the entire string, the string is
499
not passed to C{stringReceived}.
501
r = self.getProtocol()
504
struct.pack(r.structFormat, 11) + 'x' * 11)
505
self.assertEqual(r.received, [])
509
class TestInt32(TestMixin, basic.Int32StringReceiver):
511
A L{basic.Int32StringReceiver} storing received strings in an array.
513
@ivar received: array holding received strings.
517
class Int32TestCase(unittest.TestCase, IntNTestCaseMixin):
519
Test case for int32-prefixed protocol
522
strings = ["a", "b" * 16]
523
illegalStrings = ["\x10\x00\x00\x00aaaaaa"]
524
partialStrings = ["\x00\x00\x00", "hello there", ""]
528
Test specific behavior of the 32-bits length.
530
r = self.getProtocol()
532
self.assertEquals(r.transport.value(), "\x00\x00\x00\x03foo")
533
r.dataReceived("\x00\x00\x00\x04ubar")
534
self.assertEquals(r.received, ["ubar"])
537
class TestInt16(TestMixin, basic.Int16StringReceiver):
539
A L{basic.Int16StringReceiver} storing received strings in an array.
541
@ivar received: array holding received strings.
545
class Int16TestCase(unittest.TestCase, IntNTestCaseMixin):
547
Test case for int16-prefixed protocol
550
strings = ["a", "b" * 16]
551
illegalStrings = ["\x10\x00aaaaaa"]
552
partialStrings = ["\x00", "hello there", ""]
556
Test specific behavior of the 16-bits length.
558
r = self.getProtocol()
560
self.assertEquals(r.transport.value(), "\x00\x03foo")
561
r.dataReceived("\x00\x04ubar")
562
self.assertEquals(r.received, ["ubar"])
564
def test_tooLongSend(self):
566
Send too much data: that should cause an error.
568
r = self.getProtocol()
569
tooSend = "b" * (2**(r.prefixLength*8) + 1)
570
self.assertRaises(AssertionError, r.sendString, tooSend)
573
class TestInt8(TestMixin, basic.Int8StringReceiver):
575
A L{basic.Int8StringReceiver} storing received strings in an array.
577
@ivar received: array holding received strings.
581
class Int8TestCase(unittest.TestCase, IntNTestCaseMixin):
583
Test case for int8-prefixed protocol
586
strings = ["a", "b" * 16]
587
illegalStrings = ["\x00\x00aaaaaa"]
588
partialStrings = ["\x08", "dzadz", ""]
592
Test specific behavior of the 8-bits length.
594
r = self.getProtocol()
596
self.assertEquals(r.transport.value(), "\x03foo")
597
r.dataReceived("\x04ubar")
598
self.assertEquals(r.received, ["ubar"])
600
def test_tooLongSend(self):
602
Send too much data: that should cause an error.
604
r = self.getProtocol()
605
tooSend = "b" * (2**(r.prefixLength*8) + 1)
606
self.assertRaises(AssertionError, r.sendString, tooSend)
609
class OnlyProducerTransport(object):
610
# Transport which isn't really a transport, just looks like one to
611
# someone not looking very hard.
614
disconnecting = False
619
def pauseProducing(self):
622
def resumeProducing(self):
625
def write(self, bytes):
626
self.data.append(bytes)
629
class ConsumingProtocol(basic.LineReceiver):
630
# Protocol that really, really doesn't want any more bytes.
632
def lineReceived(self, line):
633
self.transport.write(line)
634
self.pauseProducing()
637
class ProducerTestCase(unittest.TestCase):
638
def testPauseResume(self):
639
p = ConsumingProtocol()
640
t = OnlyProducerTransport()
643
p.dataReceived('hello, ')
645
self.failIf(t.paused)
646
self.failIf(p.paused)
648
p.dataReceived('world\r\n')
650
self.assertEquals(t.data, ['hello, world'])
651
self.failUnless(t.paused)
652
self.failUnless(p.paused)
656
self.failIf(t.paused)
657
self.failIf(p.paused)
659
p.dataReceived('hello\r\nworld\r\n')
661
self.assertEquals(t.data, ['hello, world', 'hello'])
662
self.failUnless(t.paused)
663
self.failUnless(p.paused)
666
p.dataReceived('goodbye\r\n')
668
self.assertEquals(t.data, ['hello, world', 'hello', 'world'])
669
self.failUnless(t.paused)
670
self.failUnless(p.paused)
674
self.assertEquals(t.data, ['hello, world', 'hello', 'world', 'goodbye'])
675
self.failUnless(t.paused)
676
self.failUnless(p.paused)
680
self.assertEquals(t.data, ['hello, world', 'hello', 'world', 'goodbye'])
681
self.failIf(t.paused)
682
self.failIf(p.paused)
686
class TestableProxyClientFactory(portforward.ProxyClientFactory):
688
Test proxy client factory that keeps the last created protocol instance.
690
@ivar protoInstance: the last instance of the protocol.
691
@type protoInstance: L{portforward.ProxyClient}
694
def buildProtocol(self, addr):
696
Create the protocol instance and keeps track of it.
698
proto = portforward.ProxyClientFactory.buildProtocol(self, addr)
699
self.protoInstance = proto
704
class TestableProxyFactory(portforward.ProxyFactory):
706
Test proxy factory that keeps the last created protocol instance.
708
@ivar protoInstance: the last instance of the protocol.
709
@type protoInstance: L{portforward.ProxyServer}
711
@ivar clientFactoryInstance: client factory used by C{protoInstance} to
712
create forward connections.
713
@type clientFactoryInstance: L{TestableProxyClientFactory}
716
def buildProtocol(self, addr):
718
Create the protocol instance, keeps track of it, and makes it use
719
C{clientFactoryInstance} as client factory.
721
proto = portforward.ProxyFactory.buildProtocol(self, addr)
722
self.clientFactoryInstance = TestableProxyClientFactory()
723
# Force the use of this specific instance
724
proto.clientProtocolFactory = lambda: self.clientFactoryInstance
725
self.protoInstance = proto
730
class Portforwarding(unittest.TestCase):
732
Test port forwarding.
736
self.serverProtocol = wire.Echo()
737
self.clientProtocol = protocol.Protocol()
743
self.proxyServerFactory.protoInstance.transport.loseConnection()
744
except AttributeError:
747
self.proxyServerFactory.clientFactoryInstance.protoInstance.transport.loseConnection()
748
except AttributeError:
751
self.clientProtocol.transport.loseConnection()
752
except AttributeError:
755
self.serverProtocol.transport.loseConnection()
756
except AttributeError:
758
return defer.gatherResults(
759
[defer.maybeDeferred(p.stopListening) for p in self.openPorts])
762
def test_portforward(self):
764
Test port forwarding through Echo protocol.
766
realServerFactory = protocol.ServerFactory()
767
realServerFactory.protocol = lambda: self.serverProtocol
768
realServerPort = reactor.listenTCP(0, realServerFactory,
769
interface='127.0.0.1')
770
self.openPorts.append(realServerPort)
771
self.proxyServerFactory = TestableProxyFactory('127.0.0.1',
772
realServerPort.getHost().port)
773
proxyServerPort = reactor.listenTCP(0, self.proxyServerFactory,
774
interface='127.0.0.1')
775
self.openPorts.append(proxyServerPort)
780
def testDataReceived(data):
781
received.extend(data)
782
if len(received) >= nBytes:
783
self.assertEquals(''.join(received), 'x' * nBytes)
785
self.clientProtocol.dataReceived = testDataReceived
787
def testConnectionMade():
788
self.clientProtocol.transport.write('x' * nBytes)
789
self.clientProtocol.connectionMade = testConnectionMade
791
clientFactory = protocol.ClientFactory()
792
clientFactory.protocol = lambda: self.clientProtocol
795
'127.0.0.1', proxyServerPort.getHost().port, clientFactory)
801
class StringTransportTestCase(unittest.TestCase):
803
Test L{proto_helpers.StringTransport} helper behaviour.
806
def test_noUnicode(self):
808
Test that L{proto_helpers.StringTransport} doesn't accept unicode data.
810
s = proto_helpers.StringTransport()
811
self.assertRaises(TypeError, s.write, u'foo')