~stepankk/pyopenssl/bug-845445

« back to all changes in this revision

Viewing changes to OpenSSL/test/test_ssl.py

  • Committer: Jean-Paul Calderone
  • Date: 2011-06-06 12:33:31 UTC
  • mfrom: (153.1.4 sni)
  • Revision ID: exarkun@divmod.com-20110606123331-vm00hkfitja61m7c
Add client and server support for SNI.

Show diffs side-by-side

Legend: added lines

removed lines

Lines of Context:
5
5
Unit tests for L{OpenSSL.SSL}.
6
6
"""
7
7
 
 
8
from gc import collect
8
9
from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK
9
 
from sys import platform
 
10
from sys import platform, version_info
10
11
from socket import error, socket
11
12
from os import makedirs
12
13
from os.path import join
13
14
from unittest import main
 
15
from weakref import ref
14
16
 
15
17
from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM
16
18
from OpenSSL.crypto import PKey, X509, X509Extension
873
875
 
874
876
 
875
877
 
 
878
class ServerNameCallbackTests(TestCase, _LoopbackMixin):
    """
    Tests for L{Context.set_tlsext_servername_callback} and its interaction with
    L{Connection}.
    """
    def _servernameHandshake(self, hostName=None):
        """
        Perform an in-memory handshake against a server whose context records
        the result of L{Connection.get_servername} from inside its servername
        callback.

        The test's own reference to the callback is dropped (and a collection
        forced) before the handshake, so the handshake only succeeds in
        invoking the callback if the L{Context} keeps it alive.

        @param hostName: If not C{None}, a byte string the client passes to
            L{Connection.set_tlsext_host_name} before the handshake.

        @return: A two-tuple of the server L{Connection} and the list of
            C{(connection, servername)} pairs recorded by the callback.
        """
        args = []
        def servername(conn):
            args.append((conn, conn.get_servername()))
        context = Context(TLSv1_METHOD)
        context.set_tlsext_servername_callback(servername)

        # Lose our reference to it.  The Context is responsible for keeping it
        # alive now.
        del servername
        collect()

        # Necessary to actually accept the connection
        context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
        context.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))

        # Do a little connection to trigger the logic
        server = Connection(context, None)
        server.set_accept_state()

        client = Connection(Context(TLSv1_METHOD), None)
        client.set_connect_state()
        if hostName is not None:
            client.set_tlsext_host_name(hostName)

        self._interactInMemory(server, client)
        return server, args


    def test_wrong_args(self):
        """
        L{Context.set_tlsext_servername_callback} raises L{TypeError} if called
        with other than one argument.
        """
        context = Context(TLSv1_METHOD)
        self.assertRaises(TypeError, context.set_tlsext_servername_callback)
        self.assertRaises(
            TypeError, context.set_tlsext_servername_callback, 1, 2)


    def test_old_callback_forgotten(self):
        """
        If L{Context.set_tlsext_servername_callback} is used to specify a new
        callback, the one it replaces is dereferenced.
        """
        def callback(connection):
            pass

        def replacement(connection):
            pass

        context = Context(TLSv1_METHOD)
        context.set_tlsext_servername_callback(callback)

        # Only a weak reference survives; once the Context lets go of the old
        # callback, nothing else keeps it alive.
        tracker = ref(callback)
        del callback

        context.set_tlsext_servername_callback(replacement)
        collect()
        self.assertIdentical(None, tracker())


    def test_no_servername(self):
        """
        When a client specifies no server name, the callback passed to
        L{Context.set_tlsext_servername_callback} is invoked and the result of
        L{Connection.get_servername} is C{None}.
        """
        server, args = self._servernameHandshake()
        self.assertEqual([(server, None)], args)


    def test_servername(self):
        """
        When a client specifies a server name in its hello message, the callback
        passed to L{Context.set_tlsext_servername_callback} is invoked and the
        result of L{Connection.get_servername} is that server name.
        """
        server, args = self._servernameHandshake(b("foo1.example.com"))
        self.assertEqual([(server, b("foo1.example.com"))], args)
 
975
 
 
976
 
 
977
 
876
978
class ConnectionTests(TestCase, _LoopbackMixin):
877
979
    """
