~justin-fathomdb/nova/justinsb-openstack-api-volumes

« back to all changes in this revision

Viewing changes to vendor/Twisted-10.0.0/twisted/conch/test/test_userauth.py

  • Committer: Jesse Andrews
  • Date: 2010-05-28 06:05:26 UTC
  • Revision ID: git-v1:bf6e6e718cdc7488e2da87b21e258ccc065fe499
initial commit

Show diffs side-by-side

added

removed

Lines of Context:
 
1
# -*- test-case-name: twisted.conch.test.test_userauth -*-
 
2
# Copyright (c) 2007-2010 Twisted Matrix Laboratories.
 
3
# See LICENSE for details.
 
4
 
 
5
"""
 
6
Tests for the implementation of the ssh-userauth service.
 
7
 
 
8
Maintainer: Paul Swartz
 
9
"""
 
10
 
 
11
from zope.interface import implements
 
12
 
 
13
from twisted.cred.checkers import ICredentialsChecker
 
14
from twisted.cred.credentials import IUsernamePassword, ISSHPrivateKey
 
15
from twisted.cred.credentials import IPluggableAuthenticationModules
 
16
from twisted.cred.credentials import IAnonymous
 
17
from twisted.cred.error import UnauthorizedLogin
 
18
from twisted.cred.portal import IRealm, Portal
 
19
from twisted.conch.error import ConchError, ValidPublicKey
 
20
from twisted.internet import defer, task
 
21
from twisted.protocols import loopback
 
22
from twisted.trial import unittest
 
23
 
 
24
try:
    import Crypto.Cipher.DES3
    import Crypto.Cipher.XOR
    import pyasn1
except ImportError:
    # Without PyCrypto/pyasn1 the SSH key support cannot be used.  Setting
    # keys to None makes the test cases in this module set their C{skip}
    # attribute, while the stub namespaces below let the module-level class
    # statements that subclass transport/userauth classes still execute.
    keys = None


    class transport:
        class SSHTransportBase:
            """
            A stub class so that later class definitions won't die.
            """

    class userauth:
        class SSHUserAuthClient:
            """
            A stub class so that later class definitions won't die.
            """
else:
    from twisted.conch.ssh.common import NS
    from twisted.conch.checkers import SSHProtocolChecker
    from twisted.conch.ssh import keys, userauth, transport
    from twisted.conch.test import keydata
 
47
 
 
48
 
 
49
 
 
50
class ClientUserAuth(userauth.SSHUserAuthClient):
    """
    A mock user auth client.
    """


    def getPublicKey(self):
        """
        Return one of two public keys, depending on C{self.lastPublicKey}.

        While C{self.lastPublicKey} is falsy (presumably the first call; the
        attribute is maintained by the base class), return a L{Deferred}
        which fires with the L{keys.Key} for the DSA key.  Otherwise return
        the L{keys.Key} for the RSA key directly, without a L{Deferred}.

        NOTE(review): mixing a Deferred and a plain return value looks
        deliberate, to exercise both return styles in the client.
        """
        if self.lastPublicKey:
            return keys.Key.fromString(keydata.publicRSA_openssh)
        else:
            return defer.succeed(keys.Key.fromString(keydata.publicDSA_openssh))


    def getPrivateKey(self):
        """
        Return a L{Deferred} which fires with the L{keys.Key} for the RSA
        private key.
        """
        return defer.succeed(keys.Key.fromString(keydata.privateRSA_openssh))


    def getPassword(self, prompt=None):
        """
        Return a L{Deferred} which fires with 'foo' as the password.

        @param prompt: ignored.
        """
        return defer.succeed('foo')


    def getGenericAnswers(self, name, information, answers):
        """
        Return a L{Deferred} which fires with 'foo' as the answer to two
        questions.  All arguments are ignored.
        """
        return defer.succeed(('foo', 'foo'))
 
87
 
 
88
 
 
89
 
 
90
class OldClientAuth(userauth.SSHUserAuthClient):
    """
    Imitate the legacy L{SSHUserAuthClient} API, under which getPrivateKey()
    produced a raw PyCrypto key object and getPublicKey() produced a string.
    """


    def getPrivateKey(self):
        """
        Return a L{Deferred} firing with the underlying key object of the RSA
        private key, rather than a L{keys.Key} wrapper.
        """
        rsaKey = keys.Key.fromString(keydata.privateRSA_openssh)
        return defer.succeed(rsaKey.keyObject)


    def getPublicKey(self):
        """
        Return the RSA public key as the string produced by
        L{keys.Key.blob}, rather than a L{keys.Key} instance.
        """
        rsaKey = keys.Key.fromString(keydata.publicRSA_openssh)
        return rsaKey.blob()
 
104
 
 
105
class ClientAuthWithoutPrivateKey(userauth.SSHUserAuthClient):
    """
    A client which can offer a public key but has no matching private key.
    """


    def getPrivateKey(self):
        """
        Signal that no private key is available by returning C{None}.
        """
        return None


    def getPublicKey(self):
        """
        Return the L{keys.Key} for the RSA public key.
        """
        publicKey = keys.Key.fromString(keydata.publicRSA_openssh)
        return publicKey
 
