~landscape/zope3/newer-from-ztk

« back to all changes in this revision

Viewing changes to src/twisted/test/test_protocols.py

  • Committer: Thomas Hervé
  • Date: 2009-07-08 13:52:04 UTC
  • Revision ID: thomas@canonical.com-20090708135204-df5eesrthifpylf8
Remove twisted copy

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
2
 
# See LICENSE for details.
3
 
 
4
 
 
5
 
"""
6
 
Test cases for twisted.protocols package.
7
 
"""
8
 
 
9
 
from twisted.trial import unittest
10
 
from twisted.protocols import basic, wire, portforward
11
 
from twisted.internet import reactor, protocol, defer, task, error
12
 
 
13
 
import struct
14
 
import StringIO
15
 
 
16
 
class StringIOWithoutClosing(StringIO.StringIO):
    """
    In-memory file whose C{close} does nothing.

    Used as the file behind C{protocol.FileWrapper} transports in these
    tests, so the written data can still be inspected with C{getvalue}.
    """

    def close(self):
        """
        Ignore the close request; the buffer stays usable.
        """
        return None
24
 
 
25
 
class LineTester(basic.LineReceiver):
    """
    A line receiver that records every line it gets and reacts to a few
    command tokens found in the data.

    @type delimiter: C{str}
    @ivar delimiter: character used between received lines.

    @type MAX_LENGTH: C{int}
    @ivar MAX_LENGTH: size of a line when C{lineLengthExceeded} will be called.

    @type clock: L{twisted.internet.task.Clock}
    @ivar clock: clock simulating reactor callLater. Pass it to constructor if
        you want to use the pause/rawpause functionalities.
    """

    delimiter = '\n'
    MAX_LENGTH = 64

    def __init__(self, clock=None):
        """
        @param clock: optional clock used to schedule C{resumeProducing}
            after a 'pause' or 'rawpause' token.
        """
        self.clock = clock


    def connectionMade(self):
        """
        Start each connection with an empty record of received lines.
        """
        self.received = []


    def lineReceived(self, line):
        """
        Record C{line}, then act on it when it is a known token.

        The tokens are:
          - C{''}: switch to raw mode.
          - 'pause': pause, with a resume scheduled on C{clock}.
          - 'rawpause': pause and switch to raw mode, with a resume
            scheduled on C{clock}.
          - 'stop': stop producing.
          - 'len N': expect N raw bytes.
          - 'produce...' / 'unproduce...': (un)register this receiver as a
            producer on the transport.
        """
        self.received.append(line)
        if not line:
            self.setRawMode()
        elif line == 'pause':
            self.pauseProducing()
            self.clock.callLater(0, self.resumeProducing)
        elif line == 'rawpause':
            self.pauseProducing()
            self.setRawMode()
            # Open an empty entry for rawDataReceived to append into.
            self.received.append('')
            self.clock.callLater(0, self.resumeProducing)
        elif line == 'stop':
            self.stopProducing()
        elif line.startswith('len '):
            self.length = int(line[4:])
        elif line.startswith('produce'):
            self.transport.registerProducer(self, False)
        elif line.startswith('unproduce'):
            self.transport.unregisterProducer()


    def rawDataReceived(self, data):
        """
        Append raw data to the most recent entry until the byte count set
        by the last 'len' line is used up.

        Any surplus bytes are handed back to line mode.
        """
        wanted = self.length
        chunk = data[:wanted]
        self.received[-1] += chunk
        self.length = wanted - len(chunk)
        if self.length == 0:
            self.setLineMode(data[wanted:])


    def lineLengthExceeded(self, line):
        """
        Drop the first C{MAX_LENGTH + 1} characters of an over-long line.

        Line parsing resumes on whatever follows them, if anything does.
        """
        cutoff = self.MAX_LENGTH + 1
        if len(line) > cutoff:
            self.setLineMode(line[cutoff:])
95
 
 
96
 
 
97
 
class LineOnlyTester(basic.LineOnlyReceiver):
    """
    A buffering line-only receiver that simply remembers every line.
    """

    delimiter = '\n'
    MAX_LENGTH = 64

    def connectionMade(self):
        """
        Reset the list of received lines for the new connection.
        """
        self.received = []


    def lineReceived(self, line):
        """
        Remember C{line}, keeping arrival order.
        """
        self.received.append(line)
