~0x44/nova/extdoc

« back to all changes in this revision

Viewing changes to vendor/Twisted-10.0.0/twisted/test/test_protocols.py

  • Committer: Jesse Andrews
  • Date: 2010-05-28 06:05:26 UTC
  • Revision ID: git-v1:bf6e6e718cdc7488e2da87b21e258ccc065fe499
initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (c) 2001-2009 Twisted Matrix Laboratories.
 
2
# See LICENSE for details.
 
3
 
 
4
"""
 
5
Test cases for twisted.protocols package.
 
6
"""
 
7
 
 
8
import struct
 
9
 
 
10
from twisted.trial import unittest
 
11
from twisted.protocols import basic, wire, portforward
 
12
from twisted.internet import reactor, protocol, defer, task, error
 
13
from twisted.test import proto_helpers
 
14
 
 
15
 
 
16
class LineTester(basic.LineReceiver):
    """
    A line receiver that records every line it gets and reacts to a few
    command tokens: C{pause}, C{rawpause}, C{stop}, C{len N}, C{produce}
    and C{unproduce}.

    @type delimiter: C{str}
    @ivar delimiter: character separating received lines.

    @type MAX_LENGTH: C{int}
    @ivar MAX_LENGTH: line size above which C{lineLengthExceeded} is called.

    @type clock: L{twisted.internet.task.Clock}
    @ivar clock: clock used for C{callLater}; it must be given to the
        constructor if the C{pause} or C{rawpause} commands are sent.
    """

    delimiter = '\n'
    MAX_LENGTH = 64

    def __init__(self, clock=None):
        """
        Remember C{clock} (may be C{None} when no pausing command is used).
        """
        self.clock = clock


    def connectionMade(self):
        """
        Start each connection with an empty list of received lines.
        """
        self.received = []


    def lineReceived(self, line):
        """
        Record C{line}, then act on it if it is one of the known commands.

        An empty line switches to raw mode.  C{pause} and C{rawpause} pause
        production and schedule a resume on C{self.clock}; C{rawpause} also
        switches to raw mode and opens a new empty entry that raw data is
        appended to.  C{stop} stops production, C{len N} sets how many raw
        bytes to expect, and C{produce}/C{unproduce} register/unregister this
        receiver as a producer with its transport.
        """
        self.received.append(line)
        if line == '':
            self.setRawMode()
        elif line in ('pause', 'rawpause'):
            self.pauseProducing()
            if line == 'rawpause':
                self.setRawMode()
                self.received.append('')
            self.clock.callLater(0, self.resumeProducing)
        elif line == 'stop':
            self.stopProducing()
        elif line.startswith('len '):
            self.length = int(line[4:])
        elif line.startswith('produce'):
            self.transport.registerProducer(self, False)
        elif line.startswith('unproduce'):
            self.transport.unregisterProducer()


    def rawDataReceived(self, data):
        """
        Append up to C{self.length} bytes of C{data} to the last received
        entry; once the expected amount has arrived, go back to line mode,
        handing the remainder to C{setLineMode}.
        """
        wanted = self.length
        chunk, leftover = data[:wanted], data[wanted:]
        self.length = wanted - len(chunk)
        self.received[-1] += chunk
        if self.length == 0:
            self.setLineMode(leftover)


    def lineLengthExceeded(self, line):
        """
        Drop the first C{MAX_LENGTH + 1} characters of an over-long line and
        hand whatever follows them back to C{setLineMode}.
        """
        cut = self.MAX_LENGTH + 1
        if len(line) > cut:
            self.setLineMode(line[cut:])
 
86
 
 
87
 
 
88
class LineOnlyTester(basic.LineOnlyReceiver):
    """
    A L{basic.LineOnlyReceiver} that simply collects every line it is given.
    """

    delimiter = '\n'
    MAX_LENGTH = 64

    def connectionMade(self):
        """
        Reset the collected lines at the start of each connection.
        """
        self.received = []


    def lineReceived(self, line):
        """
        Remember C{line}.
        """
        self.received.append(line)
 