117
 
 
118
 
 
119
 
 
120
class FakeTransport(transport.SSHTransportBase):
    """
    L{userauth.SSHUserAuthServer} expects an SSH transport which has a factory
    attribute which has a portal attribute. Because the portal is important for
    testing authentication, we need to be able to provide an interesting portal
    object to the L{SSHUserAuthServer}.

    In addition, we want to be able to capture any packets sent over the
    transport.

    @ivar packets: a list of 2-tuples: (messageType, data).  Each 2-tuple is
        a sent packet.
    @type packets: C{list}
    @ivar lostConnection: True if loseConnection has been called on us.
    @type lostConnection: C{bool}
    """


    class Service(object):
        """
        A mock service, representing the other service offered by the server.
        """
        name = 'nancy'


        def serviceStarted(self):
            pass



    class Factory(object):
        """
        A mock factory, representing the factory that spawned this user auth
        service.
        """


        def getService(self, transport, service):
            """
            Return our fake service.

            @return: the L{FakeTransport.Service} class (not an instance)
                when C{service} is C{'none'}, otherwise C{None}.
            """
            if service == 'none':
                return FakeTransport.Service
            return None



    def __init__(self, portal):
        """
        @param portal: the portal exposed as C{self.factory.portal}.
        """
        self.factory = self.Factory()
        self.factory.portal = portal
        self.lostConnection = False
        # NOTE(review): presumably some code reaches this object through a
        # C{.transport} attribute; pointing it at ourselves keeps that working.
        self.transport = self
        self.packets = []


    def sendPacket(self, messageType, message):
        """
        Record the packet sent by the service.
        """
        self.packets.append((messageType, message))


    def isEncrypted(self, direction):
        """
        Pretend that this transport encrypts traffic in both directions. The
        SSHUserAuthServer disables password authentication if the transport
        isn't encrypted.
        """
        return True


    def loseConnection(self):
        """
        Record that the connection was dropped, without doing anything else.
        """
        self.lostConnection = True
 
193
 
 
194
 
 
195
 
 
196
class Realm(object):
    """
    Minimal L{IRealm} backing the L{Portal} in the
    L{userauth.SSHUserAuthServer} tests.

    The avatar it produces is never used by those tests, so it returns the
    simplest value that could possibly work.
    """
    implements(IRealm)


    def requestAvatar(self, avatarId, mind, *interfaces):
        """
        Return a L{Deferred} firing with C{(interfaces[0], None, logout)},
        where C{logout} is a callable that does nothing.
        """
        def logout():
            return None
        avatarInfo = (interfaces[0], None, logout)
        return defer.succeed(avatarInfo)
 
208
 
 
209
 
 
210
 
 
211
class PasswordChecker(object):
    """
    Credentials checker which accepts a username/password pair exactly when
    the password equals the username, and rejects everything else.
    """
    credentialInterfaces = (IUsernamePassword,)
    implements(ICredentialsChecker)


    def requestAvatarId(self, creds):
        """
        Return a L{Deferred} firing with C{creds.username} when the password
        matches the username, or failing with L{UnauthorizedLogin} otherwise.
        """
        if creds.username != creds.password:
            rejection = UnauthorizedLogin("Invalid username/password pair")
            return defer.fail(rejection)
        return defer.succeed(creds.username)
 
224
 
 
225
 
 
226
 
 
227
class PrivateKeyChecker(object):
    """
    Public key checker which only recognizes the RSA key pair from
    C{keydata.publicRSA_openssh} / C{keydata.privateRSA_openssh}.
    """
    credentialInterfaces = (ISSHPrivateKey,)
    implements(ICredentialsChecker)


    def requestAvatarId(self, creds):
        """
        Authenticate C{creds} against the known RSA public key.

        @return: C{creds.username} if the blob matches and the signature
            verifies.
        @raise ValidPublicKey: if the blob matches but no signature was sent.
        @raise UnauthorizedLogin: for any other blob, or a bad signature.
        """
        knownBlob = keys.Key.fromString(keydata.publicRSA_openssh).blob()
        if creds.blob != knownBlob:
            raise UnauthorizedLogin()
        if creds.signature is None:
            raise ValidPublicKey()
        offeredKey = keys.Key.fromString(creds.blob)
        if not offeredKey.verify(creds.signature, creds.sigData):
            raise UnauthorizedLogin()
        return creds.username
 
246
 
 
247
 
 
248
 
 
249
class PAMChecker(object):
    """
    PAM-style checker which prompts for a name and a password and succeeds
    only when both answers equal the username.
    """
    credentialInterfaces = (IPluggableAuthenticationModules,)
    implements(ICredentialsChecker)


    def requestAvatarId(self, creds):
        """
        Run a two-prompt conversation through C{creds.pamConversion}.

        @return: a L{Deferred} firing with C{creds.username} if both replies
            are C{(creds.username, 0)}, or failing with L{UnauthorizedLogin}.
        """
        prompts = [('Name: ', 2), ("Password: ", 1)]
        conversation = creds.pamConversion(prompts)

        def verifyReplies(replies):
            if replies != [(creds.username, 0), (creds.username, 0)]:
                raise UnauthorizedLogin()
            return creds.username

        return conversation.addCallback(verifyReplies)
 
265
 
 
266
 
 
267
 
 
268
class AnonymousChecker(object):
    """
    Checker for L{IAnonymous} credentials: an interface for which
    L{SSHUserAuthServer} has no authentication method, so registering it
    should have no effect on the advertised methods.
    """
    credentialInterfaces = (IAnonymous,)
    implements(ICredentialsChecker)
 
