~justin-fathomdb/nova/justinsb-openstack-api-volumes

« back to all changes in this revision

Viewing changes to vendor/Twisted-10.0.0/twisted/test/test_tcp.py

  • Committer: Jesse Andrews
  • Date: 2010-05-28 06:05:26 UTC
  • Revision ID: git-v1:bf6e6e718cdc7488e2da87b21e258ccc065fe499
initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (c) 2001-2008 Twisted Matrix Laboratories.
 
2
# See LICENSE for details.
 
3
 
 
4
"""
 
5
Tests for implementations of L{IReactorTCP}.
 
6
"""
 
7
 
 
8
import socket, random, errno
 
9
 
 
10
from zope.interface import implements
 
11
 
 
12
from twisted.trial import unittest
 
13
 
 
14
from twisted.python.log import msg
 
15
from twisted.internet import protocol, reactor, defer, interfaces
 
16
from twisted.internet import error
 
17
from twisted.internet.address import IPv4Address
 
18
from twisted.internet.interfaces import IHalfCloseableProtocol, IPullProducer
 
19
from twisted.protocols import policies
 
20
from twisted.test.proto_helpers import AccumulatingProtocol
 
21
 
 
22
 
 
23
def loopUntil(predicate, interval=0):
    """
    Poll C{predicate} with a L{task.LoopingCall} and fire a L{Deferred} with
    the first true value it returns.

    This is a crude stand-in for real event notification.  Do not use this
    function in new code.

    @param predicate: A no-argument callable which is polled repeatedly.
    @param interval: Passed straight through to C{LoopingCall.start} as the
        delay between polls.
    @return: A L{Deferred} which fires with the first true result of
        C{predicate}.  If the L{Deferred} returned by C{LoopingCall.start}
        errbacks, that failure is forwarded to this L{Deferred} instead.
    """
    from twisted.internet import task

    result = defer.Deferred()

    def poll():
        value = predicate()
        if value:
            result.callback(value)

    poller = task.LoopingCall(poll)

    def stopPolling(passthrough):
        # Stop polling once a true value has been delivered, but pass that
        # value on untouched to the caller's callbacks.
        poller.stop()
        return passthrough

    result.addCallback(stopPolling)
    poller.start(interval).addErrback(result.errback)
    return result
 
44
 
 
45
 
 
46
class ClosingProtocol(protocol.Protocol):
    """
    Protocol which drops its connection as soon as it is made and insists
    that the connection then ends cleanly.
    """

    def connectionMade(self):
        """
        Ask the transport to close the connection that was just established.
        """
        transport = self.transport
        transport.loseConnection()


    def connectionLost(self, reason):
        """
        Accept only L{error.ConnectionDone} as the reason for disconnection.

        @param reason: The failure describing why the connection ended.  Any
            failure other than L{error.ConnectionDone} makes C{trap}
            re-raise it.
        """
        reason.trap(error.ConnectionDone)
 
53
 
 
54
class ClosingFactory(protocol.ServerFactory):
    """
    Server factory which stops listening as soon as it is asked to build a
    protocol, and hands out L{ClosingProtocol} instances, which immediately
    close their connection.

    @ivar port: The listening port to stop.  This class never sets it; the
        caller must assign it (e.g. the return value of C{listenTCP}) before
        any connection arrives.
    """

    def buildProtocol(self, conn):
        """
        Stop listening on C{self.port} and return a new L{ClosingProtocol}.

        @param conn: Supplied by the reactor; ignored here.
        @return: A fresh L{ClosingProtocol}.
        """
        self.port.stopListening()
        return ClosingProtocol()
 
60
 
 
61
 
 
62
class MyProtocolFactoryMixin(object):
    """
    Shared C{buildProtocol} implementation for the test factories in this
    module.  It hands out L{AccumulatingProtocol} instances and records what
    it built.

    @type protocolFactory: no-argument callable
    @ivar protocolFactory: Called to create each protocol.  It stands in for
        the usual C{protocol} attribute of a factory, because this class uses
        the name C{protocol} for the most recently built instance.

    @type protocolConnectionMade: L{NoneType} or L{defer.Deferred}
    @ivar protocolConnectionMade: When an L{AccumulatingProtocol} is
        connected and this is not C{None}, this L{Deferred} is fired with
        the protocol instance and the attribute is reset to C{None}.  This
        class never reads it itself; the firing is expected to be done by
        L{AccumulatingProtocol}.

    @type protocolConnectionLost: L{NoneType} or L{defer.Deferred}
    @ivar protocolConnectionLost: Moved onto the next protocol built, as its
        C{closedDeferred}, and then reset to C{None}.  A given L{Deferred} is
        therefore attached to at most one protocol.

    @ivar protocol: The L{AccumulatingProtocol} most recently returned from
        C{buildProtocol}.

    @type called: C{int}
    @ivar called: Number of times C{buildProtocol} has been invoked.

    @ivar peerAddresses: A C{list} of every address passed to
        C{buildProtocol}, in call order.
    """
    protocolFactory = AccumulatingProtocol

    protocolConnectionMade = None
    protocolConnectionLost = None
    protocol = None
    called = 0

    def __init__(self):
        self.peerAddresses = []


    def buildProtocol(self, addr):
        """
        Record C{addr}, bump C{called} and return a new protocol from
        C{protocolFactory} which is wired up to this factory and to any
        pending C{protocolConnectionLost} L{Deferred}.
        """
        self.peerAddresses.append(addr)
        self.called += 1

        proto = self.protocolFactory()
        proto.factory = self
        # Transfer the pending Deferred (possibly None) and forget it here.
        proto.closedDeferred, self.protocolConnectionLost = (
            self.protocolConnectionLost, None)

        self.protocol = proto
        return proto
 
115
 
 
116
 
 
117
 
 
118
class MyServerFactory(MyProtocolFactoryMixin, protocol.ServerFactory):
    """
    L{protocol.ServerFactory} for listening ports.  It takes its
    C{buildProtocol} from L{MyProtocolFactoryMixin}, so it produces
    L{AccumulatingProtocol} instances.
    """
 
122
 
 
123
 
 
124
 
 
125
class MyClientFactory(MyProtocolFactoryMixin, protocol.ClientFactory):
    """
    Client factory producing L{AccumulatingProtocol} instances, which also
    records how its connection attempts ended.

    @ivar failed: C{1} once C{clientConnectionFailed} has been called.
    @ivar stopped: C{1} once C{stopFactory} has been called.
    @ivar reason: The reason passed to C{clientConnectionFailed}.
    @ivar lostReason: The reason passed to C{clientConnectionLost}.
    @ivar deferred: Fired with C{None} by C{clientConnectionLost}.
    @ivar failDeferred: Fired with C{None} by C{clientConnectionFailed}.
    """
    failed = 0
    stopped = 0

    def __init__(self):
        MyProtocolFactoryMixin.__init__(self)
        self.deferred, self.failDeferred = defer.Deferred(), defer.Deferred()


    def clientConnectionFailed(self, connector, reason):
        # Record the state before firing, because callbacks inspect it.
        self.failed, self.reason = 1, reason
        self.failDeferred.callback(None)


    def clientConnectionLost(self, connector, reason):
        self.lostReason = reason
        self.deferred.callback(None)


    def stopFactory(self):
        self.stopped = 1
 
