~myers-1/pyopenssl/npn

« back to all changes in this revision

Viewing changes to OpenSSL/test/test_ssl.py

  • Committer: Jean-Paul Calderone
  • Date: 2010-10-08 02:19:58 UTC
  • mfrom: (132.2.57 py3k-port)
  • Revision ID: exarkun@divmod.com-20101008021958-xhodr0riwrhtmqvj
Python 3.x support

Show diffs side-by-side

added added

removed removed

Lines of Context:
23
23
from OpenSSL.SSL import Error, SysCallError, WantReadError, ZeroReturnError
24
24
from OpenSSL.SSL import Context, ContextType, Connection, ConnectionType
25
25
 
26
 
from OpenSSL.test.util import TestCase
 
26
from OpenSSL.test.util import TestCase, bytes, b
27
27
from OpenSSL.test.test_crypto import cleartextCertificatePEM, cleartextPrivateKeyPEM
28
28
from OpenSSL.test.test_crypto import client_cert_pem, client_key_pem
29
29
from OpenSSL.test.test_crypto import server_cert_pem, server_key_pem, root_cert_pem
52
52
 
53
53
 
54
54
def verify_cb(conn, cert, errnum, depth, ok):
55
 
    # print conn, cert, X509_verify_cert_error_string(errnum), depth, ok
56
55
    return ok
57
56
 
58
57
def socket_pair():
72
71
    # Let's pass some unencrypted data to make sure our socket connection is
73
72
    # fine.  Just one byte, so we don't have to worry about buffers getting
74
73
    # filled up or fragmentation.
75
 
    server.send("x")
76
 
    assert client.recv(1024) == "x"
77
 
    client.send("y")
78
 
    assert server.recv(1024) == "y"
 
74
    server.send(b("x"))
 
75
    assert client.recv(1024) == b("x")
 
76
    client.send(b("y"))
 
77
    assert server.recv(1024) == b("y")
79
78
 
80
79
    # Most of our callers want non-blocking sockets, make it easy for them.
81
80
    server.setblocking(False)
85
84
 
86
85
 
87
86
 
 
87
def handshake(client, server):
 
88
    conns = [client, server]
 
89
    while conns:
 
90
        for conn in conns:
 
91
            try:
 
92
                conn.do_handshake()
 
93
            except WantReadError:
 
94
                pass
 
95
            else:
 
96
                conns.remove(conn)
 
97
 
 
98
 
88
99
class _LoopbackMixin:
89
100
    """
90
101
    Helper mixin which defines methods for creating a connected socket pair and
101
112
        client = Connection(Context(TLSv1_METHOD), client)
102
113
        client.set_connect_state()
103
114
 
104
 
        for i in range(3):
105
 
            for conn in [client, server]:
106
 
                try:
107
 
                    conn.do_handshake()
108
 
                except WantReadError:
109
 
                    pass
 
115
        handshake(client, server)
110
116
 
111
117
        server.setblocking(True)
112
118
        client.setblocking(True)
305
311
        key = PKey()
306
312
        key.generate_key(TYPE_RSA, 128)
307
313
        pemFile = self.mktemp()
308
 
        fObj = file(pemFile, 'w')
309
 
        fObj.write(dump_privatekey(FILETYPE_PEM, key, "blowfish", passphrase))
 
314
        fObj = open(pemFile, 'w')
 
315
        pem = dump_privatekey(FILETYPE_PEM, key, "blowfish", passphrase)
 
316
        fObj.write(pem.decode('ascii'))
310
317
        fObj.close()
311
318
        return pemFile
312
319
 
327
334
        L{Context.set_passwd_cb} accepts a callable which will be invoked when
328
335
        a private key is loaded from an encrypted PEM.
329
336
        """
330
 
        passphrase = "foobar"
 
337
        passphrase = b("foobar")
331
338
        pemFile = self._write_encrypted_pem(passphrase)
332
339
        calledWith = []
333
340
        def passphraseCallback(maxlen, verify, extra):
347
354
        L{Context.use_privatekey_file} propagates any exception raised by the
348
355
        passphrase callback.
349
356
        """
350
 
        pemFile = self._write_encrypted_pem("monkeys are nice")
 
357
        pemFile = self._write_encrypted_pem(b("monkeys are nice"))
351
358
        def passphraseCallback(maxlen, verify, extra):
352
359
            raise RuntimeError("Sorry, I am a fail.")
353
360
 
361
368
        L{Context.use_privatekey_file} raises L{OpenSSL.SSL.Error} if the
362
369
        passphrase callback returns a false value.
363
370
        """
