1
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
2
# See LICENSE for details.
6
Test cases for twisted.protocols package.
9
from twisted.trial import unittest
10
from twisted.protocols import basic, wire, portforward
11
from twisted.internet import reactor, protocol, defer, task, error
16
class StringIOWithoutClosing(StringIO.StringIO):
18
A StringIO that can't be closed.
25
class LineTester(basic.LineReceiver):
27
A line receiver that parses data received and make actions on some tokens.
29
@type delimiter: C{str}
30
@ivar delimiter: character used between received lines.
31
@type MAX_LENGTH: C{int}
32
@ivar MAX_LENGTH: size of a line when C{lineLengthExceeded} will be called.
33
@type clock: L{twisted.internet.task.Clock}
34
@ivar clock: clock simulating reactor callLater. Pass it to constructor if
35
you want to use the pause/rawpause functionalities.
41
def __init__(self, clock=None):
43
If given, use a clock to make callLater calls.
47
def connectionMade(self):
49
Create/clean data received on connection.
53
def lineReceived(self, line):
55
Receive line and make some action for some tokens: pause, rawpause,
56
stop, len, produce, unproduce.
58
self.received.append(line)
63
self.clock.callLater(0, self.resumeProducing)
64
elif line == 'rawpause':
67
self.received.append('')
68
self.clock.callLater(0, self.resumeProducing)
71
elif line[:4] == 'len ':
72
self.length = int(line[4:])
73
elif line.startswith('produce'):
74
self.transport.registerProducer(self, False)
75
elif line.startswith('unproduce'):
76
self.transport.unregisterProducer()
78
def rawDataReceived(self, data):
80
Read raw data, until the quantity specified by a previous 'len' line is
83
data, rest = data[:self.length], data[self.length:]
84
self.length = self.length - len(data)
85
self.received[-1] = self.received[-1] + data
87
self.setLineMode(rest)
89
def lineLengthExceeded(self, line):
91
Adjust line mode when long lines received.
93
if len(line) > self.MAX_LENGTH + 1:
94
self.setLineMode(line[self.MAX_LENGTH + 1:])
97
class LineOnlyTester(basic.LineOnlyReceiver):
99
A buffering line only receiver.
104
def connectionMade(self):
106
Create/clean data received on connection.
110
def lineReceived(self, line):
114
self.received.append(line)
116
class WireTestCase(unittest.TestCase):
122
Test wire.Echo protocol: send some data and check it send it back.
124
t = StringIOWithoutClosing()
126
a.makeConnection(protocol.FileWrapper(t))
127
a.dataReceived("hello")
128
a.dataReceived("world")
129
a.dataReceived("how")
130
a.dataReceived("are")
131
a.dataReceived("you")
132
self.failUnlessEqual(t.getvalue(), "helloworldhowareyou")
136
Test wire.Who protocol.
138
t = StringIOWithoutClosing()
140
a.makeConnection(protocol.FileWrapper(t))
141
self.failUnlessEqual(t.getvalue(), "root\r\n")
145
Test wire.QOTD protocol.
147
t = StringIOWithoutClosing()
149
a.makeConnection(protocol.FileWrapper(t))
150
self.failUnlessEqual(t.getvalue(),
151
"An apple a day keeps the doctor away.\r\n")
153
def testDiscard(self):
155
Test wire.Discard protocol.
157
t = StringIOWithoutClosing()
159
a.makeConnection(protocol.FileWrapper(t))
160
a.dataReceived("hello")
161
a.dataReceived("world")
162
a.dataReceived("how")
163
a.dataReceived("are")
164
a.dataReceived("you")
165
self.failUnlessEqual(t.getvalue(), "")
167
class LineReceiverTestCase(unittest.TestCase):
169
Test LineReceiver, using the C{LineTester} wrapper.
184
1234567890123456789012345678901234567890123456789012345678901234567890
189
output = ['len 10', '0123456789', 'len 5', '1234\n',
190
'len 20', 'foo 123', '0123456789\n012345678',
191
'len 0', 'foo 5', '', '67890', 'len 1', 'a']
193
def testBuffer(self):
195
Test buffering for different packet size, checking received matches
198
for packet_size in range(1, 10):
199
t = StringIOWithoutClosing()
201
a.makeConnection(protocol.FileWrapper(t))
202
for i in range(len(self.buffer)/packet_size + 1):
203
s = self.buffer[i*packet_size:(i+1)*packet_size]
205
self.failUnlessEqual(self.output, a.received)
208
pause_buf = 'twiddle1\ntwiddle2\npause\ntwiddle3\n'
210
pause_output1 = ['twiddle1', 'twiddle2', 'pause']
211
pause_output2 = pause_output1+['twiddle3']
213
def testPausing(self):
215
Test pause inside data receiving. It uses fake clock to see if
216
pausing/resuming work.
218
for packet_size in range(1, 10):
219
t = StringIOWithoutClosing()
221
a = LineTester(clock)
222
a.makeConnection(protocol.FileWrapper(t))
223
for i in range(len(self.pause_buf)/packet_size + 1):
224
s = self.pause_buf[i*packet_size:(i+1)*packet_size]
226
self.failUnlessEqual(self.pause_output1, a.received)
228
self.failUnlessEqual(self.pause_output2, a.received)
230
rawpause_buf = 'twiddle1\ntwiddle2\nlen 5\nrawpause\n12345twiddle3\n'
232
rawpause_output1 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '']
233
rawpause_output2 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '12345',
236
def testRawPausing(self):
238
Test pause inside raw date receiving.
240
for packet_size in range(1, 10):
241
t = StringIOWithoutClosing()
243
a = LineTester(clock)
244
a.makeConnection(protocol.FileWrapper(t))
245
for i in range(len(self.rawpause_buf)/packet_size + 1):
246
s = self.rawpause_buf[i*packet_size:(i+1)*packet_size]
248
self.failUnlessEqual(self.rawpause_output1, a.received)
250
self.failUnlessEqual(self.rawpause_output2, a.received)
252
stop_buf = 'twiddle1\ntwiddle2\nstop\nmore\nstuff\n'
254
stop_output = ['twiddle1', 'twiddle2', 'stop']
256
def testStopProducing(self):
258
Test stop inside producing.
260
for packet_size in range(1, 10):
261
t = StringIOWithoutClosing()
263
a.makeConnection(protocol.FileWrapper(t))
264
for i in range(len(self.stop_buf)/packet_size + 1):
265
s = self.stop_buf[i*packet_size:(i+1)*packet_size]
267
self.failUnlessEqual(self.stop_output, a.received)
270
def testLineReceiverAsProducer(self):
272
Test produce/unproduce in receiving.
275
t = StringIOWithoutClosing()
276
a.makeConnection(protocol.FileWrapper(t))
277
a.dataReceived('produce\nhello world\nunproduce\ngoodbye\n')
278
self.assertEquals(a.received,
279
['produce', 'hello world', 'unproduce', 'goodbye'])
282
class LineOnlyReceiverTestCase(unittest.TestCase):
284
Test line only receiveer.
292
def testBuffer(self):
294
Test buffering over line protocol: data received should match buffer.
296
t = StringIOWithoutClosing()
298
a.makeConnection(protocol.FileWrapper(t))
299
for c in self.buffer:
301
self.failUnlessEqual(a.received, self.buffer.split('\n')[:-1])
303
def testLineTooLong(self):
305
Test sending a line too long: it should close the connection.
307
t = StringIOWithoutClosing()
309
a.makeConnection(protocol.FileWrapper(t))
310
res = a.dataReceived('x'*200)
311
self.assertTrue(isinstance(res, error.ConnectionLost))
316
def connectionMade(self):
319
def stringReceived(self, s):
320
self.received.append(s)
325
def connectionLost(self, reason):
329
class TestNetstring(TestMixin, basic.NetstringReceiver):
333
class LPTestCaseMixin:
338
def getProtocol(self):
339
t = StringIOWithoutClosing()
341
a.makeConnection(protocol.FileWrapper(t))
344
def testIllegal(self):
345
for s in self.illegal_strings:
346
r = self.getProtocol()
349
self.assertEquals(r.transport.closed, 1)
352
class NetstringReceiverTestCase(unittest.TestCase, LPTestCaseMixin):
354
strings = ['hello', 'world', 'how', 'are', 'you123', ':today', "a"*515]
357
'9999999999999999999999', 'abc', '4:abcde',
358
'51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',]
360
protocol = TestNetstring
362
def testBuffer(self):
363
for packet_size in range(1, 10):
364
t = StringIOWithoutClosing()
367
a.makeConnection(protocol.FileWrapper(t))
368
for s in self.strings:
371
for i in range(len(out)/packet_size + 1):
372
s = out[i*packet_size:(i+1)*packet_size]
375
self.assertEquals(a.received, self.strings)
378
class TestInt32(TestMixin, basic.Int32StringReceiver):
382
class Int32TestCase(unittest.TestCase, LPTestCaseMixin):
385
strings = ["a", "b" * 16]
386
illegal_strings = ["\x10\x00\x00\x00aaaaaa"]
387
partial_strings = ["\x00\x00\x00", "hello there", ""]
389
def testPartial(self):
390
for s in self.partial_strings:
391
r = self.getProtocol()
392
r.MAX_LENGTH = 99999999
395
self.assertEquals(r.received, [])
397
def testReceive(self):
398
r = self.getProtocol()
399
for s in self.strings:
400
for c in struct.pack("!i",len(s))+s:
402
self.assertEquals(r.received, self.strings)
405
class OnlyProducerTransport(object):
406
# Transport which isn't really a transport, just looks like one to
407
# someone not looking very hard.
410
disconnecting = False
415
def pauseProducing(self):
418
def resumeProducing(self):
421
def write(self, bytes):
422
self.data.append(bytes)
425
class ConsumingProtocol(basic.LineReceiver):
426
# Protocol that really, really doesn't want any more bytes.
428
def lineReceived(self, line):
429
self.transport.write(line)
430
self.pauseProducing()
433
class ProducerTestCase(unittest.TestCase):
434
def testPauseResume(self):
435
p = ConsumingProtocol()
436
t = OnlyProducerTransport()
439
p.dataReceived('hello, ')
441
self.failIf(t.paused)
442
self.failIf(p.paused)
444
p.dataReceived('world\r\n')
446
self.assertEquals(t.data, ['hello, world'])
447
self.failUnless(t.paused)
448
self.failUnless(p.paused)
452
self.failIf(t.paused)
453
self.failIf(p.paused)
455
p.dataReceived('hello\r\nworld\r\n')
457
self.assertEquals(t.data, ['hello, world', 'hello'])
458
self.failUnless(t.paused)
459
self.failUnless(p.paused)
462
p.dataReceived('goodbye\r\n')
464
self.assertEquals(t.data, ['hello, world', 'hello', 'world'])
465
self.failUnless(t.paused)
466
self.failUnless(p.paused)
470
self.assertEquals(t.data, ['hello, world', 'hello', 'world', 'goodbye'])
471
self.failUnless(t.paused)
472
self.failUnless(p.paused)
476
self.assertEquals(t.data, ['hello, world', 'hello', 'world', 'goodbye'])
477
self.failIf(t.paused)
478
self.failIf(p.paused)
481
class Portforwarding(unittest.TestCase):
483
Test port forwarding.
486
self.serverProtocol = wire.Echo()
487
self.clientProtocol = protocol.Protocol()
492
self.clientProtocol.transport.loseConnection()
496
self.serverProtocol.transport.loseConnection()
499
return defer.gatherResults(
500
[defer.maybeDeferred(p.stopListening) for p in self.openPorts])
502
def testPortforward(self):
504
Test port forwarding through Echo protocol.
506
realServerFactory = protocol.ServerFactory()
507
realServerFactory.protocol = lambda: self.serverProtocol
508
realServerPort = reactor.listenTCP(0, realServerFactory,
509
interface='127.0.0.1')
510
self.openPorts.append(realServerPort)
512
proxyServerFactory = portforward.ProxyFactory('127.0.0.1',
513
realServerPort.getHost().port)
514
proxyServerPort = reactor.listenTCP(0, proxyServerFactory,
515
interface='127.0.0.1')
516
self.openPorts.append(proxyServerPort)
521
def testDataReceived(data):
522
received.extend(data)
523
if len(received) >= nBytes:
524
self.assertEquals(''.join(received), 'x' * nBytes)
526
self.clientProtocol.dataReceived = testDataReceived
528
def testConnectionMade():
529
self.clientProtocol.transport.write('x' * nBytes)
530
self.clientProtocol.connectionMade = testConnectionMade
532
clientFactory = protocol.ClientFactory()
533
clientFactory.protocol = lambda: self.clientProtocol
536
'127.0.0.1', proxyServerPort.getHost().port, clientFactory)