~certify-web-dev/twisted/certify-trunk

« back to all changes in this revision

Viewing changes to twisted/conch/test/test_ssh.py

  • Committer: Bazaar Package Importer
  • Author(s): Matthias Klose
  • Date: 2007-01-17 14:52:35 UTC
  • mfrom: (1.1.5 upstream) (2.1.2 etch)
  • Revision ID: james.westby@ubuntu.com-20070117145235-btmig6qfmqfen0om
Tags: 2.5.0-0ubuntu1
New upstream version, compatible with python2.5.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# -*- test-case-name: twisted.conch.test.test_ssh -*-
 
2
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
 
3
# See LICENSE for details.
 
4
 
 
5
try:
 
6
    import Crypto
 
7
except ImportError:
 
8
    Crypto = None
 
9
 
 
10
from twisted.conch.ssh import common, session, forwarding
 
11
from twisted.conch import avatar, error
 
12
from twisted.cred import portal
 
13
from twisted.internet import defer, protocol, reactor
 
14
from twisted.internet.error import ProcessTerminated
 
15
from twisted.python import failure, log
 
16
from twisted.trial import unittest
 
17
 
 
18
from test_recvline import LoopbackRelay
 
19
 
 
20
import struct
 
21
 
 
22
 
 
23
class ConchTestRealm:
    """
    Realm used by the Conch tests; every request yields a new
    ConchTestAvatar.
    """

    def requestAvatar(self, avatarID, mind, *interfaces):
        """
        Assert that the avatar is being requested for 'testuser', then
        return the (interface, avatar, logout callable) triple built from
        the first requested interface.
        """
        unittest.assertEquals(avatarID, 'testuser')
        user = ConchTestAvatar()
        return (interfaces[0], user, user.logout)
 
29
 
 
30
class ConchTestAvatar(avatar.ConchUser):
    """
    Avatar handed out by ConchTestRealm.  It registers the 'session' and
    'direct-tcpip' channel types and the 'crazy' subsystem, and implements
    the global requests exercised by the client side of these tests.
    """

    # Becomes True on the instance once logout() has run.
    loggedOut = False

    def __init__(self):
        avatar.ConchUser.__init__(self)
        # (host, port) -> listening port opened by global_tcpip_forward().
        self.listeners = {}
        self.channelLookup.update({'session': session.SSHSession,
                        'direct-tcpip':forwarding.openConnectForwardingClient})
        self.subsystemLookup.update({'crazy': CrazySubsystem})

    def global_foo(self, data):
        """
        Check that the request payload is 'bar' and report success.

        Presumably dispatched for the 'foo' global request sent by
        SSHTestSubsystemChannel -- confirm against the connection service.
        """
        unittest.assertEquals(data, 'bar')
        return 1

    def global_foo_2(self, data):
        """
        Check that the request payload is 'bar2' and report success together
        with the reply payload 'data'.
        """
        unittest.assertEquals(data, 'bar2')
        return 1, 'data'

    def global_tcpip_forward(self, data):
        """
        Start listening on the (host, port) packed in data and remember the
        listener so it can be cancelled later.  Returns 1 on success.
        """
        host, port = forwarding.unpackGlobal_tcpip_forward(data)
        try: listener = reactor.listenTCP(port,
                forwarding.SSHListenForwardingFactory(self.conn,
                    (host, port),
                    forwarding.SSHListenServerForwardingChannel),
                interface = host)
        except:
            # Any failure to listen is logged and turned into a test failure.
            log.err()
            unittest.fail("something went wrong with remote->local forwarding")
            return 0
        else:
            self.listeners[(host, port)] = listener
            return 1

    def global_cancel_tcpip_forward(self, data):
        """
        Stop the listener previously registered for the (host, port) packed
        in data.  Returns 0 when no such listener is known, 1 otherwise.
        """
        host, port = forwarding.unpackGlobal_tcpip_forward(data)
        listener = self.listeners.get((host, port), None)
        if not listener:
            return 0
        del self.listeners[(host, port)]
        listener.stopListening()
        return 1

    def logout(self):
        """
        Mark this avatar as logged out and stop every remaining listener.
        """
        # The flag must be stored on the instance; a bare assignment would
        # only create a throwaway local and leave loggedOut False forever.
        self.loggedOut = True
        for listener in self.listeners.values():
            log.msg('stopListening %s' % listener)
            listener.stopListening()
 
