~myers-1/pyopenssl/npn

« back to all changes in this revision

Viewing changes to OpenSSL/test/test_ssl.py

  • Committer: Jean-Paul Calderone
  • Date: 2010-09-14 22:07:08 UTC
  • mfrom: (132.1.54 shore-up-tests)
  • Revision ID: exarkun@divmod.com-20100914220708-qbu4htpfdfsff9cm
Merge shore-up-tests, adding lots of new unit test coverage and fixing a few bugs as well

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) Jean-Paul Calderone 2008, All rights reserved
 
1
# Copyright (C) Jean-Paul Calderone 2008-2010, All rights reserved
2
2
 
3
3
"""
4
4
Unit tests for L{OpenSSL.SSL}.
5
5
"""
6
6
 
 
7
from errno import ECONNREFUSED, EINPROGRESS
7
8
from sys import platform
8
 
from socket import socket
 
9
from socket import error, socket
9
10
from os import makedirs
10
11
from os.path import join
11
12
from unittest import main
12
13
 
13
 
from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM, PKey, dump_privatekey, load_certificate, load_privatekey
14
 
from OpenSSL.SSL import WantReadError, Context, ContextType, Connection, ConnectionType, Error
 
14
from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM, FILETYPE_ASN1
 
15
from OpenSSL.crypto import PKey, X509, X509Extension
 
16
from OpenSSL.crypto import dump_privatekey, load_privatekey
 
17
from OpenSSL.crypto import dump_certificate, load_certificate
 
18
 
 
19
from OpenSSL.SSL import SENT_SHUTDOWN, RECEIVED_SHUTDOWN
15
20
from OpenSSL.SSL import SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD
16
21
from OpenSSL.SSL import OP_NO_SSLv2, OP_NO_SSLv3, OP_SINGLE_DH_USE
17
22
from OpenSSL.SSL import VERIFY_PEER, VERIFY_FAIL_IF_NO_PEER_CERT, VERIFY_CLIENT_ONCE
 
23
from OpenSSL.SSL import Error, SysCallError, WantReadError, ZeroReturnError
 
24
from OpenSSL.SSL import Context, ContextType, Connection, ConnectionType
 
25
 
18
26
from OpenSSL.test.util import TestCase
19
27
from OpenSSL.test.test_crypto import cleartextCertificatePEM, cleartextPrivateKeyPEM
20
 
from OpenSSL.test.test_crypto import client_cert_pem, client_key_pem, server_cert_pem, server_key_pem, root_cert_pem
 
28
from OpenSSL.test.test_crypto import client_cert_pem, client_key_pem
 
29
from OpenSSL.test.test_crypto import server_cert_pem, server_key_pem, root_cert_pem
 
30
 
21
31
try:
22
32
    from OpenSSL.SSL import OP_NO_QUERY_MTU
23
33
except ImportError:
32
42
    OP_NO_TICKET = None
33
43
 
34
44
 
 
45
# openssl dhparam 128 -out dh-128.pem (note that 128 is a small number of bits
 
46
# to use)
 
47
dhparam = """\
 
48
-----BEGIN DH PARAMETERS-----
 
49
MBYCEQCobsg29c9WZP/54oAPcwiDAgEC
 
50
-----END DH PARAMETERS-----
 
51
"""
 
52
 
 
53
 
 
54
def verify_cb(conn, cert, errnum, depth, ok):
 
55
    # print conn, cert, X509_verify_cert_error_string(errnum), depth, ok
 
56
    return ok
 
57
 
35
58
def socket_pair():
36
59
    """
37
60
    Establish and return a pair of network sockets connected to each other.
54
77
    client.send("y")
55
78
    assert server.recv(1024) == "y"
56
79
 
57
 
    # All our callers want non-blocking sockets, make it easy for them.
 
80
    # Most of our callers want non-blocking sockets, make it easy for them.
58
81
    server.setblocking(False)
59
82
    client.setblocking(False)
60
83
 
62
85
 
63
86
 
64
87
 
65
 
class ContextTests(TestCase):
 
88
class _LoopbackMixin:
 
89
    """
 
90
    Helper mixin which defines methods for creating a connected socket pair and
 
91
    for forcing two connected SSL sockets to talk to each other via memory BIOs.
 
92
    """
 
93
    def _loopback(self):
 
94
        (server, client) = socket_pair()
 
95
 
 
96
        ctx = Context(TLSv1_METHOD)
 
97
        ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
 
98
        ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
 
99
        server = Connection(ctx, server)
 
100
        server.set_accept_state()
 
101
        client = Connection(Context(TLSv1_METHOD), client)
 
102
        client.set_connect_state()
 
103
 
 
104
        for i in range(3):
 
105
            for conn in [client, server]:
 
106
                try:
 
107
                    conn.do_handshake()
 
108
                except WantReadError:
 
109
                    pass
 
110
 
 
111
        server.setblocking(True)
 
112
        client.setblocking(True)
 
113
        return server, client
 
114
 
 
115
 
 
116
    def _interactInMemory(self, client_conn, server_conn):
 
117
        """
 
118
        Try to read application bytes from each of the two L{Connection}
 
119
        objects.  Copy bytes back and forth between their send/receive buffers
 
120
        for as long as there is anything to copy.  When there is nothing more
 
121
        to copy, return C{None}.  If one of them actually manages to deliver
 
122
        some application bytes, return a two-tuple of the connection from which
 
123
        the bytes were read and the bytes themselves.
 
