~ntt-pf-lab/nova/monkey_patch_notification

« back to all changes in this revision

Viewing changes to vendor/Twisted-10.0.0/twisted/internet/_sslverify.py

  • Committer: Jesse Andrews
  • Date: 2010-05-28 06:05:26 UTC
  • Revision ID: git-v1:bf6e6e718cdc7488e2da87b21e258ccc065fe499
initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# -*- test-case-name: twisted.test.test_sslverify -*-
 
2
# Copyright (c) 2005 Divmod, Inc.
 
3
# Copyright (c) 2008 Twisted Matrix Laboratories.
 
4
# See LICENSE for details.
 
5
# Copyright (c) 2005-2008 Twisted Matrix Laboratories.
 
6
 
 
7
import itertools
 
8
from OpenSSL import SSL, crypto
 
9
 
 
10
from twisted.python import reflect, util
 
11
from twisted.python.hashlib import md5
 
12
from twisted.internet.defer import Deferred
 
13
from twisted.internet.error import VerifyError, CertificateError
 
14
 
 
15
# Private - shared between all OpenSSLCertificateOptions, counts up to provide
 
16
# a unique session id for each context
 
17
_sessionCounter = itertools.count().next
 
18
 
 
19
_x509names = {
 
20
    'CN': 'commonName',
 
21
    'commonName': 'commonName',
 
22
 
 
23
    'O': 'organizationName',
 
24
    'organizationName': 'organizationName',
 
25
 
 
26
    'OU': 'organizationalUnitName',
 
27
    'organizationalUnitName': 'organizationalUnitName',
 
28
 
 
29
    'L': 'localityName',
 
30
    'localityName': 'localityName',
 
31
 
 
32
    'ST': 'stateOrProvinceName',
 
33
    'stateOrProvinceName': 'stateOrProvinceName',
 
34
 
 
35
    'C': 'countryName',
 
36
    'countryName': 'countryName',
 
37
 
 
38
    'emailAddress': 'emailAddress'}
 
39
 
 
40
 
 
41
class DistinguishedName(dict):
 
42
    """
 
43
    Identify and describe an entity.
 
44
 
 
45
    Distinguished names are used to provide a minimal amount of identifying
 
46
    information about a certificate issuer or subject.  They are commonly
 
47
    created with one or more of the following fields::
 
48
 
 
49
        commonName (CN)
 
50
        organizationName (O)
 
51
        organizationalUnitName (OU)
 
52
        localityName (L)
 
53
        stateOrProvinceName (ST)
 
54
        countryName (C)
 
55
        emailAddress
 
56
    """
 
57
    __slots__ = ()
 
58
 
 
59
    def __init__(self, **kw):
 
60
        for k, v in kw.iteritems():
 
61
            setattr(self, k, v)
 
62
 
 
63
 
 
64
    def _copyFrom(self, x509name):
 
65
        d = {}
 
66
        for name in _x509names:
 
67
            value = getattr(x509name, name, None)
 
68
            if value is not None:
 
69
                setattr(self, name, value)
 
70
 
 
71
 
 
72
    def _copyInto(self, x509name):
 
73
        for k, v in self.iteritems():
 
74
            setattr(x509name, k, v)
 
75
 
 
76
 
 
77
    def __repr__(self):
 
78
        return '<DN %s>' % (dict.__repr__(self)[1:-1])
 
79
 
 
80
 
 
81
    def __getattr__(self, attr):
 
82
        try:
 
83
            return self[_x509names[attr]]
 
84
        except KeyError:
 
85
            raise AttributeError(attr)
 
86
 
 
87
 
 
88
    def __setattr__(self, attr, value):
 
89
        assert type(attr) is str
 
90
        if not attr in _x509names:
 
91
            raise AttributeError("%s is not a valid OpenSSL X509 name field" % (attr,))
 
92
        realAttr = _x509names[attr]
 
93
        value = value.encode('ascii')
 
94
        assert type(value) is str
 
95
        self[realAttr] = value
 
96
 
 
97
 
 
98
    def inspect(self):
 
99
        """
 
100
        Return a multi-line, human-readable representation of this DN.
 
101
        """
 
102
        l = []
 
103
        lablen = 0
 
104
        def uniqueValues(mapping):
 
105
            return dict.fromkeys(mapping.itervalues()).keys()
 
106
        for k in uniqueValues(_x509names):
 
107
            label = util.nameToLabel(k)
 