274
 
 
275
 
 
276
 
 
277
class SSHUserAuthServerTestCase(unittest.TestCase):
    """
    Tests for SSHUserAuthServer.
    """


    if keys is None:
        skip = "cannot run w/o PyCrypto"


    def setUp(self):
        """
        Start an L{userauth.SSHUserAuthServer} on a L{FakeTransport} whose
        portal has password, private key and PAM checkers registered.
        """
        self.realm = Realm()
        self.portal = Portal(self.realm)
        self.portal.registerChecker(PasswordChecker())
        self.portal.registerChecker(PrivateKeyChecker())
        self.portal.registerChecker(PAMChecker())
        self.authServer = userauth.SSHUserAuthServer()
        self.authServer.transport = FakeTransport(self.portal)
        self.authServer.serviceStarted()
        # Sort so the advertised method list has a consistent order.
        self.authServer.supportedAuthentications.sort()


    def tearDown(self):
        """
        Stop the server under test and drop the reference to it.
        """
        self.authServer.serviceStopped()
        self.authServer = None


    def _checkFailed(self, ignored):
        """
        Check that the authentication has failed.

        The last packet sent must be a MSG_USERAUTH_FAILURE listing every
        supported method (in the order sorted by L{setUp}), followed by a
        cleared partial-success flag.
        """
        self.assertEquals(self.authServer.transport.packets[-1],
                (userauth.MSG_USERAUTH_FAILURE,
                NS('keyboard-interactive,password,publickey') + '\x00'))


    def test_noneAuthentication(self):
        """
        A client may request a list of authentication 'method name' values
        that may continue by using the "none" authentication 'method name'.

        See RFC 4252 Section 5.2.
        """
        d = self.authServer.ssh_USERAUTH_REQUEST(NS('foo') + NS('service') +
                                                 NS('none'))
        return d.addCallback(self._checkFailed)


    def test_successfulPasswordAuthentication(self):
        """
        When provided with correct password authentication information, the
        server should respond by sending a MSG_USERAUTH_SUCCESS message with
        no other data.

        See RFC 4252, Section 5.1.
        """
        packet = NS('foo') + NS('none') + NS('password') + chr(0) + NS('foo')
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        def check(ignored):
            self.assertEqual(
                self.authServer.transport.packets,
                [(userauth.MSG_USERAUTH_SUCCESS, '')])
        return d.addCallback(check)


    def test_failedPasswordAuthentication(self):
        """
        When provided with invalid authentication details, the server should
        respond by sending a MSG_USERAUTH_FAILURE message which states whether
        the authentication was partially successful, and provides other, open
        options for authentication.

        See RFC 4252, Section 5.1.
        """
        # packet = username, next_service, authentication type, FALSE, password
        packet = NS('foo') + NS('none') + NS('password') + chr(0) + NS('bar')
        self.authServer.clock = task.Clock()
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        # Nothing is sent until the clock advances: the failure is delayed.
        self.assertEquals(self.authServer.transport.packets, [])
        self.authServer.clock.advance(2)
        return d.addCallback(self._checkFailed)


    def test_successfulPrivateKeyAuthentication(self):
        """
        Test that private key authentication completes successfully.
        """
        blob = keys.Key.fromString(keydata.publicRSA_openssh).blob()
        obj = keys.Key.fromString(keydata.privateRSA_openssh)
        packet = (NS('foo') + NS('none') + NS('publickey') + '\xff'
                + NS(obj.sshType()) + NS(blob))
        self.authServer.transport.sessionID = 'test'
        signature = obj.sign(NS('test') + chr(userauth.MSG_USERAUTH_REQUEST)
                + packet)
        packet += NS(signature)
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        def check(ignored):
            self.assertEquals(self.authServer.transport.packets,
                    [(userauth.MSG_USERAUTH_SUCCESS, '')])
        return d.addCallback(check)


    def test_requestRaisesConchError(self):
        """
        If tryAuth returns None, ssh_USERAUTH_REQUEST should report a
        ConchError failure to _ebBadAuth. Added to catch a bug noticed by
        pyflakes.
        """
        d = defer.Deferred()

        def mockCbFinishedAuth(ignored):
            # Called with the auth result only; C{self} is the test case.
            self.fail('request should have raised ConchError')

        def mockTryAuth(kind, user, data):
            return None

        def mockEbBadAuth(reason):
            d.errback(reason.value)

        self.patch(self.authServer, 'tryAuth', mockTryAuth)
        self.patch(self.authServer, '_cbFinishedAuth', mockCbFinishedAuth)
        self.patch(self.authServer, '_ebBadAuth', mockEbBadAuth)

        packet = NS('user') + NS('none') + NS('public-key') + NS('data')
        # If an error other than ConchError is raised, this will trigger an
        # exception.
        self.authServer.ssh_USERAUTH_REQUEST(packet)
        return self.assertFailure(d, ConchError)


    def test_verifyValidPrivateKey(self):
        """
        Test that verifying a valid private key works.
        """
        blob = keys.Key.fromString(keydata.publicRSA_openssh).blob()
        packet = (NS('foo') + NS('none') + NS('publickey') + '\x00'
                + NS('ssh-rsa') + NS(blob))
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        def check(ignored):
            self.assertEquals(self.authServer.transport.packets,
                    [(userauth.MSG_USERAUTH_PK_OK, NS('ssh-rsa') + NS(blob))])
        return d.addCallback(check)


    def test_failedPrivateKeyAuthenticationWithoutSignature(self):
        """
        Test that private key authentication fails when the public key
        is invalid.
        """
        blob = keys.Key.fromString(keydata.publicDSA_openssh).blob()
        packet = (NS('foo') + NS('none') + NS('publickey') + '\x00'
                + NS('ssh-dsa') + NS(blob))
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        return d.addCallback(self._checkFailed)


    def test_failedPrivateKeyAuthenticationWithSignature(self):
        """
        Test that private key authentication fails when the signature is
        invalid (here it signs the bare key blob instead of the request data).
        """
        blob = keys.Key.fromString(keydata.publicRSA_openssh).blob()
        obj = keys.Key.fromString(keydata.privateRSA_openssh)
        packet = (NS('foo') + NS('none') + NS('publickey') + '\xff'
                + NS('ssh-rsa') + NS(blob) + NS(obj.sign(blob)))
        self.authServer.transport.sessionID = 'test'
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        return d.addCallback(self._checkFailed)


    def test_successfulPAMAuthentication(self):
        """
        Test that keyboard-interactive authentication succeeds.
        """
        packet = (NS('foo') + NS('none') + NS('keyboard-interactive')
                + NS('') + NS(''))
        response = '\x00\x00\x00\x02' + NS('foo') + NS('foo')
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        self.authServer.ssh_USERAUTH_INFO_RESPONSE(response)
        def check(ignored):
            self.assertEquals(self.authServer.transport.packets,
                    [(userauth.MSG_USERAUTH_INFO_REQUEST, (NS('') + NS('')
                        + NS('') + '\x00\x00\x00\x02' + NS('Name: ') + '\x01'
                        + NS('Password: ') + '\x00')),
                     (userauth.MSG_USERAUTH_SUCCESS, '')])

        return d.addCallback(check)


    def test_failedPAMAuthentication(self):
        """
        Test that keyboard-interactive authentication fails.
        """
        packet = (NS('foo') + NS('none') + NS('keyboard-interactive')
                + NS('') + NS(''))
        response = '\x00\x00\x00\x02' + NS('bar') + NS('bar')
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        self.authServer.ssh_USERAUTH_INFO_RESPONSE(response)
        def check(ignored):
            self.assertEquals(self.authServer.transport.packets[0],
                    (userauth.MSG_USERAUTH_INFO_REQUEST, (NS('') + NS('')
                        + NS('') + '\x00\x00\x00\x02' + NS('Name: ') + '\x01'
                        + NS('Password: ') + '\x00')))
        return d.addCallback(check).addCallback(self._checkFailed)


    def test_invalid_USERAUTH_INFO_RESPONSE_not_enough_data(self):
        """
        If ssh_USERAUTH_INFO_RESPONSE gets an invalid packet,
        the user authentication should fail.
        """
        packet = (NS('foo') + NS('none') + NS('keyboard-interactive')
                + NS('') + NS(''))
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        self.authServer.ssh_USERAUTH_INFO_RESPONSE(NS('\x00\x00\x00\x00' +
            NS('hi')))
        return d.addCallback(self._checkFailed)


    def test_invalid_USERAUTH_INFO_RESPONSE_too_much_data(self):
        """
        If ssh_USERAUTH_INFO_RESPONSE gets too much data, the user
        authentication should fail.
        """
        packet = (NS('foo') + NS('none') + NS('keyboard-interactive')
                + NS('') + NS(''))
        response = '\x00\x00\x00\x02' + NS('foo') + NS('foo') + NS('foo')
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        self.authServer.ssh_USERAUTH_INFO_RESPONSE(response)
        return d.addCallback(self._checkFailed)


    def test_onlyOnePAMAuthentication(self):
        """
        Because it requires an intermediate message, one can't send a second
        keyboard-interactive request while the first is still pending.
        """
        packet = (NS('foo') + NS('none') + NS('keyboard-interactive')
                + NS('') + NS(''))
        self.authServer.ssh_USERAUTH_REQUEST(packet)
        self.authServer.ssh_USERAUTH_REQUEST(packet)
        self.assertEquals(self.authServer.transport.packets[-1][0],
                transport.MSG_DISCONNECT)
        self.assertEquals(self.authServer.transport.packets[-1][1][3],
                chr(transport.DISCONNECT_PROTOCOL_ERROR))


    def test_ignoreUnknownCredInterfaces(self):
        """
        L{SSHUserAuthServer} sets up
        C{SSHUserAuthServer.supportedAuthentications} by checking the portal's
        credentials interfaces and mapping them to SSH authentication method
        strings.  If the Portal advertises an interface that
        L{SSHUserAuthServer} can't map, it should be ignored.  This is a white
        box test.
        """
        server = userauth.SSHUserAuthServer()
        server.transport = FakeTransport(self.portal)
        self.portal.registerChecker(AnonymousChecker())
        server.serviceStarted()
        server.serviceStopped()
        server.supportedAuthentications.sort() # give a consistent order
        self.assertEquals(server.supportedAuthentications,
                          ['keyboard-interactive', 'password', 'publickey'])


    def test_removePasswordIfUnencrypted(self):
        """
        Test that the userauth service does not advertise password
        authentication if the password would be sent in cleartext.
        """
        self.assertIn('password', self.authServer.supportedAuthentications)
        # no encryption
        clearAuthServer = userauth.SSHUserAuthServer()
        clearAuthServer.transport = FakeTransport(self.portal)
        clearAuthServer.transport.isEncrypted = lambda x: False
        clearAuthServer.serviceStarted()
        clearAuthServer.serviceStopped()
        self.failIfIn('password', clearAuthServer.supportedAuthentications)
        # only encrypt incoming (the direction the password is sent)
        halfAuthServer = userauth.SSHUserAuthServer()
        halfAuthServer.transport = FakeTransport(self.portal)
        halfAuthServer.transport.isEncrypted = lambda x: x == 'in'
        halfAuthServer.serviceStarted()
        halfAuthServer.serviceStopped()
        self.assertIn('password', halfAuthServer.supportedAuthentications)


    def test_removeKeyboardInteractiveIfUnencrypted(self):
        """
        Test that the userauth service does not advertise keyboard-interactive
        authentication if the password would be sent in cleartext.
        """
        self.assertIn('keyboard-interactive',
                self.authServer.supportedAuthentications)
        # no encryption
        clearAuthServer = userauth.SSHUserAuthServer()
        clearAuthServer.transport = FakeTransport(self.portal)
        clearAuthServer.transport.isEncrypted = lambda x: False
        clearAuthServer.serviceStarted()
        clearAuthServer.serviceStopped()
        self.failIfIn('keyboard-interactive',
                clearAuthServer.supportedAuthentications)
        # only encrypt incoming (the direction the password is sent)
        halfAuthServer = userauth.SSHUserAuthServer()
        halfAuthServer.transport = FakeTransport(self.portal)
        halfAuthServer.transport.isEncrypted = lambda x: x == 'in'
        halfAuthServer.serviceStarted()
        halfAuthServer.serviceStopped()
        self.assertIn('keyboard-interactive',
                halfAuthServer.supportedAuthentications)


    def test_unencryptedConnectionWithoutPasswords(self):
        """
        If the L{SSHUserAuthServer} is not advertising passwords, then an
        unencrypted connection should not cause any warnings or exceptions.
        This is a white box test.
        """
        # create a Portal without password authentication
        portal = Portal(self.realm)
        portal.registerChecker(PrivateKeyChecker())

        # no encryption
        clearAuthServer = userauth.SSHUserAuthServer()
        clearAuthServer.transport = FakeTransport(portal)
        clearAuthServer.transport.isEncrypted = lambda x: False
        clearAuthServer.serviceStarted()
        clearAuthServer.serviceStopped()
        self.assertEquals(clearAuthServer.supportedAuthentications,
                          ['publickey'])

        # only encrypt incoming (the direction the password is sent)
        halfAuthServer = userauth.SSHUserAuthServer()
        halfAuthServer.transport = FakeTransport(portal)
        halfAuthServer.transport.isEncrypted = lambda x: x == 'in'
        halfAuthServer.serviceStarted()
        halfAuthServer.serviceStopped()
        # Check the half-encrypted server (previously clearAuthServer was
        # asserted twice, leaving this case untested).
        self.assertEquals(halfAuthServer.supportedAuthentications,
                          ['publickey'])


    def test_loginTimeout(self):
        """
        Test that the login times out.
        """
        timeoutAuthServer = userauth.SSHUserAuthServer()
        timeoutAuthServer.clock = task.Clock()
        timeoutAuthServer.transport = FakeTransport(self.portal)
        timeoutAuthServer.serviceStarted()
        timeoutAuthServer.clock.advance(11 * 60 * 60)
        timeoutAuthServer.serviceStopped()
        self.assertEquals(timeoutAuthServer.transport.packets,
                [(transport.MSG_DISCONNECT,
                '\x00' * 3 +
                chr(transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE) +
                NS("you took too long") + NS(''))])
        self.assertTrue(timeoutAuthServer.transport.lostConnection)


    def test_cancelLoginTimeout(self):
        """
        Test that stopping the service also stops the login timeout.
        """
        timeoutAuthServer = userauth.SSHUserAuthServer()
        timeoutAuthServer.clock = task.Clock()
        timeoutAuthServer.transport = FakeTransport(self.portal)
        timeoutAuthServer.serviceStarted()
        timeoutAuthServer.serviceStopped()
        timeoutAuthServer.clock.advance(11 * 60 * 60)
        self.assertEquals(timeoutAuthServer.transport.packets, [])
        self.assertFalse(timeoutAuthServer.transport.lostConnection)


    def test_tooManyAttempts(self):
        """
        Test that the server disconnects if the client fails authentication
        too many times.
        """
        packet = NS('foo') + NS('none') + NS('password') + chr(0) + NS('bar')
        self.authServer.clock = task.Clock()
        for _ in range(21):
            d = self.authServer.ssh_USERAUTH_REQUEST(packet)
            self.authServer.clock.advance(2)
        def check(ignored):
            self.assertEquals(self.authServer.transport.packets[-1],
                (transport.MSG_DISCONNECT,
                '\x00' * 3 +
                chr(transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE) +
                NS("too many bad auths") + NS('')))
        return d.addCallback(check)


    def test_failIfUnknownService(self):
        """
        If the user requests a service that we don't support, the
        authentication should fail.
        """
        packet = NS('foo') + NS('') + NS('password') + chr(0) + NS('foo')
        self.authServer.clock = task.Clock()
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        return d.addCallback(self._checkFailed)


    def test__pamConvErrors(self):
        """
        _pamConv should fail if it gets a message that's not 1 or 2.
        """
        def secondTest(ignored):
            d2 = self.authServer._pamConv([('', 90)])
            return self.assertFailure(d2, ConchError)

        d = self.authServer._pamConv([('', 3)])
        return self.assertFailure(d, ConchError).addCallback(secondTest)


    def test_tryAuthEdgeCases(self):
        """
        tryAuth() has two edge cases that are difficult to reach.

        1) an authentication method auth_* returns None instead of a Deferred.
        2) an authentication type that is defined does not have a matching
           auth_* method.

        Both these cases should return a Deferred which fails with a
        ConchError.
        """
        def mockAuth(packet):
            return None

        self.patch(self.authServer, 'auth_publickey', mockAuth) # first case
        self.patch(self.authServer, 'auth_password', None) # second case

        def secondTest(ignored):
            d2 = self.authServer.tryAuth('password', None, None)
            return self.assertFailure(d2, ConchError)

        d1 = self.authServer.tryAuth('publickey', None, None)
        return self.assertFailure(d1, ConchError).addCallback(secondTest)
 
