~ntt-pf-lab/nova/monkey_patch_notification

« back to all changes in this revision

Viewing changes to vendor/Twisted-10.0.0/twisted/conch/test/test_channel.py

  • Committer: Jesse Andrews
  • Date: 2010-05-28 06:05:26 UTC
  • Revision ID: git-v1:bf6e6e718cdc7488e2da87b21e258ccc065fe499
initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2007-2008 Twisted Matrix Laboratories
 
2
# See LICENSE for details
 
3
 
 
4
"""
 
5
Test ssh/channel.py.
 
6
"""
 
7
from twisted.conch.ssh import channel
 
8
from twisted.trial import unittest
 
9
 
 
10
 
 
11
class MockTransport(object):
 
12
    """
 
13
    A mock Transport.  All we use is the getPeer() and getHost() methods.
 
14
    Channels implement the ITransport interface, and their getPeer() and
 
15
    getHost() methods return ('SSH', <transport's getPeer/Host value>) so
 
16
    we need to implement these methods so they have something to draw
 
17
    from.
 
18
    """
 
19
    def getPeer(self):
 
20
        return ('MockPeer',)
 
21
 
 
22
    def getHost(self):
 
23
        return ('MockHost',)
 
24
 
 
25
 
 
26
class MockConnection(object):
 
27
    """
 
28
    A mock for twisted.conch.ssh.connection.SSHConnection.  Record the data
 
29
    that channels send, and when they try to close the connection.
 
30
 
 
31
    @ivar data: a C{dict} mapping channel id #s to lists of data sent by that
 
32
        channel.
 
33
    @ivar extData: a C{dict} mapping channel id #s to lists of 2-tuples
 
34
        (extended data type, data) sent by that channel.
 
35
    @ivar closes: a C{dict} mapping channel id #s to True if that channel sent
 
36
        a close message.
 
37
    """
 
38
    transport = MockTransport()
 
39
 
 
40
    def __init__(self):
 
41
        self.data = {}
 
42
        self.extData = {}
 
43
        self.closes = {}
 
44
 
 
45
    def logPrefix(self):
 
46
        """
 
47
        Return our logging prefix.
 
48
        """
 
49
        return "MockConnection"
 
50
 
 
51
    def sendData(self, channel, data):
 
52
        """
 
53
        Record the sent data.
 
54
        """
 
55
        self.data.setdefault(channel, []).append(data)
 
56
 
 
57
    def sendExtendedData(self, channel, type, data):
 
58
        """
 
59
        Record the sent extended data.
 
60
        """
 
61
        self.extData.setdefault(channel, []).append((type, data))
 
62
 
 
63
    def sendClose(self, channel):
 
64
        """
 
65
        Record that the channel sent a close message.
 
66
        """
 
67
        self.closes[channel] = True
 
68
 
 
69
 
 
70
class ChannelTestCase(unittest.TestCase):
 
71
 
 
72
    def setUp(self):
 
73
        """
 
74
        Initialize the channel.  remoteMaxPacket is 10 so that data is able
 
75
        to be sent (the default of 0 means no data is sent because no packets
 
76
        are made).
 
77
        """
 
78
        self.conn = MockConnection()
 
79
        self.channel = channel.SSHChannel(conn=self.conn,
 
80
                remoteMaxPacket=10)
 
81
        self.channel.name = 'channel'
 
82
 
 
83
    def test_init(self):
 
84
        """
 
85
        Test that SSHChannel initializes correctly.  localWindowSize defaults
 
86
        to 131072 (2**17) and localMaxPacket to 32768 (2**15) as reasonable
 
87
        defaults (what OpenSSH uses for those variables).
 
88
 
 
89
        The values in the second set of assertions are meaningless; they serve
 
90
        only to verify that the instance variables are assigned in the correct
 
91
        order.
 
92
        """
 
93
        c = channel.SSHChannel(conn=self.conn)
 
94
        self.assertEquals(c.localWindowSize, 131072)
 
95
        self.assertEquals(c.localWindowLeft, 131072)
 
96
        self.assertEquals(c.localMaxPacket, 32768)
 
97
        self.assertEquals(c.remoteWindowLeft, 0)
 
98
        self.assertEquals(c.remoteMaxPacket, 0)
 
99
        self.assertEquals(c.conn, self.conn)
 
100
        self.assertEquals(c.data, None)
 
101
        self.assertEquals(c.avatar, None)
 