108
            lablen = max(len(label), lablen)
 
109
            v = getattr(self, k, None)
 
110
            if v is not None:
 
111
                l.append((label, v))
 
112
        lablen += 2
 
113
        for n, (label, attr) in enumerate(l):
 
114
            l[n] = (label.rjust(lablen)+': '+ attr)
 
115
        return '\n'.join(l)
 
116
 
 
117
DN = DistinguishedName
 
118
 
 
119
 
 
120
class CertBase:
 
121
    def __init__(self, original):
 
122
        self.original = original
 
123
 
 
124
    def _copyName(self, suffix):
 
125
        dn = DistinguishedName()
 
126
        dn._copyFrom(getattr(self.original, 'get_'+suffix)())
 
127
        return dn
 
128
 
 
129
    def getSubject(self):
 
130
        """
 
131
        Retrieve the subject of this certificate.
 
132
 
 
133
        @rtype: L{DistinguishedName}
 
134
        @return: A copy of the subject of this certificate.
 
135
        """
 
136
        return self._copyName('subject')
 
137
 
 
138
 
 
139
 
 
140
def _handleattrhelper(Class, transport, methodName):
 
141
    """
 
142
    (private) Helper for L{Certificate.peerFromTransport} and
 
143
    L{Certificate.hostFromTransport} which checks for incompatible handle types
 
144
    and null certificates and raises the appropriate exception or returns the
 
145
    appropriate certificate object.
 
146
    """
 
147
    method = getattr(transport.getHandle(),
 
148
                     "get_%s_certificate" % (methodName,), None)
 
149
    if method is None:
 
150
        raise CertificateError(
 
151
            "non-TLS transport %r did not have %s certificate" % (transport, methodName))
 
152
    cert = method()
 
153
    if cert is None:
 
154
        raise CertificateError(
 
155
            "TLS transport %r did not have %s certificate" % (transport, methodName))
 
156
    return Class(cert)
 
157
 
 
158
 
 
159
class Certificate(CertBase):
 
160
    """
 
161
    An x509 certificate.
 
162
    """
 
163
    def __repr__(self):
 
164
        return '<%s Subject=%s Issuer=%s>' % (self.__class__.__name__,
 
165
                                              self.getSubject().commonName,
 
166
                                              self.getIssuer().commonName)
 
167
 
 
168
    def __eq__(self, other):
 
169
        if isinstance(other, Certificate):
 
170
            return self.dump() == other.dump()
 
171
        return False
 
172
 
 
173
 
 
174
    def __ne__(self, other):
 
175
        return not self.__eq__(other)
 
176
 
 
177
 
 
178
    def load(Class, requestData, format=crypto.FILETYPE_ASN1, args=()):
 
179
        """
 
180
        Load a certificate from an ASN.1- or PEM-format string.
 
181
 
 
182
        @rtype: C{Class}
 
183
        """
 
184
        return Class(crypto.load_certificate(format, requestData), *args)
 
185
    load = classmethod(load)
 
186
    _load = load
 
187
 
 
188
 
 
189
    def dumpPEM(self):
 
190
        """
 
191
        Dump this certificate to a PEM-format data string.
 
192
 
 
193
        @rtype: C{str}
 
194
        """
 
195
        return self.dump(crypto.FILETYPE_PEM)
 
196
 
 
197
 
 
198
    def loadPEM(Class, data):
 
199
        """
 
200
        Load a certificate from a PEM-format data string.
 
201
 
 
202
        @rtype: C{Class}
 
203
        """
 
204
        return Class.load(data, crypto.FILETYPE_PEM)
 
205
    loadPEM = classmethod(loadPEM)
 
206
 
 
207
 
 
208
    def peerFromTransport(Class, transport):
 
209
        """
 
210
        Get the certificate for the remote end of the given transport.
 
211
 
 
212
        @type: L{ISystemHandle}
 
213
        @rtype: C{Class}
 
214
 
 
215
        @raise: L{CertificateError}, if the given transport does not have a peer
 
216
        certificate.
 
217
        """
 
218
        return _handleattrhelper(Class, transport, 'peer')
 
219
    peerFromTransport = classmethod(peerFromTransport)
 
220
 
 
221
 
 
222
    def hostFromTransport(Class, transport):
 