77
 
 
78
class ConchSessionForTestAvatar:
    """
    Session object for a ConchTestAvatar.  It checks the pty/shell/exec
    requests made by the test channels and wires the session protocol to one
    of the fake transports defined in this module.
    """

    def __init__(self, avatar):
        unittest.assert_(isinstance(avatar, ConchTestAvatar))
        self.avatar = avatar
        self.cmd = None       # 'shell' or the exec command line
        self.proto = None     # protocol passed to openShell/execCommand
        self.ptyReq = False   # set once getPty() has been called
        self.eof = 0          # set once eofReceived() has been called

    def getPty(self, term, windowSize, attrs):
        """
        Record the pty request after checking the terminal name and window
        size the shell test channel is expected to send.
        """
        log.msg('pty req')
        unittest.assertEquals(term, 'conch-test-term')
        unittest.assertEquals(windowSize, (24, 80, 0, 0))
        self.ptyReq = True

    def openShell(self, proto):
        """
        Attach an EchoTransport to proto; a pty must have been requested
        beforehand.
        """
        log.msg('openning shell')
        unittest.assertEquals(self.ptyReq, True)
        self.proto = proto
        EchoTransport(proto)
        self.cmd = 'shell'

    def execCommand(self, proto, cmd):
        """
        Run one of the fake commands understood by the tests.

        'jumboliah' raises ConchError, 'false' attaches a FalseTransport, and
        the echo variants attach the matching echo transport, write the rest
        of the command line to it and close it.
        """
        self.cmd = cmd
        name = cmd.split()[0]
        unittest.assert_(name in ['false', 'echo', 'secho', 'eecho','jumboliah'],
                'invalid command: %s' % name)
        if cmd == 'jumboliah':
            raise error.ConchError('bad exec')
        self.proto = proto
        # command name -> (transport class, offset of the text to echo)
        echoers = {
            'echo': (EchoTransport, 5),
            'secho': (SuperEchoTransport, 6),
            'eecho': (ErrEchoTransport, 6),
        }
        if name == 'false':
            FalseTransport(proto)
        elif name in echoers:
            transportClass, offset = echoers[name]
            t = transportClass(proto)
            t.write(cmd[offset:])
            t.loseConnection()
        self.avatar.conn.transport.expectedLoseConnection = 1

    def eofReceived(self):
        """Remember that EOF was received."""
        self.eof = 1

    def closed(self):
        """
        Check the final state of the session: both echo commands must leave
        4 in the session's remoteWindowLeft, and the shell must have seen
        EOF.
        """
        log.msg('closed cmd "%s"' % self.cmd)
        if self.cmd in ('echo hello', 'eecho hello'):
            remaining = self.proto.session.remoteWindowLeft
            unittest.assertEquals(remaining, 4)
        elif self.cmd == 'shell':
            unittest.assert_(self.eof)
 
143
 
 
144
from twisted.python import components
 
145
components.registerAdapter(ConchSessionForTestAvatar, ConchTestAvatar, session.ISession)
 
146
 
 
147
class CrazySubsystem(protocol.Protocol):
    """
    Do-nothing protocol registered as the 'crazy' subsystem.
    """

    def __init__(self, *args, **kw):
        """
        Accept and ignore any construction arguments.
        """

    def connectionMade(self):
        """
        good ... good
        """
 
156
 
 
157
class FalseTransport:
    """
    Fake process transport for the 'false' command: connects the protocol
    and immediately reports termination with exit code 255.
    """

    def __init__(self, p):
        p.makeConnection(self)
        ended = failure.Failure(ProcessTerminated(255, None, None))
        p.processEnded(ended)

    def loseConnection(self):
        """
        Nothing to do; processEnded() was already reported in __init__.
        """
 
165
 
 
166
class EchoTransport:
    """
    Fake process transport that echoes everything written to it back to the
    protocol as standard output, followed by CR LF.
    """

    def __init__(self, p):
        self.proto = p
        p.makeConnection(self)
        self.closed = 0

    def write(self, data):
        """
        Log data, deliver it and a trailing CR LF via outReceived(), and
        close the transport when data contains a NUL byte.
        """
        log.msg(repr(data))
        for chunk in (data, '\r\n'):
            self.proto.outReceived(chunk)
        if '\x00' in data: # mimic 'exit' for the shell test
            self.loseConnection()

    def loseConnection(self):
        """
        Report the three streams as closed and the process as ended with
        exit code 0.  Only the first call has any effect.
        """
        if not self.closed:
            self.closed = 1
            for hook in ('inConnectionLost', 'outConnectionLost',
                         'errConnectionLost'):
                getattr(self.proto, hook)()
            self.proto.processEnded(
                failure.Failure(ProcessTerminated(0, None, None)))
 