124
        """
 
125
        wrote = True
 
126
        while wrote:
 
127
            # Loop until neither side has anything to say
 
128
            wrote = False
 
129
 
 
130
            # Copy stuff from each side's send buffer to the other side's
 
131
            # receive buffer.
 
132
            for (read, write) in [(client_conn, server_conn),
 
133
                                  (server_conn, client_conn)]:
 
134
 
 
135
                # Give the side a chance to generate some more bytes, or
 
136
                # succeed.
 
137
                try:
 
138
                    bytes = read.recv(2 ** 16)
 
139
                except WantReadError:
 
140
                    # It didn't succeed, so we'll hope it generated some
 
141
                    # output.
 
142
                    pass
 
143
                else:
 
144
                    # It did succeed, so we'll stop now and let the caller deal
 
145
                    # with it.
 
146
                    return (read, bytes)
 
147
 
 
148
                while True:
 
149
                    # Keep copying as long as there's more stuff there.
 
150
                    try:
 
151
                        dirty = read.bio_read(4096)
 
152
                    except WantReadError:
 
153
                        # Okay, nothing more waiting to be sent.  Stop
 
154
                        # processing this send buffer.
 
155
                        break
 
156
                    else:
 
157
                        # Keep track of the fact that someone generated some
 
158
                        # output.
 
159
                        wrote = True
 
160
                        write.bio_write(dirty)
 
161
 
 
162
 
 
163
 
 
164
class ContextTests(TestCase, _LoopbackMixin):
66
165
    """
67
166
    Unit tests for L{OpenSSL.SSL.Context}.
68
167
    """
97
196
        self.assertRaises(TypeError, ctx.use_privatekey, "")
98
197
 
99
198
 
 
199
    def test_set_app_data_wrong_args(self):
 
200
        """
 
201
        L{Context.set_app_data} raises L{TypeError} if called with other than
 
202
        one argument.
 
203
        """
 
204
        context = Context(TLSv1_METHOD)
 
205
        self.assertRaises(TypeError, context.set_app_data)
 
206
        self.assertRaises(TypeError, context.set_app_data, None, None)
 
207
 
 
208
 
 
209
    def test_get_app_data_wrong_args(self):
 
210
        """
 
211
        L{Context.get_app_data} raises L{TypeError} if called with any
 
212
        arguments.
 
213
        """
 
214
        context = Context(TLSv1_METHOD)
 
215
        self.assertRaises(TypeError, context.get_app_data, None)
 
216
 
 
217
 
 
218
    def test_app_data(self):
 
219
        """
 
220
        L{Context.set_app_data} stores an object for later retrieval using
 
221
        L{Context.get_app_data}.
 
222
        """
 
223
        app_data = object()
 
224
        context = Context(TLSv1_METHOD)
 
225
        context.set_app_data(app_data)
 
226
        self.assertIdentical(context.get_app_data(), app_data)
 
227
 
 
228
 
 
229
    def test_set_options_wrong_args(self):
 
230
        """
 
231
        L{Context.set_options} raises L{TypeError} if called with the wrong
 
232
        number of arguments or a non-C{int} argument.
 
233
        """
 
234
        context = Context(TLSv1_METHOD)
 
235
        self.assertRaises(TypeError, context.set_options)
 
236
        self.assertRaises(TypeError, context.set_options, None)
 
237
        self.assertRaises(TypeError, context.set_options, 1, None)
 
238
 
 
239
 
 
240
    def test_set_timeout_wrong_args(self):
 
241
        """
 
242
        L{Context.set_timeout} raises L{TypeError} if called with the wrong
 
243
        number of arguments or a non-C{int} argument.
 
244
        """
 
245
        context = Context(TLSv1_METHOD)
 
246
        self.assertRaises(TypeError, context.set_timeout)
 
247
        self.assertRaises(TypeError, context.set_timeout, None)
 
248
        self.assertRaises(TypeError, context.set_timeout, 1, None)
 
249
 
 
250
 
 
251
    def test_get_timeout_wrong_args(self):
 
252
        """
 
253
        L{Context.get_timeout} raises L{TypeError} if called with any arguments.
 
254
        """
 
255
        context = Context(TLSv1_METHOD)
 
256
        self.assertRaises(TypeError, context.get_timeout, None)
 
257
 
 
258
 
 
259
    def test_timeout(self):
 
260
        """
 
261
        L{Context.set_timeout} sets the session timeout for all connections
 
262
        created using the context object.  L{Context.get_timeout} retrieves this
 
263
        value.
 
264
        """
 
265
        context = Context(TLSv1_METHOD)
 
266
        context.set_timeout(1234)
 
267
        self.assertEquals(context.get_timeout(), 1234)
 
268
 
 
269
 
 
270
    def test_set_verify_depth_wrong_args(self):
 
271
        """
 
272
        L{Context.set_verify_depth} raises L{TypeError} if called with the wrong
 
273
        number of arguments or a non-C{int} argument.
 
274
        """
 
275
        context = Context(TLSv1_METHOD)
 
276
        self.assertRaises(TypeError, context.set_verify_depth)
 
277
        self.assertRaises(TypeError, context.set_verify_depth, None)
 
278
        self.assertRaises(TypeError, context.set_verify_depth, 1, None)
 
279
 
 
280
 
 
281
    def test_get_verify_depth_wrong_args(self):
 
282
        """
 
283
        L{Context.get_verify_depth} raises L{TypeError} if called with any arguments.
 
284
        """
 
285
        context = Context(TLSv1_METHOD)
 
286
        self.assertRaises(TypeError, context.get_verify_depth, None)
 
287
 
 
288
 
 
289
    def test_verify_depth(self):
 
290
        """
 
291
        L{Context.set_verify_depth} sets the number of certificates in a chain
 
292
        to follow before giving up.  The value can be retrieved with
 
293
        L{Context.get_verify_depth}.
 
294
        """
 
295
        context = Context(TLSv1_METHOD)
 
296
        context.set_verify_depth(11)
 
297
        self.assertEquals(context.get_verify_depth(), 11)
 
298
 
 
299
 
 
300
    def _write_encrypted_pem(self, passphrase):
 
301
        """
 
302
        Write a new private key out to a new file, encrypted using the given
 
303
        passphrase.  Return the path to the new file.
 
304
        """
 
305
        key = PKey()
 
306
        key.generate_key(TYPE_RSA, 128)
 
307
        pemFile = self.mktemp()
 
308
        fObj = file(pemFile, 'w')
 
309
        fObj.write(dump_privatekey(FILETYPE_PEM, key, "blowfish", passphrase))
 
310
        fObj.close()
 
311
        return pemFile
 
312
 
 
313
 
 
314
    def test_set_passwd_cb_wrong_args(self):
 
315
        """
 