115
 
 
116
 
class WireTestCase(unittest.TestCase):
    """
    Test the small protocols in L{twisted.protocols.wire}.
    """

    def testEcho(self):
        """
        wire.Echo writes each received chunk straight back to its transport.
        """
        out = StringIOWithoutClosing()
        proto = wire.Echo()
        proto.makeConnection(protocol.FileWrapper(out))
        for chunk in ("hello", "world", "how", "are", "you"):
            proto.dataReceived(chunk)
        self.failUnlessEqual(out.getvalue(), "helloworldhowareyou")


    def testWho(self):
        """
        wire.Who writes "root\\r\\n" as soon as it is connected.
        """
        out = StringIOWithoutClosing()
        proto = wire.Who()
        proto.makeConnection(protocol.FileWrapper(out))
        self.failUnlessEqual(out.getvalue(), "root\r\n")


    def testQOTD(self):
        """
        wire.QOTD writes its quote of the day as soon as it is connected.
        """
        out = StringIOWithoutClosing()
        proto = wire.QOTD()
        proto.makeConnection(protocol.FileWrapper(out))
        expected = "An apple a day keeps the doctor away.\r\n"
        self.failUnlessEqual(out.getvalue(), expected)


    def testDiscard(self):
        """
        wire.Discard writes nothing back, whatever it receives.
        """
        out = StringIOWithoutClosing()
        proto = wire.Discard()
        proto.makeConnection(protocol.FileWrapper(out))
        for chunk in ("hello", "world", "how", "are", "you"):
            proto.dataReceived(chunk)
        self.failUnlessEqual(out.getvalue(), "")
166
 
 
167
 
class LineReceiverTestCase(unittest.TestCase):
    """
    Test LineReceiver, using the C{LineTester} wrapper.
    """
    buffer = '''\
len 10

0123456789len 5

1234
len 20
foo 123

0123456789
012345678len 0
foo 5

1234567890123456789012345678901234567890123456789012345678901234567890
len 1

a'''

    output = ['len 10', '0123456789', 'len 5', '1234\n',
              'len 20', 'foo 123', '0123456789\n012345678',
              'len 0', 'foo 5', '', '67890', 'len 1', 'a']

    def testBuffer(self):
        """
        Feed C{buffer} in packets of 1 to 9 bytes; for every packet size
        the recorded lines must equal C{output}.
        """
        for size in range(1, 10):
            out = StringIOWithoutClosing()
            tester = LineTester()
            tester.makeConnection(protocol.FileWrapper(out))
            # Offsets run up to and including len(buffer), so a trailing
            # empty packet is delivered when the length divides evenly.
            for start in range(0, len(self.buffer) + 1, size):
                tester.dataReceived(self.buffer[start:start + size])
            self.failUnlessEqual(self.output, tester.received)


    pause_buf = 'twiddle1\ntwiddle2\npause\ntwiddle3\n'

    pause_output1 = ['twiddle1', 'twiddle2', 'pause']
    pause_output2 = pause_output1 + ['twiddle3']

    def testPausing(self):
        """
        A 'pause' line stops delivery until the fake clock runs the
        scheduled resume; only then is the remaining line received.
        """
        for size in range(1, 10):
            out = StringIOWithoutClosing()
            clock = task.Clock()
            tester = LineTester(clock)
            tester.makeConnection(protocol.FileWrapper(out))
            for start in range(0, len(self.pause_buf) + 1, size):
                tester.dataReceived(self.pause_buf[start:start + size])
            self.failUnlessEqual(self.pause_output1, tester.received)
            clock.advance(0)
            self.failUnlessEqual(self.pause_output2, tester.received)


    rawpause_buf = 'twiddle1\ntwiddle2\nlen 5\nrawpause\n12345twiddle3\n'

    rawpause_output1 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '']
    rawpause_output2 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '12345',
                        'twiddle3']

    def testRawPausing(self):
        """
        Test pausing while switching to raw data receiving: the raw bytes
        and the following line only arrive after the clock resumes.
        """
        for size in range(1, 10):
            out = StringIOWithoutClosing()
            clock = task.Clock()
            tester = LineTester(clock)
            tester.makeConnection(protocol.FileWrapper(out))
            for start in range(0, len(self.rawpause_buf) + 1, size):
                tester.dataReceived(self.rawpause_buf[start:start + size])
            self.failUnlessEqual(self.rawpause_output1, tester.received)
            clock.advance(0)
            self.failUnlessEqual(self.rawpause_output2, tester.received)


    stop_buf = 'twiddle1\ntwiddle2\nstop\nmore\nstuff\n'

    stop_output = ['twiddle1', 'twiddle2', 'stop']

    def testStopProducing(self):
        """
        After a 'stop' line, no further lines are delivered.
        """
        for size in range(1, 10):
            out = StringIOWithoutClosing()
            tester = LineTester()
            tester.makeConnection(protocol.FileWrapper(out))
            for start in range(0, len(self.stop_buf) + 1, size):
                tester.dataReceived(self.stop_buf[start:start + size])
            self.failUnlessEqual(self.stop_output, tester.received)


    def testLineReceiverAsProducer(self):
        """
        Registering and unregistering the receiver as a producer mid-stream
        does not lose any lines.
        """
        tester = LineTester()
        out = StringIOWithoutClosing()
        tester.makeConnection(protocol.FileWrapper(out))
        tester.dataReceived('produce\nhello world\nunproduce\ngoodbye\n')
        expected = ['produce', 'hello world', 'unproduce', 'goodbye']
        self.assertEquals(tester.received, expected)