187
 
 
188
class ErrEchoTransport:
    """
    Fake process transport that echoes everything written to it back to the
    protocol as standard error, followed by CR LF.
    """

    def __init__(self, p):
        self.proto = p
        p.makeConnection(self)
        self.closed = 0

    def write(self, data):
        """
        Deliver data and a trailing CR LF via errReceived().
        """
        for chunk in (data, '\r\n'):
            self.proto.errReceived(chunk)

    def loseConnection(self):
        """
        Report the three streams as closed and the process as ended with
        exit code 0.  Only the first call has any effect.
        """
        if not self.closed:
            self.closed = 1
            for hook in ('inConnectionLost', 'outConnectionLost',
                         'errConnectionLost'):
                getattr(self.proto, hook)()
            self.proto.processEnded(
                failure.Failure(ProcessTerminated(0, None, None)))
 
206
 
 
207
class SuperEchoTransport:
    """
    Fake process transport that echoes everything written to it back to the
    protocol on both standard output and standard error.
    """

    def __init__(self, p):
        self.proto = p
        p.makeConnection(self)
        self.closed = 0

    def write(self, data):
        """
        Deliver data plus CR LF via outReceived(), then the same again via
        errReceived().
        """
        for deliver in (self.proto.outReceived, self.proto.errReceived):
            deliver(data)
            deliver('\r\n')

    def loseConnection(self):
        """
        Report the three streams as closed and the process as ended with
        exit code 0.  Only the first call has any effect.
        """
        if not self.closed:
            self.closed = 1
            for hook in ('inConnectionLost', 'outConnectionLost',
                         'errConnectionLost'):
                getattr(self.proto, hook)()
            self.proto.processEnded(
                failure.Failure(ProcessTerminated(0, None, None)))
 
227
 
 
228
class _LogTimeFormatMixin:
    """
    Mixin that swaps FileLogObserver.timeFormat for a fixed format during a
    test class and restores the previous value afterwards.
    """

    def setUpClass(self):
        """
        Save the current time format and install '%Y/%m/%d %H:%M:%S %Z'.
        """
        from twisted.python import log
        observerClass = log.FileLogObserver
        self._oldTimeFormat = observerClass.timeFormat
        observerClass.timeFormat = '%Y/%m/%d %H:%M:%S %Z'

    def tearDownClass(self):
        """
        Put back the time format saved by setUpClass().
        """
        log.FileLogObserver.timeFormat = self._oldTimeFormat
 
237
 
 
238
if Crypto: # stuff that needs PyCrypto to even import
 
239
    from twisted.conch import checkers
 
240
    from twisted.conch.ssh import channel, connection, factory, keys
 
241
    from twisted.conch.ssh import transport, userauth
 
242
 
 
243
    from test_keys import publicRSA_openssh, privateRSA_openssh
 
244
    from test_keys import publicDSA_openssh, privateDSA_openssh
 
245
 
 
246
 
 
247
    class UtilityTestCase(unittest.TestCase):
 
248
        def testCounter(self):
 
249
            c = transport._Counter('\x00\x00', 2)
 
250
            for i in xrange(256 * 256):
 
251
                self.assertEquals(c(), struct.pack('!H', (i + 1) % (2 ** 16)))
 
252
            # It should wrap around, too.
 
253
            for i in xrange(256 * 256):
 
254
                self.assertEquals(c(), struct.pack('!H', (i + 1) % (2 ** 16)))
 
255
 
 
256
 
 
257
    class ConchTestPublicKeyChecker(checkers.SSHPublicKeyDatabase):
 
258
        def checkKey(self, credentials):
 
259
            unittest.assertEquals(credentials.username, 'testuser', 'bad username')
 
260
            unittest.assertEquals(credentials.blob, keys.getPublicKeyString(data=publicDSA_openssh))
 
261
            return 1
 
262
 
 
263
    class ConchTestPasswordChecker:
 
264
        credentialInterfaces = checkers.IUsernamePassword,
 
265
 
 
266
        def requestAvatarId(self, credentials):
 
267
            unittest.assertEquals(credentials.username, 'testuser', 'bad username')
 
268
            unittest.assertEquals(credentials.password, 'testpass', 'bad password')
 