716
 
 
717
 
 
718
 
 
719
 
 
720
class SSHUserAuthClientTestCase(unittest.TestCase):
 
721
    """
 
722
    Tests for SSHUserAuthClient.
 
723
    """
 
724
 
 
725
 
 
726
    if keys is None:
 
727
        skip = "cannot run w/o PyCrypto"
 
728
 
 
729
 
 
730
    def setUp(self):
 
731
        self.authClient = ClientUserAuth('foo', FakeTransport.Service())
 
732
        self.authClient.transport = FakeTransport(None)
 
733
        self.authClient.transport.sessionID = 'test'
 
734
        self.authClient.serviceStarted()
 
735
 
 
736
 
 
737
    def tearDown(self):
 
738
        self.authClient.serviceStopped()
 
739
        self.authClient = None
 
740
 
 
741
 
 
742
    def test_init(self):
 
743
        """
 
744
        Test that client is initialized properly.
 
745
        """
 
746
        self.assertEquals(self.authClient.user, 'foo')
 
747
        self.assertEquals(self.authClient.instance.name, 'nancy')
 
748
        self.assertEquals(self.authClient.transport.packets,
 
749
                [(userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
 
750
                    + NS('none'))])
 
751
 
 
752
 
 
753
    def test_USERAUTH_SUCCESS(self):
 