316
        L{Context.set_passwd_cb} raises L{TypeError} if called with the
 
317
        wrong arguments or with a non-callable first argument.
 
318
        """
 
319
        context = Context(TLSv1_METHOD)
 
320
        self.assertRaises(TypeError, context.set_passwd_cb)
 
321
        self.assertRaises(TypeError, context.set_passwd_cb, None)
 
322
        self.assertRaises(TypeError, context.set_passwd_cb, lambda: None, None, None)
 
323
 
 
324
 
100
325
    def test_set_passwd_cb(self):
101
326
        """
102
327
        L{Context.set_passwd_cb} accepts a callable which will be invoked when
103
328
        a private key is loaded from an encrypted PEM.
104
329
        """
105
 
        key = PKey()
106
 
        key.generate_key(TYPE_RSA, 128)
107
 
        pemFile = self.mktemp()
108
 
        fObj = file(pemFile, 'w')
109
330
        passphrase = "foobar"
110
 
        fObj.write(dump_privatekey(FILETYPE_PEM, key, "blowfish", passphrase))
111
 
        fObj.close()
112
 
 
 
331
        pemFile = self._write_encrypted_pem(passphrase)
113
332
        calledWith = []
114
333
        def passphraseCallback(maxlen, verify, extra):
115
334
            calledWith.append((maxlen, verify, extra))
123
342
        self.assertEqual(calledWith[0][2], None)
124
343
 
125
344
 
 
345
    def test_passwd_callback_exception(self):
 
346
        """
 
347
        L{Context.use_privatekey_file} propagates any exception raised by the
 
348
        passphrase callback.
 
349
        """
 
350
        pemFile = self._write_encrypted_pem("monkeys are nice")
 
351
        def passphraseCallback(maxlen, verify, extra):
 
352
            raise RuntimeError("Sorry, I am a fail.")
 
353
 
 
354
        context = Context(TLSv1_METHOD)
 
355
        context.set_passwd_cb(passphraseCallback)
 
356
        self.assertRaises(RuntimeError, context.use_privatekey_file, pemFile)
 
357
 
 
358
 
 
359
    def test_passwd_callback_false(self):
 
360
        """
 
361
        L{Context.use_privatekey_file} raises L{OpenSSL.SSL.Error} if the
 
362
        passphrase callback returns a false value.
 
363
        """
 
364
        pemFile = self._write_encrypted_pem("monkeys are nice")
 
365
        def passphraseCallback(maxlen, verify, extra):
 
366
            return None
 
367
 
 
368
        context = Context(TLSv1_METHOD)
 
369
        context.set_passwd_cb(passphraseCallback)
 
370
        self.assertRaises(Error, context.use_privatekey_file, pemFile)
 
371
 
 
372
 
 
373
    def test_passwd_callback_non_string(self):
 
374
        """
 
375
        L{Context.use_privatekey_file} raises L{OpenSSL.SSL.Error} if the
 
376
        passphrase callback returns a true non-string value.
 
377
        """
 
378
        pemFile = self._write_encrypted_pem("monkeys are nice")
 
379
        def passphraseCallback(maxlen, verify, extra):
 
380
            return 10
 
381
 
 
382
        context = Context(TLSv1_METHOD)
 
383
        context.set_passwd_cb(passphraseCallback)
 
384
        self.assertRaises(Error, context.use_privatekey_file, pemFile)
 
385
 
 
386
 
 
387
    def test_passwd_callback_too_long(self):
 
388
        """
 
389
        If the passphrase returned by the passphrase callback returns a string
 
390
        longer than the indicated maximum length, it is truncated.
 
391
        """
 
392
        # A priori knowledge!
 
393
        passphrase = "x" * 1024
 
394
        pemFile = self._write_encrypted_pem(passphrase)
 
395
        def passphraseCallback(maxlen, verify, extra):
 
396
            assert maxlen == 1024
 
397
            return passphrase + "y"
 
398
 
 
399
        context = Context(TLSv1_METHOD)
 
400
        context.set_passwd_cb(passphraseCallback)
 
401
        # This shall succeed because the truncated result is the correct
 
402
        # passphrase.
 
403
        context.use_privatekey_file(pemFile)
 
404
 
 
405
 
126
406
    def test_set_info_callback(self):
127
407
        """
128
408
        L{Context.set_info_callback} accepts a callable which will be invoked
158
438
 
159
439
 
160
440
    def _load_verify_locations_test(self, *args):
 
441
        """
 
442
        Create a client context which will verify the peer certificate and call
 
443
        its C{load_verify_locations} method with C{*args}.  Then connect it to a
 
444
        server and ensure that the handshake succeeds.
 
445
        """
161
446
        (server, client) = socket_pair()
162
447
 
163
448
        clientContext = Context(TLSv1_METHOD)
235
520
        self._load_verify_locations_test(None, capath)
236
521
 
237
522
 
238
 
    if platform in ("darwin", "win32"):
239
 
        "set_default_verify_paths appears not to work on OS X or Windows"
 
523
    def test_load_verify_locations_wrong_args(self):
 
524
        """
 
525
        L{Context.load_verify_locations} raises L{TypeError} if called with
 
526
        the wrong number of arguments or with non-C{str} arguments.
 
527
        """
 
528
        context = Context(TLSv1_METHOD)
 
529
        self.assertRaises(TypeError, context.load_verify_locations)
 
530
        self.assertRaises(TypeError, context.load_verify_locations, object())
 
531
        self.assertRaises(TypeError, context.load_verify_locations, object(), object())
 
532
        self.assertRaises(TypeError, context.load_verify_locations, None, None, None)
 
533
 
 
534
 
 
535
    if platform == "win32":
 
536
        "set_default_verify_paths appears not to work on Windows.  "
240
537
        "See LP#404343 and LP#404344."
241
538
    else:
242
539
        def test_set_default_verify_paths(self):
278
575
        self.assertRaises(TypeError, context.set_default_verify_paths, 1)
279
576
        self.assertRaises(TypeError, context.set_default_verify_paths, "")
280
577
 
 
578
 