269
            return defer.succeed(credentials.username)
 
270
 
 
271
    class ConchTestSSHChecker(checkers.SSHProtocolChecker):
 
272
 
 
273
        def areDone(self, avatarId):
 
274
            unittest.assertEquals(avatarId, 'testuser')
 
275
            if len(self.successfulCredentials[avatarId]) < 2:
 
276
                return 0
 
277
            else:
 
278
                return 1
 
279
 
 
280
    class ConchTestServerFactory(factory.SSHFactory):
 
281
        noisy = 0
 
282
 
 
283
        services = {
 
284
            'ssh-userauth':userauth.SSHUserAuthServer,
 
285
            'ssh-connection':connection.SSHConnection
 
286
        }
 
287
 
 
288
        def buildProtocol(self, addr):
 
289
            proto = ConchTestServer()
 
290
            proto.supportedPublicKeys = self.privateKeys.keys()
 
291
            proto.factory = self
 
292
 
 
293
            if hasattr(self, 'expectedLoseConnection'):
 
294
                proto.expectedLoseConnection = self.expectedLoseConnection
 
295
 
 
296
            self.proto = proto
 
297
            return proto
 
298
 
 
299
        def getPublicKeys(self):
 
300
            return {
 
301
                'ssh-rsa':keys.getPublicKeyString(data=publicRSA_openssh),
 
302
                'ssh-dss':keys.getPublicKeyString(data=publicDSA_openssh)
 
303
            }
 
304
 
 
305
        def getPrivateKeys(self):
 
306
            return {
 
307
                'ssh-rsa':keys.getPrivateKeyObject(data=privateRSA_openssh),
 
308
                'ssh-dss':keys.getPrivateKeyObject(data=privateDSA_openssh)
 
309
            }
 
310
 
 
311
        def getPrimes(self):
 
312
            return {
 
313
                2048:[(transport.DH_GENERATOR, transport.DH_PRIME)]
 
314
            }
 
315
 
 
316
        def getService(self, trans, name):
 
317
            return factory.SSHFactory.getService(self, trans, name)
 
318
 
 
319
    class ConchTestBase:
 
320
 
 
321
        done = 0
 
322
        allowedToError = 0
 
323
 
 
324
        def connectionLost(self, reason):
 
325
            if self.done:
 
326
                return
 
327
            if not hasattr(self,'expectedLoseConnection'):
 
328
                unittest.fail('unexpectedly lost connection %s\n%s' % (self, reason))
 
329
            self.done = 1
 
330
 
 
331
        def receiveError(self, reasonCode, desc):
 
332
            self.expectedLoseConnection = 1
 
333
            if not self.allowedToError:
 
334
                unittest.fail('got disconnect for %s: reason %s, desc: %s' %
 
335
                               (self, reasonCode, desc))
 
336
            self.loseConnection()
 
337
 
 
338
        def receiveUnimplemented(self, seqID):
 
339
            unittest.fail('got unimplemented: seqid %s'  % seqID)
 
340
            self.expectedLoseConnection = 1
 
341
            self.loseConnection()
 
342
 
 
343
    class ConchTestServer(ConchTestBase, transport.SSHServerTransport):
 
344
 
 
345
        def connectionLost(self, reason):
 
346
            ConchTestBase.connectionLost(self, reason)
 
347
            transport.SSHServerTransport.connectionLost(self, reason)
 
348
 
 
349
    class ConchTestClient(ConchTestBase, transport.SSHClientTransport):
 
350
 
 
351
        def connectionLost(self, reason):
 
352
            ConchTestBase.connectionLost(self, reason)
 
353
            transport.SSHClientTransport.connectionLost(self, reason)
 
354
 
 
355
        def verifyHostKey(self, key, fp):
 
356
            unittest.assertEquals(key, keys.getPublicKeyString(data = publicRSA_openssh))
 
357
            unittest.assertEquals(fp,'3d:13:5f:cb:c9:79:8a:93:06:27:65:bc:3d:0b:8f:af')
 
358
            return defer.succeed(1)
 
359
 
 
360
        def connectionSecure(self):
 
361
            self.requestService(ConchTestClientAuth('testuser',
 
362
                ConchTestClientConnection()))
 
363
 
 
364
    class ConchTestClientAuth(userauth.SSHUserAuthClient):
 
365
 
 
366
        hasTriedNone = 0 # have we tried the 'none' auth yet?
 
367
        canSucceedPublicKey = 0 # can we succed with this yet?
 