148
 
 
149
 
 
150
 
 
151
class ListeningTestCase(unittest.TestCase):
    """
    Tests for starting, stopping and restarting ports created with
    L{IReactorTCP.listenTCP}.
    """

    def test_listen(self):
        """
        L{IReactorTCP.listenTCP} returns an object which provides
        L{IListeningPort}.
        """
        f = MyServerFactory()
        p1 = reactor.listenTCP(0, f, interface="127.0.0.1")
        self.addCleanup(p1.stopListening)
        self.failUnless(interfaces.IListeningPort.providedBy(p1))


    def testStopListening(self):
        """
        The L{IListeningPort} returned by L{IReactorTCP.listenTCP} can be
        stopped with its C{stopListening} method.  After the L{Deferred} it
        (optionally) returns has been called back, the port number can be bound
        to a new server.
        """
        f = MyServerFactory()
        port = reactor.listenTCP(0, f, interface="127.0.0.1")
        n = port.getHost().port

        def cbStopListening(ignored):
            # Make sure we can rebind the port right away
            newPort = reactor.listenTCP(n, f, interface="127.0.0.1")
            return newPort.stopListening()

        d = defer.maybeDeferred(port.stopListening)
        d.addCallback(cbStopListening)
        return d


    def testNumberedInterface(self):
        """
        L{IReactorTCP.listenTCP} accepts a numeric IPv4 address as the
        C{interface} to listen on, and the resulting port can be stopped.
        """
        f = MyServerFactory()
        # listen only on the loopback interface
        p1 = reactor.listenTCP(0, f, interface='127.0.0.1')
        return p1.stopListening()


    def testPortRepr(self):
        """
        The repr of a listening port includes its port number while it is
        listening, and no longer includes it once it has stopped.
        """
        f = MyServerFactory()
        p = reactor.listenTCP(0, f)
        portNo = str(p.getHost().port)
        self.failIf(repr(p).find(portNo) == -1)
        def stoppedListening(ign):
            self.failIf(repr(p).find(portNo) != -1)
        d = defer.maybeDeferred(p.stopListening)
        return d.addCallback(stoppedListening)


    def test_serverRepr(self):
        """
        The repr of a server transport includes the port number the server is
        actually listening on, even when C{listenTCP} was given port C{0}.
        """
        server = MyServerFactory()
        serverConnMade = server.protocolConnectionMade = defer.Deferred()
        port = reactor.listenTCP(0, server)
        self.addCleanup(port.stopListening)

        client = MyClientFactory()
        clientConnMade = client.protocolConnectionMade = defer.Deferred()
        connector = reactor.connectTCP("127.0.0.1",
                                       port.getHost().port, client)
        self.addCleanup(connector.disconnect)
        def check(protocols):
            serverProto, clientProto = protocols
            portNumber = port.getHost().port
            self.assertEquals(
                repr(serverProto.transport),
                "<AccumulatingProtocol #0 on %s>" % (portNumber,))
            serverProto.transport.loseConnection()
            clientProto.transport.loseConnection()
        return defer.gatherResults([serverConnMade, clientConnMade]
            ).addCallback(check)


    def test_restartListening(self):
        """
        Stop and then try to restart a L{tcp.Port}: after a restart, the
        server should be able to handle client connections.
        """
        serverFactory = MyServerFactory()
        port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1")
        self.addCleanup(port.stopListening)

        def cbStopListening(ignored):
            port.startListening()

            client = MyClientFactory()
            serverFactory.protocolConnectionMade = defer.Deferred()
            client.protocolConnectionMade = defer.Deferred()
            connector = reactor.connectTCP("127.0.0.1",
                                           port.getHost().port, client)
            self.addCleanup(connector.disconnect)
            return defer.gatherResults([serverFactory.protocolConnectionMade,
                                        client.protocolConnectionMade]
                ).addCallback(close)

        def close(protocols):
            serverProto, clientProto = protocols
            clientProto.transport.loseConnection()
            serverProto.transport.loseConnection()

        d = defer.maybeDeferred(port.stopListening)
        d.addCallback(cbStopListening)
        return d


    def test_exceptInStop(self):
        """
        If the server factory raises an exception in C{stopFactory}, the
        deferred returned by L{tcp.Port.stopListening} should fail with the
        corresponding error.
        """
        serverFactory = MyServerFactory()
        def raiseException():
            raise RuntimeError("An error")
        serverFactory.stopFactory = raiseException
        port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1")

        return self.assertFailure(port.stopListening(), RuntimeError)


    def test_restartAfterExcept(self):
        """
        Even if the server factory raises an exception in C{stopFactory}, the
        corresponding C{tcp.Port} instance should be in a sane state and can
        be restarted.
        """
        serverFactory = MyServerFactory()
        def raiseException():
            raise RuntimeError("An error")
        serverFactory.stopFactory = raiseException
        port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1")
        self.addCleanup(port.stopListening)

        def cbStopListening(ignored):
            del serverFactory.stopFactory
            port.startListening()

            client = MyClientFactory()
            serverFactory.protocolConnectionMade = defer.Deferred()
            client.protocolConnectionMade = defer.Deferred()
            connector = reactor.connectTCP("127.0.0.1",
                                           port.getHost().port, client)
            self.addCleanup(connector.disconnect)
            return defer.gatherResults([serverFactory.protocolConnectionMade,
                                        client.protocolConnectionMade]
                ).addCallback(close)

        def close(protocols):
            serverProto, clientProto = protocols
            clientProto.transport.loseConnection()
            serverProto.transport.loseConnection()

        return self.assertFailure(port.stopListening(), RuntimeError
            ).addCallback(cbStopListening)


    def test_directConnectionLostCall(self):
        """
        If C{connectionLost} is called directly on a port object, it succeeds
        (and doesn't expect the presence of a C{deferred} attribute).

        C{connectionLost} is called by L{reactor.disconnectAll} at shutdown.
        """
        serverFactory = MyServerFactory()
        port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1")
        portNumber = port.getHost().port
        port.connectionLost(None)

        client = MyClientFactory()
        serverFactory.protocolConnectionMade = defer.Deferred()
        client.protocolConnectionMade = defer.Deferred()
        # The port has been shut down, so this attempt should be refused.
        reactor.connectTCP("127.0.0.1", portNumber, client)
        def check(ign):
            client.reason.trap(error.ConnectionRefusedError)
        return client.failDeferred.addCallback(check)


    def test_exceptInConnectionLostCall(self):
        """
        If C{connectionLost} is called directly on a port object and the
        server factory raises an exception in C{stopFactory}, the exception is
        passed through to the caller.

        C{connectionLost} is called by L{reactor.disconnectAll} at shutdown.
        """
        serverFactory = MyServerFactory()
        def raiseException():
            raise RuntimeError("An error")
        serverFactory.stopFactory = raiseException
        port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1")
        self.assertRaises(RuntimeError, port.connectionLost, None)
 
344
 
 
345
 
 
346
 
 
347
def callWithSpew(f):
    """
    Call C{f} while a line-number-printing trace function is installed with
    C{sys.settrace}.  This is a debugging aid.

    Whatever trace function was active before the call (for example a
    coverage tool or a debugger) is restored afterwards.  Previously it was
    unconditionally replaced with C{None}.

    @param f: A no-argument callable.
    @return: Whatever C{f} returns.
    """
    from twisted.python.util import spewerWithLinenums as spewer
    import sys
    # sys.gettrace only exists on Python 2.6+; without it, fall back to
    # assuming no tracer was active, which matches the old behaviour.
    previous = getattr(sys, 'gettrace', lambda: None)()
    sys.settrace(spewer)
    try:
        return f()
    finally:
        sys.settrace(previous)
 
355
 
 
356
class LoopbackTestCase(unittest.TestCase):
    """
    Test loopback connections.
    """
    def test_closePortInProtocolFactory(self):
        """
        A port created with L{IReactorTCP.listenTCP} can be stopped from
        within its factory's C{buildProtocol}.  A client connecting with
        L{IReactorTCP.connectTCP} still gets a connection, which is then
        closed cleanly, and the port ends up disconnected.
        """
        f = ClosingFactory()
        port = reactor.listenTCP(0, f, interface="127.0.0.1")
        self.addCleanup(port.stopListening)
        portNumber = port.getHost().port
        f.port = port
        clientF = MyClientFactory()
        reactor.connectTCP("127.0.0.1", portNumber, clientF)
        def check(x):
            self.assertTrue(clientF.protocol.made)
            self.assertTrue(port.disconnected)
            clientF.lostReason.trap(error.ConnectionDone)
        return clientF.deferred.addCallback(check)


    def _trapCnxDone(self, obj):
        """
        Trap L{error.ConnectionDone} on C{obj} if it has a C{trap} method
        (i.e. looks like a L{Failure}); otherwise do nothing.
        """
        getattr(obj, 'trap', lambda x: None)(error.ConnectionDone)


    def _connectedClientAndServerTest(self, callback):
        """
        Invoke the given callback with a client protocol and a server protocol
        which have been connected to each other.
        """
        serverFactory = MyServerFactory()
        serverConnMade = defer.Deferred()
        serverFactory.protocolConnectionMade = serverConnMade
        port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1")
        self.addCleanup(port.stopListening)

        portNumber = port.getHost().port
        clientF = MyClientFactory()
        clientConnMade = defer.Deferred()
        clientF.protocolConnectionMade = clientConnMade
        reactor.connectTCP("127.0.0.1", portNumber, clientF)

        connsMade = defer.gatherResults([serverConnMade, clientConnMade])
        def connected(protocols):
            serverProtocol, clientProtocol = protocols
            callback(serverProtocol, clientProtocol)
            serverProtocol.transport.loseConnection()
            clientProtocol.transport.loseConnection()
        connsMade.addCallback(connected)
        return connsMade


    def test_tcpNoDelay(self):
        """
        The transport of a protocol connected with L{IReactorTCP.connectTCP} or
        L{IReactorTCP.listenTCP} can have its I{TCP_NODELAY} state inspected
        and manipulated with L{ITCPTransport.getTcpNoDelay} and
        L{ITCPTransport.setTcpNoDelay}.
        """
        def check(serverProtocol, clientProtocol):
            for p in [serverProtocol, clientProtocol]:
                transport = p.transport
                self.assertEquals(transport.getTcpNoDelay(), 0)
                transport.setTcpNoDelay(1)
                self.assertEquals(transport.getTcpNoDelay(), 1)
                transport.setTcpNoDelay(0)
                self.assertEquals(transport.getTcpNoDelay(), 0)
        return self._connectedClientAndServerTest(check)


    def test_tcpKeepAlive(self):
        """
        The transport of a protocol connected with L{IReactorTCP.connectTCP} or
        L{IReactorTCP.listenTCP} can have its I{SO_KEEPALIVE} state inspected
        and manipulated with L{ITCPTransport.getTcpKeepAlive} and
        L{ITCPTransport.setTcpKeepAlive}.
        """
        def check(serverProtocol, clientProtocol):
            for p in [serverProtocol, clientProtocol]:
                transport = p.transport
                self.assertEquals(transport.getTcpKeepAlive(), 0)
                transport.setTcpKeepAlive(1)
                self.assertEquals(transport.getTcpKeepAlive(), 1)
                transport.setTcpKeepAlive(0)
                self.assertEquals(transport.getTcpKeepAlive(), 0)
        return self._connectedClientAndServerTest(check)


    def testFailing(self):
        """
        Connecting to a port nobody is listening on makes the client factory's
        C{clientConnectionFailed} get called with
        L{error.ConnectionRefusedError}.
        """
        clientF = MyClientFactory()
        # XXX we assume no one is listening on TCP port 69
        reactor.connectTCP("127.0.0.1", 69, clientF, timeout=5)
        def check(ignored):
            clientF.reason.trap(error.ConnectionRefusedError)
        return clientF.failDeferred.addCallback(check)


    def test_connectionRefusedErrorNumber(self):
        """
        Assert that the error number of the ConnectionRefusedError is
        ECONNREFUSED, and not some other socket related error.
        """

        # Bind a number of ports in the operating system.  We will attempt
        # to connect to these in turn immediately after closing them, in the
        # hopes that no one else has bound them in the meantime.  Any
        # connection which succeeds is ignored and causes us to move on to
        # the next port.  As soon as a connection attempt fails, we move on
        # to making an assertion about how it failed.  If they all succeed,
        # the test will fail.

        # It would be nice to have a simpler, reliable way to cause a
        # connection failure from the platform.
        #
        # On Linux (2.6.15), connecting to port 0 always fails.  FreeBSD
        # (5.4) rejects the connection attempt with EADDRNOTAVAIL.
        #
        # On FreeBSD (5.4), listening on a port and then repeatedly
        # connecting to it without ever accepting any connections eventually
        # leads to an ECONNREFUSED.  On Linux (2.6.15), a seemingly
        # unbounded number of connections succeed.

        serverSockets = []
        for i in xrange(10):
            serverSocket = socket.socket()
            serverSocket.bind(('127.0.0.1', 0))
            serverSocket.listen(1)
            serverSockets.append(serverSocket)
        random.shuffle(serverSockets)

        clientCreator = protocol.ClientCreator(reactor, protocol.Protocol)

        def tryConnectFailure():
            def connected(proto):
                """
                Darn.  Kill it and try again, if there are any tries left.
                """
                proto.transport.loseConnection()
                if serverSockets:
                    return tryConnectFailure()
                self.fail("Could not fail to connect - could not test errno for that case.")

            serverSocket = serverSockets.pop()
            serverHost, serverPort = serverSocket.getsockname()
            serverSocket.close()

            connectDeferred = clientCreator.connectTCP(serverHost, serverPort)
            connectDeferred.addCallback(connected)
            return connectDeferred

        refusedDeferred = tryConnectFailure()
        self.assertFailure(refusedDeferred, error.ConnectionRefusedError)
        def connRefused(exc):
            self.assertEqual(exc.osError, errno.ECONNREFUSED)
        refusedDeferred.addCallback(connRefused)
        def cleanup(passthrough):
            while serverSockets:
                serverSockets.pop().close()
            return passthrough
        refusedDeferred.addBoth(cleanup)
        return refusedDeferred


    def test_connectByServiceFail(self):
        """
        Connecting to a named service which does not exist raises
        L{error.ServiceNameUnknownError}.
        """
        self.assertRaises(
            error.ServiceNameUnknownError,
            reactor.connectTCP,
            "127.0.0.1", "thisbetternotexist", MyClientFactory())


    def test_connectByService(self):
        """
        L{IReactorTCP.connectTCP} accepts the name of a service instead of a
        port number and connects to the port number associated with that
        service, as defined by L{socket.getservbyname}.
        """
        serverFactory = MyServerFactory()
        serverConnMade = defer.Deferred()
        serverFactory.protocolConnectionMade = serverConnMade
        port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1")
        self.addCleanup(port.stopListening)
        portNumber = port.getHost().port
        clientFactory = MyClientFactory()
        clientConnMade = defer.Deferred()
        clientFactory.protocolConnectionMade = clientConnMade

        def fakeGetServicePortByName(serviceName, protocolName):
            if serviceName == 'http' and protocolName == 'tcp':
                return portNumber
            return 10
        self.patch(socket, 'getservbyname', fakeGetServicePortByName)

        reactor.connectTCP('127.0.0.1', 'http', clientFactory)

        connMade = defer.gatherResults([serverConnMade, clientConnMade])
        def connected(protocols):
            serverProtocol, clientProtocol = protocols
            self.assertTrue(
                serverFactory.called,
                "Server factory was not called upon to build a protocol.")
            serverProtocol.transport.loseConnection()
            clientProtocol.transport.loseConnection()
        connMade.addCallback(connected)
        return connMade
 