364
 
        pemFile = self._write_encrypted_pem("monkeys are nice")
 
371
        pemFile = self._write_encrypted_pem(b("monkeys are nice"))
365
372
        def passphraseCallback(maxlen, verify, extra):
366
373
            return None
367
374
 
375
382
        L{Context.use_privatekey_file} raises L{OpenSSL.SSL.Error} if the
376
383
        passphrase callback returns a true non-string value.
377
384
        """
378
 
        pemFile = self._write_encrypted_pem("monkeys are nice")
 
385
        pemFile = self._write_encrypted_pem(b("monkeys are nice"))
379
386
        def passphraseCallback(maxlen, verify, extra):
380
387
            return 10
381
388
 
390
397
        longer than the indicated maximum length, it is truncated.
391
398
        """
392
399
        # A priori knowledge!
393
 
        passphrase = "x" * 1024
 
400
        passphrase = b("x") * 1024
394
401
        pemFile = self._write_encrypted_pem(passphrase)
395
402
        def passphraseCallback(maxlen, verify, extra):
396
403
            assert maxlen == 1024
397
 
            return passphrase + "y"
 
404
            return passphrase + b("y")
398
405
 
399
406
        context = Context(TLSv1_METHOD)
400
407
        context.set_passwd_cb(passphraseCallback)
465
472
        serverSSL = Connection(serverContext, server)
466
473
        serverSSL.set_accept_state()
467
474
 
468
 
        for i in range(3):
469
 
            for ssl in clientSSL, serverSSL:
470
 
                try:
471
 
                    # Without load_verify_locations above, the handshake
472
 
                    # will fail:
473
 
                    # Error: [('SSL routines', 'SSL3_GET_SERVER_CERTIFICATE',
474
 
                    #          'certificate verify failed')]
475
 
                    ssl.do_handshake()
476
 
                except WantReadError:
477
 
                    pass
 
475
        # Without load_verify_locations above, the handshake
 
476
        # will fail:
 
477
        # Error: [('SSL routines', 'SSL3_GET_SERVER_CERTIFICATE',
 
478
        #          'certificate verify failed')]
 
479
        handshake(clientSSL, serverSSL)
478
480
 
479
481
        cert = clientSSL.get_peer_certificate()
480
482
        self.assertEqual(cert.get_subject().CN, 'Testing Root CA')
486
488
        certificates within for verification purposes.
487
489
        """
488
490
        cafile = self.mktemp()
489
 
        fObj = file(cafile, 'w')
490
 
        fObj.write(cleartextCertificatePEM)
 
491
        fObj = open(cafile, 'w')
 
492
        fObj.write(cleartextCertificatePEM.decode('ascii'))
491
493
        fObj.close()
492
494
 
493
495
        self._load_verify_locations_test(cafile)
513
515
        # Hash value computed manually with c_rehash to avoid depending on
514
516
        # c_rehash in the test suite.
515
517
        cafile = join(capath, 'c7adac82.0')
516
 
        fObj = file(cafile, 'w')
517
 
        fObj.write(cleartextCertificatePEM)
 
518
        fObj = open(cafile, 'w')
 
519
        fObj.write(cleartextCertificatePEM.decode('ascii'))
518
520
        fObj.close()
519
521
 
520
522
        self._load_verify_locations_test(None, capath)
596
598
            2. A new intermediate certificate signed by cacert (icert)
597
599
            3. A new server certificate signed by icert (scert)
598
600
        """
599
 
        caext = X509Extension('basicConstraints', False, 'CA:true')
 
601
        caext = X509Extension(b('basicConstraints'), False, b('CA:true'))
600
602
 
601
603
        # Step 1
602
604
        cakey = PKey()
605
607
        cacert.get_subject().commonName = "Authority Certificate"
606
608
        cacert.set_issuer(cacert.get_subject())
607
609
        cacert.set_pubkey(cakey)
608
 
        cacert.set_notBefore("20000101000000Z")
609
 
        cacert.set_notAfter("20200101000000Z")
 
610
        cacert.set_notBefore(b("20000101000000Z"))
 
611
        cacert.set_notAfter(b("20200101000000Z"))
610
612
        cacert.add_extensions([caext])
611
613
        cacert.set_serial_number(0)
612
614
        cacert.sign(cakey, "sha1")
618
620
        icert.get_subject().commonName = "Intermediate Certificate"
619
621
        icert.set_issuer(cacert.get_subject())