754
        """
 
755
        Test that the client succeeds properly.
 
756
        """
 
757
        instance = [None]
 
758
        def stubSetService(service):
 
759
            instance[0] = service
 
760
        self.authClient.transport.setService = stubSetService
 
761
        self.authClient.ssh_USERAUTH_SUCCESS('')
 
762
        self.assertEquals(instance[0], self.authClient.instance)
 
763
 
 
764
 
 
765
    def test_publickey(self):
 
766
        """
 
767
        Test that the client can authenticate with a public key.
 
768
        """
 
769
        self.authClient.ssh_USERAUTH_FAILURE(NS('publickey') + '\x00')
 
770
        self.assertEquals(self.authClient.transport.packets[-1],
 
771
                (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
 
772
                    + NS('publickey') + '\x00' + NS('ssh-dss')
 
773
                    + NS(keys.Key.fromString(
 
774
                        keydata.publicDSA_openssh).blob())))
 
775
       # that key isn't good
 
776
        self.authClient.ssh_USERAUTH_FAILURE(NS('publickey') + '\x00')
 
777
        blob = NS(keys.Key.fromString(keydata.publicRSA_openssh).blob())
 
778
        self.assertEquals(self.authClient.transport.packets[-1],
 
779
                (userauth.MSG_USERAUTH_REQUEST, (NS('foo') + NS('nancy')
 
780
                    + NS('publickey') + '\x00'+ NS('ssh-rsa') + blob)))
 
