~certify-web-dev/twisted/certify-trunk

« back to all changes in this revision

Viewing changes to twisted/test/test_protocols.py

  • Committer: Bazaar Package Importer
  • Author(s): Matthias Klose
  • Date: 2007-01-17 14:52:35 UTC
  • mfrom: (1.1.5 upstream) (2.1.2 etch)
  • Revision ID: james.westby@ubuntu.com-20070117145235-btmig6qfmqfen0om
Tags: 2.5.0-0ubuntu1
New upstream version, compatible with python2.5.

Show diffs side-by-side

added added

removed removed

Lines of Context:
8
8
 
9
9
from twisted.trial import unittest
10
10
from twisted.protocols import basic, wire, portforward
11
 
from twisted.internet import reactor, protocol, defer
 
11
from twisted.internet import reactor, protocol, defer, task, error
12
12
 
13
 
import string, struct
 
13
import struct
14
14
import StringIO
15
15
 
16
16
class StringIOWithoutClosing(StringIO.StringIO):
 
17
    """
 
18
    A StringIO that can't be closed.
 
19
    """
17
20
    def close(self):
18
 
        pass
 
21
        """
 
22
        Do nothing.
 
23
        """
19
24
 
20
25
class LineTester(basic.LineReceiver):
 
26
    """
 
27
    A line receiver that parses data received and performs actions on some tokens.
 
28
 
 
29
    @type delimiter: C{str}
 
30
    @ivar delimiter: character used between received lines.
 
31
    @type MAX_LENGTH: C{int}
 
32
    @ivar MAX_LENGTH: size of a line when C{lineLengthExceeded} will be called.
 
33
    @type clock: L{twisted.internet.task.Clock}
 
34
    @ivar clock: clock simulating reactor callLater. Pass it to constructor if
 
35
        you want to use the pause/rawpause functionalities.
 
36
    """
21
37
 
22
38
    delimiter = '\n'
23
39
    MAX_LENGTH = 64
24
40
 
 
41
    def __init__(self, clock=None):
 
42
        """
 
43
        If given, use a clock to make callLater calls.
 
44
        """
 
45
        self.clock = clock
 
46
 
25
47
    def connectionMade(self):
 
48
        """
 
49
        Create/clean data received on connection.
 
50
        """
26
51
        self.received = []
27
52
 
28
53
    def lineReceived(self, line):
 
54
        """
 
55
        Receive line and make some action for some tokens: pause, rawpause,
 
56
        stop, len, produce, unproduce.
 
57
        """
29
58
        self.received.append(line)
30
59
        if line == '':
31
60
            self.setRawMode()
32
61
        elif line == 'pause':
33
62
            self.pauseProducing()
34
 
            reactor.callLater(0, self.resumeProducing)
 
63
            self.clock.callLater(0, self.resumeProducing)
35
64
        elif line == 'rawpause':
36
65
            self.pauseProducing()
37
66
            self.setRawMode()
38
67
            self.received.append('')
39
 
            reactor.callLater(0, self.resumeProducing)
 
68
            self.clock.callLater(0, self.resumeProducing)
40
69
        elif line == 'stop':
41
70
            self.stopProducing()
42
71
        elif line[:4] == 'len ':
47
76
            self.transport.unregisterProducer()
48
77
 
49
78
    def rawDataReceived(self, data):
 
79
        """
 
80
        Read raw data, until the quantity specified by a previous 'len' line is
 
81
        reached.
 
82
        """
50
83
        data, rest = data[:self.length], data[self.length:]
51
84
        self.length = self.length - len(data)
52
85
        self.received[-1] = self.received[-1] + data
54
87
            self.setLineMode(rest)
55
88
 
56
89
    def lineLengthExceeded(self, line):
57
 
        if len(line) > self.MAX_LENGTH+1:
58
 
            self.setLineMode(line[self.MAX_LENGTH+1:])
 
90
        """
 
91
        Adjust line mode when long lines received.
 
92
        """
 
93
        if len(line) > self.MAX_LENGTH + 1:
 
94
            self.setLineMode(line[self.MAX_LENGTH + 1:])
59
95
 
60
96
 