368
        canSucceedPassword = 0
 
369
 
 
370
        def ssh_USERAUTH_SUCCESS(self, packet):
 
371
            if not self.canSucceedPassword and self.canSucceedPublicKey:
 
372
                unittest.fail('got USERAUTH_SUCESS before password and publickey')
 
373
            userauth.SSHUserAuthClient.ssh_USERAUTH_SUCCESS(self, packet)
 
374
 
 
375
        def getPassword(self):
 
376
            self.canSucceedPassword = 1
 
377
            return defer.succeed('testpass')
 
378
 
 
379
        def getPrivateKey(self):
 
380
            self.canSucceedPublicKey = 1
 
381
            return defer.succeed(keys.getPrivateKeyObject(data=privateDSA_openssh))
 
382
 
 
383
        def getPublicKey(self):
 
384
            return keys.getPublicKeyString(data=publicDSA_openssh)
 
385
 
 
386
    class ConchTestClientConnection(connection.SSHConnection):
 
387
 
 
388
        name = 'ssh-connection'
 
389
        results = 0
 
390
        totalResults = 8
 
391
 
 
392
        def serviceStarted(self):
 
393
            self.openChannel(SSHTestFailExecChannel(conn = self))
 
394
            self.openChannel(SSHTestFalseChannel(conn = self))
 
395
            self.openChannel(SSHTestEchoChannel(localWindow=4, localMaxPacket=5, conn = self))
 
396
            self.openChannel(SSHTestErrChannel(localWindow=4, localMaxPacket=5, conn = self))
 
397
            self.openChannel(SSHTestMaxPacketChannel(localWindow=12, localMaxPacket=1, conn = self))
 
398
            self.openChannel(SSHTestShellChannel(conn = self))
 
399
            self.openChannel(SSHTestSubsystemChannel(conn = self))
 
400
            self.openChannel(SSHUnknownChannel(conn = self))
 
401
 
 
402
        def addResult(self):
 
403
            self.results += 1
 
404
            log.msg('got %s of %s results' % (self.results, self.totalResults))
 
405
            if self.results == self.totalResults:
 
406
                self.transport.expectedLoseConnection = 1
 
407
                self.serviceStopped()
 
408
 
 
409
    class SSHUnknownChannel(channel.SSHChannel):
 
410
 
 
411
        name = 'crazy-unknown-channel'
 
412
 
 
413
        def openFailed(self, reason):
 
414
            """
 
415
            good .... good
 
416
            """
 
417
            log.msg('unknown open failed')
 
418
            log.flushErrors()
 
419
            self.conn.addResult()
 
420
 
 
421
        def channelOpen(self, ignored):
 
422
            unittest.fail("opened unknown channel")
 
423
 
 
424
    class SSHTestFailExecChannel(channel.SSHChannel):
 
425
 
 
426
        name = 'session'
 
427
 
 
428
        def openFailed(self, reason):
 
429
            unittest.fail('fail exec open failed: %s' % reason)
 
430
 
 
431
        def channelOpen(self, ignore):
 
432
            d = self.conn.sendRequest(self, 'exec', common.NS('jumboliah'), 1)
 
433
            d.addCallback(self._cbRequestWorked)
 
434
            d.addErrback(self._ebRequestWorked)
 
435
            log.msg('opened fail exec')
 
436
 
 
437
        def _cbRequestWorked(self, ignored):
 
438
            unittest.fail('fail exec succeeded')
 
439
 
 
440
        def _ebRequestWorked(self, ignored):
 
441
            log.msg('fail exec finished')
 
442
            log.flushErrors()
 
443
            self.conn.addResult()
 
444
            self.loseConnection()
 
445
 
 
446
    class SSHTestFalseChannel(channel.SSHChannel):
 
447
 
 
448
        name = 'session'
 
449
 
 
450
        def openFailed(self, reason):
 
451
            unittest.fail('false open failed: %s' % reason)
 
452
 
 
453
        def channelOpen(self, ignored):
 
454
            d = self.conn.sendRequest(self, 'exec', common.NS('false'), 1)
 
455
            d.addCallback(self._cbRequestWorked)
 
456
            d.addErrback(self._ebRequestFailed)
 
457
            log.msg('opened false')
 
458
 
 
459
        def _cbRequestWorked(self, ignored):
 
460
            pass
 
461
 
 
462
        def _ebRequestFailed(self, reason):
 
463
            unittest.fail('false exec failed: %s' % reason)
 