280
 
 
281
 
 
282
 
class LineOnlyReceiverTestCase(unittest.TestCase):
    """
    Test the line-only receiver, using the C{LineOnlyTester} wrapper.
    """
    buffer = """foo
    bleakness
    desolation
    plastic forks
    """

    def testBuffer(self):
        """
        Deliver C{buffer} one character at a time; every complete line in
        it must be recorded, in order.
        """
        out = StringIOWithoutClosing()
        tester = LineOnlyTester()
        tester.makeConnection(protocol.FileWrapper(out))
        for char in self.buffer:
            tester.dataReceived(char)
        # The text after the final delimiter is not a complete line.
        expected = self.buffer.split('\n')[:-1]
        self.failUnlessEqual(tester.received, expected)


    def testLineTooLong(self):
        """
        A 200-byte chunk with no delimiter exceeds MAX_LENGTH, and
        dataReceived must return a ConnectionLost instance.
        """
        out = StringIOWithoutClosing()
        tester = LineOnlyTester()
        tester.makeConnection(protocol.FileWrapper(out))
        result = tester.dataReceived('x' * 200)
        self.assertTrue(isinstance(result, error.ConnectionLost))
312
 
 
313
 
 
314
 
class TestMixin:
    """
    Shared behaviour for the string-receiver test protocols.

    Every string received is collected in C{received}, and C{closed} is set
    to 1 once the connection is lost.
    """

    MAX_LENGTH = 50
    closed = 0

    def connectionMade(self):
        # Fresh list of received strings for each connection.
        self.received = []

    def stringReceived(self, s):
        self.received.append(s)

    def connectionLost(self, reason):
        # reason is not inspected; only the fact of the loss is recorded.
        self.closed = 1
327
 
 
328
 
 
329
 
class TestNetstring(TestMixin, basic.NetstringReceiver):
    """
    Netstring receiver that records what it receives via C{TestMixin}.
    """
331
 
 
332
 
 
333
 
class LPTestCaseMixin:
    """
    Tests shared by the length-prefixed string receiver test cases.

    Subclasses set C{protocol} to the receiver class under test.
    C{illegal_strings} lists inputs that must leave the receiver's
    transport closed.
    """

    illegal_strings = []
    protocol = None

    def getProtocol(self):
        """
        Return a new C{self.protocol} instance connected to an in-memory
        transport.
        """
        backing = StringIOWithoutClosing()
        proto = self.protocol()
        proto.makeConnection(protocol.FileWrapper(backing))
        return proto


    def testIllegal(self):
        """
        Feed each illegal string byte by byte to a fresh protocol and check
        that its transport ends up closed.
        """
        for bad in self.illegal_strings:
            proto = self.getProtocol()
            for byte in bad:
                proto.dataReceived(byte)
            self.assertEquals(proto.transport.closed, 1)