223
        """
 
224
        Get the certificate for the local end of the given transport.
 
225
 
 
226
        @param transport: an L{ISystemHandle} provider; the transport we will
 
227
 
 
228
        @rtype: C{Class}
 
229
 
 
230
        @raise: L{CertificateError}, if the given transport does not have a host
 
231
        certificate.
 
232
        """
 
233
        return _handleattrhelper(Class, transport, 'host')
 
234
    hostFromTransport = classmethod(hostFromTransport)
 
235
 
 
236
 
 
237
    def getPublicKey(self):
 
238
        """
 
239
        Get the public key for this certificate.
 
240
 
 
241
        @rtype: L{PublicKey}
 
242
        """
 
243
        return PublicKey(self.original.get_pubkey())
 
244
 
 
245
 
 
246
    def dump(self, format=crypto.FILETYPE_ASN1):
 
247
        return crypto.dump_certificate(format, self.original)
 
248
 
 
249
 
 
250
    def serialNumber(self):
 
251
        """
 
252
        Retrieve the serial number of this certificate.
 
253
 
 
254
        @rtype: C{int}
 
255
        """
 
256
        return self.original.get_serial_number()
 
257
 
 
258
 
 
259
    def digest(self, method='md5'):
 
260
        """
 
261
        Return a digest hash of this certificate using the specified hash
 
262
        algorithm.
 
263
 
 
264
        @param method: One of C{'md5'} or C{'sha'}.
 
265
        @rtype: C{str}
 
266
        """
 
267
        return self.original.digest(method)
 
268
 
 
269
 
 
270
    def _inspect(self):
 
271
        return '\n'.join(['Certificate For Subject:',
 
272
                          self.getSubject().inspect(),
 
273
                          '\nIssuer:',
 
274
                          self.getIssuer().inspect(),
 
275
                          '\nSerial Number: %d' % self.serialNumber(),
 
276
                          'Digest: %s' % self.digest()])
 
277
 
 
278
 
 
279
    def inspect(self):
 
280
        """
 
281
        Return a multi-line, human-readable representation of this
 
282
        Certificate, including information about the subject, issuer, and
 
283
        public key.
 
284
        """
 
285
        return '\n'.join((self._inspect(), self.getPublicKey().inspect()))
 
286
 
 
287
 
 
288
    def getIssuer(self):
 
289
        """
 
290
        Retrieve the issuer of this certificate.
 
291
 
 
292
        @rtype: L{DistinguishedName}
 
293
        @return: A copy of the issuer of this certificate.
 
294
        """
 
295
        return self._copyName('issuer')
 
296
 
 
297
 
 
298
    def options(self, *authorities):
 
299
        raise NotImplementedError('Possible, but doubtful we need this yet')
 
300
 
 
301
 
 
302
 
 
303
class CertificateRequest(CertBase):
 
304
    """
 
305
    An x509 certificate request.
 
306
 
 
307
    Certificate requests are given to certificate authorities to be signed and
 
308
    returned resulting in an actual certificate.
 
309
    """
 
310
    def load(Class, requestData, requestFormat=crypto.FILETYPE_ASN1):
 
311
        req = crypto.load_certificate_request(requestFormat, requestData)
 
312
        dn = DistinguishedName()
 
313
        dn._copyFrom(req.get_subject())
 
314
        if not req.verify(req.get_pubkey()):
 
315
            raise VerifyError("Can't verify that request for %r is self-signed." % (dn,))
 
316
        return Class(req)
 
317
    load = classmethod(load)
 
318
 
 
319
 
 
320
    def dump(self, format=crypto.FILETYPE_ASN1):
 
321
        return crypto.dump_certificate_request(format, self.original)
 
322
 
 
323
 
 
324
 
 
325
class PrivateCertificate(Certificate):
 
326
    """
 
327
    An x509 certificate and private key.
 
328
    """
 
329
    def __repr__(self):
 
330
        return Certificate.__repr__(self) + ' with ' + repr(self.privateKey)
 
331
 
 
332
 
 
333
    def _setPrivateKey(self, privateKey):
 
334
        if not privateKey.matches(self.getPublicKey()):
 
335
            raise VerifyError(
 
336
                "Certificate public and private keys do not match.")
 
337
        self.privateKey = privateKey
 
338
        return self
 
339
 
 
340
 
 
341
    def newCertificate(self, newCertData, format=crypto.FILETYPE_ASN1):
 
342
        """
 
343
        Create a new L{PrivateCertificate} from the given certificate data and
 
344
        this instance's private key.
 
345
        """
 
