~justin-fathomdb/nova/justinsb-openstack-api-volumes

« back to all changes in this revision

Viewing changes to vendor/Twisted-10.0.0/twisted/protocols/test/test_tls.py

  • Committer: Jesse Andrews
  • Date: 2010-05-28 06:05:26 UTC
  • Revision ID: git-v1:bf6e6e718cdc7488e2da87b21e258ccc065fe499
initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (c) 2009 Twisted Matrix Laboratories.
 
2
# See LICENSE for details.
 
3
 
 
4
"""
 
5
Tests for L{twisted.protocols.tls}.
 
6
"""
 
7
 
 
8
try:
 
9
    from twisted.protocols.tls import TLSMemoryBIOProtocol, TLSMemoryBIOFactory
 
10
except ImportError:
 
11
    # Skip the whole test module if it can't be imported.
 
12
    skip = "pyOpenSSL 0.10 or newer required for twisted.protocol.tls"
 
13
else:
 
14
    # Otherwise, the pyOpenSSL dependency must be satisfied, so all these
 
15
    # imports will work.
 
16
    from OpenSSL.crypto import X509Type
 
17
    from OpenSSL.SSL import TLSv1_METHOD, Error, Context, ConnectionType
 
18
    from twisted.internet.ssl import ClientContextFactory, PrivateCertificate
 
19
    from twisted.internet.ssl import DefaultOpenSSLContextFactory
 
20
 
 
21
from twisted.python.filepath import FilePath
 
22
from twisted.internet.interfaces import ISystemHandle, ISSLTransport
 
23
from twisted.internet.error import ConnectionDone
 
24
from twisted.internet.defer import Deferred, gatherResults
 
25
from twisted.internet.protocol import Protocol, ClientFactory, ServerFactory
 
26
from twisted.protocols.loopback import loopbackAsync, collapsingPumpPolicy
 
27
from twisted.trial.unittest import TestCase
 
28
from twisted.test.test_tcp import ConnectionLostNotifyingProtocol
 
29
from twisted.test.test_ssl import certPath
 
30
from twisted.test.proto_helpers import StringTransport
 
31
 
 
32
 
 
33
class HandshakeCallbackContextFactory:
    """
    L{HandshakeCallbackContextFactory} is a factory for SSL contexts which
    allows applications to get notification when the SSL handshake completes.

    @ivar _finished: A L{Deferred} which will be called back (with C{None})
        the first time the info callback reports that a handshake is done.
    """
    # pyOpenSSL needs to expose this.
    # https://bugs.launchpad.net/pyopenssl/+bug/372832
    SSL_CB_HANDSHAKE_DONE = 0x20

    def __init__(self):
        self._finished = Deferred()


    def factoryAndDeferred(cls):
        """
        Create a new L{HandshakeCallbackContextFactory} and return a two-tuple
        of it and a L{Deferred} which will fire when a connection created with
        it completes a TLS handshake.

        @return: A two-tuple of the new context factory and its C{_finished}
            L{Deferred}.
        """
        contextFactory = cls()
        return contextFactory, contextFactory._finished
    factoryAndDeferred = classmethod(factoryAndDeferred)


    def _info(self, connection, where, ret):
        """
        This is the "info callback" on the context.  It will be called
        periodically by pyOpenSSL with information about the state of a
        connection.  When it indicates the handshake is complete, it will fire
        C{self._finished}.

        @param connection: The connection the notification is about (unused).
        @param where: A bitmask of state flags; only
            C{SSL_CB_HANDSHAKE_DONE} is inspected here.
        @param ret: Additional status information (unused).
        """
        # A Deferred may only be fired once.  If a second handshake-done
        # notification arrives (e.g. the handshake is renegotiated), firing
        # again would raise from inside the OpenSSL callback, so only the
        # first completed handshake is reported.
        if where & self.SSL_CB_HANDSHAKE_DONE and not self._finished.called:
            self._finished.callback(None)


    def getContext(self):
        """
        Create and return an SSL context configured to use L{self._info} as the
        info callback.

        @return: A new TLSv1 L{OpenSSL.SSL.Context}.
        """
        context = Context(TLSv1_METHOD)
        context.set_info_callback(self._info)
        return context
 
79
 
 
80
 
 
81
 
 
82
class AccumulatingProtocol(Protocol):
    """
    A protocol which collects the bytes it receives and closes its connection
    after receiving a certain minimum of data.

    @ivar howMany: The number of bytes of data to wait for before closing the
        connection.
    @ivar received: A C{list} of C{str} of the bytes received so far.
    """
    def __init__(self, howMany):
        self.howMany = howMany


    def connectionMade(self):
        self.received = []
        # Running total of the lengths of the chunks in self.received, so
        # dataReceived doesn't have to re-sum every chunk on each call.
        self._receivedCount = 0


    def dataReceived(self, bytes):
        """
        Record C{bytes} and, once at least C{howMany} bytes have arrived in
        total, ask the transport to close the connection.
        """
        self.received.append(bytes)
        self._receivedCount += len(bytes)
        if self._receivedCount >= self.howMany:
            self.transport.loseConnection()
 