106
 
 
107
class WireTestCase(unittest.TestCase):
    """
    Tests for the simple protocols in L{twisted.protocols.wire}.
    """

    def test_echo(self):
        """
        L{wire.Echo} writes every chunk it receives back to its transport.
        """
        transport = proto_helpers.StringTransport()
        echo = wire.Echo()
        echo.makeConnection(transport)
        for chunk in ["hello", "world", "how", "are", "you"]:
            echo.dataReceived(chunk)
        self.assertEquals(transport.value(), "helloworldhowareyou")


    def test_who(self):
        """
        L{wire.Who} writes a fixed C{root} line as soon as it is connected.
        """
        transport = proto_helpers.StringTransport()
        wire.Who().makeConnection(transport)
        self.assertEquals(transport.value(), "root\r\n")


    def test_QOTD(self):
        """
        L{wire.QOTD} writes its quote of the day as soon as it is connected.
        """
        transport = proto_helpers.StringTransport()
        wire.QOTD().makeConnection(transport)
        self.assertEquals(transport.value(),
                          "An apple a day keeps the doctor away.\r\n")


    def test_discard(self):
        """
        L{wire.Discard} writes nothing, whatever it receives.
        """
        transport = proto_helpers.StringTransport()
        discard = wire.Discard()
        discard.makeConnection(transport)
        for chunk in ["hello", "world", "how", "are", "you"]:
            discard.dataReceived(chunk)
        self.assertEqual(transport.value(), "")
 
160
 
 
161
 
 
162
 
 
163
class LineReceiverTestCase(unittest.TestCase):
    """
    Test LineReceiver, using the C{LineTester} wrapper.
    """
    buffer = '''\
len 10

0123456789len 5

1234
len 20
foo 123

0123456789
012345678len 0
foo 5

1234567890123456789012345678901234567890123456789012345678901234567890
len 1

a'''

    output = ['len 10', '0123456789', 'len 5', '1234\n',
              'len 20', 'foo 123', '0123456789\n012345678',
              'len 0', 'foo 5', '', '67890', 'len 1', 'a']


    def _connectedTester(self, clock=None):
        """
        Return a new L{LineTester} (using C{clock} for C{callLater}) that is
        connected to a L{protocol.FileWrapper} around an in-memory file.
        """
        tester = LineTester(clock)
        tester.makeConnection(
            protocol.FileWrapper(proto_helpers.StringIOWithoutClosing()))
        return tester


    def _feedInChunks(self, proto, data, packetSize):
        """
        Deliver C{data} to C{proto.dataReceived} in consecutive slices of at
        most C{packetSize} characters.

        C{len(data) // packetSize + 1} slices are delivered, so the final
        slice may be the empty string.
        """
        for i in range(len(data) // packetSize + 1):
            proto.dataReceived(data[i * packetSize:(i + 1) * packetSize])


    def testBuffer(self):
        """
        Test buffering for different packet size, checking received matches
        expected data.
        """
        for packet_size in range(1, 10):
            a = self._connectedTester()
            self._feedInChunks(a, self.buffer, packet_size)
            self.failUnlessEqual(self.output, a.received)


    pause_buf = 'twiddle1\ntwiddle2\npause\ntwiddle3\n'

    pause_output1 = ['twiddle1', 'twiddle2', 'pause']
    pause_output2 = pause_output1+['twiddle3']

    def test_pausing(self):
        """
        Test pause inside data receiving. It uses fake clock to see if
        pausing/resuming work.
        """
        for packet_size in range(1, 10):
            clock = task.Clock()
            a = self._connectedTester(clock)
            self._feedInChunks(a, self.pause_buf, packet_size)
            self.assertEquals(self.pause_output1, a.received)
            clock.advance(0)
            self.assertEquals(self.pause_output2, a.received)

    rawpause_buf = 'twiddle1\ntwiddle2\nlen 5\nrawpause\n12345twiddle3\n'

    rawpause_output1 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '']
    rawpause_output2 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '12345',
                        'twiddle3']

    def test_rawPausing(self):
        """
        Test pause inside raw data receiving.
        """
        for packet_size in range(1, 10):
            clock = task.Clock()
            a = self._connectedTester(clock)
            self._feedInChunks(a, self.rawpause_buf, packet_size)
            self.assertEquals(self.rawpause_output1, a.received)
            clock.advance(0)
            self.assertEquals(self.rawpause_output2, a.received)

    stop_buf = 'twiddle1\ntwiddle2\nstop\nmore\nstuff\n'

    stop_output = ['twiddle1', 'twiddle2', 'stop']

    def test_stopProducing(self):
        """
        Test stop inside producing.
        """
        for packet_size in range(1, 10):
            a = self._connectedTester()
            self._feedInChunks(a, self.stop_buf, packet_size)
            self.assertEquals(self.stop_output, a.received)


    def test_lineReceiverAsProducer(self):
        """
        Test produce/unproduce in receiving.
        """
        a = self._connectedTester()
        a.dataReceived('produce\nhello world\nunproduce\ngoodbye\n')
        self.assertEquals(a.received,
                          ['produce', 'hello world', 'unproduce', 'goodbye'])


    def test_clearLineBuffer(self):
        """
        L{LineReceiver.clearLineBuffer} removes all buffered data and returns
        it as a C{str} and can be called from beneath C{dataReceived}.
        """
        class ClearingReceiver(basic.LineReceiver):
            def lineReceived(self, line):
                self.line = line
                self.rest = self.clearLineBuffer()

        # Named "receiver" rather than "protocol" so the module-level
        # twisted.internet.protocol import is not shadowed.
        receiver = ClearingReceiver()
        receiver.dataReceived('foo\r\nbar\r\nbaz')
        self.assertEqual(receiver.line, 'foo')
        self.assertEqual(receiver.rest, 'bar\r\nbaz')

        # Deliver another line to make sure the previously buffered data is
        # really gone.
        receiver.dataReceived('quux\r\n')
        self.assertEqual(receiver.line, 'quux')
        self.assertEqual(receiver.rest, '')
 