346
        return self.load(newCertData, self.privateKey, format)
 
347
 
 
348
 
 
349
    def load(Class, data, privateKey, format=crypto.FILETYPE_ASN1):
 
350
        return Class._load(data, format)._setPrivateKey(privateKey)
 
351
    load = classmethod(load)
 
352
 
 
353
 
 
354
    def inspect(self):
 
355
        return '\n'.join([Certificate._inspect(self),
 
356
                          self.privateKey.inspect()])
 
357
 
 
358
 
 
359
    def dumpPEM(self):
 
360
        """
 
361
        Dump both public and private parts of a private certificate to
 
362
        PEM-format data.
 
363
        """
 
364
        return self.dump(crypto.FILETYPE_PEM) + self.privateKey.dump(crypto.FILETYPE_PEM)
 
365
 
 
366
 
 
367
    def loadPEM(Class, data):
 
368
        """
 
369
        Load both private and public parts of a private certificate from a
 
370
        chunk of PEM-format data.
 
371
        """
 
372
        return Class.load(data, KeyPair.load(data, crypto.FILETYPE_PEM),
 
373
                          crypto.FILETYPE_PEM)
 
374
    loadPEM = classmethod(loadPEM)
 
375
 
 
376
 
 
377
    def fromCertificateAndKeyPair(Class, certificateInstance, privateKey):
 
378
        privcert = Class(certificateInstance.original)
 
379
        return privcert._setPrivateKey(privateKey)
 
380
    fromCertificateAndKeyPair = classmethod(fromCertificateAndKeyPair)
 
381
 
 
382
 
 
383
    def options(self, *authorities):
 
384
        options = dict(privateKey=self.privateKey.original,
 
385
                       certificate=self.original)
 
386
        if authorities:
 
387
            options.update(dict(verify=True,
 
388
                                requireCertificate=True,
 
389
                                caCerts=[auth.original for auth in authorities]))
 
390
        return OpenSSLCertificateOptions(**options)
 
391
 
 
392
 
 
393
    def certificateRequest(self, format=crypto.FILETYPE_ASN1,
 
394
                           digestAlgorithm='md5'):
 
395
        return self.privateKey.certificateRequest(
 
396
            self.getSubject(),
 
397
            format,
 
398
            digestAlgorithm)
 
399
 
 
400
 
 
401
    def signCertificateRequest(self,
 
402
                               requestData,
 
403
                               verifyDNCallback,
 
404
                               serialNumber,
 
405
                               requestFormat=crypto.FILETYPE_ASN1,
 
406
                               certificateFormat=crypto.FILETYPE_ASN1):
 
407
        issuer = self.getSubject()
 
408
        return self.privateKey.signCertificateRequest(
 
409
            issuer,
 
410
            requestData,
 
411
            verifyDNCallback,
 
412
            serialNumber,
 
413
            requestFormat,
 
414
            certificateFormat)
 
415
 
 
416
 
 
417
    def signRequestObject(self, certificateRequest, serialNumber,
 
418
                          secondsToExpiry=60 * 60 * 24 * 365, # One year
 
419
                          digestAlgorithm='md5'):
 
420
        return self.privateKey.signRequestObject(self.getSubject(),
 
421
                                                 certificateRequest,
 
422
                                                 serialNumber,
 
423
                                                 secondsToExpiry,
 
424
                                                 digestAlgorithm)
 
425
 
 
426
 
 
427
class PublicKey:
 
428
    def __init__(self, osslpkey):
 
429
        self.original = osslpkey
 
430
        req1 = crypto.X509Req()
 
431
        req1.set_pubkey(osslpkey)
 
432
        self._emptyReq = crypto.dump_certificate_request(crypto.FILETYPE_ASN1, req1)
 
433
 
 
434
 
 
435
    def matches(self, otherKey):
 
436
        return self._emptyReq == otherKey._emptyReq
 
437
 
 
438
 
 
439
    # XXX This could be a useful method, but sometimes it triggers a segfault,
 
440
    # so we'll steer clear for now.
 
441
#     def verifyCertificate(self, certificate):
 
442
#         """
 
443
#         returns None, or raises a VerifyError exception if the certificate
 
444
#         could not be verified.
 
445
#         """
 
446
#         if not certificate.original.verify(self.original):
 
447
#             raise VerifyError("We didn't sign that certificate.")
 
448
 
 
449
    def __repr__(self):
 