620
622
        icert.set_pubkey(ikey)
621
 
        icert.set_notBefore("20000101000000Z")
622
 
        icert.set_notAfter("20200101000000Z")
 
623
        icert.set_notBefore(b("20000101000000Z"))
 
624
        icert.set_notAfter(b("20200101000000Z"))
623
625
        icert.add_extensions([caext])
624
626
        icert.set_serial_number(0)
625
627
        icert.sign(cakey, "sha1")
631
633
        scert.get_subject().commonName = "Server Certificate"
632
634
        scert.set_issuer(icert.get_subject())
633
635
        scert.set_pubkey(skey)
634
 
        scert.set_notBefore("20000101000000Z")
635
 
        scert.set_notAfter("20200101000000Z")
636
 
        scert.add_extensions([X509Extension('basicConstraints', True, 'CA:false')])
 
636
        scert.set_notBefore(b("20000101000000Z"))
 
637
        scert.set_notAfter(b("20200101000000Z"))
 
638
        scert.add_extensions([
 
639
                X509Extension(b('basicConstraints'), True, b('CA:false'))])
637
640
        scert.set_serial_number(0)
638
641
        scert.sign(ikey, "sha1")
639
642
 
681
684
        # Dump the CA certificate to a file because that's the only way to load
682
685
        # it as a trusted CA in the client context.
683
686
        for cert, name in [(cacert, 'ca.pem'), (icert, 'i.pem'), (scert, 's.pem')]:
684
 
            fObj = file(name, 'w')
685
 
            fObj.write(dump_certificate(FILETYPE_PEM, cert))
686
 
            fObj.close()
687
 
            fObj = file(name.replace('pem', 'asn1'), 'w')
688
 
            fObj.write(dump_certificate(FILETYPE_ASN1, cert))
 
687
            fObj = open(name, 'w')
 
688
            fObj.write(dump_certificate(FILETYPE_PEM, cert).decode('ascii'))
689
689
            fObj.close()
690
690
 
691
691
        for key, name in [(cakey, 'ca.key'), (ikey, 'i.key'), (skey, 's.key')]:
692
 
            fObj = file(name, 'w')
693
 
            fObj.write(dump_privatekey(FILETYPE_PEM, key))
 
692
            fObj = open(name, 'w')
 
693
            fObj.write(dump_privatekey(FILETYPE_PEM, key).decode('ascii'))
694
694
            fObj.close()
695
695
 
696
696
        # Create the server context
724
724
 
725
725
        # Write out the chain file.
726
726
        chainFile = self.mktemp()
727
 
        fObj = file(chainFile, 'w')
 
727
        fObj = open(chainFile, 'w')
728
728
        # Most specific to least general.
729
 
        fObj.write(dump_certificate(FILETYPE_PEM, scert))
730
 
        fObj.write(dump_certificate(FILETYPE_PEM, icert))
731
 
        fObj.write(dump_certificate(FILETYPE_PEM, cacert))
 
729
        fObj.write(dump_certificate(FILETYPE_PEM, scert).decode('ascii'))
 
730
        fObj.write(dump_certificate(FILETYPE_PEM, icert).decode('ascii'))
 
731
        fObj.write(dump_certificate(FILETYPE_PEM, cacert).decode('ascii'))
732
732
        fObj.close()
733
733
 
734
734
        serverContext = Context(TLSv1_METHOD)
735
735
        serverContext.use_certificate_chain_file(chainFile)
736
736
        serverContext.use_privatekey(skey)
737
737
 
738
 
        fObj = file('ca.pem', 'w')
739
 
        fObj.write(dump_certificate(FILETYPE_PEM, cacert))
 
738
        fObj = open('ca.pem', 'w')
 
739
        fObj.write(dump_certificate(FILETYPE_PEM, cacert).decode('ascii'))
740
740
        fObj.close()
741
741
 
742
742
        clientContext = Context(TLSv1_METHOD)
921
921
        # XXX An assertion?  Or something?
922
922
 
923
923
 
924
 
    def test_connect_ex(self):
925
 
        """
926
 
        If there is a connection error, L{Connection.connect_ex} returns the
927
 
        errno instead of raising an exception.
928
 
        """
929
 
        port = socket()
930
 
        port.bind(('', 0))
931
 
        port.listen(3)
 
924
    if platform == "darwin":
 
925
        "connect_ex sometimes causes a kernel panic on OS X 10.6.4"
 
926
    else:
 
927
        def test_connect_ex(self):
 