298
 
 
299
 
 
300
 
 
301
class LineOnlyReceiverTestCase(unittest.TestCase):
    """
    Tests for L{basic.LineOnlyReceiver}, using L{LineOnlyTester}.
    """
    buffer = """foo
    bleakness
    desolation
    plastic forks
    """

    def test_buffer(self):
        """
        Delivering the buffer one character at a time yields exactly the
        lines it contains.
        """
        transport = proto_helpers.StringTransport()
        receiver = LineOnlyTester()
        receiver.makeConnection(transport)
        for char in self.buffer:
            receiver.dataReceived(char)
        expected = self.buffer.split('\n')[:-1]
        self.assertEquals(receiver.received, expected)


    def test_lineTooLong(self):
        """
        Delivering a line longer than C{MAX_LENGTH} makes C{dataReceived}
        return an L{error.ConnectionLost} instance.
        """
        transport = proto_helpers.StringTransport()
        receiver = LineOnlyTester()
        receiver.makeConnection(transport)
        result = receiver.dataReceived('x' * 200)
        self.assertIsInstance(result, error.ConnectionLost)
 
331
 
 
332
 
 
333
 
 
334
class TestMixin:
    """
    Mixin for string-receiving test protocols.  It records every string
    passed to C{stringReceived} in C{received} and sets C{closed} to 1 when
    the connection is lost.
    """

    MAX_LENGTH = 50
    closed = 0

    def connectionMade(self):
        self.received = []

    def stringReceived(self, s):
        self.received.append(s)

    def connectionLost(self, reason):
        self.closed = 1
 
347
 
 
348
 
 
349
class TestNetstring(TestMixin, basic.NetstringReceiver):
    """
    A L{basic.NetstringReceiver} whose received strings are collected by
    L{TestMixin}.
    """
 
351
 
 
352
 
 
353
class LPTestCaseMixin:
    """
    Mixin for length-prefixed protocol test cases.

    Subclasses set C{protocol} to the protocol class under test and
    C{illegalStrings} to inputs that must cause its transport to be closed.
    """

    illegalStrings = []
    protocol = None

    def getProtocol(self):
        """
        Return a new instance of C{self.protocol} connected to a new instance
        of L{proto_helpers.StringTransport}.
        """
        transport = proto_helpers.StringTransport()
        proto = self.protocol()
        proto.makeConnection(transport)
        return proto


    def test_illegal(self):
        """
        Feeding any entry of C{illegalStrings}, one character at a time, to
        a fresh protocol leaves its transport disconnecting.
        """
        for illegal in self.illegalStrings:
            proto = self.getProtocol()
            for char in illegal:
                proto.dataReceived(char)
            self.assertTrue(proto.transport.disconnecting)
 