281
579
    def test_add_extra_chain_cert_invalid_cert(self):
282
580
        """
283
581
        L{Context.add_extra_chain_cert} raises L{TypeError} if called with
290
588
        self.assertRaises(TypeError, context.add_extra_chain_cert, object(), object())
291
589
 
292
590
 
 
591
    def _create_certificate_chain(self):
 
592
        """
 
593
        Construct and return a chain of certificates.
 
594
 
 
595
            1. A new self-signed certificate authority certificate (cacert)
 
596
            2. A new intermediate certificate signed by cacert (icert)
 
597
            3. A new server certificate signed by icert (scert)
 
598
        """
 
599
        caext = X509Extension('basicConstraints', False, 'CA:true')
 
600
 
 
601
        # Step 1
 
602
        cakey = PKey()
 
603
        cakey.generate_key(TYPE_RSA, 512)
 
604
        cacert = X509()
 
605
        cacert.get_subject().commonName = "Authority Certificate"
 
606
        cacert.set_issuer(cacert.get_subject())
 
607
        cacert.set_pubkey(cakey)
 
608
        cacert.set_notBefore("20000101000000Z")
 
609
        cacert.set_notAfter("20200101000000Z")
 
610
        cacert.add_extensions([caext])
 
611
        cacert.set_serial_number(0)
 
612
        cacert.sign(cakey, "sha1")
 
613
 
 
614
        # Step 2
 
615
        ikey = PKey()
 
616
        ikey.generate_key(TYPE_RSA, 512)
 
617
        icert = X509()
 
618
        icert.get_subject().commonName = "Intermediate Certificate"
 
619
        icert.set_issuer(cacert.get_subject())
 
620
        icert.set_pubkey(ikey)
 
621
        icert.set_notBefore("20000101000000Z")
 
622
        icert.set_notAfter("20200101000000Z")
 
623
        icert.add_extensions([caext])
 
624
        icert.set_serial_number(0)
 
625
        icert.sign(cakey, "sha1")
 
626
 
 
627
        # Step 3
 
628
        skey = PKey()
 
629
        skey.generate_key(TYPE_RSA, 512)
 
630
        scert = X509()
 
631
        scert.get_subject().commonName = "Server Certificate"
 
632
        scert.set_issuer(icert.get_subject())
 
633
        scert.set_pubkey(skey)
 
634
        scert.set_notBefore("20000101000000Z")
 
635
        scert.set_notAfter("20200101000000Z")
 
636
        scert.add_extensions([X509Extension('basicConstraints', True, 'CA:false')])
 
637
        scert.set_serial_number(0)
 
638
        scert.sign(ikey, "sha1")
 
639
 
 
640
        return [(cakey, cacert), (ikey, icert), (skey, scert)]
 
641
 
 
642
 
 
643
    def _handshake_test(self, serverContext, clientContext):
 
644
        """
 
645
        Verify that a client and server created with the given contexts can
 
646
        successfully handshake and communicate.
 
647
        """
 
648
        serverSocket, clientSocket = socket_pair()
 
649
 
 
650
        server = Connection(serverContext, serverSocket)
 
651
        server.set_accept_state()
 
652
 
 
653
        client = Connection(clientContext, clientSocket)
 
654
        client.set_connect_state()
 
655
 
 
656
        # Make them talk to each other.
 
657
        # self._interactInMemory(client, server)
 
658
        for i in range(3):
 
659
            for s in [client, server]:
 
660
                try:
 
661
                    s.do_handshake()
 
662
                except WantReadError:
 
663
                    pass
 
664
 
 
665
 
293
666
    def test_add_extra_chain_cert(self):
294
667
        """
295
668
        L{Context.add_extra_chain_cert} accepts an L{X509} instance to add to
296
669
        the certificate chain.
297
 
        """
298
 
        context = Context(TLSv1_METHOD)
299
 
        context.add_extra_chain_cert(load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
300
 
        # XXX Oh no, actually asserting something about its behavior would be really hard.
301
 
        # See #477521.
302
 
 
303
 
 
304
 
 
305
 
class ConnectionTests(TestCase):
 
670
 
 
671
        See L{_create_certificate_chain} for the details of the certificate
 
672
        chain tested.
 
673
 
 
674
        The chain is tested by starting a server with scert and connecting
 
675
        to it with a client which trusts cacert and requires verification to
 
676
        succeed.
 
677
        """
 
678
        chain = self._create_certificate_chain()
 
679
        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain
 
680
 
 
681
        # Dump the CA certificate to a file because that's the only way to load
 
682
        # it as a trusted CA in the client context.
 
683
        for cert, name in [(cacert, 'ca.pem'), (icert, 'i.pem'), (scert, 's.pem')]:
 
684
            fObj = file(name, 'w')
 
685
            fObj.write(dump_certificate(FILETYPE_PEM, cert))
 
686
            fObj.close()
 
687
            fObj = file(name.replace('pem', 'asn1'), 'w')
 
688
            fObj.write(dump_certificate(FILETYPE_ASN1, cert))
 
689
            fObj.close()
 
690
 
 
691
        for key, name in [(cakey, 'ca.key'), (ikey, 'i.key'), (skey, 's.key')]:
 
692
            fObj = file(name, 'w')
 
693
            fObj.write(dump_privatekey(FILETYPE_PEM, key))
 
694
            fObj.close()
 
695
 
 
696
        # Create the server context
 
697
        serverContext = Context(TLSv1_METHOD)
 
698
        serverContext.use_privatekey(skey)
 
699
        serverContext.use_certificate(scert)
 
700
        # The client already has cacert, we only need to give them icert.
 
701
        serverContext.add_extra_chain_cert(icert)
 
702
 
 
703
        # Create the client
 
704
        clientContext = Context(TLSv1_METHOD)
 
705
        clientContext.set_verify(
 
706
            VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb)
 
707
        clientContext.load_verify_locations('ca.pem')
 
708
 
 
709
        # Try it out.
 
710
        self._handshake_test(serverContext, clientContext)
 
711
 
 
712
 
 
713
    def test_use_certificate_chain_file(self):
 
