# Copyright (c) 2009 Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Tests for L{twisted.protocols.tls}.
"""

# The TLS implementation under test needs a sufficiently recent pyOpenSSL.
# Guard its import so that a missing dependency skips this module (trial
# honours a module-level ``skip`` attribute) instead of failing at import
# time, and so that ``skip`` is only set when the import really failed.
try:
    from twisted.protocols.tls import TLSMemoryBIOProtocol, TLSMemoryBIOFactory
except ImportError:
    # Skip the whole test module if it can't be imported.
    skip = "pyOpenSSL 0.10 or newer required for twisted.protocol.tls"
else:
    # Otherwise, the pyOpenSSL dependency must be satisfied, so all these
    # imports will work.
    from OpenSSL.crypto import X509Type
    from OpenSSL.SSL import TLSv1_METHOD, Error, Context, ConnectionType
    from twisted.internet.ssl import ClientContextFactory, PrivateCertificate
    from twisted.internet.ssl import DefaultOpenSSLContextFactory

from twisted.python.filepath import FilePath
from twisted.internet.interfaces import ISystemHandle, ISSLTransport
from twisted.internet.error import ConnectionDone
from twisted.internet.defer import Deferred, gatherResults
from twisted.internet.protocol import Protocol, ClientFactory, ServerFactory
from twisted.protocols.loopback import loopbackAsync, collapsingPumpPolicy
from twisted.trial.unittest import TestCase
from twisted.test.test_tcp import ConnectionLostNotifyingProtocol
from twisted.test.test_ssl import certPath
from twisted.test.proto_helpers import StringTransport


class HandshakeCallbackContextFactory:
    """
    L{HandshakeCallbackContextFactory} is a factory for SSL contexts which
    allows applications to get notification when the SSL handshake completes.

    @ivar _finished: A L{Deferred} which will be called back when the handshake
        is done.
    """
    # pyOpenSSL needs to expose this.
    # https://bugs.launchpad.net/pyopenssl/+bug/372832
    SSL_CB_HANDSHAKE_DONE = 0x20

    def __init__(self):
        """
        Create the L{Deferred} which L{_info} fires on handshake completion.
        """
        self._finished = Deferred()


    @classmethod
    def factoryAndDeferred(cls):
        """
        Create a new L{HandshakeCallbackContextFactory} and return a two-tuple
        of it and a L{Deferred} which will fire when a connection created with
        it completes a TLS handshake.

        @return: A C{tuple} of C{(contextFactory, handshakeDeferred)}.
        """
        contextFactory = cls()
        return contextFactory, contextFactory._finished


    def _info(self, connection, where, ret):
        """
        This is the "info callback" on the context.  It will be called
        periodically by pyOpenSSL with information about the state of a
        connection.  When it indicates the handshake is complete, it will fire
        C{self._finished}.

        @param connection: The connection the notification is about (unused).
        @param where: A bitmask describing the event being reported; only the
            C{SSL_CB_HANDSHAKE_DONE} bit is inspected.
        @param ret: Additional information about the event (unused).
        """
        if where & self.SSL_CB_HANDSHAKE_DONE:
            self._finished.callback(None)


    def getContext(self):
        """
        Create and return an SSL context configured to use L{self._info} as the
        info callback.

        @return: A new TLSv1 L{Context}.
        """
        context = Context(TLSv1_METHOD)
        context.set_info_callback(self._info)
        return context



class AccumulatingProtocol(Protocol):
    """
    A protocol which collects the bytes it receives and closes its connection
    after receiving a certain minimum of data.

    @ivar howMany: The number of bytes of data to wait for before closing the
        connection.
    @ivar received: A C{list} of C{str} of the bytes received so far.
    """
    def __init__(self, howMany):
        """
        @param howMany: The number of bytes to receive before the connection
            is closed.
        """
        self.howMany = howMany


    def connectionMade(self):
        """
        Start each connection with an empty list of received chunks.
        """
        self.received = []


    def dataReceived(self, bytes):
        """
        Record the received chunk and drop the connection once at least
        C{howMany} bytes have been collected in total.

        @param bytes: The chunk of data which was just received.
        """
        self.received.append(bytes)
        if sum(map(len, self.received)) >= self.howMany:
            self.transport.loseConnection()



class TLSMemoryBIOTests(TestCase):
    """
    Tests for the implementation of L{ISSLTransport} which runs over another
    transport.
    """

    def test_interfaces(self):
        """
        L{TLSMemoryBIOProtocol} instances provide L{ISSLTransport} and
        L{ISystemHandle}.
        """
        proto = TLSMemoryBIOProtocol(None, None)
        self.assertTrue(ISSLTransport.providedBy(proto))
        self.assertTrue(ISystemHandle.providedBy(proto))


    def test_getHandle(self):
        """
        L{TLSMemoryBIOProtocol.getHandle} returns the L{OpenSSL.SSL.Connection}
        instance it uses to actually implement TLS.

        This may seem odd.  In fact, it is.  The L{OpenSSL.SSL.Connection} is
        not actually the "system handle" here, nor even an object the reactor
        knows about directly.  However, L{twisted.internet.ssl.Certificate}'s
        C{peerFromTransport} and C{hostFromTransport} methods depend on being
        able to get an L{OpenSSL.SSL.Connection} object in order to work
        properly.  Implementing L{ISystemHandle.getHandle} like this is the
        easiest way for those APIs to be made to work.  If they are changed,
        then it may make sense to get rid of this implementation of
        L{ISystemHandle} and return the underlying socket instead.
        """
        factory = ClientFactory()
        contextFactory = ClientContextFactory()
        wrapperFactory = TLSMemoryBIOFactory(contextFactory, True, factory)
        proto = TLSMemoryBIOProtocol(wrapperFactory, Protocol())
        transport = StringTransport()
        proto.makeConnection(transport)
        self.assertIsInstance(proto.getHandle(), ConnectionType)


    def test_makeConnection(self):
        """
        When L{TLSMemoryBIOProtocol} is connected to a transport, it connects
        the protocol it wraps to a transport.
        """
        clientProtocol = Protocol()
        clientFactory = ClientFactory()
        clientFactory.protocol = lambda: clientProtocol

        contextFactory = ClientContextFactory()
        wrapperFactory = TLSMemoryBIOFactory(
            contextFactory, True, clientFactory)
        sslProtocol = wrapperFactory.buildProtocol(None)

        transport = StringTransport()
        sslProtocol.makeConnection(transport)

        # The wrapped protocol gets a transport, but not the raw one handed
        # to the TLS protocol.
        self.assertNotIdentical(clientProtocol.transport, None)
        self.assertNotIdentical(clientProtocol.transport, transport)


    def test_handshake(self):
        """
        The TLS handshake is performed when L{TLSMemoryBIOProtocol} is
        connected to a transport.
        """
        clientFactory = ClientFactory()
        clientFactory.protocol = Protocol

        clientContextFactory, handshakeDeferred = (
            HandshakeCallbackContextFactory.factoryAndDeferred())
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverFactory = ServerFactory()
        serverFactory.protocol = Protocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        loopbackAsync(sslServerProtocol, sslClientProtocol)

        # Only wait for the handshake to complete.  Anything after that isn't
        # important here.
        return handshakeDeferred


    def test_handshakeFailure(self):
        """
        L{TLSMemoryBIOProtocol} reports errors in the handshake process to the
        application-level protocol object using its C{connectionLost} method
        and disconnects the underlying transport.
        """
        clientConnectionLost = Deferred()
        clientFactory = ClientFactory()
        clientFactory.protocol = (
            lambda: ConnectionLostNotifyingProtocol(
                clientConnectionLost))

        clientContextFactory = HandshakeCallbackContextFactory()
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverConnectionLost = Deferred()
        serverFactory = ServerFactory()
        serverFactory.protocol = (
            lambda: ConnectionLostNotifyingProtocol(
                serverConnectionLost))

        # This context factory rejects any clients which do not present a
        # certificate.
        certificateData = FilePath(certPath).getContent()
        certificate = PrivateCertificate.loadPEM(certificateData)
        serverContextFactory = certificate.options(certificate)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)

        def cbConnectionLost(protocol):
            # The connection should close on its own in response to the error
            # induced by the client not supplying the required certificate.
            # After that, check to make sure the protocol's connectionLost was
            # called with the right thing.
            protocol.lostConnectionReason.trap(Error)
        clientConnectionLost.addCallback(cbConnectionLost)
        serverConnectionLost.addCallback(cbConnectionLost)

        # Additionally, the underlying transport should have been told to
        # go away.
        return gatherResults([
                clientConnectionLost, serverConnectionLost,
                connectionDeferred])


    def test_getPeerCertificate(self):
        """
        L{TLSMemoryBIOProtocol.getPeerCertificate} returns the
        L{OpenSSL.crypto.X509Type} instance representing the peer's
        certificate.
        """
        # Set up a client and server so there's a certificate to grab.
        clientFactory = ClientFactory()
        clientFactory.protocol = Protocol

        clientContextFactory, handshakeDeferred = (
            HandshakeCallbackContextFactory.factoryAndDeferred())
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverFactory = ServerFactory()
        serverFactory.protocol = Protocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        loopbackAsync(sslServerProtocol, sslClientProtocol)

        # Wait for the handshake
        def cbHandshook(ignored):
            # Grab the server's certificate and check it out
            cert = sslClientProtocol.getPeerCertificate()
            self.assertIsInstance(cert, X509Type)
            self.assertEquals(
                cert.digest('md5'),
                '9B:A4:AB:43:10:BE:82:AE:94:3E:6B:91:F2:F3:40:E8')
        handshakeDeferred.addCallback(cbHandshook)
        return handshakeDeferred


    def test_writeAfterHandshake(self):
        """
        Bytes written to L{TLSMemoryBIOProtocol} after the handshake is
        complete are received by the protocol on the other side of the
        connection.
        """
        bytes = "some bytes"

        clientProtocol = Protocol()
        clientFactory = ClientFactory()
        clientFactory.protocol = lambda: clientProtocol

        clientContextFactory, handshakeDeferred = (
            HandshakeCallbackContextFactory.factoryAndDeferred())
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverProtocol = AccumulatingProtocol(len(bytes))
        serverFactory = ServerFactory()
        serverFactory.protocol = lambda: serverProtocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)

        # Wait for the handshake to finish before writing anything.
        def cbHandshook(ignored):
            clientProtocol.transport.write(bytes)

            # The server will drop the connection once it gets the bytes.
            return connectionDeferred
        handshakeDeferred.addCallback(cbHandshook)

        # Once the connection is lost, make sure the server received the
        # expected bytes.
        def cbDisconnected(ignored):
            self.assertEquals("".join(serverProtocol.received), bytes)
        handshakeDeferred.addCallback(cbDisconnected)

        return handshakeDeferred


    def test_writeBeforeHandshake(self):
        """
        Bytes written to L{TLSMemoryBIOProtocol} before the handshake is
        complete are received by the protocol on the other side of the
        connection once the handshake succeeds.
        """
        bytes = "some bytes"

        class SimpleSendingProtocol(Protocol):
            def connectionMade(self):
                self.transport.write(bytes)

        clientFactory = ClientFactory()
        clientFactory.protocol = SimpleSendingProtocol

        clientContextFactory, handshakeDeferred = (
            HandshakeCallbackContextFactory.factoryAndDeferred())
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverProtocol = AccumulatingProtocol(len(bytes))
        serverFactory = ServerFactory()
        serverFactory.protocol = lambda: serverProtocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)

        # Wait for the connection to end, then make sure the server received
        # the bytes sent by the client.
        def cbConnectionDone(ignored):
            self.assertEquals("".join(serverProtocol.received), bytes)
        connectionDeferred.addCallback(cbConnectionDone)
        return connectionDeferred


    def test_writeSequence(self):
        """
        Bytes written to L{TLSMemoryBIOProtocol} with C{writeSequence} are
        received by the protocol on the other side of the connection.
        """
        bytes = "some bytes"

        class SimpleSendingProtocol(Protocol):
            def connectionMade(self):
                # Hand the data over one character at a time.
                self.transport.writeSequence(list(bytes))

        clientFactory = ClientFactory()
        clientFactory.protocol = SimpleSendingProtocol

        clientContextFactory = HandshakeCallbackContextFactory()
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverProtocol = AccumulatingProtocol(len(bytes))
        serverFactory = ServerFactory()
        serverFactory.protocol = lambda: serverProtocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)

        # Wait for the connection to end, then make sure the server received
        # the bytes sent by the client.
        def cbConnectionDone(ignored):
            self.assertEquals("".join(serverProtocol.received), bytes)
        connectionDeferred.addCallback(cbConnectionDone)
        return connectionDeferred


    def test_multipleWrites(self):
        """
        If multiple separate TLS messages are received in a single chunk from
        the underlying transport, all of the application bytes from each
        message are delivered to the application-level protocol.
        """
        bytes = [str(i) for i in range(10)]

        class SimpleSendingProtocol(Protocol):
            def connectionMade(self):
                # One write per element, so each becomes a separate message.
                for b in bytes:
                    self.transport.write(b)

        clientFactory = ClientFactory()
        clientFactory.protocol = SimpleSendingProtocol

        clientContextFactory = HandshakeCallbackContextFactory()
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverProtocol = AccumulatingProtocol(sum(map(len, bytes)))
        serverFactory = ServerFactory()
        serverFactory.protocol = lambda: serverProtocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        # The collapsing pump policy is what delivers the separate writes to
        # the peer as a single chunk.
        connectionDeferred = loopbackAsync(
            sslServerProtocol, sslClientProtocol, collapsingPumpPolicy)

        # Wait for the connection to end, then make sure the server received
        # the bytes sent by the client.
        def cbConnectionDone(ignored):
            self.assertEquals("".join(serverProtocol.received), ''.join(bytes))
        connectionDeferred.addCallback(cbConnectionDone)
        return connectionDeferred


    def test_hugeWrite(self):
        """
        If a very long string is passed to L{TLSMemoryBIOProtocol.write}, any
        trailing part of it which cannot be sent immediately is buffered and
        sent later.
        """
        bytes = "some bytes"
        factor = 8192

        class SimpleSendingProtocol(Protocol):
            def connectionMade(self):
                self.transport.write(bytes * factor)

        clientFactory = ClientFactory()
        clientFactory.protocol = SimpleSendingProtocol

        clientContextFactory = HandshakeCallbackContextFactory()
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverProtocol = AccumulatingProtocol(len(bytes) * factor)
        serverFactory = ServerFactory()
        serverFactory.protocol = lambda: serverProtocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)

        # Wait for the connection to end, then make sure the server received
        # the bytes sent by the client.
        def cbConnectionDone(ignored):
            self.assertEquals("".join(serverProtocol.received), bytes * factor)
        connectionDeferred.addCallback(cbConnectionDone)
        return connectionDeferred


    def test_disorderlyShutdown(self):
        """
        If a L{TLSMemoryBIOProtocol} loses its connection unexpectedly, this is
        reported to the application.
        """
        clientConnectionLost = Deferred()
        clientFactory = ClientFactory()
        clientFactory.protocol = (
            lambda: ConnectionLostNotifyingProtocol(
                clientConnectionLost))

        clientContextFactory = HandshakeCallbackContextFactory()
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        # Client speaks first, so the server can be dumb.
        serverProtocol = Protocol()

        loopbackAsync(serverProtocol, sslClientProtocol)

        # Now destroy the connection.
        serverProtocol.transport.loseConnection()

        # And when the connection completely dies, check the reason.
        def cbDisconnected(clientProtocol):
            clientProtocol.lostConnectionReason.trap(Error)
        clientConnectionLost.addCallback(cbDisconnected)
        return clientConnectionLost


    def test_loseConnectionAfterHandshake(self):
        """
        L{TLSMemoryBIOProtocol.loseConnection} sends a TLS close alert and
        shuts down the underlying connection.
        """
        clientConnectionLost = Deferred()
        clientFactory = ClientFactory()
        clientFactory.protocol = (
            lambda: ConnectionLostNotifyingProtocol(
                clientConnectionLost))

        clientContextFactory, handshakeDeferred = (
            HandshakeCallbackContextFactory.factoryAndDeferred())
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverProtocol = Protocol()
        serverFactory = ServerFactory()
        serverFactory.protocol = lambda: serverProtocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        loopbackAsync(sslServerProtocol, sslClientProtocol)

        # Wait for the handshake before dropping the connection.
        def cbHandshake(ignored):
            serverProtocol.transport.loseConnection()

            # Now wait for the client to notice.
            return clientConnectionLost
        handshakeDeferred.addCallback(cbHandshake)

        # Wait for the connection to end, then make sure the client was
        # notified of a clean connection close rather than an error.
        def cbConnectionDone(clientProtocol):
            clientProtocol.lostConnectionReason.trap(ConnectionDone)

            # The server should have closed its underlying transport, in
            # addition to whatever it did to shut down the TLS layer.
            self.assertTrue(serverProtocol.transport.q.disconnect)

            # The client should also have closed its underlying transport once
            # it saw the server shut down the TLS layer, so as to avoid relying
            # on the server to close the underlying connection.
            self.assertTrue(clientProtocol.transport.q.disconnect)
        handshakeDeferred.addCallback(cbConnectionDone)
        return handshakeDeferred