378
 
 
379
 
 
380
class NetstringReceiverTestCase(unittest.TestCase, LPTestCaseMixin):
    """
    Tests for L{basic.NetstringReceiver}, driven through L{TestNetstring}.
    """

    strings = ['hello', 'world', 'how', 'are', 'you123', ':today', "a"*515]

    illegalStrings = [
        '9999999999999999999999', 'abc', '4:abcde',
        '51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',]

    protocol = TestNetstring

    def test_buffer(self):
        """
        Test that when strings are received in chunks of different lengths,
        they are still parsed correctly.
        """
        for packet_size in range(1, 10):
            t = proto_helpers.StringTransport()
            a = TestNetstring()
            a.MAX_LENGTH = 699
            a.makeConnection(t)
            for s in self.strings:
                a.sendString(s)
            out = t.value()
            for i in range(len(out) // packet_size + 1):
                s = out[i*packet_size:(i+1)*packet_size]
                if s:
                    a.dataReceived(s)
            self.assertEquals(a.received, self.strings)


    def test_sendNonStrings(self):
        """
        L{basic.NetstringReceiver.sendString} will send objects that are not
        strings by sending their string representation according to str(),
        and emits a L{DeprecationWarning} for every such object.
        """
        nonStrings = [ [], { 1 : 'a', 2 : 'b' }, ['a', 'b', 'c'], 673,
                       (12, "fine", "and", "you?") ]
        a = TestNetstring()
        t = proto_helpers.StringTransport()
        a.MAX_LENGTH = 100
        a.makeConnection(t)
        for s in nonStrings:
            a.sendString(s)
            out = t.value()
            t.clear()
            # A netstring is "<length>:<data>,"; split on the first colon
            # (the payload itself may contain colons) and drop the ",".
            length, _, rest = out.partition(":")
            data = rest[:-1]
            self.assertEquals(int(length), len(str(s)))
            self.assertEquals(data, str(s))

        flushed = self.flushWarnings(
            offendingFunctions=[self.test_sendNonStrings])
        # One warning per non-string sent; check every one of them rather
        # than only the first.
        self.assertEqual(len(flushed), len(nonStrings))
        for warning in flushed:
            self.assertEqual(
                warning["message"],
                "data passed to sendString() must be a string. Non-string support "
                "is deprecated since Twisted 10.0")
            self.assertEqual(
                warning['category'],
                DeprecationWarning)
 
439
 
 
440
 
 
441
class IntNTestCaseMixin(LPTestCaseMixin):
    """
    TestCase mixin for int-prefixed protocols.

    Subclasses set C{protocol}, C{strings} (payloads that must round-trip),
    C{illegalStrings} (inputs that must close the transport) and
    C{partialStrings} (inputs that must not yield any string).
    """

    protocol = None
    strings = None
    illegalStrings = None
    partialStrings = None

    def test_receive(self):
        """
        Strings framed with the protocol's length prefix and delivered one
        character at a time are received unchanged and in order.
        """
        proto = self.getProtocol()
        for payload in self.strings:
            framed = struct.pack(proto.structFormat, len(payload)) + payload
            for char in framed:
                proto.dataReceived(char)
        self.assertEquals(proto.received, self.strings)


    def test_partial(self):
        """
        Incomplete input, delivered one character at a time to a fresh
        protocol, never produces a received string.
        """
        for partial in self.partialStrings:
            proto = self.getProtocol()
            for char in partial:
                proto.dataReceived(char)
            self.assertEquals(proto.received, [])


    def test_send(self):
        """
        C{sendString} writes the packed length followed by the payload.
        """
        proto = self.getProtocol()
        payload = "b" * 16
        proto.sendString(payload)
        expected = struct.pack(proto.structFormat, 16) + payload
        self.assertEquals(proto.transport.value(), expected)


    def test_lengthLimitExceeded(self):
        """
        When a length prefix is received which is greater than the protocol's
        C{MAX_LENGTH} attribute, the C{lengthLimitExceeded} method is called
        with the received length prefix.
        """
        reported = []
        proto = self.getProtocol()
        proto.lengthLimitExceeded = reported.append
        proto.MAX_LENGTH = 10
        proto.dataReceived(struct.pack(proto.structFormat, 11))
        self.assertEqual(reported, [11])


    def test_longStringNotDelivered(self):
        """
        If a length prefix for a string longer than C{MAX_LENGTH} is delivered
        to C{dataReceived} at the same time as the entire string, the string is
        not passed to C{stringReceived}.
        """
        proto = self.getProtocol()
        proto.MAX_LENGTH = 10
        oversized = struct.pack(proto.structFormat, 11) + 'x' * 11
        proto.dataReceived(oversized)
        self.assertEqual(proto.received, [])
 
506
 
 
507
 
 
508
 
 
509
class TestInt32(TestMixin, basic.Int32StringReceiver):
    """
    L{basic.Int32StringReceiver} test protocol that keeps every string it
    receives.

    @ivar received: list of the strings received so far, in order.
    """
 
515
 
 
516
 
 
517
class Int32TestCase(unittest.TestCase, IntNTestCaseMixin):
    """
    Tests for the 32-bit length-prefixed receiver, L{TestInt32}.
    """
    protocol = TestInt32
    strings = ["a", "b" * 16]
    illegalStrings = ["\x10\x00\x00\x00aaaaaa"]
    partialStrings = ["\x00\x00\x00", "hello there", ""]

    def test_data(self):
        """
        Outgoing strings get a 4-byte big-endian length prefix, and incoming
        data framed that way is decoded back into the bare string.
        """
        proto = self.getProtocol()
        proto.sendString("foo")
        self.assertEquals(proto.transport.value(), "\x00\x00\x00\x03foo")
        proto.dataReceived("\x00\x00\x00\x04ubar")
        self.assertEquals(proto.received, ["ubar"])
 
535
 
 
536
 
 
537
class TestInt16(TestMixin, basic.Int16StringReceiver):
    """
    L{basic.Int16StringReceiver} test protocol that keeps every string it
    receives.

    @ivar received: list of the strings received so far, in order.
    """
 
543
 
 
544
 
 
545
class Int16TestCase(unittest.TestCase, IntNTestCaseMixin):
    """
    Tests for the 16-bit length-prefixed receiver, L{TestInt16}.
    """
    protocol = TestInt16
    strings = ["a", "b" * 16]
    illegalStrings = ["\x10\x00aaaaaa"]
    partialStrings = ["\x00", "hello there", ""]

    def test_data(self):
        """
        Outgoing strings get a 2-byte big-endian length prefix, and incoming
        data framed that way is decoded back into the bare string.
        """
        proto = self.getProtocol()
        proto.sendString("foo")
        self.assertEquals(proto.transport.value(), "\x00\x03foo")
        proto.dataReceived("\x00\x04ubar")
        self.assertEquals(proto.received, ["ubar"])


    def test_tooLongSend(self):
        """
        A string too long for the length prefix to describe makes
        C{sendString} raise C{AssertionError}.
        """
        proto = self.getProtocol()
        limit = 2 ** (proto.prefixLength * 8)
        self.assertRaises(AssertionError, proto.sendString, "b" * (limit + 1))
 
571
 
 
572
 
 
573
class TestInt8(TestMixin, basic.Int8StringReceiver):
    """
    L{basic.Int8StringReceiver} test protocol that keeps every string it
    receives.

    @ivar received: list of the strings received so far, in order.
    """
 
579
 
 
580
 
 
581
class Int8TestCase(unittest.TestCase, IntNTestCaseMixin):
    """
    Tests for the 8-bit length-prefixed receiver, L{TestInt8}.
    """
    protocol = TestInt8
    strings = ["a", "b" * 16]
    illegalStrings = ["\x00\x00aaaaaa"]
    partialStrings = ["\x08", "dzadz", ""]

    def test_data(self):
        """
        Outgoing strings get a single-byte length prefix, and incoming data
        framed that way is decoded back into the bare string.
        """
        proto = self.getProtocol()
        proto.sendString("foo")
        self.assertEquals(proto.transport.value(), "\x03foo")
        proto.dataReceived("\x04ubar")
        self.assertEquals(proto.received, ["ubar"])


    def test_tooLongSend(self):
        """
        A string too long for the length prefix to describe makes
        C{sendString} raise C{AssertionError}.
        """
        proto = self.getProtocol()
        limit = 2 ** (proto.prefixLength * 8)
        self.assertRaises(AssertionError, proto.sendString, "b" * (limit + 1))
 
607
 
 
608
 
 
609
class OnlyProducerTransport(object):
    """
    Not really a transport: it only tracks pause/resume state and records
    writes, which is enough to fool someone not looking very hard.

    @ivar data: every value passed to C{write}, in order.
    """

    paused = False
    disconnecting = False

    def __init__(self):
        self.data = []

    def pauseProducing(self):
        """
        Mark the transport as paused.
        """
        self.paused = True

    def resumeProducing(self):
        """
        Mark the transport as no longer paused.
        """
        self.paused = False

    def write(self, bytes):
        """
        Record C{bytes} instead of sending it anywhere.
        """
        self.data.append(bytes)
 
627
 
 
628
 
 
629
class ConsumingProtocol(basic.LineReceiver):
    """
    A line receiver that really, really doesn't want any more bytes: it
    writes each line to its transport and then immediately pauses itself.
    """

    def lineReceived(self, line):
        """
        Write C{line} to the transport, then pause production.
        """
        self.transport.write(line)
        self.pauseProducing()
 
635
 
 
636
 
 
637
class ProducerTestCase(unittest.TestCase):
    """
    Tests for pausing and resuming a L{basic.LineReceiver}, using
    L{ConsumingProtocol} over an L{OnlyProducerTransport}.
    """

    def testPauseResume(self):
        """
        Every delivered line pauses both the protocol and its transport.
        Each resume delivers the next complete buffered line (pausing again).
        Once nothing is buffered, resuming leaves both unpaused.
        """
        proto = ConsumingProtocol()
        transport = OnlyProducerTransport()
        proto.makeConnection(transport)

        def checkPaused(paused):
            # The protocol and its transport must agree on being paused.
            if paused:
                self.failUnless(transport.paused)
                self.failUnless(proto.paused)
            else:
                self.failIf(transport.paused)
                self.failIf(proto.paused)

        proto.dataReceived('hello, ')
        self.failIf(transport.data)
        checkPaused(False)

        proto.dataReceived('world\r\n')
        self.assertEquals(transport.data, ['hello, world'])
        checkPaused(True)

        proto.resumeProducing()
        checkPaused(False)

        proto.dataReceived('hello\r\nworld\r\n')
        self.assertEquals(transport.data, ['hello, world', 'hello'])
        checkPaused(True)

        proto.resumeProducing()
        proto.dataReceived('goodbye\r\n')
        self.assertEquals(transport.data, ['hello, world', 'hello', 'world'])
        checkPaused(True)

        proto.resumeProducing()
        self.assertEquals(transport.data,
                          ['hello, world', 'hello', 'world', 'goodbye'])
        checkPaused(True)

        proto.resumeProducing()
        self.assertEquals(transport.data,
                          ['hello, world', 'hello', 'world', 'goodbye'])
        checkPaused(False)
 
683
 
 
684
 
 
685
 
 
686
class TestableProxyClientFactory(portforward.ProxyClientFactory):
    """
    Proxy client factory that remembers the most recent protocol it built.

    @ivar protoInstance: the last protocol instance built.
    @type protoInstance: L{portforward.ProxyClient}
    """

    def buildProtocol(self, addr):
        """
        Build the protocol via the base class, store it as C{protoInstance}
        and return it.
        """
        self.protoInstance = portforward.ProxyClientFactory.buildProtocol(
            self, addr)
        return self.protoInstance
 
701
 
 
702
 
 
703
 
 
704
class TestableProxyFactory(portforward.ProxyFactory):
    """
    Proxy factory that remembers the most recent protocol it built, along
    with the client factory that protocol uses.

    @ivar protoInstance: the last protocol instance built.
    @type protoInstance: L{portforward.ProxyServer}

    @ivar clientFactoryInstance: client factory used by C{protoInstance} to
        create forward connections.
    @type clientFactoryInstance: L{TestableProxyClientFactory}
    """

    def buildProtocol(self, addr):
        """
        Build the protocol via the base class, give this factory a fresh
        L{TestableProxyClientFactory}, and make the protocol use it.
        """
        proto = portforward.ProxyFactory.buildProtocol(self, addr)
        self.clientFactoryInstance = TestableProxyClientFactory()
        # The lambda reads the attribute at call time, so the protocol always
        # gets whatever client factory this object currently holds.
        proto.clientProtocolFactory = lambda: self.clientFactoryInstance
        self.protoInstance = proto
        return proto
 
727
 
 
728
 
 
729
 
 
730
class Portforwarding(unittest.TestCase):
    """
    Test port forwarding.
    """

    def setUp(self):
        """
        Create the echo server protocol, the client protocol, and the list
        that collects every listening port opened by the test.
        """
        self.serverProtocol = wire.Echo()
        self.clientProtocol = protocol.Protocol()
        self.openPorts = []


    def tearDown(self):
        """
        Drop every connection the test may have created, then stop listening
        on every port in C{openPorts}.

        Any of the objects below may be missing (or have no transport) if the
        test did not get far enough to create them; the resulting
        C{AttributeError} is ignored.
        """
        try:
            self.proxyServerFactory.protoInstance.transport.loseConnection()
        except AttributeError:
            pass
        try:
            self.proxyServerFactory.clientFactoryInstance.protoInstance.transport.loseConnection()
        except AttributeError:
            pass
        try:
            self.clientProtocol.transport.loseConnection()
        except AttributeError:
            pass
        try:
            self.serverProtocol.transport.loseConnection()
        except AttributeError:
            pass
        return defer.gatherResults(
            [defer.maybeDeferred(p.stopListening) for p in self.openPorts])


    def test_portforward(self):
        """
        Test port forwarding through Echo protocol.

        The client writes C{nBytes} bytes to the proxy, which forwards them
        to an echo server.  The test passes once exactly those bytes have
        come back to the client.
        """
        realServerFactory = protocol.ServerFactory()
        realServerFactory.protocol = lambda: self.serverProtocol
        realServerPort = reactor.listenTCP(0, realServerFactory,
                                           interface='127.0.0.1')
        self.openPorts.append(realServerPort)
        self.proxyServerFactory = TestableProxyFactory('127.0.0.1',
                                realServerPort.getHost().port)
        proxyServerPort = reactor.listenTCP(0, self.proxyServerFactory,
                                            interface='127.0.0.1')
        self.openPorts.append(proxyServerPort)

        nBytes = 1000
        received = []
        d = defer.Deferred()
        def testDataReceived(data):
            received.extend(data)
            # Only fire the Deferred here; the comparison runs as a callback
            # on it (added below).  An assertion failing inside dataReceived
            # would be raised into the reactor instead of the test, leaving
            # d unfired, so the test would only end at its timeout.  The
            # d.called guard keeps any late data from firing d twice.
            if len(received) >= nBytes and not d.called:
                d.callback(''.join(received))
        self.clientProtocol.dataReceived = testDataReceived

        def testConnectionMade():
            self.clientProtocol.transport.write('x' * nBytes)
        self.clientProtocol.connectionMade = testConnectionMade

        clientFactory = protocol.ClientFactory()
        clientFactory.protocol = lambda: self.clientProtocol

        reactor.connectTCP(
            '127.0.0.1', proxyServerPort.getHost().port, clientFactory)

        d.addCallback(self.assertEquals, 'x' * nBytes)
        return d
 
798
 
 
799
 
 
800
 
 
801
class StringTransportTestCase(unittest.TestCase):
    """
    Tests for the L{proto_helpers.StringTransport} helper.
    """

    def test_noUnicode(self):
        """
        Writing a C{unicode} string to L{proto_helpers.StringTransport}
        raises C{TypeError}.
        """
        transport = proto_helpers.StringTransport()
        self.assertRaises(TypeError, transport.write, u'foo')