714
        """
 
715
        L{Context.use_certificate_chain_file} reads a certificate chain from
 
716
        the specified file.
 
717
 
 
718
        The chain is tested by starting a server with scert and connecting
 
719
        to it with a client which trusts cacert and requires verification to
 
720
        succeed.
 
721
        """
 
722
        chain = self._create_certificate_chain()
 
723
        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain
 
724
 
 
725
        # Write out the chain file.
 
726
        chainFile = self.mktemp()
 
727
        fObj = file(chainFile, 'w')
 
728
        # Most specific to least general.
 
729
        fObj.write(dump_certificate(FILETYPE_PEM, scert))
 
730
        fObj.write(dump_certificate(FILETYPE_PEM, icert))
 
731
        fObj.write(dump_certificate(FILETYPE_PEM, cacert))
 
732
        fObj.close()
 
733
 
 
734
        serverContext = Context(TLSv1_METHOD)
 
735
        serverContext.use_certificate_chain_file(chainFile)
 
736
        serverContext.use_privatekey(skey)
 
737
 
 
738
        fObj = file('ca.pem', 'w')
 
739
        fObj.write(dump_certificate(FILETYPE_PEM, cacert))
 
740
        fObj.close()
 
741
 
 
742
        clientContext = Context(TLSv1_METHOD)
 
743
        clientContext.set_verify(
 
744
            VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb)
 
745
        clientContext.load_verify_locations('ca.pem')
 
746
 
 
747
        self._handshake_test(serverContext, clientContext)
 
748
 
 
749
    # XXX load_client_ca
 
750
    # XXX set_session_id
 
751
 
 
752
    def test_get_verify_mode_wrong_args(self):
 
753
        """
 
754
        L{Context.get_verify_mode} raises L{TypeError} if called with any
 
755
        arguments.
 
756
        """
 
757
        context = Context(TLSv1_METHOD)
 
758
        self.assertRaises(TypeError, context.get_verify_mode, None)
 
759
 
 
760
 
 
761
    def test_get_verify_mode(self):
 
762
        """
 
763
        L{Context.get_verify_mode} returns the verify mode flags previously
 
764
        passed to L{Context.set_verify}.
 
765
        """
 
766
        context = Context(TLSv1_METHOD)
 
767
        self.assertEquals(context.get_verify_mode(), 0)
 
768
        context.set_verify(
 
769
            VERIFY_PEER | VERIFY_CLIENT_ONCE, lambda *args: None)
 
770
        self.assertEquals(
 
771
            context.get_verify_mode(), VERIFY_PEER | VERIFY_CLIENT_ONCE)
 
772
 
 
773
 
 
774
    def test_load_tmp_dh_wrong_args(self):
 
775
        """
 
776
        L{Context.load_tmp_dh} raises L{TypeError} if called with the wrong
 
777
        number of arguments or with a non-C{str} argument.
 
778
        """
 
779
        context = Context(TLSv1_METHOD)
 
780
        self.assertRaises(TypeError, context.load_tmp_dh)
 
781
        self.assertRaises(TypeError, context.load_tmp_dh, "foo", None)
 
782
        self.assertRaises(TypeError, context.load_tmp_dh, object())
 
783
 
 
784
 
 
785
    def test_load_tmp_dh_missing_file(self):
 
786
        """
 
787
        L{Context.load_tmp_dh} raises L{OpenSSL.SSL.Error} if the specified file
 
788
        does not exist.
 
789
        """
 
790
        context = Context(TLSv1_METHOD)
 
791
        self.assertRaises(Error, context.load_tmp_dh, "hello")
 
792
 
 
793
 
 
794
    def test_load_tmp_dh(self):
 
795
        """
 
796
        L{Context.load_tmp_dh} loads Diffie-Hellman parameters from the
 
797
        specified file.
 
798
        """
 
799
        context = Context(TLSv1_METHOD)
 
800
        dhfilename = self.mktemp()
 
801
        dhfile = open(dhfilename, "w")
 
802
        dhfile.write(dhparam)
 
803
        dhfile.close()
 
804
        context.load_tmp_dh(dhfilename)
 
805
        # XXX What should I assert here? -exarkun
 
806
 
 
807
 
 
808
    def test_set_cipher_list(self):
 
809
        """
 
810
        L{Context.set_cipher_list} accepts a C{str} naming the ciphers which
 
811
        connections created with the context object will be able to choose from.
 
812
        """
 
813
        context = Context(TLSv1_METHOD)
 
814
        context.set_cipher_list("hello world:EXP-RC4-MD5")
 
815
        conn = Connection(context, None)
 
816
        self.assertEquals(conn.get_cipher_list(), ["EXP-RC4-MD5"])
 
817
 
 
818
 
 
819
 
 
820
class ConnectionTests(TestCase, _LoopbackMixin):
306
821
    """
307
822
    Unit tests for L{OpenSSL.SSL.Connection}.
308
823
    """
 
824
    # XXX want_write
 
825
    # XXX want_read
 
826
    # XXX get_peer_certificate -> None
 
827
    # XXX sock_shutdown
 
828
    # XXX master_key -> TypeError
 
829
    # XXX server_random -> TypeError
 
830
    # XXX state_string
 
831
    # XXX connect -> TypeError
 
832
    # XXX connect_ex -> TypeError
 
833
    # XXX set_connect_state -> TypeError
 
834
    # XXX set_accept_state -> TypeError
 
835
    # XXX renegotiate_pending
 
836
    # XXX do_handshake -> TypeError
 
837
    # XXX bio_read -> TypeError
 
838
    # XXX recv -> TypeError
 
839
    # XXX send -> TypeError
 
840
    # XXX bio_write -> TypeError
 
841
 
309
842
    def test_type(self):
310
843
        """
311
844
        L{Connection} and L{ConnectionType} refer to the same type object and
335
868
        self.assertRaises(TypeError, connection.get_context, None)
336
869
 
337
870
 
 
871
    def test_pending(self):
 
872
        """
 
873
        L{Connection.pending} returns the number of bytes available for
 
874
        immediate read.
 