464
 
 
465
        def dataReceived(self, data):
 
466
            unittest.fail('got data when using false')
 
467
 
 
468
        def request_exit_status(self, status):
 
469
            status, = struct.unpack('>L', status)
 
470
            if status == 0:
 
471
                unittest.fail('false exit status was 0')
 
472
            log.msg('finished false')
 
473
            self.conn.addResult()
 
474
            return 1
 
475
 
 
476
    class SSHTestEchoChannel(channel.SSHChannel):
 
477
 
 
478
        name = 'session'
 
479
        testBuf = ''
 
480
        eofCalled = 0
 
481
 
 
482
        def openFailed(self, reason):
 
483
            unittest.fail('echo open failed: %s' % reason)
 
484
 
 
485
        def channelOpen(self, ignore):
 
486
            d = self.conn.sendRequest(self, 'exec', common.NS('echo hello'), 1)
 
487
            d.addErrback(self._ebRequestFailed)
 
488
            log.msg('opened echo')
 
489
 
 
490
        def _ebRequestFailed(self, reason):
 
491
            unittest.fail('echo exec failed: %s' % reason)
 
492
 
 
493
        def dataReceived(self, data):
 
494
            self.testBuf += data
 
495
 
 
496
        def errReceived(self, dataType, data):
 
497
            unittest.fail('echo channel got extended data')
 
498
 
 
499
        def request_exit_status(self, status):
 
500
            self.status ,= struct.unpack('>L', status)
 
501
 
 
502
        def eofReceived(self):
 
503
            log.msg('eof received')
 
504
            self.eofCalled = 1
 
505
 
 
506
        def closed(self):
 
507
            if self.status != 0:
 
508
                unittest.fail('echo exit status was not 0: %i' % self.status)
 
509
            if self.testBuf != "hello\r\n":
 
510
                unittest.fail('echo did not return hello: %s' % repr(self.testBuf))
 
511
            unittest.assertEquals(self.localWindowLeft, 4)
 
512
            unittest.assert_(self.eofCalled)
 
513
            log.msg('finished echo')
 
514
            self.conn.addResult()
 
515
            return 1
 
516
 
 
517
    class SSHTestErrChannel(channel.SSHChannel):
 
518
 
 
519
        name = 'session'
 
520
        testBuf = ''
 
521
        eofCalled = 0
 
522
 
 
523
        def openFailed(self, reason):
 
524
            unittest.fail('err open failed: %s' % reason)
 
525
 
 
526
        def channelOpen(self, ignore):
 
527
            d = self.conn.sendRequest(self, 'exec', common.NS('eecho hello'), 1)
 
528
            d.addErrback(self._ebRequestFailed)
 
529
            log.msg('opened err')
 
530
 
 
531
        def _ebRequestFailed(self, reason):
 
532
            unittest.fail('err exec failed: %s' % reason)
 
533
 
 
534
        def dataReceived(self, data):
 
535
            unittest.fail('err channel got regular data: %s' % repr(data))
 
536
 
 
537
        def extReceived(self, dataType, data):
 
538
            unittest.assertEquals(dataType, connection.EXTENDED_DATA_STDERR)
 
539
            self.testBuf += data
 
540
 
 
541
        def request_exit_status(self, status):
 
542
            self.status ,= struct.unpack('>L', status)
 
543
 
 
544
        def eofReceived(self):
 
545
            log.msg('eof received')
 
546
            self.eofCalled = 1
 
547
 
 
548
        def closed(self):
 
549
            if self.status != 0:
 
550
                unittest.fail('err exit status was not 0: %i' % self.status)
 
551
            if self.testBuf != "hello\r\n":
 
552
                unittest.fail('err did not return hello: %s' % repr(self.testBuf))
 
553
            unittest.assertEquals(self.localWindowLeft, 4)
 
554
            unittest.assert_(self.eofCalled)
 
555
            log.msg('finished err')
 
556
            self.conn.addResult()
 
557
            return 1
 
558
 
 
559
    class SSHTestMaxPacketChannel(channel.SSHChannel):
 
560
 
 
561
        name = 'session'
 
562
        testBuf = ''
 
563
        testExtBuf = ''
 
564
        eofCalled = 0
 
565
 
 
566
        def openFailed(self, reason):
 
567
            unittest.fail('max packet open failed: %s' % reason)
 
568
 
 
569
        def channelOpen(self, ignore):
 