563
 
 
564
 
 
565
class StartStopFactory(protocol.Factory):
    """
    Factory which insists on being started exactly once and then stopped
    exactly once.  Any other sequence raises L{RuntimeError}.

    @ivar started: C{1} once C{startFactory} has succeeded, else C{0}.
    @ivar stopped: C{1} once C{stopFactory} has succeeded, else C{0}.
    """

    started = 0
    stopped = 0

    def startFactory(self):
        # A second start, or a start after a stop, is a usage error.
        alreadyUsed = self.started or self.stopped
        if alreadyUsed:
            raise RuntimeError
        self.started = 1


    def stopFactory(self):
        # Stopping is only legal after a start and before any other stop.
        canStop = self.started and not self.stopped
        if not canStop:
            raise RuntimeError
        self.stopped = 1
 
579
 
 
580
 
 
581
class ClientStartStopFactory(MyClientFactory):
    """
    L{MyClientFactory} which raises L{RuntimeError} unless it is started
    exactly once and then stopped exactly once.

    @ivar started: C{1} once C{startFactory} has succeeded, else C{0}.
    @ivar stopped: C{1} once C{stopFactory} has succeeded, else C{0}.
    """

    started = 0
    stopped = 0

    def startFactory(self):
        if not (self.started or self.stopped):
            self.started = 1
            return
        raise RuntimeError


    def stopFactory(self):
        if self.started and not self.stopped:
            self.stopped = 1
            return
        raise RuntimeError
 
595
 
 
596
 
 
597
class FactoryTestCase(unittest.TestCase):
    """Tests for factories."""

    def test_serverStartStop(self):
        """
        The factory passed to L{IReactorTCP.listenTCP} should be started only
        when it transitions from being used on no ports to being used on one
        port and should be stopped only when it transitions from being used on
        one port to being used on no ports.
        """
        # Note - this test doesn't need to use listenTCP.  It is exercising
        # logic implemented in Factory.doStart and Factory.doStop, so it could
        # just call that directly.  Some other test can make sure that
        # listenTCP and stopListening correctly call doStart and
        # doStop. -exarkun

        f = StartStopFactory()

        # listen on port
        p1 = reactor.listenTCP(0, f, interface='127.0.0.1')
        self.addCleanup(p1.stopListening)

        self.assertEqual((f.started, f.stopped), (1, 0))

        # listen on two more ports.  Register cleanups for them as well, so
        # that an assertion failure before they are explicitly stopped does
        # not leave listening ports behind (p1 already has one).
        p2 = reactor.listenTCP(0, f, interface='127.0.0.1')
        self.addCleanup(p2.stopListening)
        p3 = reactor.listenTCP(0, f, interface='127.0.0.1')
        self.addCleanup(p3.stopListening)

        self.assertEqual((f.started, f.stopped), (1, 0))

        # close two ports
        d1 = defer.maybeDeferred(p1.stopListening)
        d2 = defer.maybeDeferred(p2.stopListening)
        closedDeferred = defer.gatherResults([d1, d2])
        def cbClosed(ignored):
            self.assertEqual((f.started, f.stopped), (1, 0))
            # Close the last port
            return p3.stopListening()
        closedDeferred.addCallback(cbClosed)

        def cbClosedAll(ignored):
            self.assertEquals((f.started, f.stopped), (1, 1))
        closedDeferred.addCallback(cbClosedAll)
        return closedDeferred


    def test_clientStartStop(self):
        """
        The factory passed to L{IReactorTCP.connectTCP} should be started when
        the connection attempt starts and stopped when it is over.
        """
        f = ClosingFactory()
        p = reactor.listenTCP(0, f, interface="127.0.0.1")
        self.addCleanup(p.stopListening)
        portNumber = p.getHost().port
        f.port = p

        factory = ClientStartStopFactory()
        reactor.connectTCP("127.0.0.1", portNumber, factory)
        self.assertTrue(factory.started)
        return loopUntil(lambda: factory.stopped)
 
658
 
 
659
 
 
660
 
 
661
class ConnectorTestCase(unittest.TestCase):
    """
    Tests for the L{IConnector} objects returned by
    L{IReactorTCP.connectTCP}.
    """

    def test_connectorIdentity(self):
        """
        L{IReactorTCP.connectTCP} returns an L{IConnector} whose destination
        is the address given to C{connectTCP}.  That very connector is what
        the factory receives in both C{startedConnecting} and
        C{clientConnectionLost}.
        """
        closer = ClosingFactory()
        listeningPort = reactor.listenTCP(0, closer, interface="127.0.0.1")
        self.addCleanup(listeningPort.stopListening)
        portNumber = listeningPort.getHost().port
        closer.port = listeningPort

        connectorsSeen = []
        failuresSeen = []

        def recordLost(lostConnector, reason):
            connectorsSeen.append(lostConnector)
            failuresSeen.append(reason)

        clientFactory = ClientStartStopFactory()
        clientFactory.clientConnectionLost = recordLost
        clientFactory.startedConnecting = connectorsSeen.append

        connector = reactor.connectTCP("127.0.0.1", portNumber, clientFactory)
        self.assertTrue(interfaces.IConnector.providedBy(connector))
        destination = connector.getDestination()
        self.assertEquals(destination.type, "TCP")
        self.assertEquals(destination.host, "127.0.0.1")
        self.assertEquals(destination.port, portNumber)

        def clientFactoryStopped(ignored):
            failuresSeen[0].trap(error.ConnectionDone)
            self.assertEqual(connectorsSeen, [connector, connector])

        return loopUntil(lambda: clientFactory.stopped).addCallback(
            clientFactoryStopped)


    def test_userFail(self):
        """
        A factory whose C{startedConnecting} calls
        L{IConnector.stopConnecting} has C{clientConnectionFailed} called
        with an L{error.UserError} reason.
        """
        serverFactory = MyServerFactory()
        listeningPort = reactor.listenTCP(0, serverFactory,
                                          interface="127.0.0.1")
        self.addCleanup(listeningPort.stopListening)
        portNumber = listeningPort.getHost().port

        def abortAttempt(connector):
            connector.stopConnecting()

        clientFactory = ClientStartStopFactory()
        clientFactory.startedConnecting = abortAttempt
        reactor.connectTCP("127.0.0.1", portNumber, clientFactory)

        def check(ignored):
            self.assertEquals(clientFactory.failed, 1)
            clientFactory.reason.trap(error.UserError)

        return loopUntil(lambda: clientFactory.stopped).addCallback(check)


    def test_reconnect(self):
        """
        A factory which calls L{IConnector.connect} from
        C{clientConnectionLost} triggers a second connection attempt.  Here
        that attempt is refused, because the server stopped listening after
        the first connection.
        """
        closer = ClosingFactory()
        listeningPort = reactor.listenTCP(0, closer, interface="127.0.0.1")
        self.addCleanup(listeningPort.stopListening)
        portNumber = listeningPort.getHost().port
        closer.port = listeningPort

        clientFactory = MyClientFactory()

        def reconnectOnLoss(connector, reason):
            connector.connect()

        clientFactory.clientConnectionLost = reconnectOnLoss
        reactor.connectTCP("127.0.0.1", portNumber, clientFactory)

        def reconnectFailed(ignored):
            lastProtocol = clientFactory.protocol
            self.assertEqual((lastProtocol.made, lastProtocol.closed), (1, 1))
            clientFactory.reason.trap(error.ConnectionRefusedError)
            self.assertEqual(clientFactory.stopped, 1)

        return loopUntil(lambda: clientFactory.failed).addCallback(
            reconnectFailed)
 
