1
# Copyright (C) 2007-2008 Twisted Matrix Laboratories
2
# See LICENSE for details
7
from twisted.conch.ssh import channel
8
from twisted.trial import unittest
11
class MockTransport(object):
13
A mock Transport. All we use is the getPeer() and getHost() methods.
14
Channels implement the ITransport interface, and their getPeer() and
15
getHost() methods return ('SSH', <transport's getPeer/Host value>) so
16
we need to implement these methods so they have something to draw
26
class MockConnection(object):
28
A mock for twisted.conch.ssh.connection.SSHConnection. Record the data
29
that channels send, and when they try to close the connection.
31
@ivar data: a C{dict} mapping channel id #s to lists of data sent by that
33
@ivar extData: a C{dict} mapping channel id #s to lists of 2-tuples
34
(extended data type, data) sent by that channel.
35
@ivar closes: a C{dict} mapping channel id #s to True if that channel sent
38
transport = MockTransport()
47
Return our logging prefix.
49
return "MockConnection"
51
def sendData(self, channel, data):
    """
    Record the sent data.

    @param channel: the channel that sent the data; used as the key into
        C{self.data}.
    @param data: the data to append to that channel's list.
    """
    # Create the per-channel list on first use, then append to it.
    sent = self.data.setdefault(channel, [])
    sent.append(data)
57
def sendExtendedData(self, channel, type, data):
    """
    Record the sent extended data.

    @param channel: the channel that sent the data; used as the key into
        C{self.extData}.
    @param type: the extended data type.
    @param data: the extended data itself.
    """
    # Each entry is stored as an (extended data type, data) 2-tuple.
    entry = (type, data)
    self.extData.setdefault(channel, []).append(entry)
63
def sendClose(self, channel):
    """
    Record that the channel sent a close message.

    @param channel: the channel that is closing; C{self.closes} maps it
        to C{True}.
    """
    self.closes[channel] = True
70
class ChannelTestCase(unittest.TestCase):
74
Initialize the channel. remoteMaxPacket is 10 so that data is able
75
to be sent (the default of 0 means no data is sent because no packets
78
self.conn = MockConnection()
79
self.channel = channel.SSHChannel(conn=self.conn,
81
self.channel.name = 'channel'
85
Test that SSHChannel initializes correctly. localWindowSize defaults
86
to 131072 (2**17) and localMaxPacket to 32768 (2**15) as reasonable
87
defaults (what OpenSSH uses for those variables).
89
The values in the second set of assertions are meaningless; they serve
90
only to verify that the instance variables are assigned in the correct
93
c = channel.SSHChannel(conn=self.conn)
94
self.assertEquals(c.localWindowSize, 131072)
95
self.assertEquals(c.localWindowLeft, 131072)
96
self.assertEquals(c.localMaxPacket, 32768)
97
self.assertEquals(c.remoteWindowLeft, 0)
98
self.assertEquals(c.remoteMaxPacket, 0)
99
self.assertEquals(c.conn, self.conn)
100
self.assertEquals(c.data, None)
101
self.assertEquals(c.avatar, None)
103
c2 = channel.SSHChannel(1, 2, 3, 4, 5, 6, 7)
104
self.assertEquals(c2.localWindowSize, 1)
105
self.assertEquals(c2.localWindowLeft, 1)
106
self.assertEquals(c2.localMaxPacket, 2)
107
self.assertEquals(c2.remoteWindowLeft, 3)
108
self.assertEquals(c2.remoteMaxPacket, 4)
109
self.assertEquals(c2.conn, 5)
110
self.assertEquals(c2.data, 6)
111
self.assertEquals(c2.avatar, 7)
115
Test that str(SSHChannel) gives the channel name and local and
116
remote windows at a glance.
118
self.assertEquals(str(self.channel), '<SSHChannel channel (lw 131072 '
121
def test_logPrefix(self):
    """
    Test that SSHChannel.logPrefix gives the name of the channel, the
    local channel ID and the underlying connection.
    """
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(self.channel.logPrefix(), 'SSHChannel channel '
                     '(unknown) on MockConnection')
129
def test_addWindowBytes(self):
131
Test that addWindowBytes adds bytes to the window and resumes writing
135
def stubStartWriting():
137
self.channel.startWriting = stubStartWriting
138
self.channel.write('test')
139
self.channel.writeExtended(1, 'test')
140
self.channel.addWindowBytes(50)
141
self.assertEquals(self.channel.remoteWindowLeft, 50 - 4 - 4)
142
self.assertTrue(self.channel.areWriting)
143
self.assertTrue(cb[0])
144
self.assertEquals(self.channel.buf, '')
145
self.assertEquals(self.conn.data[self.channel], ['test'])
146
self.assertEquals(self.channel.extBuf, [])
147
self.assertEquals(self.conn.extData[self.channel], [(1, 'test')])
150
self.channel.addWindowBytes(20)
151
self.assertFalse(cb[0])
153
self.channel.write('a'*80)
154
self.channel.loseConnection()
155
self.channel.addWindowBytes(20)
156
self.assertFalse(cb[0])
158
def test_requestReceived(self):
    """
    Test that requestReceived handles requests by dispatching them to a
    handler attribute: a 'test-method' request reaches
    request_test_method and the result reflects what the handler
    returned, while a request with no handler ('bad-method') fails.
    """
    def acceptOnlyEmpty(data):
        # The handler succeeds only for an empty payload.
        return data == ''

    self.channel.request_test_method = acceptOnlyEmpty
    dispatch = self.channel.requestReceived
    self.assertTrue(dispatch('test-method', ''))
    self.assertFalse(dispatch('test-method', 'a'))
    self.assertFalse(dispatch('bad-method', ''))
168
def test_closeReceieved(self):
    """
    Test that the default closeReceived closes the connection: the
    channel's closing flag is false beforehand and true afterwards.
    """
    chan = self.channel
    self.assertFalse(chan.closing)
    chan.closeReceived()
    self.assertTrue(chan.closing)
176
def test_write(self):
178
Test that write handles data correctly. Send data up to the size
179
of the remote window, splitting the data into packets of length
183
def stubStopWriting():
185
# no window to start with
186
self.channel.stopWriting = stubStopWriting
187
self.channel.write('d')
188
self.channel.write('a')
189
self.assertFalse(self.channel.areWriting)
190
self.assertTrue(cb[0])
192
self.channel.addWindowBytes(20)
193
self.channel.write('ta')
194
data = self.conn.data[self.channel]
195
self.assertEquals(data, ['da', 'ta'])
196
self.assertEquals(self.channel.remoteWindowLeft, 16)
197
# larger than max packet
198
self.channel.write('12345678901')
199
self.assertEquals(data, ['da', 'ta', '1234567890', '1'])
200
self.assertEquals(self.channel.remoteWindowLeft, 5)
201
# running out of window
203
self.channel.write('123456')
204
self.assertFalse(self.channel.areWriting)
205
self.assertTrue(cb[0])
206
self.assertEquals(data, ['da', 'ta', '1234567890', '1', '12345'])
207
self.assertEquals(self.channel.buf, '6')
208
self.assertEquals(self.channel.remoteWindowLeft, 0)
210
def test_writeExtended(self):
212
Test that writeExtended handles data correctly. Send extended data
213
up to the size of the window, splitting the extended data into packets
214
of length remoteMaxPacket.
217
def stubStopWriting():
219
# no window to start with
220
self.channel.stopWriting = stubStopWriting
221
self.channel.writeExtended(1, 'd')
222
self.channel.writeExtended(1, 'a')
223
self.channel.writeExtended(2, 't')
224
self.assertFalse(self.channel.areWriting)
225
self.assertTrue(cb[0])
227
self.channel.addWindowBytes(20)
228
self.channel.writeExtended(2, 'a')
229
data = self.conn.extData[self.channel]
230
self.assertEquals(data, [(1, 'da'), (2, 't'), (2, 'a')])
231
self.assertEquals(self.channel.remoteWindowLeft, 16)
232
# larger than max packet
233
self.channel.writeExtended(3, '12345678901')
234
self.assertEquals(data, [(1, 'da'), (2, 't'), (2, 'a'),
235
(3, '1234567890'), (3, '1')])
236
self.assertEquals(self.channel.remoteWindowLeft, 5)
237
# running out of window
239
self.channel.writeExtended(4, '123456')
240
self.assertFalse(self.channel.areWriting)
241
self.assertTrue(cb[0])
242
self.assertEquals(data, [(1, 'da'), (2, 't'), (2, 'a'),
243
(3, '1234567890'), (3, '1'), (4, '12345')])
244
self.assertEquals(self.channel.extBuf, [[4, '6']])
245
self.assertEquals(self.channel.remoteWindowLeft, 0)
247
def test_writeSequence(self):
    """
    Test that writeSequence is equivalent to write(''.join(sequence)).
    """
    # Open the remote window first so the written data is sent.
    self.channel.addWindowBytes(20)
    self.channel.writeSequence(map(str, range(10)))
    # The ten one-character strings must arrive as a single joined write.
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(self.conn.data[self.channel], ['0123456789'])
255
def test_loseConnection(self):
    """
    Test that loseConnection() doesn't close the channel until all
    the pending regular and extended data has been sent.
    """
    self.channel.write('data')
    self.channel.writeExtended(1, 'datadata')
    self.channel.loseConnection()
    # Nothing has been flushed yet, so no close may have been sent.
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(self.conn.closes.get(self.channel), None)
    self.channel.addWindowBytes(4)  # send regular data
    # Extended data is still pending, so the channel must stay open.
    self.assertEqual(self.conn.closes.get(self.channel), None)
    self.channel.addWindowBytes(8)  # send extended data
    self.assertTrue(self.conn.closes.get(self.channel))
269
def test_getPeer(self):
    """
    Test that getPeer() returns ('SSH', <connection transport peer>).
    """
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(self.channel.getPeer(), ('SSH', 'MockPeer'))
275
def test_getHost(self):
    """
    Test that getHost() returns ('SSH', <connection transport host>).
    """
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(self.channel.getHost(), ('SSH', 'MockHost'))