1
# Copyright (c) 2007 Twisted Matrix Laboratories.
2
# See LICENSE for details
5
This module tests twisted.conch.ssh.connection.
10
from twisted.conch import error
11
from twisted.conch.ssh import channel, common, connection
12
from twisted.trial import unittest
13
from twisted.conch.test import test_userauth
16
class TestChannel(channel.SSHChannel):
18
A mocked-up version of twisted.conch.ssh.channel.SSHChannel.
20
@ivar gotOpen: True if channelOpen has been called.
21
@type gotOpen: C{bool}
22
@ivar specificData: the specific channel open data passed to channelOpen.
23
@type specificData: C{str}
24
@ivar openFailureReason: the reason passed to openFailed.
25
@type openFailed: C{error.ConchError}
26
@ivar inBuffer: a C{list} of strings received by the channel.
27
@type inBuffer: C{list}
28
@ivar extBuffer: a C{list} of 2-tuples (type, extended data) of received by
30
@type extBuffer: C{list}
31
@ivar numberRequests: the number of requests that have been made to this
33
@type numberRequests: C{int}
34
@ivar gotEOF: True if the other side sent EOF.
36
@ivar gotOneClose: True if the other side closed the connection.
37
@type gotOneClose: C{bool}
38
@ivar gotClosed: True if the channel is closed.
39
@type gotClosed: C{bool}
45
return "TestChannel %i" % self.id
47
def channelOpen(self, specificData):
49
The channel is open. Set up the instance variables.
52
self.specificData = specificData
55
self.numberRequests = 0
57
self.gotOneClose = False
58
self.gotClosed = False
60
def openFailed(self, reason):
62
Opening the channel failed. Store the reason why.
64
self.openFailureReason = reason
66
def request_test(self, data):
68
A test request. Return True if data is 'data'.
72
self.numberRequests += 1
75
def dataReceived(self, data):
77
Data was received. Store it in the buffer.
79
self.inBuffer.append(data)
81
def extReceived(self, code, data):
83
Extended data was received. Store it in the buffer.
85
self.extBuffer.append((code, data))
87
def eofReceived(self):
89
EOF was received. Remember it.
93
def closeReceived(self):
95
Close was received. Remember it.
97
self.gotOneClose = True
101
The channel is closed. Rembember it.
103
self.gotClosed = True
107
A mocked-up version of twisted.conch.avatar.ConchUser
110
def lookupChannel(self, channelType, windowSize, maxPacket, data):
112
The server wants us to return a channel. If the requested channel is
113
our TestChannel, return it, otherwise return None.
115
if channelType == TestChannel.name:
116
return TestChannel(remoteWindow=windowSize,
117
remoteMaxPacket=maxPacket,
118
data=data, avatar=self)
120
def gotGlobalRequest(self, requestType, data):
122
The client has made a global request. If the global request is
123
'TestGlobal', return True. If the global request is 'TestData',
124
return True and the request-specific data we received. Otherwise,
127
if requestType == 'TestGlobal':
129
elif requestType == 'TestData':
134
class TestConnection(connection.SSHConnection):
136
A subclass of SSHConnection for testing.
138
@ivar channel: the current channel.
139
@type channel. C{TestChannel}
143
return "TestConnection"
145
def global_TestGlobal(self, data):
147
The other side made the 'TestGlobal' global request. Return True.
151
def global_Test_Data(self, data):
153
The other side made the 'Test-Data' global request. Return True and
154
the data we received.
158
def channel_TestChannel(self, windowSize, maxPacket, data):
160
The other side is requesting the TestChannel. Create a C{TestChannel}
161
instance, store it, and return it.
163
self.channel = TestChannel(remoteWindow=windowSize,
164
remoteMaxPacket=maxPacket, data=data)
167
def channel_ErrorChannel(self, windowSize, maxPacket, data):
169
The other side is requesting the ErrorChannel. Raise an exception.
171
raise AssertionError('no such thing')
175
class ConnectionTestCase(unittest.TestCase):
177
if test_userauth.transport is None:
178
skip = "Cannot run without PyCrypto"
181
self.transport = test_userauth.FakeTransport(None)
182
self.transport.avatar = TestAvatar()
183
self.conn = TestConnection()
184
self.conn.transport = self.transport
185
self.conn.serviceStarted()
187
def _openChannel(self, channel):
189
Open the channel with the default connection.
191
self.conn.openChannel(channel)
192
self.transport.packets = self.transport.packets[:-1]
193
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(struct.pack('>2L',
194
channel.id, 255) + '\x00\x02\x00\x00\x00\x00\x80\x00')
197
self.conn.serviceStopped()
199
def test_linkAvatar(self):
201
Test that the connection links itself to the avatar in the
204
self.assertIdentical(self.transport.avatar.conn, self.conn)
206
def test_serviceStopped(self):
208
Test that serviceStopped() closes any open channels.
210
channel1 = TestChannel()
211
channel2 = TestChannel()
212
self.conn.openChannel(channel1)
213
self.conn.openChannel(channel2)
214
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION('\x00\x00\x00\x00' * 4)
215
self.assertTrue(channel1.gotOpen)
216
self.assertFalse(channel2.gotOpen)
217
self.conn.serviceStopped()
218
self.assertTrue(channel1.gotClosed)
220
def test_GLOBAL_REQUEST(self):
222
Test that global request packets are dispatched to the global_*
223
methods and the return values are translated into success or failure
226
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestGlobal') + '\xff')
227
self.assertEquals(self.transport.packets,
228
[(connection.MSG_REQUEST_SUCCESS, '')])
229
self.transport.packets = []
230
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestData') + '\xff' +
232
self.assertEquals(self.transport.packets,
233
[(connection.MSG_REQUEST_SUCCESS, 'test data')])
234
self.transport.packets = []
235
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestBad') + '\xff')
236
self.assertEquals(self.transport.packets,
237
[(connection.MSG_REQUEST_FAILURE, '')])
238
self.transport.packets = []
239
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestGlobal') + '\x00')
240
self.assertEquals(self.transport.packets, [])
242
def test_REQUEST_SUCCESS(self):
244
Test that global request success packets cause the Deferred to be
247
d = self.conn.sendGlobalRequest('request', 'data', True)
248
self.conn.ssh_REQUEST_SUCCESS('data')
250
self.assertEquals(data, 'data')
252
d.addErrback(self.fail)
255
def test_REQUEST_FAILURE(self):
257
Test that global request failure packets cause the Deferred to be
260
d = self.conn.sendGlobalRequest('request', 'data', True)
261
self.conn.ssh_REQUEST_FAILURE('data')
263
self.assertEquals(f.value.data, 'data')
264
d.addCallback(self.fail)
268
def test_CHANNEL_OPEN(self):
270
Test that open channel packets cause a channel to be created and
271
opened or a failure message to be returned.
273
del self.transport.avatar
274
self.conn.ssh_CHANNEL_OPEN(common.NS('TestChannel') +
275
'\x00\x00\x00\x01' * 4)
276
self.assertTrue(self.conn.channel.gotOpen)
277
self.assertEquals(self.conn.channel.conn, self.conn)
278
self.assertEquals(self.conn.channel.data, '\x00\x00\x00\x01')
279
self.assertEquals(self.conn.channel.specificData, '\x00\x00\x00\x01')
280
self.assertEquals(self.conn.channel.remoteWindowLeft, 1)
281
self.assertEquals(self.conn.channel.remoteMaxPacket, 1)
282
self.assertEquals(self.transport.packets,
283
[(connection.MSG_CHANNEL_OPEN_CONFIRMATION,
284
'\x00\x00\x00\x01\x00\x00\x00\x00\x00\x02\x00\x00'
285
'\x00\x00\x80\x00')])
286
self.transport.packets = []
287
self.conn.ssh_CHANNEL_OPEN(common.NS('BadChannel') +
288
'\x00\x00\x00\x02' * 4)
289
self.flushLoggedErrors()
290
self.assertEquals(self.transport.packets,
291
[(connection.MSG_CHANNEL_OPEN_FAILURE,
292
'\x00\x00\x00\x02\x00\x00\x00\x03' + common.NS(
293
'unknown channel') + common.NS(''))])
294
self.transport.packets = []
295
self.conn.ssh_CHANNEL_OPEN(common.NS('ErrorChannel') +
296
'\x00\x00\x00\x02' * 4)
297
self.flushLoggedErrors()
298
self.assertEquals(self.transport.packets,
299
[(connection.MSG_CHANNEL_OPEN_FAILURE,
300
'\x00\x00\x00\x02\x00\x00\x00\x02' + common.NS(
301
'unknown failure') + common.NS(''))])
303
def test_CHANNEL_OPEN_CONFIRMATION(self):
305
Test that channel open confirmation packets cause the channel to be
306
notified that it's open.
308
channel = TestChannel()
309
self.conn.openChannel(channel)
310
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION('\x00\x00\x00\x00'*5)
311
self.assertEquals(channel.remoteWindowLeft, 0)
312
self.assertEquals(channel.remoteMaxPacket, 0)
313
self.assertEquals(channel.specificData, '\x00\x00\x00\x00')
314
self.assertEquals(self.conn.channelsToRemoteChannel[channel],
316
self.assertEquals(self.conn.localToRemoteChannel[0], 0)
318
def test_CHANNEL_OPEN_FAILURE(self):
320
Test that channel open failure packets cause the channel to be
321
notified that its opening failed.
323
channel = TestChannel()
324
self.conn.openChannel(channel)
325
self.conn.ssh_CHANNEL_OPEN_FAILURE('\x00\x00\x00\x00\x00\x00\x00'
326
'\x01' + common.NS('failure!'))
327
self.assertEquals(channel.openFailureReason.args, ('failure!', 1))
328
self.assertEquals(self.conn.channels.get(channel), None)
331
def test_CHANNEL_WINDOW_ADJUST(self):
333
Test that channel window adjust messages add bytes to the channel
336
channel = TestChannel()
337
self._openChannel(channel)
338
oldWindowSize = channel.remoteWindowLeft
339
self.conn.ssh_CHANNEL_WINDOW_ADJUST('\x00\x00\x00\x00\x00\x00\x00'
341
self.assertEquals(channel.remoteWindowLeft, oldWindowSize + 1)
343
def test_CHANNEL_DATA(self):
345
Test that channel data messages are passed up to the channel, or
346
cause the channel to be closed if the data is too large.
348
channel = TestChannel(localWindow=6, localMaxPacket=5)
349
self._openChannel(channel)
350
self.conn.ssh_CHANNEL_DATA('\x00\x00\x00\x00' + common.NS('data'))
351
self.assertEquals(channel.inBuffer, ['data'])
352
self.assertEquals(self.transport.packets,
353
[(connection.MSG_CHANNEL_WINDOW_ADJUST, '\x00\x00\x00\xff'
354
'\x00\x00\x00\x04')])
355
self.transport.packets = []
356
longData = 'a' * (channel.localWindowLeft + 1)
357
self.conn.ssh_CHANNEL_DATA('\x00\x00\x00\x00' + common.NS(longData))
358
self.assertEquals(channel.inBuffer, ['data'])
359
self.assertEquals(self.transport.packets,
360
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
361
channel = TestChannel()
362
self._openChannel(channel)
363
bigData = 'a' * (channel.localMaxPacket + 1)
364
self.transport.packets = []
365
self.conn.ssh_CHANNEL_DATA('\x00\x00\x00\x01' + common.NS(bigData))
366
self.assertEquals(channel.inBuffer, [])
367
self.assertEquals(self.transport.packets,
368
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
370
def test_CHANNEL_EXTENDED_DATA(self):
372
Test that channel extended data messages are passed up to the channel,
373
or cause the channel to be closed if they're too big.
375
channel = TestChannel(localWindow=6, localMaxPacket=5)
376
self._openChannel(channel)
377
self.conn.ssh_CHANNEL_EXTENDED_DATA('\x00\x00\x00\x00\x00\x00\x00'
378
'\x00' + common.NS('data'))
379
self.assertEquals(channel.extBuffer, [(0, 'data')])
380
self.assertEquals(self.transport.packets,
381
[(connection.MSG_CHANNEL_WINDOW_ADJUST, '\x00\x00\x00\xff'
382
'\x00\x00\x00\x04')])
383
self.transport.packets = []
384
longData = 'a' * (channel.localWindowLeft + 1)
385
self.conn.ssh_CHANNEL_EXTENDED_DATA('\x00\x00\x00\x00\x00\x00\x00'
386
'\x00' + common.NS(longData))
387
self.assertEquals(channel.extBuffer, [(0, 'data')])
388
self.assertEquals(self.transport.packets,
389
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
390
channel = TestChannel()
391
self._openChannel(channel)
392
bigData = 'a' * (channel.localMaxPacket + 1)
393
self.transport.packets = []
394
self.conn.ssh_CHANNEL_EXTENDED_DATA('\x00\x00\x00\x01\x00\x00\x00'
395
'\x00' + common.NS(bigData))
396
self.assertEquals(channel.extBuffer, [])
397
self.assertEquals(self.transport.packets,
398
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
400
def test_CHANNEL_EOF(self):
402
Test that channel eof messages are passed up to the channel.
404
channel = TestChannel()
405
self._openChannel(channel)
406
self.conn.ssh_CHANNEL_EOF('\x00\x00\x00\x00')
407
self.assertTrue(channel.gotEOF)
409
def test_CHANNEL_CLOSE(self):
411
Test that channel close messages are passed up to the channel. Also,
412
test that channel.close() is called if both sides are closed when this
415
channel = TestChannel()
416
self._openChannel(channel)
417
self.conn.sendClose(channel)
418
self.conn.ssh_CHANNEL_CLOSE('\x00\x00\x00\x00')
419
self.assertTrue(channel.gotOneClose)
420
self.assertTrue(channel.gotClosed)
422
def test_CHANNEL_REQUEST_success(self):
424
Test that channel requests that succeed send MSG_CHANNEL_SUCCESS.
426
channel = TestChannel()
427
self._openChannel(channel)
428
self.conn.ssh_CHANNEL_REQUEST('\x00\x00\x00\x00' + common.NS('test')
430
self.assertEquals(channel.numberRequests, 1)
431
d = self.conn.ssh_CHANNEL_REQUEST('\x00\x00\x00\x00' + common.NS(
432
'test') + '\xff' + 'data')
434
self.assertEquals(self.transport.packets,
435
[(connection.MSG_CHANNEL_SUCCESS, '\x00\x00\x00\xff')])
439
def test_CHANNEL_REQUEST_failure(self):
441
Test that channel requests that fail send MSG_CHANNEL_FAILURE.
443
channel = TestChannel()
444
self._openChannel(channel)
445
d = self.conn.ssh_CHANNEL_REQUEST('\x00\x00\x00\x00' + common.NS(
448
self.assertEquals(self.transport.packets,
449
[(connection.MSG_CHANNEL_FAILURE, '\x00\x00\x00\xff'
451
d.addCallback(self.fail)
455
def test_CHANNEL_REQUEST_SUCCESS(self):
457
Test that channel request success messages cause the Deferred to be
460
channel = TestChannel()
461
self._openChannel(channel)
462
d = self.conn.sendRequest(channel, 'test', 'data', True)
463
self.conn.ssh_CHANNEL_SUCCESS('\x00\x00\x00\x00')
465
self.assertTrue(result)
468
def test_CHANNEL_REQUEST_FAILURE(self):
470
Test that channel request failure messages cause the Deferred to be
473
channel = TestChannel()
474
self._openChannel(channel)
475
d = self.conn.sendRequest(channel, 'test', '', True)
476
self.conn.ssh_CHANNEL_FAILURE('\x00\x00\x00\x00')
478
self.assertEquals(result.value.value, 'channel request failed')
479
d.addCallback(self.fail)
483
def test_sendGlobalRequest(self):
485
Test that global request messages are sent in the right format.
487
d = self.conn.sendGlobalRequest('wantReply', 'data', True)
488
self.conn.sendGlobalRequest('noReply', '', False)
489
self.assertEquals(self.transport.packets,
490
[(connection.MSG_GLOBAL_REQUEST, common.NS('wantReply') +
492
(connection.MSG_GLOBAL_REQUEST, common.NS('noReply') +
494
self.assertEquals(self.conn.deferreds, {'global':[d]})
496
def test_openChannel(self):
498
Test that open channel messages are sent in the right format.
500
channel = TestChannel()
501
self.conn.openChannel(channel, 'aaaa')
502
self.assertEquals(self.transport.packets,
503
[(connection.MSG_CHANNEL_OPEN, common.NS('TestChannel') +
504
'\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x80\x00aaaa')])
505
self.assertEquals(channel.id, 0)
506
self.assertEquals(self.conn.localChannelID, 1)
508
def test_sendRequest(self):
510
Test that channel request messages are sent in the right format.
512
channel = TestChannel()
513
self._openChannel(channel)
514
d = self.conn.sendRequest(channel, 'test', 'test', True)
515
self.conn.sendRequest(channel, 'test2', '', False)
516
channel.localClosed = True # emulate sending a close message
517
self.conn.sendRequest(channel, 'test3', '', True)
518
self.assertEquals(self.transport.packets,
519
[(connection.MSG_CHANNEL_REQUEST, '\x00\x00\x00\xff' +
520
common.NS('test') + '\x01test'),
521
(connection.MSG_CHANNEL_REQUEST, '\x00\x00\x00\xff' +
522
common.NS('test2') + '\x00')])
523
self.assertEquals(self.conn.deferreds, {0:[d]})
525
def test_adjustWindow(self):
527
Test that channel window adjust messages cause bytes to be added
530
channel = TestChannel(localWindow=5)
531
self._openChannel(channel)
532
channel.localWindowLeft = 0
533
self.conn.adjustWindow(channel, 1)
534
self.assertEquals(channel.localWindowLeft, 1)
535
channel.localClosed = True
536
self.conn.adjustWindow(channel, 2)
537
self.assertEquals(channel.localWindowLeft, 1)
538
self.assertEquals(self.transport.packets,
539
[(connection.MSG_CHANNEL_WINDOW_ADJUST, '\x00\x00\x00\xff'
540
'\x00\x00\x00\x01')])
542
def test_sendData(self):
544
Test that channel data messages are sent in the right format.
546
channel = TestChannel()
547
self._openChannel(channel)
548
self.conn.sendData(channel, 'a')
549
channel.localClosed = True
550
self.conn.sendData(channel, 'b')
551
self.assertEquals(self.transport.packets,
552
[(connection.MSG_CHANNEL_DATA, '\x00\x00\x00\xff' +
555
def test_sendExtendedData(self):
557
Test that channel extended data messages are sent in the right format.
559
channel = TestChannel()
560
self._openChannel(channel)
561
self.conn.sendExtendedData(channel, 1, 'test')
562
channel.localClosed = True
563
self.conn.sendExtendedData(channel, 2, 'test2')
564
self.assertEquals(self.transport.packets,
565
[(connection.MSG_CHANNEL_EXTENDED_DATA, '\x00\x00\x00\xff' +
566
'\x00\x00\x00\x01' + common.NS('test'))])
568
def test_sendEOF(self):
570
Test that channel EOF messages are sent in the right format.
572
channel = TestChannel()
573
self._openChannel(channel)
574
self.conn.sendEOF(channel)
575
self.assertEquals(self.transport.packets,
576
[(connection.MSG_CHANNEL_EOF, '\x00\x00\x00\xff')])
577
channel.localClosed = True
578
self.conn.sendEOF(channel)
579
self.assertEquals(self.transport.packets,
580
[(connection.MSG_CHANNEL_EOF, '\x00\x00\x00\xff')])
582
def test_sendClose(self):
584
Test that channel close messages are sent in the right format.
586
channel = TestChannel()
587
self._openChannel(channel)
588
self.conn.sendClose(channel)
589
self.assertTrue(channel.localClosed)
590
self.assertEquals(self.transport.packets,
591
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
592
self.conn.sendClose(channel)
593
self.assertEquals(self.transport.packets,
594
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
596
channel2 = TestChannel()
597
self._openChannel(channel2)
598
channel2.remoteClosed = True
599
self.conn.sendClose(channel2)
600
self.assertTrue(channel2.gotClosed)
602
def test_getChannelWithAvatar(self):
604
Test that getChannel dispatches to the avatar when an avatar is
605
present. Correct functioning without the avatar is verified in
608
channel = self.conn.getChannel('TestChannel', 50, 30, 'data')
609
self.assertEquals(channel.data, 'data')
610
self.assertEquals(channel.remoteWindowLeft, 50)
611
self.assertEquals(channel.remoteMaxPacket, 30)
612
self.assertRaises(error.ConchError, self.conn.getChannel,
613
'BadChannel', 50, 30, 'data')
615
def test_gotGlobalRequestWithoutAvatar(self):
617
Test that gotGlobalRequests dispatches to global_* without an avatar.
619
del self.transport.avatar
620
self.assertTrue(self.conn.gotGlobalRequest('TestGlobal', 'data'))
621
self.assertEquals(self.conn.gotGlobalRequest('Test-Data', 'data'),
623
self.assertFalse(self.conn.gotGlobalRequest('BadGlobal', 'data'))