751
 
 
752
 
 
753
 
 
754
class CannotBindTestCase(unittest.TestCase):
    """
    Tests for correct behavior when a reactor cannot bind to the required TCP
    port.
    """

    def test_cannotBind(self):
        """
        L{IReactorTCP.listenTCP} raises L{error.CannotListenError} if the
        address to listen on is already in use.
        """
        f = MyServerFactory()

        p1 = reactor.listenTCP(0, f, interface='127.0.0.1')
        self.addCleanup(p1.stopListening)
        n = p1.getHost().port
        # A second getHost() call must report the same address.
        dest = p1.getHost()
        self.assertEquals(dest.type, "TCP")
        self.assertEquals(dest.host, "127.0.0.1")
        self.assertEquals(dest.port, n)

        # make sure new listen raises error
        self.assertRaises(error.CannotListenError,
                          reactor.listenTCP, n, f, interface='127.0.0.1')


    def _fireWhenDoneFunc(self, d, f):
        """
        Wrap C{f} so that calling the wrapper runs C{f} and then fires C{d}
        with C{''}.

        The wrapper returns whatever C{f} returned and carries C{f}'s function
        metadata (via C{mergeFunctionMetadata}).  If C{f} raises, C{d} is not
        fired.  Because a Deferred can only be fired once, the wrapper should
        be called at most once.
        """
        from twisted.python import util as tputil
        def newf(*args, **kw):
            rtn = f(*args, **kw)
            d.callback('')
            return rtn
        return tputil.mergeFunctionMetadata(f, newf)


    def test_clientBind(self):
        """
        L{IReactorTCP.connectTCP} calls C{Factory.clientConnectionFailed} with
        L{error.ConnectBindError} if the bind address specified is already in
        use.

        A first client (cf1) connects with an ephemeral local port.  A second
        client (cf2) then tries to bind to the same local port and must fail.
        Finally cf1's connection is closed and its factory must stop.
        """
        theDeferred = defer.Deferred()
        sf = MyServerFactory()
        sf.startFactory = self._fireWhenDoneFunc(theDeferred, sf.startFactory)
        p = reactor.listenTCP(0, sf, interface="127.0.0.1")
        self.addCleanup(p.stopListening)

        def _connect1(results):
            d = defer.Deferred()
            cf1 = MyClientFactory()
            cf1.buildProtocol = self._fireWhenDoneFunc(d, cf1.buildProtocol)
            reactor.connectTCP("127.0.0.1", p.getHost().port, cf1,
                               bindAddress=("127.0.0.1", 0))
            d.addCallback(_conmade, cf1)
            return d

        def _conmade(results, cf1):
            d = defer.Deferred()
            cf1.protocol.connectionMade = self._fireWhenDoneFunc(
                d, cf1.protocol.connectionMade)
            d.addCallback(_check1connect2, cf1)
            return d

        def _check1connect2(results, cf1):
            self.assertEquals(cf1.protocol.made, 1)

            d1 = defer.Deferred()
            d2 = defer.Deferred()
            # The local port cf1 is bound to; cf2 will try to reuse it.
            port = cf1.protocol.transport.getHost().port
            cf2 = MyClientFactory()
            cf2.clientConnectionFailed = self._fireWhenDoneFunc(
                d1, cf2.clientConnectionFailed)
            cf2.stopFactory = self._fireWhenDoneFunc(d2, cf2.stopFactory)
            reactor.connectTCP("127.0.0.1", p.getHost().port, cf2,
                               bindAddress=("127.0.0.1", port))
            d1.addCallback(_check2failed, cf1, cf2)
            d2.addCallback(_check2stopped, cf1, cf2)
            # consumeErrors=True keeps a failed check in d1/d2 from being
            # left as an unhandled error (which would only be reported at
            # garbage collection, possibly against another test); _stop
            # re-raises the first such failure once cf1 has been cleaned up.
            dl = defer.DeferredList([d1, d2], consumeErrors=True)
            dl.addCallback(_stop, cf1, cf2)
            return dl

        def _check2failed(results, cf1, cf2):
            self.assertEquals(cf2.failed, 1)
            # trap re-raises the failure unless it is a ConnectBindError.
            cf2.reason.trap(error.ConnectBindError)
            return results

        def _check2stopped(results, cf1, cf2):
            self.assertEquals(cf2.stopped, 1)
            return results

        def _stop(results, cf1, cf2):
            d = defer.Deferred()
            d.addCallback(_check1cleanup, cf1)
            failures = [result for (succeeded, result) in results
                        if not succeeded]
            if failures:
                # Propagate the first failed check after cleanup completes.
                d.addCallback(lambda ignored: failures[0])
            cf1.stopFactory = self._fireWhenDoneFunc(d, cf1.stopFactory)
            cf1.protocol.transport.loseConnection()
            return d

        def _check1cleanup(results, cf1):
            self.assertEquals(cf1.stopped, 1)

        theDeferred.addCallback(_connect1)
        return theDeferred
 
860
 
 
861
 
 
862
 
 
863
class MyOtherClientFactory(protocol.ClientFactory):
    """
    Client factory which builds an L{AccumulatingProtocol} and remembers both
    the protocol it built and the address it was built for.

    Unlike the default C{ClientFactory.buildProtocol}, this does not set the
    protocol's C{factory} attribute.
    """
    def buildProtocol(self, address):
        self.address = address
        built = AccumulatingProtocol()
        self.protocol = built
        return built
 
868
 
 
869
 
 
870
 
 
871
class LocalRemoteAddressTestCase(unittest.TestCase):
    """
    Tests for correct getHost/getPeer values and that the correct address is
    passed to buildProtocol.
    """
    def test_hostAddress(self):
        """
        L{IListeningPort.getHost} returns the same address as a client
        connection's L{ITCPTransport.getPeer}.
        """
        serverFactory = MyServerFactory()
        serverFactory.protocolConnectionLost = defer.Deferred()
        serverConnectionLost = serverFactory.protocolConnectionLost
        port = reactor.listenTCP(0, serverFactory, interface='127.0.0.1')
        self.addCleanup(port.stopListening)
        n = port.getHost().port

        clientFactory = MyClientFactory()
        onConnection = clientFactory.protocolConnectionMade = defer.Deferred()
        connector = reactor.connectTCP('127.0.0.1', n, clientFactory)

        def check(ignored):
            self.assertEquals([port.getHost()], clientFactory.peerAddresses)
            self.assertEquals(
                port.getHost(), clientFactory.protocol.transport.getPeer())
        onConnection.addCallback(check)

        def cleanup(passthrough):
            # Clean up the client explicitly here so that tear down of
            # the server side of the connection begins, then wait for
            # the server side to actually disconnect.  This runs even when
            # check failed, so a failed assertion does not also leave the
            # connection open; the original result (possibly a Failure) is
            # handed back once the server side is gone.
            connector.disconnect()
            serverConnectionLost.addCallback(lambda ignored: passthrough)
            return serverConnectionLost
        onConnection.addBoth(cleanup)

        return onConnection
 
907
 
 
908
 
 
909
 
 
910
class WriterProtocol(protocol.Protocol):
 
911
    def connectionMade(self):
 
912
        # use everything ITransport claims to provide. If something here
 
913
        # fails, the exception will be written to the log, but it will not
 
914
        # directly flunk the test. The test will fail when maximum number of
 
915
        # iterations have passed and the writer's factory.done has not yet
 
916
        # been set.
 
917
        self.transport.write("Hello Cleveland!\n")
 
918
        seq = ["Goodbye", " cruel", " world", "\n"]
 
919
        self.transport.writeSequence(seq)
 
920
        peer = self.transport.getPeer()
 
921
        if peer.type != "TCP":
 
922
            print "getPeer returned non-TCP socket:", peer
 
923
            self.factory.problem = 1
 
924
        us = self.transport.getHost()
 
925
        if us.type != "TCP":
 
926
            print "getHost returned non-TCP socket:", us
 
927
            self.factory.problem = 1
 
928
        self.factory.done = 1
 
929
 
 
930
        self.transport.loseConnection()
 
931
 
 
932
class ReaderProtocol(protocol.Protocol):
    """
    Client-side protocol which appends every received chunk to
    C{factory.data} and sets C{factory.done} to 1 when the connection is lost.
    """
    def dataReceived(self, data):
        factory = self.factory
        factory.data = factory.data + data

    def connectionLost(self, reason):
        self.factory.done = 1
 