450
        return '<%s %s>' % (self.__class__.__name__, self.keyHash())
 
451
 
 
452
 
 
453
    def keyHash(self):
 
454
        """
 
455
        MD5 hex digest of signature on an empty certificate request with this
 
456
        key.
 
457
        """
 
458
        return md5(self._emptyReq).hexdigest()
 
459
 
 
460
 
 
461
    def inspect(self):
 
462
        return 'Public Key with Hash: %s' % (self.keyHash(),)
 
463
 
 
464
 
 
465
 
 
466
class KeyPair(PublicKey):
 
467
 
 
468
    def load(Class, data, format=crypto.FILETYPE_ASN1):
 
469
        return Class(crypto.load_privatekey(format, data))
 
470
    load = classmethod(load)
 
471
 
 
472
 
 
473
    def dump(self, format=crypto.FILETYPE_ASN1):
 
474
        return crypto.dump_privatekey(format, self.original)
 
475
 
 
476
 
 
477
    def __getstate__(self):
 
478
        return self.dump()
 
479
 
 
480
 
 
481
    def __setstate__(self, state):
 
482
        self.__init__(crypto.load_privatekey(crypto.FILETYPE_ASN1, state))
 
483
 
 
484
 
 
485
    def inspect(self):
 
486
        t = self.original.type()
 
487
        if t == crypto.TYPE_RSA:
 
488
            ts = 'RSA'
 
489
        elif t == crypto.TYPE_DSA:
 
490
            ts = 'DSA'
 
491
        else:
 
492
            ts = '(Unknown Type!)'
 
493
        L = (self.original.bits(), ts, self.keyHash())
 
494
        return '%s-bit %s Key Pair with Hash: %s' % L
 
495
 
 
496
 
 
497
    def generate(Class, kind=crypto.TYPE_RSA, size=1024):
 
498
        pkey = crypto.PKey()
 
499
        pkey.generate_key(kind, size)
 
500
        return Class(pkey)
 
501
 
 
502
 
 
503
    def newCertificate(self, newCertData, format=crypto.FILETYPE_ASN1):
 
504
        return PrivateCertificate.load(newCertData, self, format)
 
505
    generate = classmethod(generate)
 
506
 
 
507
 
 
508
    def requestObject(self, distinguishedName, digestAlgorithm='md5'):
 
509
        req = crypto.X509Req()
 
510
        req.set_pubkey(self.original)
 
511
        distinguishedName._copyInto(req.get_subject())
 
512
        req.sign(self.original, digestAlgorithm)
 
513
        return CertificateRequest(req)
 
514
 
 
515
 
 
516
    def certificateRequest(self, distinguishedName,
 
517
                           format=crypto.FILETYPE_ASN1,
 
518
                           digestAlgorithm='md5'):
 
519
        """Create a certificate request signed with this key.
 
520
 
 
521
        @return: a string, formatted according to the 'format' argument.
 
522
        """
 
523
        return self.requestObject(distinguishedName, digestAlgorithm).dump(format)
 
524
 
 
525
 
 
526
    def signCertificateRequest(self,
 
527
                               issuerDistinguishedName,
 
528
                               requestData,
 
529
                               verifyDNCallback,
 
530
                               serialNumber,
 
531
                               requestFormat=crypto.FILETYPE_ASN1,
 
532
                               certificateFormat=crypto.FILETYPE_ASN1,
 
533
                               secondsToExpiry=60 * 60 * 24 * 365, # One year
 
534
                               digestAlgorithm='md5'):
 
535
        """
 
536
        Given a blob of certificate request data and a certificate authority's
 
537
        DistinguishedName, return a blob of signed certificate data.
 
538
 
 
539
        If verifyDNCallback returns a Deferred, I will return a Deferred which
 
540
        fires the data when that Deferred has completed.
 
541
        """
 
542
        hlreq = CertificateRequest.load(requestData, requestFormat)
 
543
 
 
544
        dn = hlreq.getSubject()
 
545
        vval = verifyDNCallback(dn)
 
546
 
 
547
        def verified(value):
 
548
            if not value:
 
549
                raise VerifyError("DN callback %r rejected request DN %r" % (verifyDNCallback, dn))
 
550
            return self.signRequestObject(issuerDistinguishedName, hlreq,
 
551
                                          serialNumber, secondsToExpiry, digestAlgorithm).dump(certificateFormat)
 