875
        """
 
876
        connection = Connection(Context(TLSv1_METHOD), None)
 
877
        self.assertEquals(connection.pending(), 0)
 
878
 
 
879
 
 
880
    def test_pending_wrong_args(self):
 
881
        """
 
882
        L{Connection.pending} raises L{TypeError} if called with any arguments.
 
883
        """
 
884
        connection = Connection(Context(TLSv1_METHOD), None)
 
885
        self.assertRaises(TypeError, connection.pending, None)
 
886
 
 
887
 
 
888
    def test_connect_wrong_args(self):
 
889
        """
 
890
        L{Connection.connect} raises L{TypeError} if called with a non-address
 
891
        argument or with the wrong number of arguments.
 
892
        """
 
893
        connection = Connection(Context(TLSv1_METHOD), socket())
 
894
        self.assertRaises(TypeError, connection.connect, None)
 
895
        self.assertRaises(TypeError, connection.connect)
 
896
        self.assertRaises(TypeError, connection.connect, ("127.0.0.1", 1), None)
 
897
 
 
898
 
 
899
    def test_connect_refused(self):
 
900
        """
 
901
        L{Connection.connect} raises L{socket.error} if the underlying socket
 
902
        connect method raises it.
 
903
        """
 
904
        client = socket()
 
905
        context = Context(TLSv1_METHOD)
 
906
        clientSSL = Connection(context, client)
 
907
        exc = self.assertRaises(error, clientSSL.connect, ("127.0.0.1", 1))
 
908
        self.assertEquals(exc.args[0], ECONNREFUSED)
 
909
 
 
910
 
 
911
    def test_connect(self):
 
912
        """
 
913
        L{Connection.connect} establishes a connection to the specified address.
 
914
        """
 
915
        port = socket()
 
916
        port.bind(('', 0))
 
917
        port.listen(3)
 
918
 
 
919
        clientSSL = Connection(Context(TLSv1_METHOD), socket())
 
920
        clientSSL.connect(port.getsockname())
 
921
 
 
922
 
 
923
    def test_connect_ex(self):
 
924
        """
 
925
        If there is a connection error, L{Connection.connect_ex} returns the
 
926
        errno instead of raising an exception.
 
927
        """
 
928
        port = socket()
 
929
        port.bind(('', 0))
 
930
        port.listen(3)
 
931
 
 
932
        clientSSL = Connection(Context(TLSv1_METHOD), socket())
 
933
        clientSSL.setblocking(False)
 
934
        self.assertEquals(
 
935
            clientSSL.connect_ex(port.getsockname()), EINPROGRESS)
 
936
 
 
937
 
 
938
    def test_accept_wrong_args(self):
 
939
        """
 
940
        L{Connection.accept} raises L{TypeError} if called with any arguments.
 
941
        """
 
942
        connection = Connection(Context(TLSv1_METHOD), socket())
 
943
        self.assertRaises(TypeError, connection.accept, None)
 
944
 
 
945
 
 
946
    def test_accept(self):
 
947
        """
 
948
        L{Connection.accept} accepts a pending connection attempt and returns a
 
949
        tuple of a new L{Connection} (the accepted client) and the address the
 
950
        connection originated from.
 
951
        """
 
952
        ctx = Context(TLSv1_METHOD)
 
953
        ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
 
954
        ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
 
955
        port = socket()
 
956
        portSSL = Connection(ctx, port)
 
957
        portSSL.bind(('', 0))
 
958
        portSSL.listen(3)
 
959
 
 
960
        clientSSL = Connection(Context(TLSv1_METHOD), socket())
 
961
        clientSSL.connect(portSSL.getsockname())
 
962
 
 
963
        serverSSL, address = portSSL.accept()
 
964
 
 
965
        self.assertTrue(isinstance(serverSSL, Connection))
 
966
        self.assertIdentical(serverSSL.get_context(), ctx)
 
967
        self.assertEquals(address, clientSSL.getsockname())
 
968
 
 
969
 
 
970
    def test_shutdown_wrong_args(self):
 
971
        """
 
972
        L{Connection.shutdown} raises L{TypeError} if called with the wrong
 
973
        number of arguments or with arguments other than integers.
 
974
        """
 
975
        connection = Connection(Context(TLSv1_METHOD), None)
 
976
        self.assertRaises(TypeError, connection.shutdown, None)
 
977
        self.assertRaises(TypeError, connection.get_shutdown, None)
 
978
        self.assertRaises(TypeError, connection.set_shutdown)
 
979
        self.assertRaises(TypeError, connection.set_shutdown, None)
 
980
        self.assertRaises(TypeError, connection.set_shutdown, 0, 1)
 
981
 
 
982
 
 
983
    def test_shutdown(self):
 
984
        """
 
985
        L{Connection.shutdown} performs an SSL-level connection shutdown.
 
986
        """
 
987
        server, client = self._loopback()
 
988
        self.assertFalse(server.shutdown())
 
989
        self.assertEquals(server.get_shutdown(), SENT_SHUTDOWN)
 
990
        self.assertRaises(ZeroReturnError, client.recv, 1024)
 
991
        self.assertEquals(client.get_shutdown(), RECEIVED_SHUTDOWN)
 
992
        client.shutdown()
 
993
        self.assertEquals(client.get_shutdown(), SENT_SHUTDOWN|RECEIVED_SHUTDOWN)
 
994
        self.assertRaises(ZeroReturnError, server.recv, 1024)
 
995
        self.assertEquals(server.get_shutdown(), SENT_SHUTDOWN|RECEIVED_SHUTDOWN)
 
996
 
 
997
 
 
998
    def test_set_shutdown(self):
 
999
        """
 
1000
        L{Connection.set_shutdown} sets the state of the SSL connection shutdown
 
1001
        process.
 
1002
        """
 
1003
        connection = Connection(Context(TLSv1_METHOD), socket())
 
1004
        connection.set_shutdown(RECEIVED_SHUTDOWN)
 
1005
        self.assertEquals(connection.get_shutdown(), RECEIVED_SHUTDOWN)
 
1006
 
 
1007
 
 
1008
    def test_app_data_wrong_args(self):
 
1009
        """
 
