9
9
from twisted.trial import unittest
10
10
from twisted.protocols import basic, wire, portforward
11
from twisted.internet import reactor, protocol, defer
11
from twisted.internet import reactor, protocol, defer, task, error
16
16
class StringIOWithoutClosing(StringIO.StringIO):
18
A StringIO that can't be closed.
20
25
class LineTester(basic.LineReceiver):
27
A line receiver that parses data received and make actions on some tokens.
29
@type delimiter: C{str}
30
@ivar delimiter: character used between received lines.
31
@type MAX_LENGTH: C{int}
32
@ivar MAX_LENGTH: size of a line when C{lineLengthExceeded} will be called.
33
@type clock: L{twisted.internet.task.Clock}
34
@ivar clock: clock simulating reactor callLater. Pass it to constructor if
35
you want to use the pause/rawpause functionalities.
41
def __init__(self, clock=None):
43
If given, use a clock to make callLater calls.
25
47
def connectionMade(self):
49
Create/clean data received on connection.
28
53
def lineReceived(self, line):
55
Receive line and make some action for some tokens: pause, rawpause,
56
stop, len, produce, unproduce.
29
58
self.received.append(line)
32
61
elif line == 'pause':
33
62
self.pauseProducing()
34
reactor.callLater(0, self.resumeProducing)
63
self.clock.callLater(0, self.resumeProducing)
35
64
elif line == 'rawpause':
36
65
self.pauseProducing()
38
67
self.received.append('')
39
reactor.callLater(0, self.resumeProducing)
68
self.clock.callLater(0, self.resumeProducing)
40
69
elif line == 'stop':
41
70
self.stopProducing()
42
71
elif line[:4] == 'len ':
47
76
self.transport.unregisterProducer()
49
78
def rawDataReceived(self, data):
80
Read raw data, until the quantity specified by a previous 'len' line is
50
83
data, rest = data[:self.length], data[self.length:]
51
84
self.length = self.length - len(data)
52
85
self.received[-1] = self.received[-1] + data
54
87
self.setLineMode(rest)
56
89
def lineLengthExceeded(self, line):
57
if len(line) > self.MAX_LENGTH+1:
58
self.setLineMode(line[self.MAX_LENGTH+1:])
91
Adjust line mode when long lines received.
93
if len(line) > self.MAX_LENGTH + 1:
94
self.setLineMode(line[self.MAX_LENGTH + 1:])
61
97
class LineOnlyTester(basic.LineOnlyReceiver):
99
A buffering line only receiver.
66
104
def connectionMade(self):
106
Create/clean data received on connection.
67
108
self.received = []
69
110
def lineReceived(self, line):
70
114
self.received.append(line)
72
116
class WireTestCase(unittest.TestCase):
74
120
def testEcho(self):
122
Test wire.Echo protocol: send some data and check it send it back.
75
124
t = StringIOWithoutClosing()
77
126
a.makeConnection(protocol.FileWrapper(t))
83
132
self.failUnlessEqual(t.getvalue(), "helloworldhowareyou")
85
134
def testWho(self):
136
Test wire.Who protocol.
86
138
t = StringIOWithoutClosing()
88
140
a.makeConnection(protocol.FileWrapper(t))
89
141
self.failUnlessEqual(t.getvalue(), "root\r\n")
91
143
def testQOTD(self):
145
Test wire.QOTD protocol.
92
147
t = StringIOWithoutClosing()
94
149
a.makeConnection(protocol.FileWrapper(t))
131
191
'len 0', 'foo 5', '', '67890', 'len 1', 'a']
133
193
def testBuffer(self):
195
Test buffering for different packet size, checking received matches
134
198
for packet_size in range(1, 10):
135
199
t = StringIOWithoutClosing()
147
211
pause_output2 = pause_output1+['twiddle3']
149
213
def testPausing(self):
215
Test pause inside data receiving. It uses fake clock to see if
216
pausing/resuming work.
150
218
for packet_size in range(1, 10):
151
219
t = StringIOWithoutClosing()
221
a = LineTester(clock)
153
222
a.makeConnection(protocol.FileWrapper(t))
154
223
for i in range(len(self.pause_buf)/packet_size + 1):
155
224
s = self.pause_buf[i*packet_size:(i+1)*packet_size]
156
225
a.dataReceived(s)
157
226
self.failUnlessEqual(self.pause_output1, a.received)
159
228
self.failUnlessEqual(self.pause_output2, a.received)
161
230
rawpause_buf = 'twiddle1\ntwiddle2\nlen 5\nrawpause\n12345twiddle3\n'
163
232
rawpause_output1 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '']
164
rawpause_output2 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '12345', 'twiddle3']
233
rawpause_output2 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '12345',
166
236
def testRawPausing(self):
238
Test pause inside raw date receiving.
167
240
for packet_size in range(1, 10):
168
241
t = StringIOWithoutClosing()
243
a = LineTester(clock)
170
244
a.makeConnection(protocol.FileWrapper(t))
171
245
for i in range(len(self.rawpause_buf)/packet_size + 1):
172
246
s = self.rawpause_buf[i*packet_size:(i+1)*packet_size]
173
247
a.dataReceived(s)
174
248
self.failUnlessEqual(self.rawpause_output1, a.received)
176
250
self.failUnlessEqual(self.rawpause_output2, a.received)
178
252
stop_buf = 'twiddle1\ntwiddle2\nstop\nmore\nstuff\n'
180
254
stop_output = ['twiddle1', 'twiddle2', 'stop']
181
256
def testStopProducing(self):
258
Test stop inside producing.
182
260
for packet_size in range(1, 10):
183
261
t = StringIOWithoutClosing()
192
270
def testLineReceiverAsProducer(self):
272
Test produce/unproduce in receiving.
194
275
t = StringIOWithoutClosing()
195
276
a.makeConnection(protocol.FileWrapper(t))
196
277
a.dataReceived('produce\nhello world\nunproduce\ngoodbye\n')
197
self.assertEquals(a.received, ['produce', 'hello world', 'unproduce', 'goodbye'])
278
self.assertEquals(a.received,
279
['produce', 'hello world', 'unproduce', 'goodbye'])
200
282
class LineOnlyReceiverTestCase(unittest.TestCase):
284
Test line only receiveer.
214
301
self.failUnlessEqual(a.received, self.buffer.split('\n')[:-1])
216
303
def testLineTooLong(self):
305
Test sending a line too long: it should close the connection.
217
307
t = StringIOWithoutClosing()
218
308
a = LineOnlyTester()
219
309
a.makeConnection(protocol.FileWrapper(t))
220
310
res = a.dataReceived('x'*200)
221
self.failIfEqual(res, None)
311
self.assertTrue(isinstance(res, error.ConnectionLost))
226
316
def connectionMade(self):
227
317
self.received = []
264
354
strings = ['hello', 'world', 'how', 'are', 'you123', ':today', "a"*515]
266
illegal_strings = ['9999999999999999999999', 'abc', '4:abcde',
267
'51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',]
357
'9999999999999999999999', 'abc', '4:abcde',
358
'51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',]
269
360
protocol = TestNetstring
271
362
def testBuffer(self):
272
363
for packet_size in range(1, 10):
273
364
t = StringIOWithoutClosing()
390
481
class Portforwarding(unittest.TestCase):
483
Test port forwarding.
486
self.serverProtocol = wire.Echo()
487
self.clientProtocol = protocol.Protocol()
492
self.clientProtocol.transport.loseConnection()
496
self.serverProtocol.transport.loseConnection()
499
return defer.gatherResults(
500
[defer.maybeDeferred(p.stopListening) for p in self.openPorts])
391
502
def testPortforward(self):
392
serverProtocol = wire.Echo()
504
Test port forwarding through Echo protocol.
393
506
realServerFactory = protocol.ServerFactory()
394
realServerFactory.protocol = lambda: serverProtocol
507
realServerFactory.protocol = lambda: self.serverProtocol
395
508
realServerPort = reactor.listenTCP(0, realServerFactory,
396
509
interface='127.0.0.1')
510
self.openPorts.append(realServerPort)
398
512
proxyServerFactory = portforward.ProxyFactory('127.0.0.1',
399
realServerPort.getHost().port)
513
realServerPort.getHost().port)
400
514
proxyServerPort = reactor.listenTCP(0, proxyServerFactory,
401
515
interface='127.0.0.1')
516
self.openPorts.append(proxyServerPort)
405
clientProtocol = protocol.Protocol()
406
clientProtocol.dataReceived = received.extend
407
clientProtocol.connectionMade = lambda: clientProtocol.transport.write('x' * nBytes)
521
def testDataReceived(data):
522
received.extend(data)
523
if len(received) >= nBytes:
524
self.assertEquals(''.join(received), 'x' * nBytes)
526
self.clientProtocol.dataReceived = testDataReceived
528
def testConnectionMade():
529
self.clientProtocol.transport.write('x' * nBytes)
530
self.clientProtocol.connectionMade = testConnectionMade
408
532
clientFactory = protocol.ClientFactory()
409
clientFactory.protocol = lambda: clientProtocol
411
reactor.connectTCP('127.0.0.1', proxyServerPort.getHost().port,
415
while len(received) < nBytes and c < 100:
416
reactor.iterate(0.01)
419
self.assertEquals(''.join(received), 'x' * nBytes)
421
clientProtocol.transport.loseConnection()
422
serverProtocol.transport.loseConnection()
423
return defer.gatherResults([
424
defer.maybeDeferred(realServerPort.stopListening),
425
defer.maybeDeferred(proxyServerPort.stopListening)])
533
clientFactory.protocol = lambda: self.clientProtocol
536
'127.0.0.1', proxyServerPort.getHost().port, clientFactory)