1
# Copyright (c) 2001-2009 Twisted Matrix Laboratories.
2
# See LICENSE for details.
5
Tests for twisted SSL support.
8
from twisted.trial import unittest
9
from twisted.internet import protocol, reactor, interfaces, defer
10
from twisted.protocols import basic
11
from twisted.python import util
12
from twisted.python.reflect import getClass, fullyQualifiedName
13
from twisted.python.runtime import platform
14
from twisted.test.test_tcp import WriteDataTestCase, ProperlyCloseFilesMixin
19
from OpenSSL import SSL, crypto
20
from twisted.internet import ssl
21
from twisted.test.ssl_helpers import ClientTLSContext
24
# ugh, make pyflakes happy.
30
certPath = util.sibpath(__file__, "server.pem")
34
class UnintelligentProtocol(basic.LineReceiver):
36
@ivar deferred: a deferred that will fire at connection lost.
37
@type deferred: L{defer.Deferred}
39
@cvar pretext: text sent before TLS is set up.
42
@cvar posttext: text sent after TLS is set up.
43
@type posttext: C{str}
47
"last thing before tls starts",
51
"first thing after tls started",
55
self.deferred = defer.Deferred()
58
def connectionMade(self):
59
for l in self.pretext:
63
def lineReceived(self, line):
65
self.transport.startTLS(ClientTLSContext(), self.factory.client)
66
for l in self.posttext:
68
self.transport.loseConnection()
71
def connectionLost(self, reason):
72
self.deferred.callback(None)
76
class LineCollector(basic.LineReceiver):
78
@ivar deferred: a deferred that will fire at connection lost.
79
@type deferred: L{defer.Deferred}
81
@ivar doTLS: whether the protocol is initiate TLS or not.
84
@ivar fillBuffer: if set to True, it will send lots of data once
85
C{STARTTLS} is received.
86
@type fillBuffer: C{bool}
89
def __init__(self, doTLS, fillBuffer=False):
91
self.fillBuffer = fillBuffer
92
self.deferred = defer.Deferred()
95
def connectionMade(self):
96
self.factory.rawdata = ''
97
self.factory.lines = []
100
def lineReceived(self, line):
101
self.factory.lines.append(line)
102
if line == 'STARTTLS':
105
self.sendLine('X' * 1000)
106
self.sendLine('READY')
108
ctx = ServerTLSContext(
109
privateKeyFileName=certPath,
110
certificateFileName=certPath,
112
self.transport.startTLS(ctx, self.factory.server)
117
def rawDataReceived(self, data):
118
self.factory.rawdata += data
119
self.transport.loseConnection()
122
def connectionLost(self, reason):
123
self.deferred.callback(None)
127
class SingleLineServerProtocol(protocol.Protocol):
129
A protocol that sends a single line of data at C{connectionMade}.
132
def connectionMade(self):
133
self.transport.write("+OK <some crap>\r\n")
134
self.transport.getPeerCertificate()
138
class RecordingClientProtocol(protocol.Protocol):
140
@ivar deferred: a deferred that will fire with first received content.
141
@type deferred: L{defer.Deferred}
145
self.deferred = defer.Deferred()
148
def connectionMade(self):
149
self.transport.getPeerCertificate()
152
def dataReceived(self, data):
153
self.deferred.callback(data)
157
class ImmediatelyDisconnectingProtocol(protocol.Protocol):
159
A protocol that disconnect immediately on connection. It fires the
160
C{connectionDisconnected} deferred of its factory on connetion lost.
163
def connectionMade(self):
164
self.transport.loseConnection()
167
def connectionLost(self, reason):
168
self.factory.connectionDisconnected.callback(None)
172
def generateCertificateObjects(organization, organizationalUnit):
174
Create a certificate for given C{organization} and C{organizationalUnit}.
176
@return: a tuple of (key, request, certificate) objects.
179
pkey.generate_key(crypto.TYPE_RSA, 512)
180
req = crypto.X509Req()
181
subject = req.get_subject()
182
subject.O = organization
183
subject.OU = organizationalUnit
185
req.sign(pkey, "md5")
187
# Here comes the actual certificate
189
cert.set_serial_number(1)
190
cert.gmtime_adj_notBefore(0)
191
cert.gmtime_adj_notAfter(60) # Testing certificates need not be long lived
192
cert.set_issuer(req.get_subject())
193
cert.set_subject(req.get_subject())
194
cert.set_pubkey(req.get_pubkey())
195
cert.sign(pkey, "md5")
197
return pkey, req, cert
201
def generateCertificateFiles(basename, organization, organizationalUnit):
203
Create certificate files key, req and cert prefixed by C{basename} for
204
given C{organization} and C{organizationalUnit}.
206
pkey, req, cert = generateCertificateObjects(organization, organizationalUnit)
208
for ext, obj, dumpFunc in [
209
('key', pkey, crypto.dump_privatekey),
210
('req', req, crypto.dump_certificate_request),
211
('cert', cert, crypto.dump_certificate)]:
212
fName = os.extsep.join((basename, ext))
213
fObj = file(fName, 'w')
214
fObj.write(dumpFunc(crypto.FILETYPE_PEM, obj))
219
class ContextGeneratingMixin:
221
Offer methods to create L{ssl.DefaultOpenSSLContextFactory} for both client
224
@ivar clientBase: prefix of client certificate files.
225
@type clientBase: C{str}
227
@ivar serverBase: prefix of server certificate files.
228
@type serverBase: C{str}
230
@ivar clientCtxFactory: a generated context factory to be used in
231
C{reactor.connectSSL}.
232
@type clientCtxFactory: L{ssl.DefaultOpenSSLContextFactory}
234
@ivar serverCtxFactory: a generated context factory to be used in
235
C{reactor.listenSSL}.
236
@type serverCtxFactory: L{ssl.DefaultOpenSSLContextFactory}
239
def makeContextFactory(self, org, orgUnit, *args, **kwArgs):
241
generateCertificateFiles(base, org, orgUnit)
242
serverCtxFactory = ssl.DefaultOpenSSLContextFactory(
243
os.extsep.join((base, 'key')),
244
os.extsep.join((base, 'cert')),
247
return base, serverCtxFactory
250
def setupServerAndClient(self, clientArgs, clientKwArgs, serverArgs,
252
self.clientBase, self.clientCtxFactory = self.makeContextFactory(
253
*clientArgs, **clientKwArgs)
254
self.serverBase, self.serverCtxFactory = self.makeContextFactory(
255
*serverArgs, **serverKwArgs)
260
class ServerTLSContext(ssl.DefaultOpenSSLContextFactory):
262
A context factory with a default method set to L{SSL.TLSv1_METHOD}.
266
def __init__(self, *args, **kw):
267
kw['sslmethod'] = SSL.TLSv1_METHOD
268
ssl.DefaultOpenSSLContextFactory.__init__(self, *args, **kw)
272
class StolenTCPTestCase(ProperlyCloseFilesMixin, unittest.TestCase):
274
For SSL transports, test many of the same things which are tested for
278
def createServer(self, address, portNumber, factory):
280
Create an SSL server with a certificate using L{IReactorSSL.listenSSL}.
282
cert = ssl.PrivateCertificate.loadPEM(file(certPath).read())
283
contextFactory = cert.options()
284
return reactor.listenSSL(
285
portNumber, factory, contextFactory, interface=address)
288
def connectClient(self, address, portNumber, clientCreator):
290
Create an SSL client using L{IReactorSSL.connectSSL}.
292
contextFactory = ssl.CertificateOptions()
293
return clientCreator.connectSSL(address, portNumber, contextFactory)
296
def getHandleExceptionType(self):
298
Return L{SSL.Error} as the expected error type which will be raised by
299
a write to the L{OpenSSL.SSL.Connection} object after it has been
305
_iocp = 'twisted.internet.iocpreactor.reactor.IOCPReactor'
307
def getHandleErrorCode(self):
309
Return the argument L{SSL.Error} will be constructed with for this
310
case. This is basically just a random OpenSSL implementation detail.
311
It would be better if this test worked in a way which did not require
314
# Windows 2000 SP 4 and Windows XP SP 2 give back WSAENOTSOCK for
315
# SSL.Connection.write for some reason. The twisted.protocols.tls
316
# implementation of IReactorSSL doesn't suffer from this imprecation,
317
# though, since it is isolated from the Windows I/O layer (I suppose?).
319
# If test_properlyCloseFiles waited for the SSL handshake to complete
320
# and performed an orderly shutdown, then this would probably be a
321
# little less weird: writing to a shutdown SSL connection has a more
322
# well-defined failure mode (or at least it should).
323
name = fullyQualifiedName(getClass(reactor))
324
if platform.getType() == 'win32' and name != self._iocp:
325
return errno.WSAENOTSOCK
326
# This is terribly implementation-specific.
327
return [('SSL routines', 'SSL_write', 'protocol is shutdown')]
331
class TLSTestCase(unittest.TestCase):
333
Tests for startTLS support.
335
@ivar fillBuffer: forwarded to L{LineCollector.fillBuffer}
336
@type fillBuffer: C{bool}
345
if self.clientProto.transport is not None:
346
self.clientProto.transport.loseConnection()
347
if self.serverProto.transport is not None:
348
self.serverProto.transport.loseConnection()
351
def _runTest(self, clientProto, serverProto, clientIsServer=False):
353
Helper method to run TLS tests.
355
@param clientProto: protocol instance attached to the client
357
@param serverProto: protocol instance attached to the server
359
@param clientIsServer: flag indicated if client should initiate
360
startTLS instead of server.
362
@return: a L{defer.Deferred} that will fire when both connections are
365
self.clientProto = clientProto
366
cf = self.clientFactory = protocol.ClientFactory()
367
cf.protocol = lambda: clientProto
373
self.serverProto = serverProto
374
sf = self.serverFactory = protocol.ServerFactory()
375
sf.protocol = lambda: serverProto
381
port = reactor.listenTCP(0, sf, interface="127.0.0.1")
382
self.addCleanup(port.stopListening)
384
reactor.connectTCP('127.0.0.1', port.getHost().port, cf)
386
return defer.gatherResults([clientProto.deferred, serverProto.deferred])
391
Test for server and client startTLS: client should received data both
392
before and after the startTLS.
396
self.serverFactory.lines,
397
UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
399
d = self._runTest(UnintelligentProtocol(),
400
LineCollector(True, self.fillBuffer))
401
return d.addCallback(check)
404
def test_unTLS(self):
406
Test for server startTLS not followed by a startTLS in client: the data
407
received after server startTLS should be received as raw.
411
self.serverFactory.lines,
412
UnintelligentProtocol.pretext
414
self.failUnless(self.serverFactory.rawdata,
415
"No encrypted bytes received")
416
d = self._runTest(UnintelligentProtocol(),
417
LineCollector(False, self.fillBuffer))
418
return d.addCallback(check)
421
def test_backwardsTLS(self):
423
Test startTLS first initiated by client.
427
self.clientFactory.lines,
428
UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
430
d = self._runTest(LineCollector(True, self.fillBuffer),
431
UnintelligentProtocol(), True)
432
return d.addCallback(check)
436
class SpammyTLSTestCase(TLSTestCase):
438
Test TLS features with bytes sitting in the out buffer.
444
class BufferingTestCase(unittest.TestCase):
450
if self.serverProto.transport is not None:
451
self.serverProto.transport.loseConnection()
452
if self.clientProto.transport is not None:
453
self.clientProto.transport.loseConnection()
456
def test_openSSLBuffering(self):
457
serverProto = self.serverProto = SingleLineServerProtocol()
458
clientProto = self.clientProto = RecordingClientProtocol()
460
server = protocol.ServerFactory()
461
client = self.client = protocol.ClientFactory()
463
server.protocol = lambda: serverProto
464
client.protocol = lambda: clientProto
466
sCTX = ssl.DefaultOpenSSLContextFactory(certPath, certPath)
467
cCTX = ssl.ClientContextFactory()
469
port = reactor.listenSSL(0, server, sCTX, interface='127.0.0.1')
470
self.addCleanup(port.stopListening)
472
reactor.connectSSL('127.0.0.1', port.getHost().port, client, cCTX)
474
return clientProto.deferred.addCallback(
475
self.assertEquals, "+OK <some crap>\r\n")
479
class ConnectionLostTestCase(unittest.TestCase, ContextGeneratingMixin):
481
def testImmediateDisconnect(self):
482
org = "twisted.test.test_ssl"
483
self.setupServerAndClient(
484
(org, org + ", client"), {},
485
(org, org + ", server"), {})
487
# Set up a server, connect to it with a client, which should work since our verifiers
488
# allow anything, then disconnect.
489
serverProtocolFactory = protocol.ServerFactory()
490
serverProtocolFactory.protocol = protocol.Protocol
491
self.serverPort = serverPort = reactor.listenSSL(0,
492
serverProtocolFactory, self.serverCtxFactory)
494
clientProtocolFactory = protocol.ClientFactory()
495
clientProtocolFactory.protocol = ImmediatelyDisconnectingProtocol
496
clientProtocolFactory.connectionDisconnected = defer.Deferred()
497
clientConnector = reactor.connectSSL('127.0.0.1',
498
serverPort.getHost().port, clientProtocolFactory, self.clientCtxFactory)
500
return clientProtocolFactory.connectionDisconnected.addCallback(
501
lambda ignoredResult: self.serverPort.stopListening())
504
def testFailedVerify(self):
505
org = "twisted.test.test_ssl"
506
self.setupServerAndClient(
507
(org, org + ", client"), {},
508
(org, org + ", server"), {})
512
self.clientCtxFactory.getContext().set_verify(SSL.VERIFY_PEER, verify)
514
serverConnLost = defer.Deferred()
515
serverProtocol = protocol.Protocol()
516
serverProtocol.connectionLost = serverConnLost.callback
517
serverProtocolFactory = protocol.ServerFactory()
518
serverProtocolFactory.protocol = lambda: serverProtocol
519
self.serverPort = serverPort = reactor.listenSSL(0,
520
serverProtocolFactory, self.serverCtxFactory)
522
clientConnLost = defer.Deferred()
523
clientProtocol = protocol.Protocol()
524
clientProtocol.connectionLost = clientConnLost.callback
525
clientProtocolFactory = protocol.ClientFactory()
526
clientProtocolFactory.protocol = lambda: clientProtocol
527
clientConnector = reactor.connectSSL('127.0.0.1',
528
serverPort.getHost().port, clientProtocolFactory, self.clientCtxFactory)
530
dl = defer.DeferredList([serverConnLost, clientConnLost], consumeErrors=True)
531
return dl.addCallback(self._cbLostConns)
534
def _cbLostConns(self, results):
535
(sSuccess, sResult), (cSuccess, cResult) = results
537
self.failIf(sSuccess)
538
self.failIf(cSuccess)
540
acceptableErrors = [SSL.Error]
542
# Rather than getting a verification failure on Windows, we are getting
543
# a connection failure. Without something like sslverify proxying
544
# in-between we can't fix up the platform's errors, so let's just
545
# specifically say it is only OK in this one case to keep the tests
546
# passing. Normally we'd like to be as strict as possible here, so
547
# we're not going to allow this to report errors incorrectly on any
550
if platform.isWindows():
551
from twisted.internet.error import ConnectionLost
552
acceptableErrors.append(ConnectionLost)
554
sResult.trap(*acceptableErrors)
555
cResult.trap(*acceptableErrors)
557
return self.serverPort.stopListening()
563
L{OpenSSL.SSL.Context} double which can more easily be inspected.
565
def __init__(self, method):
566
self._method = method
570
def set_options(self, options):
571
self._options |= options
574
def use_certificate_file(self, fileName):
578
def use_privatekey_file(self, fileName):
583
class DefaultOpenSSLContextFactoryTests(unittest.TestCase):
585
Tests for L{ssl.DefaultOpenSSLContextFactory}.
588
# pyOpenSSL Context objects aren't introspectable enough. Pass in
589
# an alternate context factory so we can inspect what is done to it.
590
self.contextFactory = ssl.DefaultOpenSSLContextFactory(
591
certPath, certPath, _contextFactory=FakeContext)
592
self.context = self.contextFactory.getContext()
595
def test_method(self):
597
L{ssl.DefaultOpenSSLContextFactory.getContext} returns an SSL context
598
which can use SSLv3 or TLSv1 but not SSLv2.
600
# SSLv23_METHOD allows SSLv2, SSLv3, or TLSv1
601
self.assertEqual(self.context._method, SSL.SSLv23_METHOD)
603
# And OP_NO_SSLv2 disables the SSLv2 support.
604
self.assertTrue(self.context._options & SSL.OP_NO_SSLv2)
606
# Make sure SSLv3 and TLSv1 aren't disabled though.
607
self.assertFalse(self.context._options & SSL.OP_NO_SSLv3)
608
self.assertFalse(self.context._options & SSL.OP_NO_TLSv1)
611
def test_missingCertificateFile(self):
613
Instantiating L{ssl.DefaultOpenSSLContextFactory} with a certificate
614
filename which does not identify an existing file results in the
615
initializer raising L{OpenSSL.SSL.Error}.
619
ssl.DefaultOpenSSLContextFactory, certPath, self.mktemp())
622
def test_missingPrivateKeyFile(self):
624
Instantiating L{ssl.DefaultOpenSSLContextFactory} with a private key
625
filename which does not identify an existing file results in the
626
initializer raising L{OpenSSL.SSL.Error}.
630
ssl.DefaultOpenSSLContextFactory, self.mktemp(), certPath)
634
class ClientContextFactoryTests(unittest.TestCase):
636
Tests for L{ssl.ClientContextFactory}.
639
self.contextFactory = ssl.ClientContextFactory()
640
self.contextFactory._contextFactory = FakeContext
641
self.context = self.contextFactory.getContext()
644
def test_method(self):
646
L{ssl.ClientContextFactory.getContext} returns a context which can use
647
SSLv3 or TLSv1 but not SSLv2.
649
self.assertEqual(self.context._method, SSL.SSLv23_METHOD)
650
self.assertTrue(self.context._options & SSL.OP_NO_SSLv2)
651
self.assertFalse(self.context._options & SSL.OP_NO_SSLv3)
652
self.assertFalse(self.context._options & SSL.OP_NO_TLSv1)
656
if interfaces.IReactorSSL(reactor, None) is None:
657
for tCase in [StolenTCPTestCase, TLSTestCase, SpammyTLSTestCase,
658
BufferingTestCase, ConnectionLostTestCase,
659
DefaultOpenSSLContextFactoryTests,
660
ClientContextFactoryTests]:
661
tCase.skip = "Reactor does not support SSL, cannot run SSL tests"
663
# Otherwise trial will run this test here
664
del WriteDataTestCase