1
# Copyright (c) 2001-2009 Twisted Matrix Laboratories.
2
# See LICENSE for details.
5
Test case for L{twisted.protocols.loopback}.
8
from zope.interface import implements
10
from twisted.trial import unittest
11
from twisted.trial.util import suppress as SUPPRESS
12
from twisted.protocols import basic, loopback
13
from twisted.internet import defer
14
from twisted.internet.protocol import Protocol
15
from twisted.internet.defer import Deferred
16
from twisted.internet.interfaces import IAddress, IPushProducer, IPullProducer
17
from twisted.internet import reactor, interfaces
20
class SimpleProtocol(basic.LineReceiver):
22
self.conn = defer.Deferred()
26
def connectionMade(self):
27
self.conn.callback(None)
29
def lineReceived(self, line):
30
self.lines.append(line)
32
def connectionLost(self, reason):
33
self.connLost.append(reason)
36
class DoomProtocol(SimpleProtocol):
38
def lineReceived(self, line):
41
# by this point we should have connection closed,
42
# but just in case we didn't we won't ever send 'Hello 4'
43
self.sendLine("Hello %d" % self.i)
44
SimpleProtocol.lineReceived(self, line)
45
if self.lines[-1] == "Hello 3":
46
self.transport.loseConnection()
49
class LoopbackTestCaseMixin:
50
def testRegularFunction(self):
54
def sendALine(result):
55
s.sendLine("THIS IS LINE ONE!")
56
s.transport.loseConnection()
57
s.conn.addCallback(sendALine)
60
self.assertEquals(c.lines, ["THIS IS LINE ONE!"])
61
self.assertEquals(len(s.connLost), 1)
62
self.assertEquals(len(c.connLost), 1)
63
d = defer.maybeDeferred(self.loopbackFunc, s, c)
67
def testSneakyHiddenDoom(self):
71
def sendALine(result):
72
s.sendLine("DOOM LINE")
73
s.conn.addCallback(sendALine)
76
self.assertEquals(s.lines, ['Hello 1', 'Hello 2', 'Hello 3'])
77
self.assertEquals(c.lines, ['DOOM LINE', 'Hello 1', 'Hello 2', 'Hello 3'])
78
self.assertEquals(len(s.connLost), 1)
79
self.assertEquals(len(c.connLost), 1)
80
d = defer.maybeDeferred(self.loopbackFunc, s, c)
86
class LoopbackTestCase(LoopbackTestCaseMixin, unittest.TestCase):
87
loopbackFunc = staticmethod(loopback.loopback)
89
def testRegularFunction(self):
91
Suppress loopback deprecation warning.
93
return LoopbackTestCaseMixin.testRegularFunction(self)
94
testRegularFunction.suppress = [
95
SUPPRESS(message="loopback\(\) is deprecated",
96
category=DeprecationWarning)]
100
class LoopbackAsyncTestCase(LoopbackTestCase):
101
loopbackFunc = staticmethod(loopback.loopbackAsync)
104
def test_makeConnection(self):
106
Test that the client and server protocol both have makeConnection
107
invoked on them by loopbackAsync.
109
class TestProtocol(Protocol):
111
def makeConnection(self, transport):
112
self.transport = transport
114
server = TestProtocol()
115
client = TestProtocol()
116
loopback.loopbackAsync(server, client)
117
self.failIfEqual(client.transport, None)
118
self.failIfEqual(server.transport, None)
121
def _hostpeertest(self, get, testServer):
123
Test one of the permutations of client/server host/peer.
125
class TestProtocol(Protocol):
126
def makeConnection(self, transport):
127
Protocol.makeConnection(self, transport)
128
self.onConnection.callback(transport)
131
server = TestProtocol()
132
d = server.onConnection = Deferred()
136
client = TestProtocol()
137
d = client.onConnection = Deferred()
139
loopback.loopbackAsync(server, client)
141
def connected(transport):
142
host = getattr(transport, get)()
143
self.failUnless(IAddress.providedBy(host))
145
return d.addCallback(connected)
148
def test_serverHost(self):
150
Test that the server gets a transport with a properly functioning
151
implementation of L{ITransport.getHost}.
153
return self._hostpeertest("getHost", True)
156
def test_serverPeer(self):
158
Like C{test_serverHost} but for L{ITransport.getPeer}
160
return self._hostpeertest("getPeer", True)
163
def test_clientHost(self, get="getHost"):
165
Test that the client gets a transport with a properly functioning
166
implementation of L{ITransport.getHost}.
168
return self._hostpeertest("getHost", False)
171
def test_clientPeer(self):
173
Like C{test_clientHost} but for L{ITransport.getPeer}.
175
return self._hostpeertest("getPeer", False)
178
def _greetingtest(self, write, testServer):
180
Test one of the permutations of write/writeSequence client/server.
182
class GreeteeProtocol(Protocol):
184
def dataReceived(self, bytes):
186
if self.bytes == "bytes":
187
self.received.callback(None)
189
class GreeterProtocol(Protocol):
190
def connectionMade(self):
191
getattr(self.transport, write)("bytes")
194
server = GreeterProtocol()
195
client = GreeteeProtocol()
196
d = client.received = Deferred()
198
server = GreeteeProtocol()
199
d = server.received = Deferred()
200
client = GreeterProtocol()
202
loopback.loopbackAsync(server, client)
206
def test_clientGreeting(self):
208
Test that on a connection where the client speaks first, the server
209
receives the bytes sent by the client.
211
return self._greetingtest("write", False)
214
def test_clientGreetingSequence(self):
216
Like C{test_clientGreeting}, but use C{writeSequence} instead of
217
C{write} to issue the greeting.
219
return self._greetingtest("writeSequence", False)
222
def test_serverGreeting(self, write="write"):
224
Test that on a connection where the server speaks first, the client
225
receives the bytes sent by the server.
227
return self._greetingtest("write", True)
230
def test_serverGreetingSequence(self):
232
Like C{test_serverGreeting}, but use C{writeSequence} instead of
233
C{write} to issue the greeting.
235
return self._greetingtest("writeSequence", True)
238
def _producertest(self, producerClass):
239
toProduce = map(str, range(0, 10))
241
class ProducingProtocol(Protocol):
242
def connectionMade(self):
243
self.producer = producerClass(list(toProduce))
244
self.producer.start(self.transport)
246
class ReceivingProtocol(Protocol):
248
def dataReceived(self, bytes):
250
if self.bytes == ''.join(toProduce):
251
self.received.callback((client, server))
253
server = ProducingProtocol()
254
client = ReceivingProtocol()
255
client.received = Deferred()
257
loopback.loopbackAsync(server, client)
258
return client.received
261
def test_pushProducer(self):
263
Test a push producer registered against a loopback transport.
265
class PushProducer(object):
266
implements(IPushProducer)
269
def __init__(self, toProduce):
270
self.toProduce = toProduce
272
def resumeProducing(self):
275
def start(self, consumer):
276
self.consumer = consumer
277
consumer.registerProducer(self, True)
278
self._produceAndSchedule()
280
def _produceAndSchedule(self):
282
self.consumer.write(self.toProduce.pop(0))
283
reactor.callLater(0, self._produceAndSchedule)
285
self.consumer.unregisterProducer()
286
d = self._producertest(PushProducer)
288
def finished((client, server)):
290
server.producer.resumed,
291
"Streaming producer should not have been resumed.")
292
d.addCallback(finished)
296
def test_pullProducer(self):
298
Test a pull producer registered against a loopback transport.
300
class PullProducer(object):
301
implements(IPullProducer)
303
def __init__(self, toProduce):
304
self.toProduce = toProduce
306
def start(self, consumer):
307
self.consumer = consumer
308
self.consumer.registerProducer(self, False)
310
def resumeProducing(self):
311
self.consumer.write(self.toProduce.pop(0))
312
if not self.toProduce:
313
self.consumer.unregisterProducer()
314
return self._producertest(PullProducer)
317
def test_writeNotReentrant(self):
319
L{loopback.loopbackAsync} does not call a protocol's C{dataReceived}
320
method while that protocol's transport's C{write} method is higher up
323
class Server(Protocol):
324
def dataReceived(self, bytes):
325
self.transport.write("bytes")
327
class Client(Protocol):
330
def connectionMade(self):
331
reactor.callLater(0, self.go)
334
self.transport.write("foo")
337
def dataReceived(self, bytes):
338
self.wasReady = self.ready
339
self.transport.loseConnection()
344
d = loopback.loopbackAsync(client, server)
345
def cbFinished(ignored):
346
self.assertTrue(client.wasReady)
347
d.addCallback(cbFinished)
351
def test_pumpPolicy(self):
353
The callable passed as the value for the C{pumpPolicy} parameter to
354
L{loopbackAsync} is called with a L{_LoopbackQueue} of pending bytes
355
and a protocol to which they should be delivered.
358
def dummyPolicy(queue, target):
361
bytes.append(queue.get())
362
pumpCalls.append((target, bytes))
367
finished = loopback.loopbackAsync(server, client, dummyPolicy)
368
self.assertEquals(pumpCalls, [])
370
client.transport.write("foo")
371
client.transport.write("bar")
372
server.transport.write("baz")
373
server.transport.write("quux")
374
server.transport.loseConnection()
376
def cbComplete(ignored):
379
# The order here is somewhat arbitrary. The implementation
380
# happens to always deliver data to the client first.
381
[(client, ["baz", "quux", None]),
382
(server, ["foo", "bar"])])
383
finished.addCallback(cbComplete)
387
def test_identityPumpPolicy(self):
389
L{identityPumpPolicy} is a pump policy which calls the target's
390
C{dataReceived} method one for each string in the queue passed to it.
394
client.dataReceived = bytes.append
395
queue = loopback._LoopbackQueue()
400
loopback.identityPumpPolicy(queue, client)
402
self.assertEquals(bytes, ["foo", "bar"])
405
def test_collapsingPumpPolicy(self):
407
L{collapsingPumpPolicy} is a pump policy which calls the target's
408
C{dataReceived} only once with all of the strings in the queue passed
409
to it joined together.
413
client.dataReceived = bytes.append
414
queue = loopback._LoopbackQueue()
419
loopback.collapsingPumpPolicy(queue, client)
421
self.assertEquals(bytes, ["foobar"])
425
class LoopbackTCPTestCase(LoopbackTestCase):
426
loopbackFunc = staticmethod(loopback.loopbackTCP)
429
class LoopbackUNIXTestCase(LoopbackTestCase):
430
loopbackFunc = staticmethod(loopback.loopbackUNIX)
432
if interfaces.IReactorUNIX(reactor, None) is None:
433
skip = "Current reactor does not support UNIX sockets"