1010
        L{Connection.set_app_data} raises L{TypeError} if called with other than
 
1011
        one argument.  L{Connection.get_app_data} raises L{TypeError} if called
 
1012
        with any arguments.
 
1013
        """
 
1014
        conn = Connection(Context(TLSv1_METHOD), None)
 
1015
        self.assertRaises(TypeError, conn.get_app_data, None)
 
1016
        self.assertRaises(TypeError, conn.set_app_data)
 
1017
        self.assertRaises(TypeError, conn.set_app_data, None, None)
 
1018
 
 
1019
 
 
1020
    def test_app_data(self):
 
1021
        """
 
1022
        Any object can be set as app data by passing it to
 
1023
        L{Connection.set_app_data} and later retrieved with
 
1024
        L{Connection.get_app_data}.
 
1025
        """
 
1026
        conn = Connection(Context(TLSv1_METHOD), None)
 
1027
        app_data = object()
 
1028
        conn.set_app_data(app_data)
 
1029
        self.assertIdentical(conn.get_app_data(), app_data)
 
1030
 
 
1031
 
 
1032
    def test_makefile(self):
 
1033
        """
 
1034
        L{Connection.makefile} is not implemented and calling that method raises
 
1035
        L{NotImplementedError}.
 
1036
        """
 
1037
        conn = Connection(Context(TLSv1_METHOD), None)
 
1038
        self.assertRaises(NotImplementedError, conn.makefile)
 
1039
 
 
1040
 
 
1041
 
 
1042
class ConnectionGetCipherListTests(TestCase):
 
1043
    """
 
1044
    Tests for L{Connection.get_cipher_list}.
 
1045
    """
 
1046
    def test_wrong_args(self):
 
1047
        """
 
1048
        L{Connection.get_cipher_list} raises L{TypeError} if called with any
 
1049
        arguments.
 
1050
        """
 
1051
        connection = Connection(Context(TLSv1_METHOD), None)
 
1052
        self.assertRaises(TypeError, connection.get_cipher_list, None)
 
1053
 
 
1054
 
 
1055
    def test_result(self):
 
1056
        """
 
1057
        L{Connection.get_cipher_list} returns a C{list} of C{str} giving the
 
1058
        names of the ciphers which might be used.
 
1059
        """
 
1060
        connection = Connection(Context(TLSv1_METHOD), None)
 
1061
        ciphers = connection.get_cipher_list()
 
1062
        self.assertTrue(isinstance(ciphers, list))
 
1063
        for cipher in ciphers:
 
1064
            self.assertTrue(isinstance(cipher, str))
 
1065
 
 
1066
 
 
1067
 
 
1068
class ConnectionSendallTests(TestCase, _LoopbackMixin):
 
1069
    """
 
1070
    Tests for L{Connection.sendall}.
 
1071
    """
 
1072
    def test_wrong_args(self):
 
1073
        """
 
1074
        When called with arguments other than a single string,
 
1075
        L{Connection.sendall} raises L{TypeError}.
 
1076
        """
 
1077
        connection = Connection(Context(TLSv1_METHOD), None)
 
1078
        self.assertRaises(TypeError, connection.sendall)
 
1079
        self.assertRaises(TypeError, connection.sendall, object())
 
1080
        self.assertRaises(TypeError, connection.sendall, "foo", "bar")
 
1081
 
 
1082
 
 
1083
    def test_short(self):
 
1084
        """
 
1085
        L{Connection.sendall} transmits all of the bytes in the string passed to
 
1086
        it.
 
1087
        """
 
1088
        server, client = self._loopback()
 
1089
        server.sendall('x')
 
1090
        self.assertEquals(client.recv(1), 'x')
 
1091
 
 
1092
 
 
1093
    def test_long(self):
 
1094
        """
 
1095
        L{Connection.sendall} transmits all of the bytes in the string passed to
 
1096
        it even if this requires multiple calls of an underlying write function.
 
1097
        """
 
1098
        server, client = self._loopback()
 
1099
        message ='x' * 1024 * 128 + 'y'
 
1100
        server.sendall(message)
 
1101
        accum = []
 
1102
        received = 0
 
1103
        while received < len(message):
 
1104
            bytes = client.recv(1024)
 
1105
            accum.append(bytes)
 
1106
            received += len(bytes)
 
1107
        self.assertEquals(message, ''.join(accum))
 
1108
 
 
1109
 
 
1110
    def test_closed(self):
 
1111
        """
 
1112
        If the underlying socket is closed, L{Connection.sendall} propagates the
 
1113
        write error from the low level write call.
 
1114
        """
 
1115
        server, client = self._loopback()
 
1116
        client.close()
 
1117
        server.sendall("hello, world")
 
1118
        self.assertRaises(SysCallError, server.sendall, "hello, world")
 
1119
 
 
1120
 
 
1121
 
 
1122
class ConnectionRenegotiateTests(TestCase, _LoopbackMixin):
 
1123
    """
 
1124
    Tests for SSL renegotiation APIs.
 
1125
    """
 
1126
    def test_renegotiate_wrong_args(self):
 
1127
        """
 
1128
        L{Connection.renegotiate} raises L{TypeError} if called with any
 
1129
        arguments.
 
1130
        """
 
1131
        connection = Connection(Context(TLSv1_METHOD), None)
 
1132
        self.assertRaises(TypeError, connection.renegotiate, None)
 
1133
 
 
1134
 
 
1135
    def test_total_renegotiations_wrong_args(self):
 
1136
        """
 
1137
        L{Connection.total_renegotiations} raises L{TypeError} if called with
 
1138
        any arguments.
 
1139
        """
 
1140
        connection = Connection(Context(TLSv1_METHOD), None)
 
1141
        self.assertRaises(TypeError, connection.total_renegotiations, None)
 
1142
 
 
1143
 
 
1144
    def test_total_renegotiations(self):
 
1145
        """
 
1146
        L{Connection.total_renegotiations} returns C{0} before any
 
1147
        renegotiations have happened.
 