937
 
 
938
class WriterClientFactory(protocol.ClientFactory):
    """
    Client factory which builds L{ReaderProtocol} instances and collects what
    they receive.

    @ivar done: 0 until the protocol reports the connection lost, then 1.
    @ivar data: All bytes received so far, concatenated.
    @ivar protocol: The most recently built L{ReaderProtocol}.
    """
    def __init__(self):
        self.done, self.data = 0, ""

    def buildProtocol(self, addr):
        reader = ReaderProtocol()
        reader.factory = self
        self.protocol = reader
        return reader
 
947
 
 
948
class WriteDataTestCase(unittest.TestCase):
    """
    Test that connected TCP sockets can actually write data. Try to exercise
    the entire ITransport interface.
    """

    def test_writer(self):
        """
        L{ITCPTransport.write} and L{ITCPTransport.writeSequence} send bytes to
        the other end of the connection.
        """
        f = protocol.Factory()
        f.protocol = WriterProtocol
        f.done = 0
        f.problem = 0
        wrappedF = WiredFactory(f)
        p = reactor.listenTCP(0, wrappedF, interface="127.0.0.1")
        self.addCleanup(p.stopListening)
        n = p.getHost().port
        clientF = WriterClientFactory()
        wrappedClientF = WiredFactory(clientF)
        reactor.connectTCP("127.0.0.1", n, wrappedClientF)

        def check(ignored):
            # assertEqual (rather than failUnless(a == b)) includes both
            # values in the failure message.
            self.assertTrue(f.done, "writer didn't finish, it probably died")
            self.assertEqual(f.problem, 0, "writer indicated an error")
            self.assertTrue(clientF.done,
                            "client didn't see connection dropped")
            expected = "".join(["Hello Cleveland!\n",
                                "Goodbye", " cruel", " world", "\n"])
            self.assertEqual(clientF.data, expected,
                             "client didn't receive all the data it expected")
        # Only check once both ends have seen the disconnection.
        d = defer.gatherResults([wrappedF.onDisconnect,
                                 wrappedClientF.onDisconnect])
        return d.addCallback(check)


    def test_writeAfterShutdownWithoutReading(self):
        """
        A TCP transport which is written to after the connection has been shut
        down should notify its protocol that the connection has been lost, even
        if the TCP transport is not actively being monitored for read events
        (ie, pauseProducing was called on it).
        """
        # Skipping based on the name of the reactor is unpleasant: generally
        # tests shouldn't care _at all_ what reactor is being used.  The two
        # reactors below are known not to support the behavior tested here.
        if reactor.__class__.__name__ == 'IOCPReactor':
            raise unittest.SkipTest(
                "iocpreactor does not, in fact, stop reading immediately after "
                "pauseProducing is called. This results in a bonus disconnection "
                "notification. Under some circumstances, it might be possible to "
                "not receive this notifications (specifically, pauseProducing, "
                "deliver some data, proceed with this test).")
        # The Gtk reactor cannot pass this test because it fails to implement
        # IReactorTCP entirely correctly.  Gtk is quite old at this point, so
        # it's more likely that gtkreactor will be deprecated and removed
        # rather than fixed to handle this case correctly.  Since this is a
        # pre-existing (and very long-standing) issue with the Gtk reactor,
        # there's no reason for it to prevent this test being added to
        # exercise the other reactors, for which the behavior was also
        # untested but at least works correctly (now).  See #2833 for
        # information on the status of gtkreactor.
        if reactor.__class__.__name__ == 'GtkReactor':
            raise unittest.SkipTest(
                "gtkreactor does not implement unclean disconnection "
                "notification correctly.  This might more properly be "
                "a todo, but due to technical limitations it cannot be.")

        # Called back after the protocol for the client side of the connection
        # has paused its transport, preventing it from reading, therefore
        # preventing it from noticing the disconnection before the rest of the
        # actions which are necessary to trigger the case this test is for have
        # been taken.
        clientPaused = defer.Deferred()

        # Called back when the protocol for the server side of the connection
        # has received connection lost notification.
        serverLost = defer.Deferred()

        class Disconnecter(protocol.Protocol):
            """
            Protocol for the server side of the connection which disconnects
            itself in a callback on clientPaused and publishes notification
            when its connection is actually lost.
            """
            def connectionMade(self):
                """
                Set up a callback on clientPaused to lose the connection.
                """
                msg('Disconnector.connectionMade')
                def disconnect(ignored):
                    msg('Disconnector.connectionMade disconnect')
                    self.transport.loseConnection()
                    msg('loseConnection called')
                clientPaused.addCallback(disconnect)

            def connectionLost(self, reason):
                """
                Notify observers that the server side of the connection has
                ended.
                """
                msg('Disconnecter.connectionLost')
                serverLost.callback(None)
                msg('serverLost called back')

        # Create the server port to which a connection will be made.
        server = protocol.ServerFactory()
        server.protocol = Disconnecter
        port = reactor.listenTCP(0, server, interface='127.0.0.1')
        self.addCleanup(port.stopListening)
        addr = port.getHost()

        class Infinite(object):
            """
            A producer which will write to its consumer as long as
            resumeProducing is called.

            @ivar consumer: The L{IConsumer} which will be written to.
            """
            implements(IPullProducer)

            def __init__(self, consumer):
                self.consumer = consumer

            def resumeProducing(self):
                msg('Infinite.resumeProducing')
                self.consumer.write('x')
                msg('Infinite.resumeProducing wrote to consumer')

            def stopProducing(self):
                msg('Infinite.stopProducing')


        class UnreadingWriter(protocol.Protocol):
            """
            Trivial protocol which pauses its transport immediately and then
            writes some bytes to it.
            """
            def connectionMade(self):
                msg('UnreadingWriter.connectionMade')
                self.transport.pauseProducing()
                clientPaused.callback(None)
                msg('clientPaused called back')
                def write(ignored):
                    msg('UnreadingWriter.connectionMade write')
                    # This needs to be enough bytes to spill over into the
                    # userspace Twisted send buffer - if it all fits into
                    # the kernel, Twisted won't even poll for OUT events,
                    # which means it won't poll for any events at all, so
                    # the disconnection is never noticed.  This is due to
                    # #1662.  When #1662 is fixed, this test will likely
                    # need to be adjusted, otherwise connection lost
                    # notification will happen too soon and the test will
                    # probably begin to fail with ConnectionDone instead of
                    # ConnectionLost (in any case, it will no longer be
                    # entirely correct).
                    producer = Infinite(self.transport)
                    msg('UnreadingWriter.connectionMade write created producer')
                    self.transport.registerProducer(producer, False)
                    msg('UnreadingWriter.connectionMade write registered producer')
                serverLost.addCallback(write)

        # Create the client and initiate the connection
        client = MyClientFactory()
        client.protocolFactory = UnreadingWriter
        clientConnectionLost = client.deferred
        def cbClientLost(ignored):
            msg('cbClientLost')
            return client.lostReason
        clientConnectionLost.addCallback(cbClientLost)
        msg('Connecting to %s:%s' % (addr.host, addr.port))
        reactor.connectTCP(addr.host, addr.port, client)

        # By the end of the test, the client should have received notification
        # of unclean disconnection.
        msg('Returning Deferred')
        return self.assertFailure(clientConnectionLost, error.ConnectionLost)
 
1126
 
 
1127
 
 
1128
 
 
1129
class ConnectionLosingProtocol(protocol.Protocol):
    """
    Protocol which, once connected, writes a single byte and then asks its
    transport to close the connection.  Afterwards it calls
    C{master._connectionMade()} and appends its transport to C{master.ports}.

    C{master} is not set here; it must be assigned on the instance externally.
    """
    def connectionMade(self):
        transport = self.transport
        transport.write("1")
        transport.loseConnection()
        self.master._connectionMade()
        self.master.ports.append(transport)
 
1135
 
 
1136
 
 
1137
 
 
1138
class NoopProtocol(protocol.Protocol):
    """
    Protocol which, on connection, creates a Deferred, stores it as C{self.d}
    and appends it to C{master.serverConns}.  That Deferred is fired with
    C{True} when the connection is lost.

    C{master} is not set here; it must be assigned on the instance externally.
    """
    def connectionMade(self):
        lost = defer.Deferred()
        self.d = lost
        self.master.serverConns.append(lost)

    def connectionLost(self, reason):
        self.d.callback(True)
 
1145
 
 
1146
 
 
1147
 
 
1148
class ConnectionLostNotifyingProtocol(protocol.Protocol):
    """
    Protocol which fires a caller-supplied Deferred, with itself as the
    result, when its connection is lost.

    @ivar onConnectionLost: The L{Deferred} fired from C{connectionLost}.

    @ivar lostConnectionReason: C{None} until the connection is lost; after
        that, the reason object that was passed to C{connectionLost}.
    """
    def __init__(self, onConnectionLost):
        self.onConnectionLost = onConnectionLost
        self.lostConnectionReason = None


    def connectionLost(self, reason):
        # Record the reason before firing so observers can inspect it.
        self.lostConnectionReason = reason
        notify = self.onConnectionLost
        notify.callback(self)
 
1167
 
 
1168
 
 
1169
 
 
1170
class HandleSavingProtocol(ConnectionLostNotifyingProtocol):
    """
    L{ConnectionLostNotifyingProtocol} which also records its transport's
    platform-specific handle (as returned by C{getHandle}) in C{self.handle}
    when the connection is set up.
    """
    def makeConnection(self, transport):
        """
        Save C{transport.getHandle()} as C{self.handle}, then continue with
        the inherited connection setup.
        """
        self.handle = transport.getHandle()
        # ConnectionLostNotifyingProtocol does not override makeConnection,
        # so this resolves to the same implementation as Protocol's.
        return ConnectionLostNotifyingProtocol.makeConnection(self, transport)
 
