~tribaal/txaws/xss-hardening

« back to all changes in this revision

Viewing changes to txaws/client/ssl.py

  • Committer: Duncan McGreggor
  • Date: 2009-11-22 02:20:42 UTC
  • mto: (44.3.2 484858-s3-scripts)
  • mto: This revision was merged to the branch mainline in revision 52.
  • Revision ID: duncan@canonical.com-20091122022042-4zi231hxni1z53xd
* Updated the LICENSE file with copyright information.
* Updated the README with license information.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
from glob import glob
2
 
import os
3
 
import re
4
 
import sys
5
 
 
6
 
from OpenSSL import SSL
7
 
from OpenSSL.crypto import load_certificate, FILETYPE_PEM
8
 
 
9
 
from twisted.internet.ssl import CertificateOptions
10
 
 
11
 
from txaws import exception
12
 
 
13
 
 
14
 
__all__ = ["VerifyingContextFactory", "get_ca_certs"]
15
 
 
16
 
 
17
 
# Multiple defaults are supported; just add more paths, separated by colons.
18
 
if sys.platform == "darwin":
19
 
    DEFAULT_CERTS_PATH = "/System/Library/OpenSSL/certs/"
20
 
# XXX Windows users can file a bug to add theirs, since we don't know what
21
 
# the right path is
22
 
else:
23
 
    DEFAULT_CERTS_PATH = "/etc/ssl/certs/"
24
 
 
25
 
 
26
 
class VerifyingContextFactory(CertificateOptions):
27
 
    """
28
 
    A SSL context factory to pass to C{connectSSL} to check for hostname
29
 
    validity.
30
 
    """
31
 
 
32
 
    def __init__(self, host, caCerts=None):
33
 
        if caCerts is None:
34
 
            caCerts = get_global_ca_certs()
35
 
        CertificateOptions.__init__(self, verify=True, caCerts=caCerts)
36
 
        self.host = host
37
 
 
38
 
    def _dnsname_match(self, dn, host):
39
 
        pats = []
40
 
        for frag in dn.split(r"."):
41
 
            if frag == "*":
42
 
                pats.append("[^.]+")
43
 
            else:
44
 
                frag = re.escape(frag)
45
 
                pats.append(frag.replace(r"\*", "[^.]*"))
46
 
 
47
 
        rx = re.compile(r"\A" + r"\.".join(pats) + r"\Z", re.IGNORECASE)
48
 
        return bool(rx.match(host))
49
 
 
50
 
    def verify_callback(self, connection, x509, errno, depth, preverifyOK):
51
 
        # Only check depth == 0 on chained certificates.
52
 
        if depth == 0:
53
 
            dns_found = False
54
 
            if getattr(x509, "get_extension", None) is not None:
55
 
                for index in range(x509.get_extension_count()):
56
 
                    extension = x509.get_extension(index)
57
 
                    if extension.get_short_name() != "subjectAltName":
58
 
                        continue
59
 
                    data = str(extension)
60
 
                    for element in data.split(", "):
61
 
                        key, value = element.split(":")
62
 
                        if key != "DNS":
63
 
                            continue
64
 
                        if self._dnsname_match(value, self.host):
65
 
                            return preverifyOK
66
 
                        dns_found = True
67
 
                    break
68
 
            if not dns_found:
69
 
                commonName = x509.get_subject().commonName
70
 
                if commonName is None:
71
 
                    return False
72
 
                if not self._dnsname_match(commonName, self.host):
73
 
                    return False
74
 
            else:
75
 
                return False
76
 
        return preverifyOK
77
 
 
78
 
    def _makeContext(self):
79
 
        context = CertificateOptions._makeContext(self)
80
 
        context.set_verify(
81
 
            SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
82
 
            self.verify_callback)
83
 
        return context
84
 
 
85
 
 
86
 
def get_ca_certs():
87
 
    """
88
 
    Retrieve a list of CAs at either the DEFAULT_CERTS_PATH or the env
89
 
    override, TXAWS_CERTS_PATH.
90
 
 
91
 
    In order to find .pem files, this function checks first for presence of the
92
 
    TXAWS_CERTS_PATH environment variable that should point to a directory
93
 
    containing cert files. In the absense of this variable, the module-level
94
 
    DEFAULT_CERTS_PATH will be used instead.
95
 
 
96
 
    Note that both of these variables have have multiple paths in them, just
97
 
    like the familiar PATH environment variable (separated by colons).
98
 
    """
99
 
    cert_paths = os.getenv("TXAWS_CERTS_PATH", DEFAULT_CERTS_PATH).split(":")
100
 
    certificate_authority_map = {}
101
 
    for path in cert_paths:
102
 
        if not path:
103
 
            continue
104
 
        for cert_file_name in glob(os.path.join(path, "*.pem")):
105
 
            # There might be some dead symlinks in there, so let's make sure
106
 
            # it's real.
107
 
            if not os.path.exists(cert_file_name):
108
 
                continue
109
 
            cert_file = open(cert_file_name)
110
 
            data = cert_file.read()
111
 
            cert_file.close()
112
 
            x509 = load_certificate(FILETYPE_PEM, data)
113
 
            digest = x509.digest("sha1")
114
 
            # Now, de-duplicate in case the same cert has multiple names.
115
 
            certificate_authority_map[digest] = x509
116
 
    values = certificate_authority_map.values()
117
 
    if len(values) == 0:
118
 
        raise exception.CertsNotFoundError("Could not find any .pem files.")
119
 
    return values
120
 
 
121
 
 
122
 
_ca_certs = None
123
 
 
124
 
 
125
 
def get_global_ca_certs():
126
 
    """Retrieve a singleton of CA certificates."""
127
 
    global _ca_certs
128
 
    if _ca_certs is None:
129
 
        _ca_certs = get_ca_certs()
130
 
    return _ca_certs