1148
        """
 
1149
        connection = Connection(Context(TLSv1_METHOD), None)
 
1150
        self.assertEquals(connection.total_renegotiations(), 0)
 
1151
 
 
1152
 
 
1153
#     def test_renegotiate(self):
 
1154
#         """
 
1155
#         """
 
1156
#         server, client = self._loopback()
 
1157
 
 
1158
#         server.send("hello world")
 
1159
#         self.assertEquals(client.recv(len("hello world")), "hello world")
 
1160
 
 
1161
#         self.assertEquals(server.total_renegotiations(), 0)
 
1162
#         self.assertTrue(server.renegotiate())
 
1163
 
 
1164
#         server.setblocking(False)
 
1165
#         client.setblocking(False)
 
1166
#         while server.renegotiate_pending():
 
1167
#             client.do_handshake()
 
1168
#             server.do_handshake()
 
1169
 
 
1170
#         self.assertEquals(server.total_renegotiations(), 1)
 
1171
 
 
1172
 
 
1173
 
338
1174
 
339
1175
class ErrorTests(TestCase):
340
1176
    """
392
1228
 
393
1229
 
394
1230
 
395
 
def verify_cb(conn, cert, errnum, depth, ok):
396
 
    return ok
397
 
 
398
 
class MemoryBIOTests(TestCase):
 
1231
class MemoryBIOTests(TestCase, _LoopbackMixin):
399
1232
    """
400
1233
    Tests for L{OpenSSL.SSL.Connection} using a memory BIO.
401
1234
    """
441
1274
        return client_conn
442
1275
 
443
1276
 
444
 
    def _loopback(self, client_conn, server_conn):
445
 
        """
446
 
        Try to read application bytes from each of the two L{Connection}
447
 
        objects.  Copy bytes back and forth between their send/receive buffers
448
 
        for as long as there is anything to copy.  When there is nothing more
449
 
        to copy, return C{None}.  If one of them actually manages to deliver
450
 
        some application bytes, return a two-tuple of the connection from which
451
 
        the bytes were read and the bytes themselves.
452
 
        """
453
 
        wrote = True
454
 
        while wrote:
455
 
            # Loop until neither side has anything to say
456
 
            wrote = False
457
 
 
458
 
            # Copy stuff from each side's send buffer to the other side's
459
 
            # receive buffer.
460
 
            for (read, write) in [(client_conn, server_conn),
461
 
                                  (server_conn, client_conn)]:
462
 
 
463
 
                # Give the side a chance to generate some more bytes, or
464
 
                # succeed.
465
 
                try:
466
 
                    bytes = read.recv(2 ** 16)
467
 
                except WantReadError:
468
 
                    # It didn't succeed, so we'll hope it generated some
469
 
                    # output.
470
 
                    pass
471
 
                else:
472
 
                    # It did succeed, so we'll stop now and let the caller deal
473
 
                    # with it.
474
 
                    return (read, bytes)
475
 
 
476
 
                while True:
477
 
                    # Keep copying as long as there's more stuff there.
478
 
                    try:
479
 
                        dirty = read.bio_read(4096)
480
 
                    except WantReadError:
481
 
                        # Okay, nothing more waiting to be sent.  Stop
482
 
                        # processing this send buffer.
483
 
                        break
484
 
                    else:
485
 
                        # Keep track of the fact that someone generated some
486
 
                        # output.
487
 
                        wrote = True
488
 
                        write.bio_write(dirty)
489
 
 
490
 
 
491
1277
    def test_memoryConnect(self):
492
1278
        """
493
1279
        Two L{Connection}s which use memory BIOs can be manually connected by
506
1292
        # First, the handshake needs to happen.  We'll deliver bytes back and
507
1293
        # forth between the client and server until neither of them feels like
508
1294
        # speaking any more.
509
 
        self.assertIdentical(self._loopback(client_conn, server_conn), None)
 
1295
        self.assertIdentical(
 
1296
            self._interactInMemory(client_conn, server_conn), None)
510
1297
 
511
1298
        # Now that the handshake is done, there should be a key and nonces.
512
1299
        self.assertNotIdentical(server_conn.master_key(), None)
522
1309
 
523
1310
        server_conn.write(important_message)
524
1311
        self.assertEquals(
525
 
            self._loopback(client_conn, server_conn),
 
1312
            self._interactInMemory(client_conn, server_conn),
526
1313
            (client_conn, important_message))
527
1314
 
528
1315
        client_conn.write(important_message[::-1])
529
1316
        self.assertEquals(
530
 
            self._loopback(client_conn, server_conn),
 
1317
            self._interactInMemory(client_conn, server_conn),
531
1318
            (server_conn, important_message[::-1]))
532
1319
 
533
1320
 
595
1382
        server = self._server(None)
596
1383
        client = self._client(None)
597
1384
 
598
 
        self._loopback(client, server)
 
1385
        self._interactInMemory(client, server)
599
1386
 
600
1387
        size = 2 ** 15
601
1388
        sent = client.send("x" * size)
604
1391
        # meaningless.
605
1392
        self.assertTrue(sent < size)
606
1393
 
607
 
        receiver, received = self._loopback(client, server)
 
1394
        receiver, received = self._interactInMemory(client, server)
608
1395
        self.assertIdentical(receiver, server)
609
1396
 
610
1397
        # We can rely on all of these bytes being received at once because
644
1431
        expected = func(ctx)
645
1432
        self.assertEqual(client.get_client_ca_list(), [])
646
1433
        self.assertEqual(server.get_client_ca_list(), expected)
647
 
        self._loopback(client, server)
 
1434
        self._interactInMemory(client, server)
648
1435
        self.assertEqual(client.get_client_ca_list(), expected)
649
1436
        self.assertEqual(server.get_client_ca_list(), expected)
650
1437
 
827
1614
 
828
1615
        cadesc = cacert.get_subject()
829
1616
        sedesc = secert.get_subject()
830
 
        cldesc = clcert.get_subject()
831
1617
 
832
1618
        def set_replaces_add_ca(ctx):
833
1619
            ctx.add_client_ca(clcert)