1183
 
 
1184
 
 
1185
 
 
1186
class ProperlyCloseFilesMixin:
    """
    Tests for platform resources properly being cleaned up.

    Concrete test cases must implement C{createServer}, C{connectClient} and
    C{getHandleExceptionType}; C{getHandleErrorCode} has a default.
    """
    def createServer(self, address, portNumber, factory):
        """
        Bind a server port to which connections will be made.  The server
        should use the given protocol factory.

        @return: The L{IListeningPort} for the server created.
        """
        raise NotImplementedError()


    def connectClient(self, address, portNumber, clientCreator):
        """
        Establish a connection to the given address using the given
        L{ClientCreator} instance.

        @return: A Deferred which will fire with the connected protocol instance.
        """
        raise NotImplementedError()


    def getHandleExceptionType(self):
        """
        Return the exception class which will be raised when an operation is
        attempted on a closed platform handle.
        """
        raise NotImplementedError()


    def getHandleErrorCode(self):
        """
        Return the errno expected to result from writing to a closed
        platform socket handle.
        """
        # These platforms have been seen to give EBADF:
        #
        #  Linux 2.4.26, Linux 2.6.15, OS X 10.4, FreeBSD 5.4
        #  Windows 2000 SP 4, Windows XP SP 2
        return errno.EBADF


    def test_properlyCloseFiles(self):
        """
        Test that lost connections properly have their underlying socket
        resources cleaned up.
        """
        onServerConnectionLost = defer.Deferred()
        serverFactory = protocol.ServerFactory()
        serverFactory.protocol = lambda: ConnectionLostNotifyingProtocol(
            onServerConnectionLost)
        serverPort = self.createServer('127.0.0.1', 0, serverFactory)

        onClientConnectionLost = defer.Deferred()
        serverAddr = serverPort.getHost()
        clientCreator = protocol.ClientCreator(
            reactor, lambda: HandleSavingProtocol(onClientConnectionLost))
        clientDeferred = self.connectClient(
            serverAddr.host, serverAddr.port, clientCreator)

        def clientConnected(client):
            """
            Disconnect the client.  Return a Deferred which fires when both
            the client and the server have received disconnect notification.
            """
            client.transport.write(
                'some bytes to make sure the connection is set up')
            client.transport.loseConnection()
            return defer.gatherResults([
                onClientConnectionLost, onServerConnectionLost])
        clientDeferred.addCallback(clientConnected)

        def clientDisconnected(protocols):
            """
            Verify that the underlying platform socket handle has been
            cleaned up.

            @param protocols: The client and server protocols, in that order,
                as collected by the C{gatherResults} in C{clientConnected}.
            """
            # Unpack explicitly instead of using a tuple parameter, which is
            # Python 2-only syntax (removed by PEP 3113).
            client, server = protocols
            client.lostConnectionReason.trap(error.ConnectionClosed)
            server.lostConnectionReason.trap(error.ConnectionClosed)
            expectedErrorCode = self.getHandleErrorCode()
            err = self.assertRaises(
                self.getHandleExceptionType(), client.handle.send, 'bytes')
            self.assertEqual(err.args[0], expectedErrorCode)
        clientDeferred.addCallback(clientDisconnected)

        def cleanup(passthrough):
            """
            Shut down the server port.  Return a Deferred which fires when
            this has completed.
            """
            result = defer.maybeDeferred(serverPort.stopListening)
            result.addCallback(lambda ign: passthrough)
            return result
        clientDeferred.addBoth(cleanup)

        return clientDeferred
 
1284
 
 
1285
 
 
1286
 
 
1287
class ProperlyCloseFilesTestCase(unittest.TestCase, ProperlyCloseFilesMixin):
    """
    Verify that sockets created by L{IReactorTCP.connectTCP} are cleaned up
    once their associated connection has been closed.
    """
    def createServer(self, address, portNumber, factory):
        """
        Start listening with L{IReactorTCP.listenTCP} on C{address} and
        C{portNumber}, serving connections with C{factory}.
        """
        listening = reactor.listenTCP(portNumber, factory, interface=address)
        return listening


    def connectClient(self, address, portNumber, clientCreator):
        """
        Open a client connection through C{clientCreator}'s TCP support, which
        uses L{IReactorTCP.connectTCP}.
        """
        connecting = clientCreator.connectTCP(address, portNumber)
        return connecting


    def getHandleExceptionType(self):
        """
        Writing to the low-level socket object after it has been closed is
        expected to raise L{socket.error}.
        """
        return socket.error
 
1313
 
 
1314
 
 
1315
 
 
1316
class WiredForDeferreds(policies.ProtocolWrapper):
    """
    L{policies.ProtocolWrapper} which fires its factory's C{onConnect}
    Deferred (with C{None}) after the wrapped protocol's C{connectionMade}.
    It fires the factory's C{onDisconnect} Deferred after the wrapped
    protocol's C{connectionLost}.

    Since a Deferred can only be fired once, a factory using this wrapper
    supports a single connection.
    """
    # The constructor is inherited unchanged from ProtocolWrapper:
    # WiredForDeferreds(factory, wrappedProtocol).

    def connectionMade(self):
        policies.ProtocolWrapper.connectionMade(self)
        connected = self.factory.onConnect
        connected.callback(None)

    def connectionLost(self, reason):
        policies.ProtocolWrapper.connectionLost(self, reason)
        disconnected = self.factory.onDisconnect
        disconnected.callback(None)
 
1327
 
 
1328
 
 
1329
 
 
1330
class WiredFactory(policies.WrappingFactory):
    """
    L{policies.WrappingFactory} whose protocols are L{WiredForDeferreds}.

    @ivar onConnect: Deferred fired with C{None} once the wrapped connection
        has been made.
    @ivar onDisconnect: Deferred fired with C{None} once the wrapped
        connection has been lost.
    """
    protocol = WiredForDeferreds

    def __init__(self, wrappedFactory):
        policies.WrappingFactory.__init__(self, wrappedFactory)
        self.onConnect, self.onDisconnect = defer.Deferred(), defer.Deferred()
 
1337
 
 
1338
 
 
1339
 
 
1340
class AddressTestCase(unittest.TestCase):
    """
    Tests for address-related interactions with client and server protocols.
    """
    def setUp(self):
        """
        Create a listening port plus a connected client/server pair, so tests
        can inspect the addresses the factories and transports saw.

        @return: A L{defer.Deferred} which fires once both the client and the
            server protocols have had C{connectionMade} called.
        """
        class RememberingWrapper(protocol.ClientFactory):
            """
            Factory which records every address handed to its
            L{buildProtocol} and delegates protocol construction to another
            factory.

            @ivar addresses: The objects passed to buildProtocol, in order.
            @ivar factory: The factory which actually builds protocols.
            """
            def __init__(self, factory):
                self.addresses = []
                self.factory = factory

            # Only buildProtocol is forwarded; doStart, doStop, etc. are not
            # needed for these tests.
            def buildProtocol(self, addr):
                """
                Remember C{addr}, then let C{self.factory} build the protocol.
                """
                self.addresses.append(addr)
                return self.factory.buildProtocol(addr)

        # Server side: observe connect/disconnect and record the address
        # given to buildProtocol.  RememberingWrapper is a ClientFactory, and
        # ClientFactory is-a ServerFactory, so it can be used for listening.
        self.server = MyServerFactory()
        self.serverConnMade = defer.Deferred()
        self.serverConnLost = defer.Deferred()
        self.server.protocolConnectionMade = self.serverConnMade
        self.server.protocolConnectionLost = self.serverConnLost
        self.serverWrapper = RememberingWrapper(self.server)

        # Client side: the same arrangement.
        self.client = MyClientFactory()
        self.clientConnMade = defer.Deferred()
        self.clientConnLost = defer.Deferred()
        self.client.protocolConnectionMade = self.clientConnMade
        self.client.protocolConnectionLost = self.clientConnLost
        self.clientWrapper = RememberingWrapper(self.client)

        self.port = reactor.listenTCP(0, self.serverWrapper, interface='127.0.0.1')
        listening = self.port.getHost()
        self.connector = reactor.connectTCP(
            listening.host, listening.port, self.clientWrapper)

        return defer.gatherResults([self.serverConnMade, self.clientConnMade])


    def tearDown(self):
        """
        Drop the client/server connection created by L{setUp} and stop its
        listening port, waiting for all of that to finish.
        """
        self.connector.disconnect()
        portStopped = defer.maybeDeferred(self.port.stopListening)
        return defer.gatherResults(
            [self.serverConnLost, self.clientConnLost, portStopped])


    def test_buildProtocolClient(self):
        """
        L{ClientFactory.buildProtocol} should be invoked with the address of
        the server to which a connection has been established, which should
        be the same as the address reported by the C{getHost} method of the
        transport of the server protocol and as the C{getPeer} method of the
        transport of the client protocol.
        """
        serverHost = self.server.protocol.transport.getHost()
        clientPeer = self.client.protocol.transport.getPeer()

        for expected in (serverHost, clientPeer):
            self.assertEqual(
                self.clientWrapper.addresses,
                [IPv4Address('TCP', expected.host, expected.port)])


    def test_buildProtocolServer(self):
        """
        L{ServerFactory.buildProtocol} should be invoked with the address of
        the client which has connected to the port the factory is listening on,
        which should be the same as the address reported by the C{getPeer}
        method of the transport of the server protocol and as the C{getHost}
        method of the transport of the client protocol.
        """
        clientHost = self.client.protocol.transport.getHost()
        serverPeer = self.server.protocol.transport.getPeer()

        for expected in (serverPeer, clientHost):
            self.assertEqual(
                self.serverWrapper.addresses,
                [IPv4Address('TCP', expected.host, expected.port)])
 
1448
 
 
1449
 
 
1450
 
 
1451
class LargeBufferWriterProtocol(protocol.Protocol):
    """
    Protocol which, once connected, writes C{factory.len} bytes to its
    transport in one C{write} call, sets C{factory.done} to 1 and closes the
    connection.

    Win32 sockets cannot handle single huge chunks of bytes; writing one
    massive string checks that Twisted deals with this.
    """

    def connectionMade(self):
        payload = 'X' * self.factory.len
        self.transport.write(payload)
        self.factory.done = 1
        self.transport.loseConnection()
 
