1
# Copyright 2005 Divmod, Inc. See LICENSE file for details
5
from OpenSSL import SSL
6
from OpenSSL.crypto import PKey, X509, X509Req
7
from OpenSSL.crypto import TYPE_RSA
9
from twisted.trial import unittest
10
from twisted.internet import protocol, defer, reactor
11
from twisted.python import log
12
from twisted.python.reflect import objgrep, isSame
14
from twisted.internet import _sslverify as sslverify
16
from twisted.internet.error import CertificateError
19
# A couple of static PEM-format certificates to be used by various tests.
20
A_HOST_CERTIFICATE_PEM = """
21
-----BEGIN CERTIFICATE-----
22
MIIC2jCCAkMCAjA5MA0GCSqGSIb3DQEBBAUAMIG0MQswCQYDVQQGEwJVUzEiMCAG
23
A1UEAxMZZXhhbXBsZS50d2lzdGVkbWF0cml4LmNvbTEPMA0GA1UEBxMGQm9zdG9u
24
MRwwGgYDVQQKExNUd2lzdGVkIE1hdHJpeCBMYWJzMRYwFAYDVQQIEw1NYXNzYWNo
25
dXNldHRzMScwJQYJKoZIhvcNAQkBFhhub2JvZHlAdHdpc3RlZG1hdHJpeC5jb20x
26
ETAPBgNVBAsTCFNlY3VyaXR5MB4XDTA2MDgxNjAxMDEwOFoXDTA3MDgxNjAxMDEw
27
OFowgbQxCzAJBgNVBAYTAlVTMSIwIAYDVQQDExlleGFtcGxlLnR3aXN0ZWRtYXRy
28
aXguY29tMQ8wDQYDVQQHEwZCb3N0b24xHDAaBgNVBAoTE1R3aXN0ZWQgTWF0cml4
29
IExhYnMxFjAUBgNVBAgTDU1hc3NhY2h1c2V0dHMxJzAlBgkqhkiG9w0BCQEWGG5v
30
Ym9keUB0d2lzdGVkbWF0cml4LmNvbTERMA8GA1UECxMIU2VjdXJpdHkwgZ8wDQYJ
31
KoZIhvcNAQEBBQADgY0AMIGJAoGBAMzH8CDF/U91y/bdbdbJKnLgnyvQ9Ig9ZNZp
32
8hpsu4huil60zF03+Lexg2l1FIfURScjBuaJMR6HiMYTMjhzLuByRZ17KW4wYkGi
33
KXstz03VIKy4Tjc+v4aXFI4XdRw10gGMGQlGGscXF/RSoN84VoDKBfOMWdXeConJ
34
VyC4w3iJAgMBAAEwDQYJKoZIhvcNAQEEBQADgYEAviMT4lBoxOgQy32LIgZ4lVCj
35
JNOiZYg8GMQ6y0ugp86X80UjOvkGtNf/R7YgED/giKRN/q/XJiLJDEhzknkocwmO
36
S+4b2XpiaZYxRyKWwL221O7CGmtWYyZl2+92YYmmCiNzWQPfP6BOMlfax0AGLHls
38
-----END CERTIFICATE-----
41
A_PEER_CERTIFICATE_PEM = """
42
-----BEGIN CERTIFICATE-----
43
MIIC3jCCAkcCAjA6MA0GCSqGSIb3DQEBBAUAMIG2MQswCQYDVQQGEwJVUzEiMCAG
44
A1UEAxMZZXhhbXBsZS50d2lzdGVkbWF0cml4LmNvbTEPMA0GA1UEBxMGQm9zdG9u
45
MRwwGgYDVQQKExNUd2lzdGVkIE1hdHJpeCBMYWJzMRYwFAYDVQQIEw1NYXNzYWNo
46
dXNldHRzMSkwJwYJKoZIhvcNAQkBFhpzb21lYm9keUB0d2lzdGVkbWF0cml4LmNv
47
bTERMA8GA1UECxMIU2VjdXJpdHkwHhcNMDYwODE2MDEwMTU2WhcNMDcwODE2MDEw
48
MTU2WjCBtjELMAkGA1UEBhMCVVMxIjAgBgNVBAMTGWV4YW1wbGUudHdpc3RlZG1h
49
dHJpeC5jb20xDzANBgNVBAcTBkJvc3RvbjEcMBoGA1UEChMTVHdpc3RlZCBNYXRy
50
aXggTGFiczEWMBQGA1UECBMNTWFzc2FjaHVzZXR0czEpMCcGCSqGSIb3DQEJARYa
51
c29tZWJvZHlAdHdpc3RlZG1hdHJpeC5jb20xETAPBgNVBAsTCFNlY3VyaXR5MIGf
52
MA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCnm+WBlgFNbMlHehib9ePGGDXF+Nz4
53
CjGuUmVBaXCRCiVjg3kSDecwqfb0fqTksBZ+oQ1UBjMcSh7OcvFXJZnUesBikGWE
54
JE4V8Bjh+RmbJ1ZAlUPZ40bAkww0OpyIRAGMvKG+4yLFTO4WDxKmfDcrOb6ID8WJ
55
e1u+i3XGkIf/5QIDAQABMA0GCSqGSIb3DQEBBAUAA4GBAD4Oukm3YYkhedUepBEA
56
vvXIQhVDqL7mk6OqYdXmNj6R7ZMC8WWvGZxrzDI1bZuB+4aIxxd1FXC3UOHiR/xg
57
i9cDl1y8P/qRp4aEBNF6rI0D4AxTbfnHQx4ERDAOShJdYZs/2zifPJ6va6YvrEyr
58
yqDtGhklsWW3ZwBzEh5VEOUp
59
-----END CERTIFICATE-----
64
counter = itertools.count().next
65
def makeCertificate(**kw):
67
keypair.generate_key(TYPE_RSA, 1024)
70
certificate.gmtime_adj_notBefore(0)
71
certificate.gmtime_adj_notAfter(60 * 60 * 24 * 365) # One year
72
for xname in certificate.get_issuer(), certificate.get_subject():
73
for (k, v) in kw.items():
76
certificate.set_serial_number(counter())
77
certificate.set_pubkey(keypair)
78
certificate.sign(keypair, "md5")
80
return keypair, certificate
82
def otherMakeCertificate(**kw):
84
keypair.generate_key(TYPE_RSA, 1024)
87
subj = req.get_subject()
88
for (k, v) in kw.items():
91
req.set_pubkey(keypair)
92
req.sign(keypair, "md5")
95
cert.set_serial_number(counter())
96
cert.gmtime_adj_notBefore(0)
97
cert.gmtime_adj_notAfter(60 * 60 * 24 * 365) # One year
99
cert.set_issuer(req.get_subject())
100
cert.set_subject(req.get_subject())
101
cert.set_pubkey(req.get_pubkey())
102
cert.sign(keypair, "md5")
107
class DataCallbackProtocol(protocol.Protocol):
108
def dataReceived(self, data):
109
d, self.factory.onData = self.factory.onData, None
113
def connectionLost(self, reason):
114
d, self.factory.onLost = self.factory.onLost, None
118
class WritingProtocol(protocol.Protocol):
120
def connectionMade(self):
121
self.transport.write(self.byte)
123
def connectionLost(self, reason):
124
self.factory.onLost.errback(reason)
127
class OpenSSLOptions(unittest.TestCase):
128
serverPort = clientConn = None
129
onServerLost = onClientLost = None
131
def setUpClass(self):
132
self.sKey, self.sCert = makeCertificate(
133
O="Server Test Certificate",
135
self.cKey, self.cCert = makeCertificate(
136
O="Client Test Certificate",
140
if self.serverPort is not None:
141
self.serverPort.stopListening()
142
if self.clientConn is not None:
143
self.clientConn.disconnect()
146
if self.onServerLost is not None:
147
L.append(self.onServerLost)
148
if self.onClientLost is not None:
149
L.append(self.onClientLost)
151
return defer.DeferredList(L, consumeErrors=True)
153
def loopback(self, serverCertOpts, clientCertOpts,
154
onServerLost=None, onClientLost=None, onData=None):
155
if onServerLost is None:
156
self.onServerLost = onServerLost = defer.Deferred()
157
if onClientLost is None:
158
self.onClientLost = onClientLost = defer.Deferred()
160
onData = defer.Deferred()
162
serverFactory = protocol.ServerFactory()
163
serverFactory.protocol = DataCallbackProtocol
164
serverFactory.onLost = onServerLost
165
serverFactory.onData = onData
167
clientFactory = protocol.ClientFactory()
168
clientFactory.protocol = WritingProtocol
169
clientFactory.onLost = onClientLost
171
self.serverPort = reactor.listenSSL(0, serverFactory, serverCertOpts)
172
self.clientConn = reactor.connectSSL('127.0.0.1', self.serverPort.getHost().port,
173
clientFactory, clientCertOpts)
175
def testAbbreviatingDistinguishedNames(self):
176
self.assertEquals(sslverify.DN(CN='a', OU='hello'),
177
sslverify.DistinguishedName(commonName='a', organizationalUnitName='hello'))
178
self.assertNotEquals(sslverify.DN(CN='a', OU='hello'),
179
sslverify.DN(CN='a', OU='hello', emailAddress='xxx'))
180
dn = sslverify.DN(CN='abcdefg')
181
self.assertRaises(AttributeError, setattr, dn, 'Cn', 'x')
182
self.assertEquals(dn.CN, dn.commonName)
184
self.assertEquals(dn.CN, dn.commonName)
187
def testInspectDistinguishedName(self):
188
n = sslverify.DN(commonName='common name',
189
organizationName='organization name',
190
organizationalUnitName='organizational unit name',
191
localityName='locality name',
192
stateOrProvinceName='state or province name',
193
countryName='country name',
194
emailAddress='email address')
199
'organizational unit name',
201
'state or province name',
204
self.assertIn(k, s, "%r was not in inspect output." % (k,))
205
self.assertIn(k.title(), s, "%r was not in inspect output." % (k,))
208
def testInspectDistinguishedNameWithoutAllFields(self):
209
n = sslverify.DN(localityName='locality name')
214
'organizational unit name',
215
'state or province name',
218
self.assertNotIn(k, s, "%r was in inspect output." % (k,))
219
self.assertNotIn(k.title(), s, "%r was in inspect output." % (k,))
220
self.assertIn('locality name', s)
221
self.assertIn('Locality Name', s)
224
def test_inspectCertificate(self):
226
Test that the C{inspect} method of L{sslverify.Certificate} returns
227
a human-readable string containing some basic information about the
230
c = sslverify.Certificate.loadPEM(A_HOST_CERTIFICATE_PEM)
232
c.inspect().split('\n'),
233
["Certificate For Subject:",
234
" Organizational Unit Name: Security",
235
" Organization Name: Twisted Matrix Labs",
236
" Common Name: example.twistedmatrix.com",
237
" State Or Province Name: Massachusetts",
239
" Email Address: nobody@twistedmatrix.com",
240
" Locality Name: Boston",
243
" Organizational Unit Name: Security",
244
" Organization Name: Twisted Matrix Labs",
245
" Common Name: example.twistedmatrix.com",
246
" State Or Province Name: Massachusetts",
248
" Email Address: nobody@twistedmatrix.com",
249
" Locality Name: Boston",
251
"Serial Number: 12345",
252
"Digest: C4:96:11:00:30:C3:EC:EE:A3:55:AA:ED:8C:84:85:18",
253
"Public Key with Hash: ff33994c80812aa95a79cdb85362d054"])
256
def test_certificateOptionsSerialization(self):
258
Test that __setstate__(__getstate__()) round-trips properly.
260
firstOpts = sslverify.OpenSSLCertificateOptions(
261
privateKey=self.sKey,
262
certificate=self.sCert,
263
method=SSL.SSLv3_METHOD,
265
caCerts=[self.sCert],
267
requireCertificate=False,
269
enableSingleUseKeys=False,
270
enableSessions=False,
272
context = firstOpts.getContext()
273
state = firstOpts.__getstate__()
275
# The context shouldn't be in the state to serialize
276
self.failIf(objgrep(state, context, isSame), objgrep(state, context, isSame))
278
opts = sslverify.OpenSSLCertificateOptions()
279
opts.__setstate__(state)
280
self.assertEqual(opts.privateKey, self.sKey)
281
self.assertEqual(opts.certificate, self.sCert)
282
self.assertEqual(opts.method, SSL.SSLv3_METHOD)
283
self.assertEqual(opts.verify, True)
284
self.assertEqual(opts.caCerts, [self.sCert])
285
self.assertEqual(opts.verifyDepth, 2)
286
self.assertEqual(opts.requireCertificate, False)
287
self.assertEqual(opts.verifyOnce, False)
288
self.assertEqual(opts.enableSingleUseKeys, False)
289
self.assertEqual(opts.enableSessions, False)
290
self.assertEqual(opts.fixBrokenPeers, True)
293
def testAllowedAnonymousClientConnection(self):
294
onData = defer.Deferred()
295
self.loopback(sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, certificate=self.sCert, requireCertificate=False),
296
sslverify.OpenSSLCertificateOptions(requireCertificate=False),
299
return onData.addCallback(
300
lambda result: self.assertEquals(result, WritingProtocol.byte))
302
def testRefusedAnonymousClientConnection(self):
    """
    Connect a client that presents no certificate to a server whose
    options set C{verify=True} and C{requireCertificate=True}, and
    assert that both the client-lost and the server-lost Deferreds fire
    with failures.
    """
    onServerLost = defer.Deferred()
    onClientLost = defer.Deferred()
    self.loopback(sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, certificate=self.sCert, verify=True, caCerts=[self.sCert], requireCertificate=True),
                  sslverify.OpenSSLCertificateOptions(requireCertificate=False),
                  onServerLost=onServerLost,
                  onClientLost=onClientLost)

    d = defer.DeferredList([onClientLost, onServerLost], consumeErrors=True)

    def afterLost(results):
        # Unpack in the body rather than in the signature: tuple
        # parameters are Python-2-only syntax (removed by PEP 3113), and
        # this form behaves the same on Python 2.
        (cSuccess, cResult), (sSuccess, sResult) = results

        self.failIf(cSuccess)
        self.failIf(sSuccess)

        # XXX Twisted doesn't report SSL errors as SSL errors, so the
        # traps below stay disabled for now.

        # cResult.trap(SSL.Error)
        # sResult.trap(SSL.Error)

        # Twisted trunk will do the correct thing here, and not log any
        # errors.  Twisted 2.1 will do the wrong thing.  Flush any logged
        # SSL errors so the test passes either way until the buildbot is
        # updated to a reasonable facsimile of trunk behaviour.
        log.flushErrors(SSL.Error)

    return d.addCallback(afterLost)
332
def testFailedCertificateVerification(self):
    """
    Connect a verifying client to a server whose certificate is not in
    the client's CA list (the client only lists C{self.cCert} while the
    server presents C{self.sCert}), and assert that both the client-lost
    and the server-lost Deferreds fire with failures.
    """
    onServerLost = defer.Deferred()
    onClientLost = defer.Deferred()
    self.loopback(sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, certificate=self.sCert, verify=False, requireCertificate=False),
                  sslverify.OpenSSLCertificateOptions(verify=True, requireCertificate=False, caCerts=[self.cCert]),
                  onServerLost=onServerLost,
                  onClientLost=onClientLost)

    d = defer.DeferredList([onClientLost, onServerLost], consumeErrors=True)

    def afterLost(results):
        # Unpack in the body rather than in the signature: tuple
        # parameters are Python-2-only syntax (removed by PEP 3113), and
        # this form behaves the same on Python 2.  Only the success flags
        # are inspected; the failure objects themselves are ignored.
        (cSuccess, _), (sSuccess, _) = results

        self.failIf(cSuccess)
        self.failIf(sSuccess)

        # Twisted trunk will do the correct thing here, and not log any
        # errors.  Twisted 2.1 will do the wrong thing.  Flush any logged
        # SSL errors so the test passes either way until the buildbot is
        # updated to a reasonable facsimile of trunk behaviour.
        log.flushErrors(SSL.Error)

    return d.addCallback(afterLost)