552
 
 
553
        if isinstance(vval, Deferred):
 
554
            return vval.addCallback(verified)
 
555
        else:
 
556
            return verified(vval)
 
557
 
 
558
 
 
559
    def signRequestObject(self,
 
560
                          issuerDistinguishedName,
 
561
                          requestObject,
 
562
                          serialNumber,
 
563
                          secondsToExpiry=60 * 60 * 24 * 365, # One year
 
564
                          digestAlgorithm='md5'):
 
565
        """
 
566
        Sign a CertificateRequest instance, returning a Certificate instance.
 
567
        """
 
568
        req = requestObject.original
 
569
        dn = requestObject.getSubject()
 
570
        cert = crypto.X509()
 
571
        issuerDistinguishedName._copyInto(cert.get_issuer())
 
572
        cert.set_subject(req.get_subject())
 
573
        cert.set_pubkey(req.get_pubkey())
 
574
        cert.gmtime_adj_notBefore(0)
 
575
        cert.gmtime_adj_notAfter(secondsToExpiry)
 
576
        cert.set_serial_number(serialNumber)
 
577
        cert.sign(self.original, digestAlgorithm)
 
578
        return Certificate(cert)
 
579
 
 
580
 
 
581
    def selfSignedCert(self, serialNumber, **kw):
 
582
        dn = DN(**kw)
 
583
        return PrivateCertificate.fromCertificateAndKeyPair(
 
584
            self.signRequestObject(dn, self.requestObject(dn), serialNumber),
 
585
            self)
 
586
 
 
587
 
 
588
 
 
589
class OpenSSLCertificateOptions(object):
 
590
    """
 
591
    A factory for SSL context objects for both SSL servers and clients.
 
592
    """
 
593
 
 
594
    _context = None
 
595
    # Older versions of PyOpenSSL didn't provide OP_ALL.  Fudge it here, just in case.
 
596
    _OP_ALL = getattr(SSL, 'OP_ALL', 0x0000FFFF)
 
597
    # OP_NO_TICKET is not (yet) exposed by PyOpenSSL
 
598
    _OP_NO_TICKET = 0x00004000
 
599
 
 
600
    method = SSL.TLSv1_METHOD
 
601
 
 
602
    def __init__(self,
 
603
                 privateKey=None,
 
604
                 certificate=None,
 
605
                 method=None,
 
606
                 verify=False,
 
607
                 caCerts=None,
 
608
                 verifyDepth=9,
 
609
                 requireCertificate=True,
 
610
                 verifyOnce=True,
 
611
                 enableSingleUseKeys=True,
 
612
                 enableSessions=True,
 
613
                 fixBrokenPeers=False,
 
614
                 enableSessionTickets=False):
 
615
        """
 
616
        Create an OpenSSL context SSL connection context factory.
 
617
 
 
618
        @param privateKey: A PKey object holding the private key.
 
619
 
 
620
        @param certificate: An X509 object holding the certificate.
 
621
 
 
622
        @param method: The SSL protocol to use, one of SSLv23_METHOD,
 
623
        SSLv2_METHOD, SSLv3_METHOD, TLSv1_METHOD.  Defaults to TLSv1_METHOD.
 
624
 
 
625
        @param verify: If True, verify certificates received from the peer and
 
626
        fail the handshake if verification fails.  Otherwise, allow anonymous
 
627
        sessions and sessions with certificates which fail validation.  By
 
628
        default this is False.
 
629
 
 
630
        @param caCerts: List of certificate authority certificates to
 
631
        send to the client when requesting a certificate.  Only used if verify
 
632
        is True, and if verify is True, either this must be specified or
 
633
        caCertsFile must be given.  Since verify is False by default,
 
634
        this is None by default.
 
635
 
 
636
        @param verifyDepth: Depth in certificate chain down to which to verify.
 
637
        If unspecified, use the underlying default (9).
 
638
 
 
639
        @param requireCertificate: If True, do not allow anonymous sessions.
 
640
 
 
641
        @param verifyOnce: If True, do not re-verify the certificate
 
642
        on session resumption.
 
643
 
 
644
        @param enableSingleUseKeys: If True, generate a new key whenever
 
645
        ephemeral DH parameters are used to prevent small subgroup attacks.
 
646
 
 
647
        @param enableSessions: If True, set a session ID on each context.  This
 
648
        allows a shortened handshake to be used when a known client reconnects.
 
649
 
 
650
        @param fixBrokenPeers: If True, enable various non-spec protocol fixes
 
651
        for broken SSL implementations.  This should be entirely safe,
 
652
        according to the OpenSSL documentation, but YMMV.  This option is now
 
653
        off by default, because it causes problems with connections between
 
654
        peers using OpenSSL 0.9.8a.
 
655
 
 
656
        @param enableSessionTickets: If True, enable session ticket extension
 
657
        for session resumption per RFC 5077. Note there is no support for
 
658
        controlling session tickets. This option is off by default, as some
 
659
        server implementations don't correctly process incoming empty session
 
660
        ticket extensions in the hello.
 
661
        """
 