1461
 
 
1462
class LargeBufferReaderProtocol(protocol.Protocol):
    """
    Reader that only tallies incoming bytes, recording the running total and
    the eventual disconnect on its factory.
    """

    def dataReceived(self, data):
        owner = self.factory
        owner.len = owner.len + len(data)

    def connectionLost(self, reason):
        self.factory.done = 1
 
1467
 
 
1468
class LargeBufferReaderClientFactory(protocol.ClientFactory):
    """
    Client factory holding the byte count (C{len}) and completion flag
    (C{done}) that L{LargeBufferReaderProtocol} updates.
    """

    def __init__(self):
        self.done, self.len = 0, 0

    def buildProtocol(self, addr):
        reader = LargeBufferReaderProtocol()
        reader.factory = self
        # Remember the most recently built protocol instance.
        self.protocol = reader
        return reader
 
1477
 
 
1478
 
 
1479
class FireOnClose(policies.ProtocolWrapper):
    """
    Protocol wrapper that fires its factory's C{deferred} once the wrapped
    connection has been lost.
    """

    def connectionLost(self, reason):
        policies.ProtocolWrapper.connectionLost(self, reason)
        # Fire only after delegating to the base wrapper's connectionLost.
        notify = self.factory.deferred
        notify.callback(None)
 
1486
 
 
1487
 
 
1488
class FireOnCloseFactory(policies.WrappingFactory):
    """
    Wrapping factory that builds L{FireOnClose} wrappers and exposes the
    C{deferred} they fire when their connection is lost.
    """

    protocol = FireOnClose

    def __init__(self, wrappedFactory):
        policies.WrappingFactory.__init__(self, wrappedFactory)
        # Fired (with None) by FireOnClose.connectionLost.
        self.deferred = defer.Deferred()
 
1494
 
 
1495
 
 
1496
class LargeBufferTestCase(unittest.TestCase):
    """
    Check that a very large single write is buffered and delivered intact.
    """

    datalen = 60*1024*1024

    def testWriter(self):
        """
        A server writes C{datalen} bytes in one call and disconnects; once both
        ends have closed, the client must have seen every byte and the
        disconnect, and the writer must have finished.
        """
        writerFactory = protocol.Factory()
        writerFactory.protocol = LargeBufferWriterProtocol
        writerFactory.done = 0
        writerFactory.problem = 0
        writerFactory.len = self.datalen
        wrappedWriter = FireOnCloseFactory(writerFactory)
        listening = reactor.listenTCP(0, wrappedWriter, interface="127.0.0.1")
        self.addCleanup(listening.stopListening)
        portNumber = listening.getHost().port

        readerFactory = LargeBufferReaderClientFactory()
        wrappedReader = FireOnCloseFactory(readerFactory)
        reactor.connectTCP("127.0.0.1", portNumber, wrappedReader)

        bothClosed = defer.gatherResults(
            [wrappedWriter.deferred, wrappedReader.deferred])

        def verify(ignored):
            self.failUnless(writerFactory.done,
                            "writer didn't finish, it probably died")
            received = readerFactory.len
            self.failUnless(received == self.datalen,
                            "client didn't receive all the data it expected "
                            "(%d != %d)" % (received, self.datalen))
            self.failUnless(readerFactory.done,
                            "client didn't see connection dropped")

        return bothClosed.addCallback(verify)
 
1524
 
 
1525
 
 
1526
class MyHCProtocol(AccumulatingProtocol):
    """
    Accumulating protocol that accepts half-close notifications and records
    which direction(s) have been closed.
    """

    implements(IHalfCloseableProtocol)

    readHalfClosed = False
    writeHalfClosed = False

    def _closeIfFullyHalfClosed(self):
        # Once both directions are closed, reuse the base class's
        # connectionLost notification to keep the tests simple.
        if self.readHalfClosed and self.writeHalfClosed:
            self.connectionLost(None)

    def readConnectionLost(self):
        self.readHalfClosed = True
        self._closeIfFullyHalfClosed()

    def writeConnectionLost(self):
        self.writeHalfClosed = True
        self._closeIfFullyHalfClosed()
 
1544
 
 
1545
 
 
1546
class MyHCFactory(protocol.ServerFactory):
    """
    Server factory producing L{MyHCProtocol} instances; counts how many it has
    built in C{called} and keeps the latest one as C{protocol}.
    """

    called = 0
    # Optional Deferred; HalfCloseBuggyApplicationTests sets it and waits on
    # it (presumably fired by AccumulatingProtocol on connect - confirm).
    protocolConnectionMade = None

    def buildProtocol(self, addr):
        self.called = self.called + 1
        built = MyHCProtocol()
        built.factory = self
        self.protocol = built
        return built
 
1557
 
 
1558
 
 
1559
class HalfCloseTestCase(unittest.TestCase):
    """
    Tests for half-closing TCP connections with half-close-aware protocols.
    """

    def setUp(self):
        """
        Listen on loopback, wait for the port to be listening, connect a
        L{MyHCProtocol} client and continue in L{_setUp}.
        """
        factory = MyHCFactory()
        self.f = factory
        port = reactor.listenTCP(0, factory, interface="127.0.0.1")
        self.p = port
        self.addCleanup(port.stopListening)
        self.cf = protocol.ClientCreator(reactor, MyHCProtocol)

        def connectClient(ignored):
            address = port.getHost()
            return self.cf.connectTCP(address.host, address.port)

        listening = loopUntil(lambda: port.connected)
        listening.addCallback(connectClient)
        listening.addCallback(self._setUp)
        return listening

    def _setUp(self, client):
        """
        Store the connected client, arm its close notification and wait for
        the server factory to build its protocol.
        """
        self.client = client
        closed = defer.Deferred()
        self.clientProtoConnectionLost = closed
        client.closedDeferred = closed
        self.assertEquals(client.transport.connected, 1)
        # The server may not have accepted the connection yet; poll for it.
        return loopUntil(
            lambda: getattr(self.f, 'protocol', None) is not None)

    def tearDown(self):
        """
        Fully close the client, stop listening, wait for the client's
        connectionLost and then finish in L{_tearDown}.
        """
        self.assertEquals(self.client.closed, 0)
        self.client.transport.loseConnection()
        stopped = defer.maybeDeferred(self.p.stopListening)
        stopped.addCallback(lambda ignored: self.clientProtoConnectionLost)
        stopped.addCallback(self._tearDown)
        return stopped

    def _tearDown(self, ignored):
        """
        With the client closed, explicitly close the server side and wait for
        its connectionLost.
        """
        self.assertEquals(self.client.closed, 1)
        # Because the test half-closed the connection, the server side is
        # still open and needs to be closed explicitly.
        serverProto = self.f.protocol
        self.assertEquals(serverProto.closed, 0)
        serverGone = defer.Deferred()

        def onServerConnectionLost(reason):
            self.f.protocol.closed = 1
            serverGone.callback(None)

        serverProto.connectionLost = onServerConnectionLost
        serverProto.transport.loseConnection()
        serverGone.addCallback(
            lambda result: self.assertEquals(self.f.protocol.closed, 1))
        return serverGone

    def testCloseWriteCloser(self):
        """
        After the client half-closes its write side, the server should see a
        read half-close and the data sent before it.  Writes the client makes
        afterwards should not be left in its transport buffer.
        """
        client = self.client
        factory = self.f
        transport = client.transport

        def bufferDrained():
            return len(transport._tempDataBuffer) == 0

        def closeWriteSide(ignored):
            transport.loseWriteConnection()
            return loopUntil(lambda: transport._writeDisconnected)

        def checkClientHalfClosed(ignored):
            self.assertEquals(client.closed, False)
            self.assertEquals(client.writeHalfClosed, True)
            self.assertEquals(client.readHalfClosed, False)
            return loopUntil(lambda: factory.protocol.readHalfClosed)

        def writeAfterClose(ignored):
            send = client.transport.write
            send(" world")
            send("lalala fooled you")
            self.assertEquals(0, len(client.transport._tempDataBuffer))
            server = factory.protocol
            self.assertEquals(server.data, "hello")
            self.assertEquals(server.closed, False)
            self.assertEquals(server.readHalfClosed, True)

        transport.write("hello")
        chain = loopUntil(bufferDrained)
        chain.addCallback(closeWriteSide)
        chain.addCallback(checkClientHalfClosed)
        chain.addCallback(writeAfterClose)
        return chain

    def testWriteCloseNotification(self):
        """
        When the server half-closes its write side, the server protocol must
        record a write half-close and the client a read half-close, while the
        server's read side stays open.
        """
        factory = self.f
        factory.protocol.transport.loseWriteConnection()

        bothNotified = defer.gatherResults([
            loopUntil(lambda: factory.protocol.writeHalfClosed),
            loopUntil(lambda: self.client.readHalfClosed)])

        def serverStillReadable(ignored):
            self.assertEquals(factory.protocol.readHalfClosed, False)

        bothNotified.addCallback(serverStillReadable)
        return bothNotified
 