781
        self.authClient.ssh_USERAUTH_PK_OK(NS('ssh-rsa')
 
782
            + NS(keys.Key.fromString(keydata.publicRSA_openssh).blob()))
 
783
        sigData = (NS(self.authClient.transport.sessionID)
 
784
                + chr(userauth.MSG_USERAUTH_REQUEST) + NS('foo')
 
785
                + NS('nancy') + NS('publickey') + '\xff' + NS('ssh-rsa')
 
786
                + blob)
 
787
        obj = keys.Key.fromString(keydata.privateRSA_openssh)
 
788
        self.assertEquals(self.authClient.transport.packets[-1],
 
789
                (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
 
790
                    + NS('publickey') + '\xff' + NS('ssh-rsa') + blob
 
791
                    + NS(obj.sign(sigData))))
 
792
 
 
793
 
 
794
    def test_publickey_without_privatekey(self):
 
795
        """
 
796
        If the SSHUserAuthClient doesn't return anything from signData,
 
797
        the client should start the authentication over again by requesting
 
798
        'none' authentication.
 
799
        """
 
800
        authClient = ClientAuthWithoutPrivateKey('foo',
 
801
                                                 FakeTransport.Service())
 
802
 
 
803
        authClient.transport = FakeTransport(None)
 
804
        authClient.transport.sessionID = 'test'
 
805
        authClient.serviceStarted()
 
806
        authClient.tryAuth('publickey')
 
807
        authClient.transport.packets = []
 
808
        self.assertIdentical(authClient.ssh_USERAUTH_PK_OK(''), None)
 
809
        self.assertEquals(authClient.transport.packets, [
 
810
                (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy') +
 
811
                 NS('none'))])
 
812
 
 
813
 
 
814
    def test_old_publickey_getPublicKey(self):
 
815
        """
 
816
        Old SSHUserAuthClients returned strings of public key blobs from
 
817
        getPublicKey().  Test that a Deprecation warning is raised but the key is
 
818
        verified correctly.
 
819
        """
 
820
        oldAuth = OldClientAuth('foo', FakeTransport.Service())
 
821
        oldAuth.transport = FakeTransport(None)
 
822
        oldAuth.transport.sessionID = 'test'
 
823
        oldAuth.serviceStarted()
 
824
        oldAuth.transport.packets = []
 