928
            """
 
929
            If there is a connection error, L{Connection.connect_ex} returns the
 
930
            errno instead of raising an exception.
 
931
            """
 
932
            port = socket()
 
933
            port.bind(('', 0))
 
934
            port.listen(3)
932
935
 
933
 
        clientSSL = Connection(Context(TLSv1_METHOD), socket())
934
 
        clientSSL.setblocking(False)
935
 
        result = clientSSL.connect_ex(port.getsockname())
936
 
        expected = (EINPROGRESS, EWOULDBLOCK)
937
 
        self.assertTrue(
938
 
                result in expected, "%r not in %r" % (result, expected))
 
936
            clientSSL = Connection(Context(TLSv1_METHOD), socket())
 
937
            clientSSL.setblocking(False)
 
938
            result = clientSSL.connect_ex(port.getsockname())
 
939
            expected = (EINPROGRESS, EWOULDBLOCK)
 
940
            self.assertTrue(
 
941
                    result in expected, "%r not in %r" % (result, expected))
939
942
 
940
943
 
941
944
    def test_accept_wrong_args(self):
1092
1095
        it.
1093
1096
        """
1094
1097
        server, client = self._loopback()
1095
 
        server.sendall('x')
1096
 
        self.assertEquals(client.recv(1), 'x')
 
1098
        server.sendall(b('x'))
 
1099
        self.assertEquals(client.recv(1), b('x'))
1097
1100
 
1098
1101
 
1099
1102
    def test_long(self):
1105
1108
        # Should be enough, underlying SSL_write should only do 16k at a time.
1106
1109
        # On Windows, after 32k of bytes the write will block (forever - because
1107
1110
        # no one is yet reading).
1108
 
        message ='x' * (1024 * 32 - 1) + 'y'
 
1111
        message = b('x') * (1024 * 32 - 1) + b('y')
1109
1112
        server.sendall(message)
1110
1113
        accum = []
1111
1114
        received = 0
1112
1115
        while received < len(message):
1113
 
            bytes = client.recv(1024)
1114
 
            accum.append(bytes)
1115
 
            received += len(bytes)
1116
 
        self.assertEquals(message, ''.join(accum))
 
1116
            data = client.recv(1024)
 
1117
            accum.append(data)
 
1118
            received += len(data)
 
1119
        self.assertEquals(message, b('').join(accum))
1117
1120
 
1118
1121
 
1119
1122
    def test_closed(self):
1122
1125
        write error from the low level write call.
1123
1126
        """
1124
1127
        server, client = self._loopback()
1125
 
        client.close()
1126
 
        server.sendall("hello, world")
 
1128
        server.sock_shutdown(2)
1127
1129
        self.assertRaises(SysCallError, server.sendall, "hello, world")
1128
1130
 
1129
1131
 
1314
1316
        self.assertNotEquals(client_conn.client_random(), client_conn.server_random())
1315
1317
 
1316
1318
        # Here are the bytes we'll try to send.
1317
 
        important_message = 'One if by land, two if by sea.'
 
1319
        important_message = b('One if by land, two if by sea.')
1318
1320
 
1319
1321
        server_conn.write(important_message)
1320
1322
        self.assertEquals(
1337
1339
        code, as no memory BIO is involved here).  Even though this isn't a
1338
1340
        memory BIO test, it's convenient to have it here.
1339
1341
        """
1340
 
        (server, client) = socket_pair()
1341
 
 
1342
 
        # Let the encryption begin...
1343
 
        client_conn = self._client(client)
1344
 
        server_conn = self._server(server)
1345
 
 
1346
 
        # Establish the connection
1347
 
        established = False
1348
 
        while not established:
1349
 
            established = True  # assume the best
1350
 
            for ssl in client_conn, server_conn:
1351
 
                try:
1352
 
                    # Generally a recv() or send() could also work instead
1353
 
                    # of do_handshake(), and we would stop on the first
1354
 
                    # non-exception.
1355
 
                    ssl.do_handshake()
1356
 
                except WantReadError:
1357
 
                    established = False
1358
 
 
1359
 
        important_message = "Help me Obi Wan Kenobi, you're my only hope."
 
1342
        server_conn, client_conn = self._loopback()
 
1343
 
 
1344
        important_message = b("Help me Obi Wan Kenobi, you're my only hope.")
1360
1345
        client_conn.send(important_message)
1361
1346
        msg = server_conn.recv(1024)
1362
1347
        self.assertEqual(msg, important_message)