1639
 
 
1640
 
 
1641
class HalfClose2TestCase(unittest.TestCase):
    """
    Half-close tests using protocols that do not implement
    L{IHalfCloseableProtocol}.
    """

    def setUp(self):
        """
        Listen with a L{MyServerFactory}, connect a plain
        L{AccumulatingProtocol} client, and wait until both ends are up.
        """
        self.f = MyServerFactory()
        self.f.protocolConnectionMade = defer.Deferred()
        self.p = reactor.listenTCP(0, self.f, interface="127.0.0.1")
        address = self.p.getHost()

        # XXX the server side isn't tested yet since we don't do it yet
        creator = protocol.ClientCreator(reactor, AccumulatingProtocol)
        connecting = creator.connectTCP(address.host, address.port)
        connecting.addCallback(self._gotClient)
        return connecting

    def _gotClient(self, client):
        """
        Store the client and make setUp wait for the server side as well.
        """
        self.client = client
        # If the server's Deferred already fired and was cleared, this
        # returns None and setUp does not wait at all, which is correct.
        return self.f.protocolConnectionMade

    def tearDown(self):
        """
        Close the client connection and stop listening.
        """
        self.client.transport.loseConnection()
        stopping = self.p.stopListening()
        return stopping

    def testNoNotification(self):
        """
        TCP protocols support half-close connections, but not all of them
        support being notified of write closes.  In this case, test that
        half-closing the connection causes the peer's connection to be
        closed.
        """
        clientTransport = self.client.transport
        clientTransport.write("hello")
        clientTransport.loseWriteConnection()

        serverClosed = defer.Deferred()
        self.f.protocol.closedDeferred = serverClosed
        clientClosed = defer.Deferred()
        self.client.closedDeferred = clientClosed

        def checkServer(ignored):
            self.assertEqual(self.f.protocol.data, 'hello')
            self.assertEqual(self.f.protocol.closed, True)

        serverClosed.addCallback(checkServer)
        return defer.gatherResults([serverClosed, clientClosed])

    def testShutdownException(self):
        """
        If the other side has already closed its connection,
        loseWriteConnection should pass silently.
        """
        self.f.protocol.transport.loseConnection()
        clientTransport = self.client.transport
        clientTransport.write("X")
        clientTransport.loseWriteConnection()

        serverClosed = defer.Deferred()
        self.f.protocol.closedDeferred = serverClosed
        clientClosed = defer.Deferred()
        self.client.closedDeferred = clientClosed

        def checkServer(ignored):
            self.failUnlessEqual(self.f.protocol.closed, True)

        serverClosed.addCallback(checkServer)
        return defer.gatherResults([serverClosed, clientClosed])
 
1694
 
 
1695
 
 
1696
class HalfCloseBuggyApplicationTests(unittest.TestCase):
    """
    Half-close tests in which the application's notification callbacks are
    themselves broken.
    """

    def setUp(self):
        """
        Listen with a L{MyHCFactory} and connect a L{MyHCProtocol} client.
        The returned Deferred fires only once the server protocol has been
        made and the client connection is established.
        """
        self.serverFactory = MyHCFactory()
        self.serverFactory.protocolConnectionMade = defer.Deferred()
        self.port = reactor.listenTCP(
            0, self.serverFactory, interface="127.0.0.1")
        self.addCleanup(self.port.stopListening)

        where = self.port.getHost()
        clientReady = protocol.ClientCreator(
            reactor, MyHCProtocol).connectTCP(where.host, where.port)

        def rememberClient(clientProtocol):
            self.clientProtocol = clientProtocol

        clientReady.addCallback(rememberClient)
        return defer.gatherResults(
            [self.serverFactory.protocolConnectionMade, clientReady])

    def aBug(self, *args):
        """
        Stand-in for an application callback that wrongly raises.
        """
        raise RuntimeError("ONO I AM BUGGY CODE")

    def _notificationRaisesTest(self):
        """
        Half-close the client's write side and, once the client protocol has
        lost its connection, assert that exactly one RuntimeError was logged.
        """
        closed = defer.Deferred()
        self.clientProtocol.closedDeferred = closed
        self.clientProtocol.transport.loseWriteConnection()

        def checkOneErrorLogged(ignored):
            logged = self.flushLoggedErrors(RuntimeError)
            self.assertEqual(len(logged), 1)

        return closed.addCallback(checkOneErrorLogged)

    def test_readNotificationRaises(self):
        """
        An exception raised by C{readConnectionLost} while the transport is
        notifying the protocol must be logged, and the protocol must end up
        completely disconnected.
        """
        self.serverFactory.protocol.readConnectionLost = self.aBug
        return self._notificationRaisesTest()

    def test_writeNotificationRaises(self):
        """
        An exception raised by C{writeConnectionLost} while the transport is
        notifying the protocol must be logged, and the protocol must end up
        completely disconnected.
        """
        self.clientProtocol.writeConnectionLost = self.aBug
        return self._notificationRaisesTest()
 
1762
 
 
1763
 
 
1764
 
 
1765
class LogTestCase(unittest.TestCase):
    """
    Tests for the log prefix (C{logstr}) of TCP transports.
    """

    def test_logstrClientSetup(self):
        """
        A client transport keeps its default C{logstr} until the connection
        is made, after which the customized prefix is in place.
        """
        serverFactory = MyServerFactory()
        clientFactory = MyClientFactory()
        clientFactory.protocolConnectionMade = defer.Deferred()

        port = reactor.listenTCP(0, serverFactory, interface='127.0.0.1')
        self.addCleanup(port.stopListening)

        address = port.getHost()
        connector = reactor.connectTCP(
            address.host, address.port, clientFactory)
        self.addCleanup(connector.disconnect)

        # Before the connection completes the default value is unchanged.
        self.assertEquals(connector.transport.logstr, "Uninitialized")

        def checkCustomized(ignored):
            self.assertEquals(connector.transport.logstr,
                              "AccumulatingProtocol,client")

        clientFactory.protocolConnectionMade.addCallback(checkCustomized)
        return clientFactory.protocolConnectionMade
 
1796
 
 
1797
 
 
1798
 
 
1799
class PauseProducingTestCase(unittest.TestCase):
    """
    Tests for pausing a transport's production.
    """

    def test_pauseProducingInConnectionMade(self):
        """
        Calling C{pauseProducing} from a client protocol's C{connectionMade}
        used to be ignored.  Ensure it takes effect immediately and is still
        in effect after C{connectionMade} has returned.
        """
        serverFactory = MyServerFactory()
        clientFactory = MyClientFactory()
        clientFactory.protocolConnectionMade = defer.Deferred()

        port = reactor.listenTCP(0, serverFactory, interface='127.0.0.1')
        self.addCleanup(port.stopListening)

        address = port.getHost()
        connector = reactor.connectTCP(
            address.host, address.port, clientFactory)
        self.addCleanup(connector.disconnect)

        def monitored():
            # Readers plus writers currently registered with the reactor.
            return reactor.getReaders() + reactor.getWriters()

        def afterConnectionMade(proto):
            # One reactor turn later the transport must still be unmonitored.
            self.assertNotIn(proto.transport, monitored())

        def duringConnectionMade(proto):
            transport = proto.transport
            # Before pausing, the reactor is already watching the transport.
            self.assertIn(transport, monitored())
            transport.pauseProducing()
            self.assertNotIn(transport, monitored())
            nextTurn = defer.Deferred()
            nextTurn.addCallback(afterConnectionMade)
            reactor.callLater(0, nextTurn.callback, proto)
            return nextTurn

        clientFactory.protocolConnectionMade.addCallback(duringConnectionMade)
        return clientFactory.protocolConnectionMade

    if not interfaces.IReactorFDSet.providedBy(reactor):
        test_pauseProducingInConnectionMade.skip = "Reactor not providing IReactorFDSet"
 
1843
 
 
1844
 
 
1845
 
 
1846
class CallBackOrderTestCase(unittest.TestCase):
    """
    Test the order of reactor callbacks
    """

    def test_loseOrder(self):
        """
        Check that Protocol.connectionLost is called before factory's
        clientConnectionLost
        """
        server = MyServerFactory()
        server.protocolConnectionMade = (defer.Deferred()
                .addCallback(lambda proto: self.addCleanup(
                             proto.transport.loseConnection)))

        client = MyClientFactory()
        client.protocolConnectionLost = defer.Deferred()
        client.protocolConnectionMade = defer.Deferred()

        def _cbCM(res):
            """
            protocol.connectionMade callback: drop the client connection on
            the next reactor iteration.
            """
            reactor.callLater(0, client.protocol.transport.loseConnection)

        client.protocolConnectionMade.addCallback(_cbCM)

        port = reactor.listenTCP(0, server, interface='127.0.0.1')
        self.addCleanup(port.stopListening)

        connector = reactor.connectTCP(
            port.getHost().host, port.getHost().port, client)
        self.addCleanup(connector.disconnect)

        # gatherResults reports results in the order of its input list, not
        # in the order the Deferreds fire, so its result alone can never
        # reveal the callback order.  Record the actual firing order.
        fired = []

        def _cbCCL(res):
            """
            factory.clientConnectionLost callback
            """
            fired.append('CCL')
            return 'CCL'

        def _cbCL(res):
            """
            protocol.connectionLost callback
            """
            fired.append('CL')
            return 'CL'

        def _cbGather(res):
            self.assertEquals(res, ['CL', 'CCL'])
            # This is the assertion that actually checks the ordering.
            self.assertEquals(fired, ['CL', 'CCL'])

        d = defer.gatherResults([
                client.protocolConnectionLost.addCallback(_cbCL),
                client.deferred.addCallback(_cbCCL)])
        return d.addCallback(_cbGather)
 
1899
 
 
1900
 
 
1901
 
 
1902
try:
    import resource
except ImportError:
    pass
else:
    # Run more rounds than the soft open-file limit (presumably so that any
    # leaked descriptors would exhaust it - confirm against
    # ProperlyCloseFilesTestCase).
    softLimit = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
    numRounds = softLimit + 10
    # An unlimited soft limit is reported as RLIM_INFINITY (-1 on some
    # platforms, 2**63-1 on others).  Adding 10 to it yields either a tiny
    # or an effectively endless round count, so in that case keep the
    # class's own numberRounds (not visible here - confirm a default exists).
    if softLimit != resource.RLIM_INFINITY and softLimit >= 0:
        ProperlyCloseFilesTestCase.numberRounds = numRounds