354
def testSuccessfulCertificateVerification(self):
355
onData = defer.Deferred()
356
self.loopback(sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, certificate=self.sCert, verify=False, requireCertificate=False),
357
sslverify.OpenSSLCertificateOptions(verify=True, requireCertificate=True, caCerts=[self.sCert]),
360
return onData.addCallback(lambda result: self.assertEquals(result, WritingProtocol.byte))
362
def testSuccessfulSymmetricSelfSignedCertificateVerification(self):
363
onData = defer.Deferred()
364
self.loopback(sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, certificate=self.sCert, verify=True, requireCertificate=True, caCerts=[self.cCert]),
365
sslverify.OpenSSLCertificateOptions(privateKey=self.cKey, certificate=self.cCert, verify=True, requireCertificate=True, caCerts=[self.sCert]),
368
return onData.addCallback(lambda result: self.assertEquals(result, WritingProtocol.byte))
370
def testVerification(self):
371
clientDN = sslverify.DistinguishedName(commonName='client')
372
clientKey = sslverify.KeyPair.generate()
373
clientCertReq = clientKey.certificateRequest(clientDN)
375
serverDN = sslverify.DistinguishedName(commonName='server')
376
serverKey = sslverify.KeyPair.generate()
377
serverCertReq = serverKey.certificateRequest(serverDN)
380
clientSelfCertReq = clientKey.certificateRequest(clientDN)
381
clientSelfCertData = clientKey.signCertificateRequest(clientDN, clientSelfCertReq,
384
clientSelfCert = clientKey.newCertificate(clientSelfCertData)
388
serverSelfCertReq = serverKey.certificateRequest(serverDN)
389
serverSelfCertData = serverKey.signCertificateRequest(serverDN, serverSelfCertReq,
392
serverSelfCert = serverKey.newCertificate(serverSelfCertData)
396
clientCertData = serverKey.signCertificateRequest(serverDN, clientCertReq,
399
clientCert = clientKey.newCertificate(clientCertData)
403
serverCertData = clientKey.signCertificateRequest(clientDN, serverCertReq,
406
serverCert = serverKey.newCertificate(serverCertData)
409
onData = defer.Deferred()
411
serverOpts = serverCert.options(serverSelfCert)
412
clientOpts = clientCert.options(clientSelfCert)
414
self.loopback(serverOpts,
418
return onData.addCallback(lambda result: self.assertEquals(result, WritingProtocol.byte))
421
class _NotSSLTransport:
425
class _MaybeSSLTransport:
429
def get_peer_certificate(self):
432
def get_host_certificate(self):
436
class _ActualSSLTransport:
440
def get_host_certificate(self):
441
return sslverify.Certificate.loadPEM(A_HOST_CERTIFICATE_PEM).original
443
def get_peer_certificate(self):
444
return sslverify.Certificate.loadPEM(A_PEER_CERTIFICATE_PEM).original
447
class Constructors(unittest.TestCase):
448
def test_peerFromNonSSLTransport(self):
450
Verify that peerFromTransport raises an exception if the transport
451
passed is not actually an SSL transport.
453
x = self.assertRaises(CertificateError,
454
sslverify.Certificate.peerFromTransport,
456
self.failUnless(str(x).startswith("non-TLS"))
458
def test_peerFromBlankSSLTransport(self):
    """
    L{sslverify.Certificate.peerFromTransport} raises
    L{CertificateError} when given an SSL transport that has no peer
    certificate, and the error text begins with C{"TLS"}.
    """
    transport = _MaybeSSLTransport()
    error = self.assertRaises(
        CertificateError,
        sslverify.Certificate.peerFromTransport,
        transport)
    message = str(error)
    self.failUnless(message.startswith("TLS"))
468
def test_hostFromNonSSLTransport(self):
470
Verify that hostFromTransport raises an exception if the transport
471
passed is not actually an SSL transport.
473
x = self.assertRaises(CertificateError,
474
sslverify.Certificate.hostFromTransport,
476
self.failUnless(str(x).startswith("non-TLS"))
478
def test_hostFromBlankSSLTransport(self):
    """
    L{sslverify.Certificate.hostFromTransport} raises
    L{CertificateError} when given an SSL transport that has no host
    certificate, and the error text begins with C{"TLS"}.
    """
    transport = _MaybeSSLTransport()
    error = self.assertRaises(
        CertificateError,
        sslverify.Certificate.hostFromTransport,
        transport)
    message = str(error)
    self.failUnless(message.startswith("TLS"))
489
def test_hostFromSSLTransport(self):
491
Verify that hostFromTransport successfully creates the correct certificate
492
if passed a valid SSL transport.
495
sslverify.Certificate.hostFromTransport(
496
_ActualSSLTransport()).serialNumber(),
499
def test_peerFromSSLTransport(self):
501
Verify that peerFromTransport successfully creates the correct certificate
502
if passed a valid SSL transport.
505
sslverify.Certificate.peerFromTransport(
506
_ActualSSLTransport()).serialNumber(),