~tribaal/txaws/xss-hardening

« back to all changes in this revision

Viewing changes to txaws/client/ssl.py

  • Committer: Thomas Hervé
  • Date: 2011-12-01 12:19:46 UTC
  • mfrom: (102.1.4 ssl-verify)
  • Revision ID: thomas@canonical.com-20111201121946-kbhurhwhaiotemfm
Merge ssl-verify [r=free.ekanayaka] [f=781949]

Add an ssl_hostname_verification option to AWSServiceEndpoint, which enables
verification of the SSL certificate at connection time.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
from glob import glob
 
2
import os
 
3
import re
 
4
 
 
5
from OpenSSL import SSL
 
6
from OpenSSL.crypto import load_certificate, FILETYPE_PEM
 
7
 
 
8
from twisted.internet.ssl import CertificateOptions
 
9
 
 
10
 
 
11
# Names exported by "from ... import *". NOTE(review): get_global_ca_certs is
# public (no leading underscore) but not listed here — confirm intentional.
__all__ = ["VerifyingContextFactory", "get_ca_certs"]
 
12
 
 
13
 
 
14
class VerifyingContextFactory(CertificateOptions):
    """
    A SSL context factory to pass to C{connectSSL} to check for hostname
    validity.

    On top of OpenSSL's own chain verification against C{caCerts}, the leaf
    certificate must be valid for C{host}. If it carries subjectAltName DNS
    entries, one of them must match. Otherwise its subject commonName must
    match (RFC 2818, section 3.1).
    """

    def __init__(self, host, caCerts=None):
        """
        @param host: the hostname the peer certificate must be valid for.
        @param caCerts: the CA certificates to trust; when C{None}, the
            process-wide set from L{get_global_ca_certs} is used.
        """
        if caCerts is None:
            caCerts = get_global_ca_certs()
        CertificateOptions.__init__(self, verify=True, caCerts=caCerts)
        self.host = host

    def _dnsname_match(self, dn, host):
        """
        Return C{True} if C{host} matches the certificate name C{dn}.

        Matching is case-insensitive. A wildcard is honoured only in the
        leftmost label of C{dn}, and at most once there. A bare C{*} matches
        exactly one non-empty label, and a partial wildcard such as C{f*}
        matches any characters within that single label. A C{*} in any other
        label is treated literally, so names like C{www.*.com} never match.
        Limiting wildcards this way also prevents the pathological regex
        backtracking described in CVE-2013-2099.

        @param dn: a DNS name taken from the certificate (SAN or commonName).
        @param host: the hostname being connected to.
        """
        if not dn:
            return False
        fragments = dn.split(".")
        leftmost, remainder = fragments[0], fragments[1:]
        if leftmost.count("*") > 1:
            # Names such as "w*w*w.example.com" are refused outright.
            return False
        if leftmost == "*":
            pats = ["[^.]+"]
        else:
            # re.escape turns "*" into "\*"; re-enable it as an in-label
            # wildcard for the leftmost label only.
            pats = [re.escape(leftmost).replace(r"\*", "[^.]*")]
        # All remaining labels must match literally.
        pats.extend(re.escape(frag) for frag in remainder)
        rx = re.compile(r"\A" + r"\.".join(pats) + r"\Z", re.IGNORECASE)
        return bool(rx.match(host))

    def verify_callback(self, connection, x509, errno, depth, preverifyOK):
        """
        Verification callback installed on the context by L{_makeContext}.

        For certificates above the leaf (C{depth != 0}), OpenSSL's verdict
        C{preverifyOK} is returned unchanged. For the leaf certificate, the
        hostname check runs first. C{False} is returned if the certificate is
        not valid for C{self.host}; otherwise C{preverifyOK} is returned.
        """
        # Only check depth == 0 on chained certificates.
        if depth != 0:
            return preverifyOK
        dns_found = False
        # Guard for X509 objects without get_extension (presumably older
        # pyOpenSSL releases).
        if getattr(x509, "get_extension", None) is not None:
            for index in range(x509.get_extension_count()):
                extension = x509.get_extension(index)
                # pyOpenSSL documents get_short_name() as returning bytes;
                # accept text too for robustness across versions.
                if extension.get_short_name() not in (b"subjectAltName",
                                                      "subjectAltName"):
                    continue
                for element in str(extension).split(", "):
                    # partition() rather than split(): entries such as
                    # "IP Address:2001:DB8::1" contain extra colons.
                    key, _, value = element.partition(":")
                    if key != "DNS":
                        continue
                    if self._dnsname_match(value, self.host):
                        return preverifyOK
                    dns_found = True
                break
        if dns_found:
            # DNS names were present but none matched. RFC 2818 forbids
            # falling back to the commonName in that case.
            return False
        commonName = x509.get_subject().commonName
        if commonName is None or not self._dnsname_match(commonName,
                                                         self.host):
            return False
        return preverifyOK

    def _makeContext(self):
        """
        Build the context via L{CertificateOptions} and replace its
        verification settings. The peer must present a certificate, and
        L{verify_callback} decides whether it is acceptable.
        """
        context = CertificateOptions._makeContext(self)
        context.set_verify(
            SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
            self.verify_callback)
        return context
 
72
 
 
73
 
 
74
def get_ca_certs(files="/etc/ssl/certs/*.pem"):
    """
    Retrieve a list of CAs pointed by C{files}.

    Each file matched by the glob pattern C{files} is read and parsed with
    C{load_certificate} as PEM. Paths that do not exist (e.g. dead symlinks)
    are skipped. Presumably each file holds a single certificate — TODO
    confirm for bundle files.

    @param files: a glob pattern naming PEM certificate files.
    @return: a list of the loaded certificates, de-duplicated by their SHA-1
        digest (the same certificate often appears under several names).
    """
    certificateAuthorityMap = {}
    for certFileName in glob(files):
        # There might be some dead symlinks in there, so let's make sure it's
        # real.
        if not os.path.exists(certFileName):
            continue
        # Binary mode: PEM is ASCII, and load_certificate takes a bytes
        # buffer. The with-block guarantees the file is closed.
        with open(certFileName, "rb") as certFile:
            data = certFile.read()
        x509 = load_certificate(FILETYPE_PEM, data)
        digest = x509.digest("sha1")
        # Now, de-duplicate in case the same cert has multiple names.
        certificateAuthorityMap[digest] = x509
    # Return a real list (not a dict view on Python 3) so callers can index
    # it, reuse it, and pass it on as caCerts.
    return list(certificateAuthorityMap.values())
 
90
 
 
91
 
 
92
# Lazily populated cache for get_global_ca_certs(); None means "not loaded".
_ca_certs = None


def get_global_ca_certs():
    """
    Return the process-wide CA certificates, loading them on first use.

    The first call stores the result of L{get_ca_certs} in the module-level
    C{_ca_certs}. Every later call returns that same cached object.
    """
    global _ca_certs
    if _ca_certs is not None:
        return _ca_certs
    _ca_certs = get_ca_certs()
    return _ca_certs