61
97
class LineOnlyTester(basic.LineOnlyReceiver):
62
 
 
 
98
    """
 
99
    A buffering line only receiver.
 
100
    """
63
101
    delimiter = '\n'
64
102
    MAX_LENGTH = 64
65
103
 
66
104
    def connectionMade(self):
 
105
        """
 
106
        Create/clean data received on connection.
 
107
        """
67
108
        self.received = []
68
109
 
69
110
    def lineReceived(self, line):
 
111
        """
 
112
        Save received data.
 
113
        """
70
114
        self.received.append(line)
71
115
 
72
116
class WireTestCase(unittest.TestCase):
73
 
 
 
117
    """
 
118
    Test wire protocols.
 
119
    """
74
120
    def testEcho(self):
 
121
        """
 
122
        Test wire.Echo protocol: send some data and check that it sends it back.
 
123
        """
75
124
        t = StringIOWithoutClosing()
76
125
        a = wire.Echo()
77
126
        a.makeConnection(protocol.FileWrapper(t))
83
132
        self.failUnlessEqual(t.getvalue(), "helloworldhowareyou")
84
133
 
85
134
    def testWho(self):
 
135
        """
 
136
        Test wire.Who protocol.
 
137
        """
86
138
        t = StringIOWithoutClosing()
87
139
        a = wire.Who()
88
140
        a.makeConnection(protocol.FileWrapper(t))
89
141
        self.failUnlessEqual(t.getvalue(), "root\r\n")
90
142
 
91
143
    def testQOTD(self):
 
144
        """
 
145
        Test wire.QOTD protocol.
 
146
        """
92
147
        t = StringIOWithoutClosing()
93
148
        a = wire.QOTD()
94
149
        a.makeConnection(protocol.FileWrapper(t))
96
151
                             "An apple a day keeps the doctor away.\r\n")
97
152
 
98
153
    def testDiscard(self):
 
154
        """
 
155
        Test wire.Discard protocol.
 
156
        """
99
157
        t = StringIOWithoutClosing()
100
158
        a = wire.Discard()
101
159
        a.makeConnection(protocol.FileWrapper(t))
107
165
        self.failUnlessEqual(t.getvalue(), "")
108
166
 
109
167
class LineReceiverTestCase(unittest.TestCase):
110
 
 
 
168
    """
 
169
    Test LineReceiver, using the C{LineTester} wrapper.
 
170
    """