825
        self.assertWarns(DeprecationWarning, "Returning a string from "
 
826
                         "SSHUserAuthClient.getPublicKey() is deprecated since "
 
827
                         "Twisted 9.0.  Return a keys.Key() instead.",
 
828
                         userauth.__file__, oldAuth.tryAuth, 'publickey')
 
829
        self.assertEquals(oldAuth.transport.packets, [
 
830
                (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy') +
 
831
                 NS('publickey') + '\x00' + NS('ssh-rsa') +
 
832
                 NS(keys.Key.fromString(keydata.publicRSA_openssh).blob()))])
 
833
 
 
834
 
 
835
    def test_old_publickey_getPrivateKey(self):
 
836
        """
 
837
        Old SSHUserAuthClients returned a PyCrypto key object from
 
838
        getPrivateKey().  Test that _cbSignData signs the data warns the
 
839
        user about the deprecation, but signs the data correctly.
 
840
        """
 
841
        oldAuth = OldClientAuth('foo', FakeTransport.Service())
 
842
        d = self.assertWarns(DeprecationWarning, "Returning a PyCrypto key "
 
843
                             "object from SSHUserAuthClient.getPrivateKey() is "
 
844
                             "deprecated since Twisted 9.0.  "
 
845
                             "Return a keys.Key() instead.", userauth.__file__,
 
846
                             oldAuth.signData, None, 'data')
 
847
        def _checkSignedData(sig):
 
848
            self.assertEquals(sig,
 
849
                keys.Key.fromString(keydata.privateRSA_openssh).sign(
 
850
                    'data'))
 
851
        d.addCallback(_checkSignedData)
 
852
        return d
 
853
 
 
854
 
 
855
    def test_no_publickey(self):
 
856
        """
 
857
        If there's no public key, auth_publickey should return a Deferred
 
858
        called back with a False value.
 
859
        """
 
860
        self.authClient.getPublicKey = lambda x: None
 
861
        d = self.authClient.tryAuth('publickey')
 
862
        def check(result):
 
863
            self.assertFalse(result)
 
864
        return d.addCallback(check)
 
865
 
 
866
    def test_password(self):
 
867
        """
 
868
        Test that the client can authentication with a password.  This
 
869
        includes changing the password.
 
870
        """
 
871
        self.authClient.ssh_USERAUTH_FAILURE(NS('password') + '\x00')
 
872
        self.assertEquals(self.authClient.transport.packets[-1],
 
873
                (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
 
874
                    + NS('password') + '\x00' + NS('foo')))
 
875
        self.authClient.ssh_USERAUTH_PK_OK(NS('') + NS(''))
 
876
        self.assertEquals(self.authClient.transport.packets[-1],
 
877
                (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
 
878
                    + NS('password') + '\xff' + NS('foo') * 2))
 
879
 
 
880
 
 
881
    def test_no_password(self):
 
882
        """
 
883
        If getPassword returns None, tryAuth should return False.
 
884
        """
 
885
        self.authClient.getPassword = lambda: None
 
886
        self.assertFalse(self.authClient.tryAuth('password'))
 
887
 
 
888
 
 
889
    def test_keyboardInteractive(self):
 
890
        """
 
891
        Test that the client can authenticate using keyboard-interactive
 
892
        authentication.
 
893
        """
 
894
        self.authClient.ssh_USERAUTH_FAILURE(NS('keyboard-interactive')
 
895
               + '\x00')
 
896
        self.assertEquals(self.authClient.transport.packets[-1],
 
897
                (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
 
898
                    + NS('keyboard-interactive') + NS('')*2))
 
899
        self.authClient.ssh_USERAUTH_PK_OK(NS('')*3 + '\x00\x00\x00\x02'
 
900
                + NS('Name: ') + '\xff' + NS('Password: ') + '\x00')
 
901
        self.assertEquals(self.authClient.transport.packets[-1],
 
902
                (userauth.MSG_USERAUTH_INFO_RESPONSE, '\x00\x00\x00\x02'
 
903
                    + NS('foo')*2))
 
904
 
 
905
 
 
906
    def test_USERAUTH_PK_OK_unknown_method(self):
 
907
        """
 
908
        If C{SSHUserAuthClient} gets a MSG_USERAUTH_PK_OK packet when it's not
 
909
        expecting it, it should fail the current authentication and move on to
 
910
        the next type.
 
911
        """
 
912
        self.authClient.lastAuth = 'unknown'
 
913
        self.authClient.transport.packets = []
 
914
        self.authClient.ssh_USERAUTH_PK_OK('')
 
915
        self.assertEquals(self.authClient.transport.packets,
 
916
                          [(userauth.MSG_USERAUTH_REQUEST, NS('foo') +
 
917
                            NS('nancy') + NS('none'))])
 
918
 
 
919
 
 
920
    def test_USERAUTH_FAILURE_sorting(self):
 
921
        """
 
922
        ssh_USERAUTH_FAILURE should sort the methods by their position
 
923
        in SSHUserAuthClient.preferredOrder.  Methods that are not in
 
924
        preferredOrder should be sorted at the end of that list.
 
925
        """
 
926
        def auth_firstmethod():
 
927
            self.authClient.transport.sendPacket(255, 'here is data')
 
928
        def auth_anothermethod():
 
929
            self.authClient.transport.sendPacket(254, 'other data')
 
930
            return True
 
931
        self.authClient.auth_firstmethod = auth_firstmethod
 
932
        self.authClient.auth_anothermethod = auth_anothermethod
 
933
 
 
934
        # although they shouldn't get called, method callbacks auth_* MUST
 
935
        # exist in order for the test to work properly.
 