350
 
 
351
 
 
352
 
class NetstringReceiverTestCase(unittest.TestCase, LPTestCaseMixin):
    """
    Test L{basic.NetstringReceiver} through the C{TestNetstring} wrapper.
    """

    strings = ['hello', 'world', 'how', 'are', 'you123', ':today', "a"*515]

    illegal_strings = [
        '9999999999999999999999', 'abc', '4:abcde',
        '51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',]

    protocol = TestNetstring

    def testBuffer(self):
        """
        Encode C{strings} with sendString, feed the encoded bytes back in
        packets of 1 to 9 bytes, and check every string is decoded.
        """
        for size in range(1, 10):
            out = StringIOWithoutClosing()
            receiver = TestNetstring()
            # Large enough for the 515-byte string in C{strings}.
            receiver.MAX_LENGTH = 699
            receiver.makeConnection(protocol.FileWrapper(out))
            for s in self.strings:
                receiver.sendString(s)
            encoded = out.getvalue()
            # Offsets stop short of len(encoded), so every packet delivered
            # is non-empty.
            for start in range(0, len(encoded), size):
                receiver.dataReceived(encoded[start:start + size])
            self.assertEquals(receiver.received, self.strings)
376
 
 
377
 
 
378
 
class TestInt32(TestMixin, basic.Int32StringReceiver):
    """
    Int32 string receiver that records what it receives via C{TestMixin}.
    """
    MAX_LENGTH = 50
380
 
 
381
 
 
382
 
class Int32TestCase(unittest.TestCase, LPTestCaseMixin):
    """
    Test L{basic.Int32StringReceiver} through the C{TestInt32} wrapper.
    """

    protocol = TestInt32
    strings = ["a", "b" * 16]
    illegal_strings = ["\x10\x00\x00\x00aaaaaa"]
    partial_strings = ["\x00\x00\x00", "hello there", ""]

    def testPartial(self):
        """
        Inputs that never complete a message must not yield any string,
        even with the length limit effectively lifted.
        """
        for partial in self.partial_strings:
            proto = self.getProtocol()
            proto.MAX_LENGTH = 99999999
            for byte in partial:
                proto.dataReceived(byte)
            self.assertEquals(proto.received, [])


    def testReceive(self):
        """
        Strings framed with a network-order 4-byte length prefix and fed
        byte by byte must be received intact and in order.
        """
        proto = self.getProtocol()
        for s in self.strings:
            framed = struct.pack("!i", len(s)) + s
            for byte in framed:
                proto.dataReceived(byte)
        self.assertEquals(proto.received, self.strings)
403
 
 
404
 
 
405
 
class OnlyProducerTransport(object):
    """
    Transport which isn't really a transport, just looks like one to
    someone not looking very hard.

    It keeps every write in C{data} and tracks in C{paused} whether it has
    been paused.
    """

    paused = False
    disconnecting = False

    def __init__(self):
        # Written chunks, in order.
        self.data = []

    def pauseProducing(self):
        """
        Mark the transport as paused.
        """
        self.paused = True

    def resumeProducing(self):
        """
        Mark the transport as no longer paused.
        """
        self.paused = False

    def write(self, bytes):
        """
        Record C{bytes} without sending it anywhere.
        """
        self.data.append(bytes)
423
 
 
424
 
 
425
 
class ConsumingProtocol(basic.LineReceiver):
    """
    Protocol that really, really doesn't want any more bytes: after each
    line it writes the line out and pauses itself.
    """

    def lineReceived(self, line):
        """
        Copy C{line} to the transport, then pause until resumed.
        """
        self.transport.write(line)
        self.pauseProducing()
430
 
        self.pauseProducing()
431
 
 
432
 
 
433
 