102
 
 
103
 
 
104
 
 
105
class TLSMemoryBIOTests(TestCase):
    """
    Tests for the implementation of L{ISSLTransport} which runs over another
    L{ITransport}.
    """
    def test_interfaces(self):
        """
        L{TLSMemoryBIOProtocol} instances provide L{ISSLTransport} and
        L{ISystemHandle}.
        """
        proto = TLSMemoryBIOProtocol(None, None)
        self.assertTrue(ISSLTransport.providedBy(proto))
        self.assertTrue(ISystemHandle.providedBy(proto))


    def test_getHandle(self):
        """
        L{TLSMemoryBIOProtocol.getHandle} returns the L{OpenSSL.SSL.Connection}
        instance it uses to actually implement TLS.

        This may seem odd.  In fact, it is.  The L{OpenSSL.SSL.Connection} is
        not actually the "system handle" here, nor even an object the reactor
        knows about directly.  However, L{twisted.internet.ssl.Certificate}'s
        C{peerFromTransport} and C{hostFromTransport} methods depend on being
        able to get an L{OpenSSL.SSL.Connection} object in order to work
        properly.  Implementing L{ISystemHandle.getHandle} like this is the
        easiest way for those APIs to be made to work.  If they are changed,
        then it may make sense to get rid of this implementation of
        L{ISystemHandle} and return the underlying socket instead.
        """
        factory = ClientFactory()
        contextFactory = ClientContextFactory()
        wrapperFactory = TLSMemoryBIOFactory(contextFactory, True, factory)
        proto = TLSMemoryBIOProtocol(wrapperFactory, Protocol())
        transport = StringTransport()
        proto.makeConnection(transport)
        self.assertIsInstance(proto.getHandle(), ConnectionType)


    def test_makeConnection(self):
        """
        When L{TLSMemoryBIOProtocol} is connected to a transport, it connects
        the protocol it wraps to a transport of its own, distinct from the
        underlying one.
        """
        clientProtocol = Protocol()
        clientFactory = ClientFactory()
        clientFactory.protocol = lambda: clientProtocol

        contextFactory = ClientContextFactory()
        wrapperFactory = TLSMemoryBIOFactory(
            contextFactory, True, clientFactory)
        sslProtocol = wrapperFactory.buildProtocol(None)

        transport = StringTransport()
        sslProtocol.makeConnection(transport)

        self.assertNotIdentical(clientProtocol.transport, None)
        self.assertNotIdentical(clientProtocol.transport, transport)


    def test_handshake(self):
        """
        The TLS handshake is performed when L{TLSMemoryBIOProtocol} is
        connected to a transport.
        """
        clientFactory = ClientFactory()
        clientFactory.protocol = Protocol

        clientContextFactory, handshakeDeferred = (
            HandshakeCallbackContextFactory.factoryAndDeferred())
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverFactory = ServerFactory()
        serverFactory.protocol = Protocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        loopbackAsync(sslServerProtocol, sslClientProtocol)

        # Only wait for the handshake to complete.  Anything after that isn't
        # important here.
        return handshakeDeferred


    def test_handshakeFailure(self):
        """
        L{TLSMemoryBIOProtocol} reports errors in the handshake process to the
        application-level protocol object using its C{connectionLost} method
        and disconnects the underlying transport.
        """
        clientConnectionLost = Deferred()
        clientFactory = ClientFactory()
        clientFactory.protocol = (
            lambda: ConnectionLostNotifyingProtocol(
                clientConnectionLost))

        clientContextFactory = HandshakeCallbackContextFactory()
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverConnectionLost = Deferred()
        serverFactory = ServerFactory()
        serverFactory.protocol = (
            lambda: ConnectionLostNotifyingProtocol(
                serverConnectionLost))

        # This context factory rejects any clients which do not present a
        # certificate.
        certificateData = FilePath(certPath).getContent()
        certificate = PrivateCertificate.loadPEM(certificateData)
        serverContextFactory = certificate.options(certificate)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)

        def cbConnectionLost(protocol):
            # The connection should close on its own in response to the error
            # induced by the client not supplying the required certificate.
            # After that, check to make sure the protocol's connectionLost was
            # called with the right thing.
            protocol.lostConnectionReason.trap(Error)
        clientConnectionLost.addCallback(cbConnectionLost)
        serverConnectionLost.addCallback(cbConnectionLost)

        # Additionally, the underlying transport should have been told to
        # go away.
        return gatherResults([
                clientConnectionLost, serverConnectionLost,
                connectionDeferred])


    def test_getPeerCertificate(self):
        """
        L{TLSMemoryBIOProtocol.getPeerCertificate} returns the
        L{OpenSSL.crypto.X509Type} instance representing the peer's
        certificate.
        """
        # Set up a client and server so there's a certificate to grab.
        clientFactory = ClientFactory()
        clientFactory.protocol = Protocol

        clientContextFactory, handshakeDeferred = (
            HandshakeCallbackContextFactory.factoryAndDeferred())
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverFactory = ServerFactory()
        serverFactory.protocol = Protocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        loopbackAsync(sslServerProtocol, sslClientProtocol)

        # Wait for the handshake
        def cbHandshook(ignored):
            # Grab the server's certificate and check it out
            cert = sslClientProtocol.getPeerCertificate()
            self.assertIsInstance(cert, X509Type)
            self.assertEquals(
                cert.digest('md5'),
                '9B:A4:AB:43:10:BE:82:AE:94:3E:6B:91:F2:F3:40:E8')
        handshakeDeferred.addCallback(cbHandshook)
        return handshakeDeferred


    def test_writeAfterHandshake(self):
        """
        Bytes written to L{TLSMemoryBIOProtocol} after the handshake is
        complete are received by the protocol on the other side of the
        connection.
        """
        bytes = "some bytes"

        clientProtocol = Protocol()
        clientFactory = ClientFactory()
        clientFactory.protocol = lambda: clientProtocol

        clientContextFactory, handshakeDeferred = (
            HandshakeCallbackContextFactory.factoryAndDeferred())
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverProtocol = AccumulatingProtocol(len(bytes))
        serverFactory = ServerFactory()
        serverFactory.protocol = lambda: serverProtocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)

        # Wait for the handshake to finish before writing anything.
        def cbHandshook(ignored):
            clientProtocol.transport.write(bytes)

            # The server will drop the connection once it gets the bytes.
            return connectionDeferred
        handshakeDeferred.addCallback(cbHandshook)

        # Once the connection is lost, make sure the server received the
        # expected bytes.
        def cbDisconnected(ignored):
            self.assertEquals("".join(serverProtocol.received), bytes)
        handshakeDeferred.addCallback(cbDisconnected)

        return handshakeDeferred


    def test_writeBeforeHandshake(self):
        """
        Bytes written to L{TLSMemoryBIOProtocol} before the handshake is
        complete are received by the protocol on the other side of the
        connection once the handshake succeeds.
        """
        bytes = "some bytes"

        class SimpleSendingProtocol(Protocol):
            def connectionMade(self):
                self.transport.write(bytes)

        clientFactory = ClientFactory()
        clientFactory.protocol = SimpleSendingProtocol

        clientContextFactory, handshakeDeferred = (
            HandshakeCallbackContextFactory.factoryAndDeferred())
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverProtocol = AccumulatingProtocol(len(bytes))
        serverFactory = ServerFactory()
        serverFactory.protocol = lambda: serverProtocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)

        # Wait for the connection to end, then make sure the server received
        # the bytes sent by the client.
        def cbConnectionDone(ignored):
            self.assertEquals("".join(serverProtocol.received), bytes)
        connectionDeferred.addCallback(cbConnectionDone)
        return connectionDeferred


    def test_writeSequence(self):
        """
        Bytes written to L{TLSMemoryBIOProtocol} with C{writeSequence} are
        received by the protocol on the other side of the connection.
        """
        bytes = "some bytes"
        class SimpleSendingProtocol(Protocol):
            def connectionMade(self):
                self.transport.writeSequence(list(bytes))

        clientFactory = ClientFactory()
        clientFactory.protocol = SimpleSendingProtocol

        clientContextFactory = HandshakeCallbackContextFactory()
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverProtocol = AccumulatingProtocol(len(bytes))
        serverFactory = ServerFactory()
        serverFactory.protocol = lambda: serverProtocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)

        # Wait for the connection to end, then make sure the server received
        # the bytes sent by the client.
        def cbConnectionDone(ignored):
            self.assertEquals("".join(serverProtocol.received), bytes)
        connectionDeferred.addCallback(cbConnectionDone)
        return connectionDeferred


    def test_multipleWrites(self):
        """
        If multiple separate TLS messages are received in a single chunk from
        the underlying transport, all of the application bytes from each
        message are delivered to the application-level protocol.
        """
        bytes = [str(i) for i in range(10)]
        class SimpleSendingProtocol(Protocol):
            def connectionMade(self):
                for b in bytes:
                    self.transport.write(b)

        clientFactory = ClientFactory()
        clientFactory.protocol = SimpleSendingProtocol

        clientContextFactory = HandshakeCallbackContextFactory()
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverProtocol = AccumulatingProtocol(sum(map(len, bytes)))
        serverFactory = ServerFactory()
        serverFactory.protocol = lambda: serverProtocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        # collapsingPumpPolicy delivers everything queued in one chunk, so
        # several TLS records arrive in a single dataReceived call.
        connectionDeferred = loopbackAsync(
            sslServerProtocol, sslClientProtocol, collapsingPumpPolicy)

        # Wait for the connection to end, then make sure the server received
        # the bytes sent by the client.
        def cbConnectionDone(ignored):
            self.assertEquals("".join(serverProtocol.received), ''.join(bytes))
        connectionDeferred.addCallback(cbConnectionDone)
        return connectionDeferred


    def test_hugeWrite(self):
        """
        If a very long string is passed to L{TLSMemoryBIOProtocol.write}, any
        trailing part of it which cannot be sent immediately is buffered and
        sent later.
        """
        bytes = "some bytes"
        factor = 8192
        class SimpleSendingProtocol(Protocol):
            def connectionMade(self):
                self.transport.write(bytes * factor)

        clientFactory = ClientFactory()
        clientFactory.protocol = SimpleSendingProtocol

        clientContextFactory = HandshakeCallbackContextFactory()
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverProtocol = AccumulatingProtocol(len(bytes) * factor)
        serverFactory = ServerFactory()
        serverFactory.protocol = lambda: serverProtocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)

        # Wait for the connection to end, then make sure the server received
        # the bytes sent by the client.
        def cbConnectionDone(ignored):
            self.assertEquals("".join(serverProtocol.received), bytes * factor)
        connectionDeferred.addCallback(cbConnectionDone)
        return connectionDeferred


    def test_disorderlyShutdown(self):
        """
        If a L{TLSMemoryBIOProtocol} loses its connection unexpectedly, this is
        reported to the application.
        """
        clientConnectionLost = Deferred()
        clientFactory = ClientFactory()
        clientFactory.protocol = (
            lambda: ConnectionLostNotifyingProtocol(
                clientConnectionLost))

        clientContextFactory = HandshakeCallbackContextFactory()
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        # Client speaks first, so the server can be dumb.
        serverProtocol = Protocol()

        # Connecting the pair is what gives serverProtocol its transport.
        loopbackAsync(serverProtocol, sslClientProtocol)

        # Now destroy the connection.
        serverProtocol.transport.loseConnection()

        # And when the connection completely dies, check the reason.
        def cbDisconnected(clientProtocol):
            clientProtocol.lostConnectionReason.trap(Error)
        clientConnectionLost.addCallback(cbDisconnected)
        return clientConnectionLost


    def test_loseConnectionAfterHandshake(self):
        """
        L{TLSMemoryBIOProtocol.loseConnection} sends a TLS close alert and
        shuts down the underlying connection.
        """
        clientConnectionLost = Deferred()
        clientFactory = ClientFactory()
        clientFactory.protocol = (
            lambda: ConnectionLostNotifyingProtocol(
                clientConnectionLost))

        clientContextFactory, handshakeDeferred = (
            HandshakeCallbackContextFactory.factoryAndDeferred())
        wrapperFactory = TLSMemoryBIOFactory(
            clientContextFactory, True, clientFactory)
        sslClientProtocol = wrapperFactory.buildProtocol(None)

        serverProtocol = Protocol()
        serverFactory = ServerFactory()
        serverFactory.protocol = lambda: serverProtocol

        serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
        wrapperFactory = TLSMemoryBIOFactory(
            serverContextFactory, False, serverFactory)
        sslServerProtocol = wrapperFactory.buildProtocol(None)

        loopbackAsync(sslServerProtocol, sslClientProtocol)

        # Wait for the handshake before dropping the connection.
        def cbHandshake(ignored):
            serverProtocol.transport.loseConnection()

            # Now wait for the client to notice.
            return clientConnectionLost
        handshakeDeferred.addCallback(cbHandshake)

        # Wait for the connection to end, then make sure the client was
        # notified of a clean shutdown (ConnectionDone), not an error.
        def cbConnectionDone(clientProtocol):
            clientProtocol.lostConnectionReason.trap(ConnectionDone)

            # The server should have closed its underlying transport, in
            # addition to whatever it did to shut down the TLS layer.
            self.assertTrue(serverProtocol.transport.q.disconnect)

            # The client should also have closed its underlying transport once
            # it saw the server shut down the TLS layer, so as to avoid relying
            # on the server to close the underlying connection.
            self.assertTrue(clientProtocol.transport.q.disconnect)
        handshakeDeferred.addCallback(cbConnectionDone)
        return handshakeDeferred
 
566