570
            d = self.conn.sendRequest(self, 'exec', common.NS('secho hello'), 1)
 
571
            d.addErrback(self._ebRequestFailed)
 
572
            log.msg('opened max packet')
 
573
 
 
574
        def _ebRequestFailed(self, reason):
 
575
            unittest.fail('max packet exec failed: %s' % reason)
 
576
 
 
577
        def dataReceived(self, data):
 
578
            self.testBuf += data
 
579
 
 
580
        def extReceived(self, dataType, data):
 
581
            unittest.assertEquals(dataType, connection.EXTENDED_DATA_STDERR)
 
582
            self.testExtBuf += data
 
583
 
 
584
        def request_exit_status(self, status):
 
585
            self.status ,= struct.unpack('>L', status)
 
586
 
 
587
        def eofReceived(self):
 
588
            log.msg('eof received')
 
589
            self.eofCalled = 1
 
590
 
 
591
        def closed(self):
 
592
            if self.status != 0:
 
593
                unittest.fail('echo exit status was not 0: %i' % self.status)
 
594
            unittest.assertEquals(self.testBuf, 'hello\r\n')
 
595
            unittest.assertEquals(self.testExtBuf, 'hello\r\n')
 
596
            unittest.assertEquals(self.localWindowLeft, 12)
 
597
            unittest.assert_(self.eofCalled)
 
598
            log.msg('finished max packet')
 
599
            self.conn.addResult()
 
600
            return 1
 
601
 
 
602
    class SSHTestShellChannel(channel.SSHChannel):
 
603
 
 
604
        name = 'session'
 
605
        testBuf = ''
 
606
        eofCalled = 0
 
607
        closeCalled = 0
 
608
 
 
609
        def openFailed(self, reason):
 
610
            unittest.fail('shell open failed: %s' % reason)
 
611
 
 
612
        def channelOpen(self, ignored):
 
613
            data = session.packRequest_pty_req('conch-test-term', (24, 80, 0, 0), '')
 
614
            d = self.conn.sendRequest(self, 'pty-req', data, 1)
 
615
            d.addCallback(self._cbPtyReq)
 
616
            d.addErrback(self._ebPtyReq)
 
617
            log.msg('opened shell')
 
618
 
 
619
        def _cbPtyReq(self, ignored):
 
620
            d = self.conn.sendRequest(self, 'shell', '', 1)
 
621
            d.addCallback(self._cbShellOpen)
 
622
            d.addErrback(self._ebShellOpen)
 
623
 
 
624
        def _ebPtyReq(self, reason):
 
625
            unittest.fail('pty request failed: %s' % reason)
 
626
 
 
627
        def _cbShellOpen(self, ignored):
 
628
            self.write('testing the shell!\x00')
 
629
            self.conn.sendEOF(self)
 
630
 
 
631
        def _ebShellOpen(self, reason):
 
632
            unittest.fail('shell request failed: %s' % reason)
 
633
 
 
634
        def dataReceived(self, data):
 
635
            self.testBuf += data
 
636
 
 
637
        def request_exit_status(self, status):
 
638
            self.status ,= struct.unpack('>L', status)
 
639
 
 
640
        def eofReceived(self):
 
641
            self.eofCalled = 1
 
642
 
 
643
        def closed(self):
 
644
            log.msg('calling shell closed')
 
645
            if self.status != 0:
 
646
                log.msg('shell exit status was not 0: %i' % self.status)
 
647
            unittest.assertEquals(self.testBuf, 'testing the shell!\x00\r\n')
 
648
            unittest.assert_(self.eofCalled)
 
649
            log.msg('finished shell')
 
650
            self.conn.addResult()
 
651
 
 
652
    class SSHTestSubsystemChannel(channel.SSHChannel):
 
653
 
 
654
        name = 'session'
 
655
 
 
656
        def openFailed(self, reason):
 
657
            unittest.fail('subsystem open failed: %s' % reason)
 
658
 
 
659
        def channelOpen(self, ignore):
 
660
            d = self.conn.sendRequest(self, 'subsystem', common.NS('not-crazy'), 1)
 
661
            d.addCallback(self._cbRequestWorked)
 
662
            d.addErrback(self._ebRequestFailed)
 
663
 
 
664
 
 
665
        def _cbRequestWorked(self, ignored):
 
666
            unittest.fail('opened non-crazy subsystem')
 
667
 
 
668
        def _ebRequestFailed(self, ignored):
 