class ProducerTestCase(unittest.TestCase):
    """
    Test that pausing and resuming a L{basic.LineReceiver} propagates to its
    transport and controls line delivery.
    """

    def testPauseResume(self):
        """
        Each line pauses C{ConsumingProtocol}; buffered lines are only
        delivered one per C{resumeProducing}, and the transport's paused
        state always mirrors the protocol's.
        """
        p = ConsumingProtocol()
        t = OnlyProducerTransport()
        p.makeConnection(t)

        def checkPaused(expected):
            # Protocol and transport must agree on the paused state.
            if expected:
                self.failUnless(t.paused)
                self.failUnless(p.paused)
            else:
                self.failIf(t.paused)
                self.failIf(p.paused)

        # An incomplete line: nothing written, nothing paused.
        p.dataReceived('hello, ')
        self.failIf(t.data)
        checkPaused(False)

        # Completing the line delivers it and pauses.
        p.dataReceived('world\r\n')
        self.assertEquals(t.data, ['hello, world'])
        checkPaused(True)

        p.resumeProducing()
        checkPaused(False)

        # Two lines at once: only the first is delivered before pausing.
        p.dataReceived('hello\r\nworld\r\n')
        self.assertEquals(t.data, ['hello, world', 'hello'])
        checkPaused(True)

        # Resuming releases the buffered line; new data waits behind it.
        p.resumeProducing()
        p.dataReceived('goodbye\r\n')
        self.assertEquals(t.data, ['hello, world', 'hello', 'world'])
        checkPaused(True)

        p.resumeProducing()
        self.assertEquals(t.data, ['hello, world', 'hello', 'world', 'goodbye'])
        checkPaused(True)

        # Nothing left buffered, so resuming leaves everything unpaused.
        p.resumeProducing()
        self.assertEquals(t.data, ['hello, world', 'hello', 'world', 'goodbye'])
        checkPaused(False)
479
 
 
480
 
 
481
 
class Portforwarding(unittest.TestCase):
    """
    Test port forwarding.
    """
    def setUp(self):
        self.serverProtocol = wire.Echo()
        self.clientProtocol = protocol.Protocol()
        self.openPorts = []


    def tearDown(self):
        """
        Drop the client and server connections, if they were ever made, and
        stop every port the test started listening on.

        @return: a Deferred that fires once all listening ports have stopped.
        """
        for proto in (self.clientProtocol, self.serverProtocol):
            # A protocol that never got connected has no transport (None).
            # Skip it explicitly rather than hiding every exception behind a
            # bare "except: pass".
            transport = getattr(proto, 'transport', None)
            if transport is not None:
                transport.loseConnection()
        return defer.gatherResults(
            [defer.maybeDeferred(p.stopListening) for p in self.openPorts])


    def testPortforward(self):
        """
        Test port forwarding through Echo protocol: bytes sent to the proxy
        must come back unchanged from the echo server behind it.
        """
        realServerFactory = protocol.ServerFactory()
        realServerFactory.protocol = lambda: self.serverProtocol
        realServerPort = reactor.listenTCP(0, realServerFactory,
                                           interface='127.0.0.1')
        self.openPorts.append(realServerPort)

        proxyServerFactory = portforward.ProxyFactory('127.0.0.1',
                                realServerPort.getHost().port)
        proxyServerPort = reactor.listenTCP(0, proxyServerFactory,
                                            interface='127.0.0.1')
        self.openPorts.append(proxyServerPort)

        nBytes = 1000
        received = []
        d = defer.Deferred()
        def testDataReceived(data):
            received.extend(data)
            # Fire once with everything echoed back. The comparison runs as a
            # callback on d (below): raised here, inside the reactor, a
            # mismatch would only be logged and d would never fire, so the
            # test would hang until the trial timeout instead of failing.
            if len(received) >= nBytes and not d.called:
                d.callback(''.join(received))
        self.clientProtocol.dataReceived = testDataReceived

        def testConnectionMade():
            self.clientProtocol.transport.write('x' * nBytes)
        self.clientProtocol.connectionMade = testConnectionMade

        clientFactory = protocol.ClientFactory()
        clientFactory.protocol = lambda: self.clientProtocol

        reactor.connectTCP(
            '127.0.0.1', proxyServerPort.getHost().port, clientFactory)

        d.addCallback(self.assertEquals, 'x' * nBytes)
        return d
539