662
 
 
663
        assert (privateKey is None) == (certificate is None), "Specify neither or both of privateKey and certificate"
 
664
        self.privateKey = privateKey
 
665
        self.certificate = certificate
 
666
        if method is not None:
 
667
            self.method = method
 
668
 
 
669
        self.verify = verify
 
670
        assert ((verify and caCerts) or
 
671
                (not verify)), "Specify client CA certificate information if and only if enabling certificate verification"
 
672
 
 
673
        self.caCerts = caCerts
 
674
        self.verifyDepth = verifyDepth
 
675
        self.requireCertificate = requireCertificate
 
676
        self.verifyOnce = verifyOnce
 
677
        self.enableSingleUseKeys = enableSingleUseKeys
 
678
        self.enableSessions = enableSessions
 
679
        self.fixBrokenPeers = fixBrokenPeers
 
680
        self.enableSessionTickets = enableSessionTickets
 
681
 
 
682
 
 
683
    def __getstate__(self):
 
684
        d = self.__dict__.copy()
 
685
        try:
 
686
            del d['_context']
 
687
        except KeyError:
 
688
            pass
 
689
        return d
 
690
 
 
691
 
 
692
    def __setstate__(self, state):
 
693
        self.__dict__ = state
 
694
 
 
695
 
 
696
    def getContext(self):
 
697
        """Return a SSL.Context object.
 
698
        """
 
699
        if self._context is None:
 
700
            self._context = self._makeContext()
 
701
        return self._context
 
702
 
 
703
 
 
704
    def _makeContext(self):
 
705
        ctx = SSL.Context(self.method)
 
706
 
 
707
        if self.certificate is not None and self.privateKey is not None:
 
708
            ctx.use_certificate(self.certificate)
 
709
            ctx.use_privatekey(self.privateKey)
 
710
            # Sanity check
 
711
            ctx.check_privatekey()
 
712
 
 
713
        verifyFlags = SSL.VERIFY_NONE
 
714
        if self.verify:
 
715
            verifyFlags = SSL.VERIFY_PEER
 
716
            if self.requireCertificate:
 
717
                verifyFlags |= SSL.VERIFY_FAIL_IF_NO_PEER_CERT
 
718
            if self.verifyOnce:
 
719
                verifyFlags |= SSL.VERIFY_CLIENT_ONCE
 
720
            if self.caCerts:
 
721
                store = ctx.get_cert_store()
 
722
                for cert in self.caCerts:
 
723
                    store.add_cert(cert)
 
724
 
 
725
        # It'd be nice if pyOpenSSL let us pass None here for this behavior (as
 
726
        # the underlying OpenSSL API call allows NULL to be passed).  It
 
727
        # doesn't, so we'll supply a function which does the same thing.
 
728
        def _verifyCallback(conn, cert, errno, depth, preverify_ok):
 
729
            return preverify_ok
 
730
        ctx.set_verify(verifyFlags, _verifyCallback)
 
731
 
 
732
        if self.verifyDepth is not None:
 
733
            ctx.set_verify_depth(self.verifyDepth)
 
734
 
 
735
        if self.enableSingleUseKeys:
 
736
            ctx.set_options(SSL.OP_SINGLE_DH_USE)
 
737
 
 
738
        if self.fixBrokenPeers:
 
739
            ctx.set_options(self._OP_ALL)
 
740
 
 
741
        if self.enableSessions:
 
742
            sessionName = md5("%s-%d" % (reflect.qual(self.__class__), _sessionCounter())).hexdigest()
 
743
            ctx.set_session_id(sessionName)
 
744
 
 
745
        if not self.enableSessionTickets:
 
746
            ctx.set_options(self._OP_NO_TICKET)
 
747
 
 
748
        return ctx