111
171
    buffer = '''\
112
172
len 10
113
173
 
131
191
              'len 0', 'foo 5', '', '67890', 'len 1', 'a']
132
192
 
133
193
    def testBuffer(self):
 
194
        """
 
195
        Test buffering for different packet size, checking received matches
 
196
        expected data.
 
197
        """
134
198
        for packet_size in range(1, 10):
135
199
            t = StringIOWithoutClosing()
136
200
            a = LineTester()
147
211
    pause_output2 = pause_output1+['twiddle3']
148
212
 
149
213
    def testPausing(self):
 
214
        """
 
215
        Test pause inside data receiving. It uses a fake clock to see if
 
216
        pausing/resuming work.
 
217
        """
150
218
        for packet_size in range(1, 10):
151
219
            t = StringIOWithoutClosing()
152
 
            a = LineTester()
 
220
            clock = task.Clock()
 
221
            a = LineTester(clock)
153
222
            a.makeConnection(protocol.FileWrapper(t))
154
223
            for i in range(len(self.pause_buf)/packet_size + 1):
155
224
                s = self.pause_buf[i*packet_size:(i+1)*packet_size]
156
225
                a.dataReceived(s)
157
226
            self.failUnlessEqual(self.pause_output1, a.received)
158
 
            reactor.iterate(0)
 
227
            clock.advance(0)
159
228
            self.failUnlessEqual(self.pause_output2, a.received)
160
229
 
161
230
    rawpause_buf = 'twiddle1\ntwiddle2\nlen 5\nrawpause\n12345twiddle3\n'
162
231
 
163
232
    rawpause_output1 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '']
164
 
    rawpause_output2 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '12345', 'twiddle3']
 
233
    rawpause_output2 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '12345',
 
234
                        'twiddle3']
165
235
 
166
236
    def testRawPausing(self):
 
237
        """
 
238
        Test pause inside raw data receiving.
 
239
        """
167
240
        for packet_size in range(1, 10):
168
241
            t = StringIOWithoutClosing()
169
 
            a = LineTester()
 
242
            clock = task.Clock()
 
243
            a = LineTester(clock)
170
244
            a.makeConnection(protocol.FileWrapper(t))
171
245
            for i in range(len(self.rawpause_buf)/packet_size + 1):
172
246
                s = self.rawpause_buf[i*packet_size:(i+1)*packet_size]
173
247
                a.dataReceived(s)
174
248
            self.failUnlessEqual(self.rawpause_output1, a.received)
175
 
            reactor.iterate(0)
 
249
            clock.advance(0)
176
250
            self.failUnlessEqual(self.rawpause_output2, a.received)
177
251
 
178
252
    stop_buf = 'twiddle1\ntwiddle2\nstop\nmore\nstuff\n'
179
253
 
180
254
    stop_output = ['twiddle1', 'twiddle2', 'stop']
 
255
 
181
256
    def testStopProducing(self):
 
257
        """
 
258
        Test stop inside producing.
 
259
        """
182
260
        for packet_size in range(1, 10):
183
261
            t = StringIOWithoutClosing()
184
262
            a = LineTester()
190
268
 
191
269
 
192
270
    def testLineReceiverAsProducer(self):
 
271
        """
 
272
        Test produce/unproduce in receiving.
 
273
        """
193
274
        a = LineTester()
194
275
        t = StringIOWithoutClosing()
195
276
        a.makeConnection(protocol.FileWrapper(t))
196
277
        a.dataReceived('produce\nhello world\nunproduce\ngoodbye\n')
197
 
        self.assertEquals(a.received, ['produce', 'hello world', 'unproduce', 'goodbye'])
 
278
        self.assertEquals(a.received,
 
279
                          ['produce', 'hello world', 'unproduce', 'goodbye'])
198
280
 
199
281
 
200
282
class LineOnlyReceiverTestCase(unittest.TestCase):
201
 
 
 
283
    """
 
284
    Test line only receiver.
 
285
    """
202
286
    buffer = """foo
203
287
    bleakness
204
288
    desolation
206
290
    """
207
291
 
208
292
    def testBuffer(self):
 
293
        """
 
294
        Test buffering over line protocol: data received should match buffer.
 
295
        """
209
296
        t = StringIOWithoutClosing()
210
297
        a = LineOnlyTester()
211
298
        a.makeConnection(protocol.FileWrapper(t))
214
301
        self.failUnlessEqual(a.received, self.buffer.split('\n')[:-1])
215
302
 
216
303
    def testLineTooLong(self):
 
304
        """
 
305
        Test sending a line too long: it should close the connection.
 
306
        """
217
307
        t = StringIOWithoutClosing()
218
308
        a = LineOnlyTester()
219
309
        a.makeConnection(protocol.FileWrapper(t))
220
310
        res = a.dataReceived('x'*200)
221
 
        self.failIfEqual(res, None)
222
 
            
223
 
                
 
311
        self.assertTrue(isinstance(res, error.ConnectionLost))
 
312
 
 
313
 
224
314
class TestMixin:
225
 
    
 
315
 
226
316
    def connectionMade(self):
227
317
        self.received = []
228
318
 
250
340
        a = self.protocol()
251
341
        a.makeConnection(protocol.FileWrapper(t))
252
342
        return a
253
 
    
 
343
 
254
344
    def testIllegal(self):
255
345
        for s in self.illegal_strings:
256
346
            r = self.getProtocol()
263
353
 
264
354
    strings = ['hello', 'world', 'how', 'are', 'you123', ':today', "a"*515]
265
355
 
266
 
    illegal_strings = ['9999999999999999999999', 'abc', '4:abcde',
267
 
                       '51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',]
 
356
    illegal_strings = [
 
357
        '9999999999999999999999', 'abc', '4:abcde',
 
358
        '51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',]
268
359
 
269
360
    protocol = TestNetstring
270
 
    
 
361
 
271
362
    def testBuffer(self):
272
363
        for packet_size in range(1, 10):
273
364
            t = StringIOWithoutClosing()
294
385
    strings = ["a", "b" * 16]
295
386
    illegal_strings = ["\x10\x00\x00\x00aaaaaa"]
296
387
    partial_strings = ["\x00\x00\x00", "hello there", ""]
297
 
    
 
388
 
298
389
    def testPartial(self):
299
390
        for s in self.partial_strings:
300
391
            r = self.getProtocol()
388
479
 
389
480
 
390
481
class Portforwarding(unittest.TestCase):
 
482
    """
 
483
    Test port forwarding.
 
484
    """
 
485
    def setUp(self):
 
486
        self.serverProtocol = wire.Echo()
 
487
        self.clientProtocol = protocol.Protocol()
 
488
        self.openPorts = []
 
489
 
 
490
    def tearDown(self):
 
491
        try:
 
492
            self.clientProtocol.transport.loseConnection()
 
493
        except:
 
494
            pass
 
495
        try:
 
496
            self.serverProtocol.transport.loseConnection()
 
497
        except:
 
498
            pass
 
499
        return defer.gatherResults(
 
500
            [defer.maybeDeferred(p.stopListening) for p in self.openPorts])
 
501
 
391
502
    def testPortforward(self):
392
 
        serverProtocol = wire.Echo()
 
503
        """
 
504
        Test port forwarding through Echo protocol.
 
505
        """
393
506
        realServerFactory = protocol.ServerFactory()
394
 
        realServerFactory.protocol = lambda: serverProtocol
 
507
        realServerFactory.protocol = lambda: self.serverProtocol
395
508
        realServerPort = reactor.listenTCP(0, realServerFactory,
396
509
                                           interface='127.0.0.1')
 
510
        self.openPorts.append(realServerPort)
397
511
 
398
512
        proxyServerFactory = portforward.ProxyFactory('127.0.0.1',
399
 
                                                      realServerPort.getHost().port)
 
513
                                realServerPort.getHost().port)
400
514
        proxyServerPort = reactor.listenTCP(0, proxyServerFactory,
401
515
                                            interface='127.0.0.1')
 
516
        self.openPorts.append(proxyServerPort)
402
517
 
403
518
        nBytes = 1000
404
519
        received = []
405
 
        clientProtocol = protocol.Protocol()
406
 
        clientProtocol.dataReceived = received.extend
407
 
        clientProtocol.connectionMade = lambda: clientProtocol.transport.write('x' * nBytes)
 
520
        d = defer.Deferred()
 
521
        def testDataReceived(data):
 
522
            received.extend(data)
 
523
            if len(received) >= nBytes:
 
524
                self.assertEquals(''.join(received), 'x' * nBytes)
 
525
                d.callback(None)
 
526
        self.clientProtocol.dataReceived = testDataReceived
 
527
 
 
528
        def testConnectionMade():
 
529
            self.clientProtocol.transport.write('x' * nBytes)
 
530
        self.clientProtocol.connectionMade = testConnectionMade
 
531
 
408
532
        clientFactory = protocol.ClientFactory()
409
 
        clientFactory.protocol = lambda: clientProtocol
410
 
 
411
 
        reactor.connectTCP('127.0.0.1', proxyServerPort.getHost().port,
412
 
                           clientFactory)
413
 
 
414
 
        c = 0
415
 
        while len(received) < nBytes and c < 100:
416
 
            reactor.iterate(0.01)
417
 
            c += 1
418
 
 
419
 
        self.assertEquals(''.join(received), 'x' * nBytes)
420
 
        
421
 
        clientProtocol.transport.loseConnection()
422
 
        serverProtocol.transport.loseConnection()
423
 
        return defer.gatherResults([
424
 
            defer.maybeDeferred(realServerPort.stopListening),
425
 
            defer.maybeDeferred(proxyServerPort.stopListening)])
 
533
        clientFactory.protocol = lambda: self.clientProtocol
 
534
 
 
535
        reactor.connectTCP(
 
536
            '127.0.0.1', proxyServerPort.getHost().port, clientFactory)
 
537
 
 
538
        return d
 
539