102
 
 
103
        c2 = channel.SSHChannel(1, 2, 3, 4, 5, 6, 7)
 
104
        self.assertEquals(c2.localWindowSize, 1)
 
105
        self.assertEquals(c2.localWindowLeft, 1)
 
106
        self.assertEquals(c2.localMaxPacket, 2)
 
107
        self.assertEquals(c2.remoteWindowLeft, 3)
 
108
        self.assertEquals(c2.remoteMaxPacket, 4)
 
109
        self.assertEquals(c2.conn, 5)
 
110
        self.assertEquals(c2.data, 6)
 
111
        self.assertEquals(c2.avatar, 7)
 
112
 
 
113
    def test_str(self):
 
114
        """
 
115
        Test that str(SSHChannel) works gives the channel name and local and
 
116
        remote windows at a glance..
 
117
        """
 
118
        self.assertEquals(str(self.channel), '<SSHChannel channel (lw 131072 '
 
119
                'rw 0)>')
 
120
 
 
121
    def test_logPrefix(self):
 
122
        """
 
123
        Test that SSHChannel.logPrefix gives the name of the channel, the
 
124
        local channel ID and the underlying connection.
 
125
        """
 
126
        self.assertEquals(self.channel.logPrefix(), 'SSHChannel channel '
 
127
                '(unknown) on MockConnection')
 
128
 
 
129
    def test_addWindowBytes(self):
 
130
        """
 
131
        Test that addWindowBytes adds bytes to the window and resumes writing
 
132
        if it was paused.
 
133
        """
 
134
        cb = [False]
 
135
        def stubStartWriting():
 
136
            cb[0] = True
 
137
        self.channel.startWriting = stubStartWriting
 
138
        self.channel.write('test')
 
139
        self.channel.writeExtended(1, 'test')
 
140
        self.channel.addWindowBytes(50)
 
141
        self.assertEquals(self.channel.remoteWindowLeft, 50 - 4 - 4)
 
142
        self.assertTrue(self.channel.areWriting)
 
143
        self.assertTrue(cb[0])
 
144
        self.assertEquals(self.channel.buf, '')
 
145
        self.assertEquals(self.conn.data[self.channel], ['test'])
 
146
        self.assertEquals(self.channel.extBuf, [])
 
147
        self.assertEquals(self.conn.extData[self.channel], [(1, 'test')])
 
148
 
 
149
        cb[0] = False
 
150
        self.channel.addWindowBytes(20)
 
151
        self.assertFalse(cb[0])
 
152
 
 
153
        self.channel.write('a'*80)
 
154
        self.channel.loseConnection()
 
155
        self.channel.addWindowBytes(20)
 
156
        self.assertFalse(cb[0])
 
157
 
 
158
    def test_requestReceived(self):
 
159
        """
 
160
        Test that requestReceived handles requests by dispatching them to
 
161
        request_* methods.
 
162
        """
 
163
        self.channel.request_test_method = lambda data: data == ''
 
164
        self.assertTrue(self.channel.requestReceived('test-method', ''))
 
165
        self.assertFalse(self.channel.requestReceived('test-method', 'a'))
 
166
        self.assertFalse(self.channel.requestReceived('bad-method', ''))
 
167
 
 
168
    def test_closeReceieved(self):
 
169
        """
 
170
        Test that the default closeReceieved closes the connection.
 
171
        """
 
172
        self.assertFalse(self.channel.closing)
 
173
        self.channel.closeReceived()
 
174
        self.assertTrue(self.channel.closing)
 
175
 
 
176
    def test_write(self):
 
177
        """
 
178
        Test that write handles data correctly.  Send data up to the size
 
179
        of the remote window, splitting the data into packets of length
 
180
        remoteMaxPacket.
 
181
        """
 
182
        cb = [False]
 
183
        def stubStopWriting():
 
184
            cb[0] = True
 
185
        # no window to start with
 
186
        self.channel.stopWriting = stubStopWriting
 
187
        self.channel.write('d')
 
188
        self.channel.write('a')
 
189
        self.assertFalse(self.channel.areWriting)
 
190
        self.assertTrue(cb[0])
 
191
        # regular write
 
192
        self.channel.addWindowBytes(20)
 
193
        self.channel.write('ta')
 
194
        data = self.conn.data[self.channel]
 
195
        self.assertEquals(data, ['da', 'ta'])
 