669
            d = self.conn.sendRequest(self, 'subsystem', common.NS('crazy'), 1)
 
670
            d.addCallback(self._cbRealRequestWorked)
 
671
            d.addErrback(self._ebRealRequestFailed)
 
672
 
 
673
        def _cbRealRequestWorked(self, ignored):
 
674
            d1 = self.conn.sendGlobalRequest('foo', 'bar', 1)
 
675
            d1.addErrback(self._ebFirstGlobal)
 
676
 
 
677
            d2 = self.conn.sendGlobalRequest('foo-2', 'bar2', 1)
 
678
            d2.addCallback(lambda x: unittest.assertEquals(x, 'data'))
 
679
            d2.addErrback(self._ebSecondGlobal)
 
680
 
 
681
            d3 = self.conn.sendGlobalRequest('bar', 'foo', 1)
 
682
            d3.addCallback(self._cbThirdGlobal)
 
683
            d3.addErrback(lambda x,s=self: log.msg('subsystem finished') or s.conn.addResult() or s.loseConnection())
 
684
 
 
685
        def _ebRealRequestFailed(self, reason):
 
686
            unittest.fail('opening crazy subsystem failed: %s' % reason)
 
687
 
 
688
        def _ebFirstGlobal(self, reason):
 
689
            unittest.fail('first global request failed: %s' % reason)
 
690
 
 
691
        def _ebSecondGlobal(self, reason):
 
692
            unittest.fail('second global request failed: %s' % reason)
 
693
 
 
694
        def _cbThirdGlobal(self, ignored):
 
695
            unittest.fail('second global request succeeded')
 
696
 
 
697
 
 
698
 
 
699
class SSHProtocolTestCase(unittest.TestCase):
    """
    Runs the Conch test server and client against each other over a pair of
    in-memory LoopbackRelay transports.
    """

    if not Crypto:
        skip = "can't run w/o PyCrypto"

    def testOurServerOurClient(self):
        """test the Conch server against the Conch client
        """
        # Portal with a protocol checker combining password and public key.
        testPortal = portal.Portal(ConchTestRealm())
        protocolChecker = ConchTestSSHChecker()
        protocolChecker.registerChecker(ConchTestPasswordChecker())
        protocolChecker.registerChecker(ConchTestPublicKeyChecker())
        testPortal.registerChecker(protocolChecker)

        serverFactory = ConchTestServerFactory()
        serverFactory.portal = testPortal
        serverFactory.startFactory()

        # Each side is handed a relay that targets the other side.
        self.server = serverFactory.buildProtocol(None)
        self.clientTransport = LoopbackRelay(self.server)
        self.client = ConchTestClient()
        self.serverTransport = LoopbackRelay(self.client)

        self.server.makeConnection(self.serverTransport)
        self.client.makeConnection(self.clientTransport)

        # Pump both relays until neither has buffered data left.
        relays = (('serverTransport', self.serverTransport),
                  ('clientTransport', self.clientTransport))
        while self.serverTransport.buffer or self.clientTransport.buffer:
            for system, relay in relays:
                log.callWithContext({'system': system}, relay.clearBuffer)
        self.failIf(self.server.done and self.client.done)
 
730
 
 
731
 
 
732
class TestSSHFactory(unittest.TestCase):
    """
    Tests for SSHFactory.
    """

    if not Crypto:
        skip = "can't run w/o PyCrypto"

    def testMultipleFactories(self):
        """
        Two factories in one process keep separate key exchange lists: only
        the one that provides primes offers
        'diffie-hellman-group-exchange-sha1'.
        """
        groupExchange = 'diffie-hellman-group-exchange-sha1'

        withoutPrimes = factory.SSHFactory()
        withPrimes = factory.SSHFactory()

        def fakeKeys():
            return {'ssh-rsa' : "don't use"}

        withoutPrimes.getPrimes = lambda: None
        withPrimes.getPrimes = lambda: {1:(2,3)}
        for fac in (withoutPrimes, withPrimes):
            fac.getPublicKeys = fakeKeys
            fac.getPrivateKeys = fakeKeys

        withoutPrimes.startFactory()
        withPrimes.startFactory()
        p1 = withoutPrimes.buildProtocol(None)
        p2 = withPrimes.buildProtocol(None)

        self.failIf(groupExchange in p1.supportedKeyExchanges,
                p1.supportedKeyExchanges)
        self.failUnless(groupExchange in p2.supportedKeyExchanges,
                p2.supportedKeyExchanges)