878
980
    Unit tests for L{OpenSSL.SSL.Connection}.
924
1026
        self.assertRaises(TypeError, connection.get_context, None)
925
1027
 
926
1028
 
 
1029
    def test_set_context_wrong_args(self):
 
1030
        """
 
1031
        L{Connection.set_context} raises L{TypeError} if called with a
 
1032
        non-L{Context} instance argument or with any number of arguments other
 
1033
        than 1.
 
1034
        """
 
1035
        ctx = Context(TLSv1_METHOD)
 
1036
        connection = Connection(ctx, None)
 
1037
        self.assertRaises(TypeError, connection.set_context)
 
1038
        self.assertRaises(TypeError, connection.set_context, object())
 
1039
        self.assertRaises(TypeError, connection.set_context, "hello")
 
1040
        self.assertRaises(TypeError, connection.set_context, 1)
 
1041
        self.assertRaises(TypeError, connection.set_context, 1, 2)
 
1042
        self.assertRaises(
 
1043
            TypeError, connection.set_context, Context(TLSv1_METHOD), 2)
 
1044
        self.assertIdentical(ctx, connection.get_context())
 
1045
 
 
1046
 
 
1047
    def test_set_context(self):
 
1048
        """
 
1049
        L{Connection.set_context} specifies a new L{Context} instance to be used
 
1050
        for the connection.
 
1051
        """
 
1052
        original = Context(SSLv23_METHOD)
 
1053
        replacement = Context(TLSv1_METHOD)
 
1054
        connection = Connection(original, None)
 
1055
        connection.set_context(replacement)
 
1056
        self.assertIdentical(replacement, connection.get_context())
 
1057
        # Lose our references to the contexts, just in case the Connection isn't
 
1058
        # properly managing its own contributions to their reference counts.
 
1059
        del original, replacement
 
1060
        collect()
 
1061
 
 
1062
 
 
1063
    def test_set_tlsext_host_name_wrong_args(self):
 
1064
        """
 
1065
        If L{Connection.set_tlsext_host_name} is called with a non-byte string
 
1066
        argument or a byte string with an embedded NUL or other than one
 
1067
        argument, L{TypeError} is raised.
 
1068
        """
 
1069
        conn = Connection(Context(TLSv1_METHOD), None)
 
1070
        self.assertRaises(TypeError, conn.set_tlsext_host_name)
 
1071
        self.assertRaises(TypeError, conn.set_tlsext_host_name, object())
 
1072
        self.assertRaises(TypeError, conn.set_tlsext_host_name, 123, 456)
 
1073
        self.assertRaises(
 
1074
            TypeError, conn.set_tlsext_host_name, b("with\0null"))
 
1075
 
 
1076
        if version_info >= (3,):
 
1077
            # On Python 3.x, don't accidentally implicitly convert from text.
 
1078
            self.assertRaises(
 
1079
                TypeError,
 
1080
                conn.set_tlsext_host_name, b("example.com").decode("ascii"))
 
1081
 
 
1082
 
 
1083
    def test_get_servername_wrong_args(self):
 
1084
        """
 
1085
        L{Connection.get_servername} raises L{TypeError} if called with any
 
1086
        arguments.
 
1087
        """
 
1088
        connection = Connection(Context(TLSv1_METHOD), None)
 
1089
        self.assertRaises(TypeError, connection.get_servername, object())
 
1090
        self.assertRaises(TypeError, connection.get_servername, 1)
 
1091
        self.assertRaises(TypeError, connection.get_servername, "hello")
 
1092
 
 
1093
 
927
1094
    def test_pending(self):
928
1095
        """
929
1096
        L{Connection.pending} returns the number of bytes available for