936
        self.authClient.ssh_USERAUTH_FAILURE(NS('anothermethod,password') +
 
937
                                             '\x00')
 
938
        # should send password packet
 
939
        self.assertEquals(self.authClient.transport.packets[-1],
 
940
                (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
 
941
                    + NS('password') + '\x00' + NS('foo')))
 
942
        self.authClient.ssh_USERAUTH_FAILURE(
 
943
            NS('firstmethod,anothermethod,password') + '\xff')
 
944
        self.assertEquals(self.authClient.transport.packets[-2:],
 
945
                          [(255, 'here is data'), (254, 'other data')])
 
946
 
 
947
 
 
948
    def test_disconnectIfNoMoreAuthentication(self):
 
949
        """
 
950
        If there are no more available user authentication messages,
 
951
        the SSHUserAuthClient should disconnect with code
 
952
        DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE.
 
953
        """
 
954
        self.authClient.ssh_USERAUTH_FAILURE(NS('password') + '\x00')
 
955
        self.authClient.ssh_USERAUTH_FAILURE(NS('password') + '\xff')
 
956
        self.assertEquals(self.authClient.transport.packets[-1],
 
957
                          (transport.MSG_DISCONNECT, '\x00\x00\x00\x0e' +
 
958
                           NS('no more authentication methods available') +
 
959
                           '\x00\x00\x00\x00'))
 
960
 
 
961
 
 
962
    def test_ebAuth(self):
 
963
        """
 
964
        _ebAuth (the generic authentication error handler) should send
 
965
        a request for the 'none' authentication method.
 
966
        """
 
967
        self.authClient.transport.packets = []
 
968
        self.authClient._ebAuth(None)
 
969
        self.assertEquals(self.authClient.transport.packets,
 
970
                [(userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
 
971
                    + NS('none'))])
 
972
 
 
973
 
 
974
    def test_defaults(self):
 
975
        """
 
976
        getPublicKey() should return None.  getPrivateKey() should return a
 
977
        failed Deferred.  getPassword() should return a failed Deferred.
 
978
        getGenericAnswers() should return a failed Deferred.
 
979
        """
 
980
        authClient = userauth.SSHUserAuthClient('foo', FakeTransport.Service())
 
981
        self.assertIdentical(authClient.getPublicKey(), None)
 
982
        def check(result):
 
983
            result.trap(NotImplementedError)
 
984
            d = authClient.getPassword()
 
985
            return d.addCallback(self.fail).addErrback(check2)
 
986
        def check2(result):
 
987
            result.trap(NotImplementedError)
 
988
            d = authClient.getGenericAnswers(None, None, None)
 
989
            return d.addCallback(self.fail).addErrback(check3)
 
990
        def check3(result):
 
991
            result.trap(NotImplementedError)
 
992
        d = authClient.getPrivateKey()
 
993
        return d.addCallback(self.fail).addErrback(check)
 
994
 
 
995
 
 
996
 
 
997
class LoopbackTestCase(unittest.TestCase):
    """
    Drive an SSHUserAuthServer and a ClientUserAuth against each other over
    a loopback connection.
    """

    if keys is None:
        skip = "cannot run w/o PyCrypto or PyASN1"


    class Factory:
        """
        Minimal factory whose getService always hands back the Service class.
        """

        class Service:
            """
            Service that closes the connection as soon as it is started.
            """
            name = 'TestService'


            def serviceStarted(self):
                self.transport.loseConnection()


            def serviceStopped(self):
                pass


        def getService(self, avatar, name):
            return self.Service


    def test_loopback(self):
        """
        The userauth server and client should cooperate until the server has
        switched to the client's requested service.
        """
        serverAuth = userauth.SSHUserAuthServer()
        clientAuth = ClientUserAuth('foo', self.Factory.Service())

        # Give each side a bare SSHTransportBase whose service is the
        # userauth object under test.
        serverTransport = transport.SSHTransportBase()
        serverAuth.transport = serverTransport
        serverTransport.service = serverAuth
        serverTransport.isEncrypted = lambda x: True
        clientTransport = transport.SSHTransportBase()
        clientAuth.transport = clientTransport
        clientTransport.service = clientAuth
        serverTransport.sessionID = clientTransport.sessionID = ''
        # Neither side should send a key exchange packet.
        noKexInit = lambda: None
        serverTransport.sendKexInit = clientTransport.sendKexInit = noKexInit

        # Server-side authentication: a portal whose protocol checker is done
        # only after three successful credentials for the avatar.
        serverTransport.factory = self.Factory()
        serverAuth.passwordDelay = 0 # no delay after a bad password
        portal = Portal(Realm())
        checker = SSHProtocolChecker()
        for checkerClass in (PasswordChecker, PrivateKeyChecker, PAMChecker):
            checker.registerChecker(checkerClass())
        checker.areDone = lambda aId: (
            len(checker.successfulCredentials[aId]) == 3)
        portal.registerChecker(checker)
        serverTransport.factory.portal = portal

        d = loopback.loopbackAsync(serverAuth.transport, clientAuth.transport)
        # The inner .transport attributes are set only after loopbackAsync
        # (presumably it connects each side to its loopback transport).
        serverAuth.transport.transport.logPrefix = lambda: '_ServerLoopback'
        clientAuth.transport.transport.logPrefix = lambda: '_ClientLoopback'

        serverAuth.serviceStarted()
        clientAuth.serviceStarted()

        def check(ignored):
            self.assertEquals(serverAuth.transport.service.name,
                              'TestService')
        return d.addCallback(check)