196
        self.assertEquals(self.channel.remoteWindowLeft, 16)
 
197
        # larger than max packet
 
198
        self.channel.write('12345678901')
 
199
        self.assertEquals(data, ['da', 'ta', '1234567890', '1'])
 
200
        self.assertEquals(self.channel.remoteWindowLeft, 5)
 
201
        # running out of window
 
202
        cb[0] = False
 
203
        self.channel.write('123456')
 
204
        self.assertFalse(self.channel.areWriting)
 
205
        self.assertTrue(cb[0])
 
206
        self.assertEquals(data, ['da', 'ta', '1234567890', '1', '12345'])
 
207
        self.assertEquals(self.channel.buf, '6')
 
208
        self.assertEquals(self.channel.remoteWindowLeft, 0)
 
209
 
 
210
    def test_writeExtended(self):
 
211
        """
 
212
        Test that writeExtended handles data correctly.  Send extended data
 
213
        up to the size of the window, splitting the extended data into packets
 
214
        of length remoteMaxPacket.
 
215
        """
 
216
        cb = [False]
 
217
        def stubStopWriting():
 
218
            cb[0] = True
 
219
        # no window to start with
 
220
        self.channel.stopWriting = stubStopWriting
 
221
        self.channel.writeExtended(1, 'd')
 
222
        self.channel.writeExtended(1, 'a')
 
223
        self.channel.writeExtended(2, 't')
 
224
        self.assertFalse(self.channel.areWriting)
 
225
        self.assertTrue(cb[0])
 
226
        # regular write
 
227
        self.channel.addWindowBytes(20)
 
228
        self.channel.writeExtended(2, 'a')
 
229
        data = self.conn.extData[self.channel]
 
230
        self.assertEquals(data, [(1, 'da'), (2, 't'), (2, 'a')])
 
231
        self.assertEquals(self.channel.remoteWindowLeft, 16)
 
232
        # larger than max packet
 
233
        self.channel.writeExtended(3, '12345678901')
 
234
        self.assertEquals(data, [(1, 'da'), (2, 't'), (2, 'a'),
 
235
            (3, '1234567890'), (3, '1')])
 
236
        self.assertEquals(self.channel.remoteWindowLeft, 5)
 
237
        # running out of window
 
238
        cb[0] = False
 
239
        self.channel.writeExtended(4, '123456')
 
240
        self.assertFalse(self.channel.areWriting)
 
241
        self.assertTrue(cb[0])
 
242
        self.assertEquals(data, [(1, 'da'), (2, 't'), (2, 'a'),
 
243
            (3, '1234567890'), (3, '1'), (4, '12345')])
 
244
        self.assertEquals(self.channel.extBuf, [[4, '6']])
 
245
        self.assertEquals(self.channel.remoteWindowLeft, 0)
 
246
 
 
247
    def test_writeSequence(self):
 
248
        """
 
249
        Test that writeSequence is equivalent to write(''.join(sequece)).
 
250
        """
 
251
        self.channel.addWindowBytes(20)
 
252
        self.channel.writeSequence(map(str, range(10)))
 
253
        self.assertEquals(self.conn.data[self.channel], ['0123456789'])
 
254
 
 
255
    def test_loseConnection(self):
 
256
        """
 
257
        Tesyt that loseConnection() doesn't close the channel until all
 
258
        the data is sent.
 
259
        """
 
260
        self.channel.write('data')
 
261
        self.channel.writeExtended(1, 'datadata')
 
262
        self.channel.loseConnection()
 
263
        self.assertEquals(self.conn.closes.get(self.channel), None)
 
264
        self.channel.addWindowBytes(4) # send regular data
 
265
        self.assertEquals(self.conn.closes.get(self.channel), None)
 
266
        self.channel.addWindowBytes(8) # send extended data
 
267
        self.assertTrue(self.conn.closes.get(self.channel))
 
268
 
 
269
    def test_getPeer(self):
 
270
        """
 
271
        Test that getPeer() returns ('SSH', <connection transport peer>).
 
272
        """
 
273
        self.assertEquals(self.channel.getPeer(), ('SSH', 'MockPeer'))
 
274
 
 
275
    def test_getHost(self):
 
276
        """
 
277
        Test that getHost() returns ('SSH', <connection transport host>).
 
278
        """
 
279
        self.assertEquals(self.channel.getHost(), ('SSH', 'MockHost'))