1
# Copyright (c) 2005 Divmod, Inc.
2
# Copyright (c) 2007-2009 Twisted Matrix Laboratories.
3
# See LICENSE for details.
6
Tests for L{twisted.protocols.amp}.
9
from zope.interface.verify import verifyObject
11
from twisted.python.util import setIDFunction
12
from twisted.python import filepath
13
from twisted.python.failure import Failure
14
from twisted.protocols import amp
15
from twisted.trial import unittest
16
from twisted.internet import protocol, defer, error, reactor, interfaces
17
from twisted.test import iosim
18
from twisted.test.proto_helpers import StringTransport
21
from twisted.internet import ssl
24
if ssl and not ssl.supported:
28
skipSSL = "SSL not available"
33
class TestProto(protocol.Protocol):
35
A trivial protocol for use in testing where a L{Protocol} is expected.
37
@ivar instanceId: the id of this instance
38
@ivar onConnLost: deferred that will be fired when the connection is lost
39
@ivar dataToSend: data to send on the protocol
44
def __init__(self, onConnLost, dataToSend):
    """
    Remember the notification object and outgoing data, and take the next
    sequential instance id from the class-wide counter.

    @param onConnLost: stored as C{self.onConnLost}.
    @param dataToSend: stored as C{self.dataToSend}.
    """
    # Claim the current counter value, then advance it for the next instance.
    claimedId = TestProto.instanceCount
    TestProto.instanceCount = claimedId + 1
    self.instanceId = claimedId
    self.onConnLost = onConnLost
    self.dataToSend = dataToSend
50
def connectionMade(self):
52
self.transport.write(self.dataToSend)
54
def dataReceived(self, bytes):
    """
    Append the received chunk to C{self.data}.
    """
    received = self.data
    received.append(bytes)
    # self.transport.loseConnection()
58
def connectionLost(self, reason):
    """
    Call back C{self.onConnLost} with C{self.data}; C{reason} is not used.
    """
    notify = self.onConnLost.callback
    notify(self.data)
64
Custom repr for testing to avoid coupling amp tests with repr from
67
Returns a string which contains a unique identifier that can be looked
68
up using the instanceId property::
72
return "<TestProto #%d>" % (self.instanceId,)
76
class SimpleSymmetricProtocol(amp.AMP):
78
def sendHello(self, text):
79
return self.callRemoteString(
83
def amp_HELLO(self, box):
    """
    Reply with a box that echoes back the C{'hello'} value of C{box}.
    """
    greeting = box['hello']
    return amp.Box(hello=greeting)
86
def amp_HOWDOYOUDO(self, box):
    """
    Reply with a L{amp.QuitBox} holding C{howdoyoudo='world'}; the incoming
    C{box} is ignored.
    """
    reply = amp.QuitBox(howdoyoudo='world')
    return reply
91
class UnfriendlyGreeting(Exception):
92
"""Greeting was insufficiently kind.
95
class DeathThreat(Exception):
96
"""Greeting was a death threat; L{Hello} lists this as a fatal error.
99
class UnknownProtocol(Exception):
100
"""Asked to switch to the wrong protocol.
104
class TransportPeer(amp.Argument):
105
# this serves as some informal documentation for how to get variables from
106
# the protocol or your environment and pass them to methods as arguments.
107
def retrieve(self, d, name, proto):
110
def fromStringProto(self, notAString, proto):
111
return proto.transport.getPeer()
113
def toBox(self, name, strings, objects, proto):
118
class Hello(amp.Command):
    """
    The C{'hello'} command: one required string argument plus several
    optional ones, answered with a string and an optional unicode value.
    """

    commandName = 'hello'

    arguments = [
        ('hello', amp.String()),
        ('optional', amp.Boolean(optional=True)),
        ('print', amp.Unicode(optional=True)),
        ('from', TransportPeer(optional=True)),
        ('mixedCase', amp.String(optional=True)),
        ('dash-arg', amp.String(optional=True)),
        ('underscore_arg', amp.String(optional=True)),
    ]

    response = [
        ('hello', amp.String()),
        ('print', amp.Unicode(optional=True)),
    ]

    # Exception class -> error code, for ordinary and fatal errors.
    errors = {UnfriendlyGreeting: 'UNFRIENDLY'}

    fatalErrors = {DeathThreat: 'DEAD'}
137
class NoAnswerHello(Hello):
    """
    L{Hello} under the same command name, with C{requiresAnswer} turned off.
    """
    requiresAnswer = False
    commandName = Hello.commandName
141
class FutureHello(amp.Command):
142
commandName = 'hello'
144
arguments = [('hello', amp.String()),
145
('optional', amp.Boolean(optional=True)),
146
('print', amp.Unicode(optional=True)),
147
('from', TransportPeer(optional=True)),
148
('bonus', amp.String(optional=True)), # addt'l arguments
149
# should generally be
150
# added at the end, and
154
response = [('hello', amp.String()),
155
('print', amp.Unicode(optional=True))]
157
errors = {UnfriendlyGreeting: 'UNFRIENDLY'}
159
class WTF(amp.Command):
161
An example of an invalid command.
165
class BrokenReturn(amp.Command):
166
""" An example of a perfectly good command, but the handler is going to return
170
commandName = 'broken_return'
172
class Goodbye(amp.Command):
    """
    A command whose response is a L{amp.QuitBox} holding a C{'goodbye'}
    string.
    """
    # commandName left blank on purpose: this tests implicit command names.
    responseType = amp.QuitBox
    response = [
        ('goodbye', amp.String()),
    ]
177
class Howdoyoudo(amp.Command):
    """
    The C{'howdoyoudo'} command; it declares no arguments or response of its
    own.
    """
    commandName = 'howdoyoudo'
    # responseType = amp.QuitBox
181
class WaitForever(amp.Command):
    """
    The C{'wait_forever'} command; it declares no arguments or response of
    its own.
    """
    commandName = 'wait_forever'
184
class GetList(amp.Command):
    """
    The C{'getlist'} command: takes an integer C{length} and answers with a
    C{body} list whose entries each carry an integer C{x}.
    """
    commandName = 'getlist'
    arguments = [
        ('length', amp.Integer()),
    ]
    response = [
        ('body', amp.AmpList([('x', amp.Integer())])),
    ]
189
class DontRejectMe(amp.Command):
190
commandName = 'dontrejectme'
192
('magicWord', amp.Unicode()),
193
('list', amp.AmpList([('name', amp.Unicode())], optional=True)),
195
response = [('response', amp.Unicode())]
197
class SecuredPing(amp.Command):
    """
    A command with no arguments whose response is a single boolean,
    C{pinged}.
    """
    # XXX TODO: actually make this refuse to send over an insecure connection
    response = [
        ('pinged', amp.Boolean()),
    ]
201
class TestSwitchProto(amp.ProtocolSwitchCommand):
202
commandName = 'Switch-Proto'
205
('name', amp.String()),
207
errors = {UnknownProtocol: 'UNKNOWN'}
209
class SingleUseFactory(protocol.ClientFactory):
210
def __init__(self, proto):
212
self.proto.factory = self
214
def buildProtocol(self, addr):
215
p, self.proto = self.proto, None
220
def clientConnectionFailed(self, connector, reason):
    """
    Keep the failure C{reason} on C{self.reasonFailed}; C{connector} is not
    used.
    """
    failure = reason
    self.reasonFailed = failure
224
THING_I_DONT_UNDERSTAND = 'gwebol nargo'
225
class ThingIDontUnderstandError(Exception):
228
class FactoryNotifier(amp.AMP):
230
def connectionMade(self):
    """
    If this protocol has a factory, publish this protocol on it as
    C{theProto} and, when the factory has an C{onMade} attribute, call it
    back with C{None}.
    """
    factory = self.factory
    if factory is None:
        return
    factory.theProto = self
    if not hasattr(factory, 'onMade'):
        return
    factory.onMade.callback(None)
237
from twisted.internet.interfaces import ISSLTransport
238
if not ISSLTransport.providedBy(self.transport):
239
raise DeathThreat("only send secure pings over secure channels")
240
return {'pinged': True}
241
SecuredPing.responder(emitpong)
244
class SimpleSymmetricCommandProtocol(FactoryNotifier):
246
def __init__(self, onConnLost=None):
    """
    Initialize the underlying L{amp.AMP}, then remember C{onConnLost}.

    @param onConnLost: stored as C{self.onConnLost}; the C{switchit}
        responder hands it to the L{TestProto} it creates.
    """
    baseInit = amp.AMP.__init__
    baseInit(self)
    self.onConnLost = onConnLost
250
def sendHello(self, text):
    """
    Issue a L{Hello} command with C{text} as its C{hello} argument and
    return whatever C{callRemote} returns.
    """
    arguments = {'hello': text}
    return self.callRemote(Hello, **arguments)
253
def sendUnicodeHello(self, text, translation):
    """
    Issue a L{Hello} command with C{text} as C{hello} and C{translation} as
    C{Print}, returning whatever C{callRemote} returns.
    """
    arguments = {'hello': text, 'Print': translation}
    return self.callRemote(Hello, **arguments)
258
def cmdHello(self, hello, From, optional=None, Print=None,
259
mixedCase=None, dash_arg=None, underscore_arg=None):
260
assert From == self.transport.getPeer()
261
if hello == THING_I_DONT_UNDERSTAND:
262
raise ThingIDontUnderstandError()
263
if hello.startswith('fuck'):
264
raise UnfriendlyGreeting("Don't be a dick.")
266
raise DeathThreat("aieeeeeeeee")
267
result = dict(hello=hello)
268
if Print is not None:
269
result.update(dict(Print=Print))
272
Hello.responder(cmdHello)
274
def cmdGetlist(self, length):
    """
    Answer L{GetList} with a C{body} of C{length} entries, each C{x=1}.
    """
    # One dict repeated `length` times: every entry is the same object.
    entry = dict(x=1)
    body = [entry] * length
    return {'body': body}
GetList.responder(cmdGetlist)
278
def okiwont(self, magicWord, list):
    """
    Answer L{DontRejectMe} with a message naming the first entry of C{list};
    C{magicWord} is not used.
    """
    firstName = list[0]['name']
    message = u'%s accepted' % firstName
    return {'response': message}
DontRejectMe.responder(okiwont)
283
self.waiting = defer.Deferred()
285
WaitForever.responder(waitforit)
288
return dict(howdoyoudo='world')
289
Howdoyoudo.responder(howdo)
292
return dict(goodbye="everyone")
293
Goodbye.responder(saybye)
295
def switchToTestProtocol(self, fail=False):
300
p = TestProto(self.onConnLost, SWITCH_CLIENT_DATA)
301
return self.callRemote(
303
SingleUseFactory(p), name=name).addCallback(lambda ign: p)
305
def switchit(self, name):
    """
    Answer L{TestSwitchProto}: return a new L{TestProto} for
    C{'test-proto'}, and raise L{UnknownProtocol} for any other name.
    """
    if name != 'test-proto':
        raise UnknownProtocol(name)
    return TestProto(self.onConnLost, SWITCH_SERVER_DATA)
TestSwitchProto.responder(switchit)
313
BrokenReturn.responder(donothing)
316
class DeferredSymmetricCommandProtocol(SimpleSymmetricCommandProtocol):
    """
    A L{SimpleSymmetricCommandProtocol} whose protocol-switch responder
    answers with an unfired L{defer.Deferred}.
    """
    def switchit(self, name):
        """
        For C{'test-proto'}, store a new L{TestProto} on
        C{self.maybeLaterProto} and a new Deferred on C{self.maybeLater},
        and return that Deferred. Raise L{UnknownProtocol} for any other
        name.
        """
        if name != 'test-proto':
            raise UnknownProtocol(name)
        self.maybeLaterProto = TestProto(self.onConnLost, SWITCH_SERVER_DATA)
        pending = defer.Deferred()
        self.maybeLater = pending
        return pending
    TestSwitchProto.responder(switchit)
325
class BadNoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
326
def badResponder(self, hello, From, optional=None, Print=None,
327
mixedCase=None, dash_arg=None, underscore_arg=None):
329
This responder does nothing and forgets to return a dictionary.
331
NoAnswerHello.responder(badResponder)
333
class NoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
    """
    A L{SimpleSymmetricCommandProtocol} with a well-behaved responder for
    L{NoAnswerHello}.
    """
    def goodNoAnswerResponder(self, hello, From, optional=None, Print=None,
                              mixedCase=None, dash_arg=None,
                              underscore_arg=None):
        """
        Return a response dictionary whose C{hello} value is the received
        greeting with C{"-noanswer"} appended.
        """
        reply = hello + "-noanswer"
        return {'hello': reply}
    NoAnswerHello.responder(goodNoAnswerResponder)
339
def connectedServerAndClient(ServerClass=SimpleSymmetricProtocol,
340
ClientClass=SimpleSymmetricProtocol,
342
"""Returns a 3-tuple: (client, server, pump)
344
return iosim.connectedServerAndClient(
345
ServerClass, ClientClass,
348
class TotallyDumbProtocol(protocol.Protocol):
350
def dataReceived(self, data):
353
class LiteralAmp(amp.AMP):
357
def ampBoxReceived(self, box):
358
self.boxes.append(box)
361
class ParsingTest(unittest.TestCase):
363
def test_booleanValues(self):
365
Verify that the Boolean parser parses 'True' and 'False', but nothing
369
self.assertEquals(b.fromString("True"), True)
370
self.assertEquals(b.fromString("False"), False)
371
self.assertRaises(TypeError, b.fromString, "ninja")
372
self.assertRaises(TypeError, b.fromString, "true")
373
self.assertRaises(TypeError, b.fromString, "TRUE")
374
self.assertEquals(b.toString(True), 'True')
375
self.assertEquals(b.toString(False), 'False')
377
def test_pathValueRoundTrip(self):
379
Verify the 'Path' argument can parse and emit a file path.
381
fp = filepath.FilePath(self.mktemp())
385
self.assertNotIdentical(fp, v) # sanity check
386
self.assertEquals(fp, v)
389
def test_sillyEmptyThing(self):
391
Test that empty boxes raise an error; they aren't supposed to be sent
395
return self.assertRaises(amp.NoEmptyBoxes, a.ampBoxReceived, amp.Box())
398
def test_ParsingRoundTrip(self):
400
Verify that various kinds of data make it through the encode/parse
403
c, s, p = connectedServerAndClient(ClientClass=LiteralAmp,
404
ServerClass=LiteralAmp)
406
SIMPLE = ('simple', 'test')
408
CR = ('crtest', 'test\r')
409
LF = ('lftest', 'hello\n')
410
NEWLINE = ('newline', 'test\r\none\r\ntwo')
411
NEWLINE2 = ('newline2', 'test\r\none\r\n two')
412
BLANKLINE = ('newline3', 'test\r\n\r\nblank\r\n\r\nline')
413
BODYTEST = ('body', 'blah\r\n\r\ntesttest')
420
[SIMPLE, CE, CR, LF],
422
[SIMPLE, NEWLINE, CE, NEWLINE2],
423
[BODYTEST, SIMPLE, NEWLINE]
426
for test in testData:
428
jb.update(dict(test))
431
self.assertEquals(s.boxes[-1], jb)
435
class FakeLocator(object):
437
This is a fake implementation of the interface implied by
442
Remember the given keyword arguments as a set of responders.
447
def locateResponder(self, commandName):
    """
    Return the entry stored under C{commandName} in C{self.commands}.
    """
    responders = self.commands
    return responders[commandName]
457
This is a fake implementation of the 'box sender' interface implied by
462
Create a fake sender and initialize the list of received boxes and
466
self.unhandledErrors = []
467
self.expectedErrors = 0
470
def expectError(self):
    """
    Allow one more error to be reported without failing the test.
    """
    self.expectedErrors = self.expectedErrors + 1
477
def sendBox(self, box):
    """
    Record C{box} on C{self.sentBoxes} and do nothing else with it.
    """
    sent = self.sentBoxes
    sent.append(box)
484
def unhandledError(self, failure):
486
Deal with failures by instantly re-raising them for easier debugging.
488
self.expectedErrors -= 1
489
if self.expectedErrors < 0:
490
failure.raiseException()
492
self.unhandledErrors.append(failure)
496
class CommandDispatchTests(unittest.TestCase):
498
The AMP CommandDispatcher class converts AMP boxes into commands
499
and responses using Command.responder decorator.
501
Note: Originally, AMP's factoring was such that many tests for this
502
functionality are now implemented as full round-trip tests in L{AMPTest}.
503
Future tests should be written at this level instead, to ensure API
504
compatibility and to provide more granular, readable units of test
510
Create a dispatcher to use.
512
self.locator = FakeLocator()
513
self.sender = FakeSender()
514
self.dispatcher = amp.BoxDispatcher(self.locator)
515
self.dispatcher.startReceivingBoxes(self.sender)
518
def test_receivedAsk(self):
520
L{CommandDispatcher.ampBoxReceived} should locate the appropriate
521
command in its responder lookup, based on the '_ask' key.
526
return amp.Box({"hello": "goodbye"})
527
input = amp.Box(_command="hello",
528
_ask="test-command-id",
530
self.locator.commands['hello'] = thunk
531
self.dispatcher.ampBoxReceived(input)
532
self.assertEquals(received, [input])
535
def test_sendUnhandledError(self):
    """
    An unhandled error given to the dispatcher is relayed to its box
    sender.
    """
    problem = RuntimeError("something went wrong, oh no")
    self.sender.expectError()
    self.dispatcher.unhandledError(Failure(problem))
    relayed = self.sender.unhandledErrors
    self.assertEqual(len(relayed), 1)
    self.assertEqual(relayed[0].value, problem)
547
def test_unhandledSerializationError(self):
549
Errors during serialization ought to be relayed to the sender's
550
unhandledError method.
552
err = RuntimeError("something undefined went wrong")
554
class BrokenBox(amp.Box):
555
def _sendTo(self, proto):
558
self.locator.commands['hello'] = thunk
559
input = amp.Box(_command="hello",
560
_ask="test-command-id",
562
self.sender.expectError()
563
self.dispatcher.ampBoxReceived(input)
564
self.assertEquals(len(self.sender.unhandledErrors), 1)
565
self.assertEquals(self.sender.unhandledErrors[0].value, err)
568
def test_callRemote(self):
570
L{CommandDispatcher.callRemote} should emit a properly formatted '_ask'
571
box to its boxSender and record an outstanding L{Deferred}. When a
572
corresponding '_answer' packet is received, the L{Deferred} should be
573
fired, and the results translated via the given L{Command}'s response
576
D = self.dispatcher.callRemote(Hello, hello='world')
577
self.assertEquals(self.sender.sentBoxes,
578
[amp.AmpBox(_command="hello",
582
D.addCallback(answers.append)
583
self.assertEquals(answers, [])
584
self.dispatcher.ampBoxReceived(amp.AmpBox({'hello': "yay",
587
self.assertEquals(answers, [dict(hello="yay",
591
class SimpleGreeting(amp.Command):
    """
    The C{'simple'} command: takes a unicode C{greeting} and an integer
    C{cookie}, and answers with an integer C{cookieplus}.
    """
    commandName = 'simple'
    arguments = [
        ('greeting', amp.Unicode()),
        ('cookie', amp.Integer()),
    ]
    response = [
        ('cookieplus', amp.Integer()),
    ]
601
class TestLocator(amp.CommandLocator):
603
A locator which implements a responder to a 'hello' command.
609
def greetingResponder(self, greeting, cookie):
610
self.greetings.append((greeting, cookie))
611
return dict(cookieplus=cookie + 3)
612
greetingResponder = SimpleGreeting.responder(greetingResponder)
616
class OverrideLocatorAMP(amp.AMP):
618
amp.AMP.__init__(self)
619
self.customResponder = object()
620
self.expectations = {"custom": self.customResponder}
624
def lookupFunction(self, name):
626
Override the deprecated lookupFunction function.
628
if name in self.expectations:
629
result = self.expectations[name]
632
return super(OverrideLocatorAMP, self).lookupFunction(name)
635
def greetingResponder(self, greeting, cookie):
636
self.greetings.append((greeting, cookie))
637
return dict(cookieplus=cookie + 3)
638
greetingResponder = SimpleGreeting.responder(greetingResponder)
643
class CommandLocatorTests(unittest.TestCase):
645
The CommandLocator should enable users to specify responders to commands as
646
functions that take structured objects, annotated with metadata.
649
def test_responderDecorator(self):
651
A method on a L{CommandLocator} subclass decorated with a L{Command}
652
subclass's L{responder} decorator should be returned from
653
locateResponder, wrapped in logic to serialize and deserialize its
656
locator = TestLocator()
657
responderCallable = locator.locateResponder("simple")
658
result = responderCallable(amp.Box(greeting="ni hao", cookie="5"))
660
self.assertEquals(values, amp.AmpBox(cookieplus='8'))
661
return result.addCallback(done)
664
def test_lookupFunctionDeprecatedOverride(self):
666
Subclasses which override locateResponder under its old name,
667
lookupFunction, should have the override invoked instead. (This tests
668
an AMP subclass, because in the version of the code that could invoke
669
this deprecated code path, there was no L{CommandLocator}.)
671
locator = OverrideLocatorAMP()
672
customResponderObject = self.assertWarns(
673
PendingDeprecationWarning,
674
"Override locateResponder, not lookupFunction.",
675
__file__, lambda : locator.locateResponder("custom"))
676
self.assertEquals(locator.customResponder, customResponderObject)
677
# Make sure upcalling works too
678
normalResponderObject = self.assertWarns(
679
PendingDeprecationWarning,
680
"Override locateResponder, not lookupFunction.",
681
__file__, lambda : locator.locateResponder("simple"))
682
result = normalResponderObject(amp.Box(greeting="ni hao", cookie="5"))
684
self.assertEquals(values, amp.AmpBox(cookieplus='8'))
685
return result.addCallback(done)
688
def test_lookupFunctionDeprecatedInvoke(self):
690
Invoking locateResponder under its old name, lookupFunction, should
691
emit a deprecation warning, but do the same thing.
693
locator = TestLocator()
694
responderCallable = self.assertWarns(
695
PendingDeprecationWarning,
696
"Call locateResponder, not lookupFunction.", __file__,
697
lambda : locator.lookupFunction("simple"))
698
result = responderCallable(amp.Box(greeting="ni hao", cookie="5"))
700
self.assertEquals(values, amp.AmpBox(cookieplus='8'))
701
return result.addCallback(done)
705
SWITCH_CLIENT_DATA = 'Success!'
706
SWITCH_SERVER_DATA = 'No, really. Success.'
709
class BinaryProtocolTests(unittest.TestCase):
711
Tests for L{amp.BinaryBoxProtocol}.
713
@ivar _boxSender: After C{startReceivingBoxes} is called, the L{IBoxSender}
714
which was passed to it.
719
Keep track of all boxes received by this test in its capacity as an
720
L{IBoxReceiver} implementor.
726
def startReceivingBoxes(self, sender):
728
Implement L{IBoxReceiver.startReceivingBoxes} to just remember the
731
self._boxSender = sender
734
def ampBoxReceived(self, box):
    """
    Record a box delivered by the protocol on C{self.boxes}.
    """
    seen = self.boxes
    seen.append(box)
741
def stopReceivingBoxes(self, reason):
    """
    Keep the reason box delivery stopped on C{self.stopReason}.
    """
    why = reason
    self.stopReason = why
757
def write(self, data):
    """
    Collect written data on C{self.data}, so the test case itself can be
    passed to C{makeConnection} as the transport.
    """
    written = self.data
    written.append(data)
761
def test_startReceivingBoxes(self):
    """
    Connecting an L{amp.BinaryBoxProtocol} to a transport makes it call
    C{startReceivingBoxes} on its L{IBoxReceiver}, passing itself as the
    L{IBoxSender}.
    """
    boxProtocol = amp.BinaryBoxProtocol(self)
    boxProtocol.makeConnection(None)
    self.assertIdentical(self._boxSender, boxProtocol)
772
def test_sendBoxInStartReceivingBoxes(self):
774
The L{IBoxReceiver} which is started when L{amp.BinaryBoxProtocol} is
775
connected to a transport can call C{sendBox} on the L{IBoxSender}
776
passed to it before C{startReceivingBoxes} returns and have that box
779
class SynchronouslySendingReceiver:
780
def startReceivingBoxes(self, sender):
781
sender.sendBox(amp.Box({'foo': 'bar'}))
783
transport = StringTransport()
784
protocol = amp.BinaryBoxProtocol(SynchronouslySendingReceiver())
785
protocol.makeConnection(transport)
788
'\x00\x03foo\x00\x03bar\x00\x00')
791
def test_receiveBoxStateMachine(self):
793
When a binary box protocol receives:
797
it should emit a box and send it to its boxReceiver.
799
a = amp.BinaryBoxProtocol(self)
800
a.stringReceived("hello")
801
a.stringReceived("world")
803
self.assertEquals(self.boxes, [amp.AmpBox(hello="world")])
806
def test_firstBoxFirstKeyExcessiveLength(self):
    """
    L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
    the first key it receives is larger than 255.
    """
    fakeTransport = StringTransport()
    boxProtocol = amp.BinaryBoxProtocol(self)
    boxProtocol.makeConnection(fakeTransport)
    # Length prefix 0x0100 == 256, one more than the limit.
    boxProtocol.dataReceived('\x01\x00')
    self.assertTrue(fakeTransport.disconnecting)
818
def test_firstBoxSubsequentKeyExcessiveLength(self):
    """
    L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
    a later key in the first box it receives is larger than 255.
    """
    fakeTransport = StringTransport()
    boxProtocol = amp.BinaryBoxProtocol(self)
    boxProtocol.makeConnection(fakeTransport)
    # One complete key/value pair, with the box left unterminated.
    boxProtocol.dataReceived('\x00\x01k\x00\x01v')
    self.assertFalse(fakeTransport.disconnecting)
    # Length prefix 0x0100 == 256 for the next key.
    boxProtocol.dataReceived('\x01\x00')
    self.assertTrue(fakeTransport.disconnecting)
832
def test_subsequentBoxFirstKeyExcessiveLength(self):
    """
    L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
    the first key of a later box it receives is larger than 255.
    """
    fakeTransport = StringTransport()
    boxProtocol = amp.BinaryBoxProtocol(self)
    boxProtocol.makeConnection(fakeTransport)
    # One complete box: a key/value pair followed by the empty terminator.
    boxProtocol.dataReceived('\x00\x01k\x00\x01v\x00\x00')
    self.assertFalse(fakeTransport.disconnecting)
    # Length prefix 0x0100 == 256 for the first key of the next box.
    boxProtocol.dataReceived('\x01\x00')
    self.assertTrue(fakeTransport.disconnecting)
846
def test_excessiveKeyFailure(self):
    """
    After L{amp.BinaryBoxProtocol} receives a key length prefix that is too
    large and the connection is then lost, the L{IBoxReceiver}'s
    C{stopReceivingBoxes} method is called with a L{TooLong} failure.
    """
    boxProtocol = amp.BinaryBoxProtocol(self)
    boxProtocol.makeConnection(StringTransport())
    boxProtocol.dataReceived('\x01\x00')
    lostReason = Failure(error.ConnectionDone("simulated connection done"))
    boxProtocol.connectionLost(lostReason)
    self.stopReason.trap(amp.TooLong)
    tooLong = self.stopReason.value
    self.assertTrue(tooLong.isKey)
    self.assertFalse(tooLong.isLocal)
    self.assertIdentical(tooLong.value, None)
    self.assertIdentical(tooLong.keyName, None)
864
def test_receiveBoxData(self):
    """
    Feeding a binary box protocol the serialized form of an AMP box makes
    it deliver an equal box to its box receiver.
    """
    contents = {"testKey": "valueTest",
                "anotherKey": "anotherValue"}
    a = amp.BinaryBoxProtocol(self)
    wireBytes = amp.Box(contents).serialize()
    a.dataReceived(wireBytes)
    self.assertEquals(self.boxes, [amp.Box(contents)])
877
def test_receiveLongerBoxData(self):
879
An L{amp.BinaryBoxProtocol} can receive serialized AMP boxes with
880
values of up to (2 ** 16 - 1) bytes.
882
length = (2 ** 16 - 1)
884
transport = StringTransport()
885
protocol = amp.BinaryBoxProtocol(self)
886
protocol.makeConnection(transport)
887
protocol.dataReceived(amp.Box({'k': value}).serialize())
888
self.assertEqual(self.boxes, [amp.Box({'k': value})])
889
self.assertFalse(transport.disconnecting)
892
def test_sendBox(self):
894
When a binary box protocol sends a box, it should emit the serialized
895
bytes of that box to its transport.
897
a = amp.BinaryBoxProtocol(self)
898
a.makeConnection(self)
899
aBox = amp.Box({"testKey": "valueTest",
900
"someData": "hello"})
901
a.makeConnection(self)
903
self.assertEquals(''.join(self.data), aBox.serialize())
906
def test_connectionLostStopSendingBoxes(self):
    """
    When a binary box protocol loses its connection, it should notify its
    box receiver that it has stopped receiving boxes.
    """
    a = amp.BinaryBoxProtocol(self)
    # Connect exactly once. The original connected the same protocol twice
    # and built a box it never sent; neither affected the assertion.
    a.makeConnection(self)
    connectionFailure = Failure(RuntimeError())
    a.connectionLost(connectionFailure)
    # The receiver must get the very same Failure object that was passed in.
    self.assertIdentical(self.stopReason, connectionFailure)
920
def test_protocolSwitch(self):
922
L{BinaryBoxProtocol} has the capacity to switch to a different protocol
923
on a box boundary. When a protocol is in the process of switching, it
924
cannot receive traffic.
926
otherProto = TestProto(None, "outgoing data")
928
class SwitchyReceiver:
930
def startReceivingBoxes(self, sender):
932
def ampBoxReceived(self, box):
933
test.assertFalse(self.switched,
934
"Should only receive one box!")
937
a._switchTo(otherProto)
938
a = amp.BinaryBoxProtocol(SwitchyReceiver())
939
anyOldBox = amp.Box({"include": "lots",
941
a.makeConnection(self)
942
# Include a 0-length box at the beginning of the next protocol's data,
943
# to make sure that AMP doesn't eat the data or try to deliver extra
945
moreThanOneBox = anyOldBox.serialize() + "\x00\x00Hello, world!"
946
a.dataReceived(moreThanOneBox)
947
self.assertIdentical(otherProto.transport, self)
948
self.assertEquals("".join(otherProto.data), "\x00\x00Hello, world!")
949
self.assertEquals(self.data, ["outgoing data"])
950
a.dataReceived("more data")
951
self.assertEquals("".join(otherProto.data),
952
"\x00\x00Hello, world!more data")
953
self.assertRaises(amp.ProtocolSwitched, a.sendBox, anyOldBox)
956
def test_protocolSwitchInvalidStates(self):
958
In order to make sure the protocol never gets any invalid data sent
959
into the middle of a box, it must be locked for switching before it is
960
switched. It can only be unlocked if the switch failed, and attempting
961
to send a box while it is locked should raise an exception.
963
a = amp.BinaryBoxProtocol(self)
964
a.makeConnection(self)
965
sampleBox = amp.Box({"some": "data"})
967
self.assertRaises(amp.ProtocolSwitched, a.sendBox, sampleBox)
968
a._unlockFromSwitch()
970
self.assertEquals(''.join(self.data), sampleBox.serialize())
972
otherProto = TestProto(None, "outgoing data")
973
a._switchTo(otherProto)
974
self.assertRaises(amp.ProtocolSwitched, a._unlockFromSwitch)
977
def test_protocolSwitchLoseConnection(self):
979
When the protocol is switched, it should notify its nested protocol of
982
class Loser(protocol.Protocol):
984
def connectionLost(self, reason):
986
connectionLoser = Loser()
987
a = amp.BinaryBoxProtocol(self)
988
a.makeConnection(self)
990
a._switchTo(connectionLoser)
991
connectionFailure = Failure(RuntimeError())
992
a.connectionLost(connectionFailure)
993
self.assertEquals(connectionLoser.reason, connectionFailure)
996
def test_protocolSwitchLoseClientConnection(self):
998
When the protocol is switched, it should notify its nested client
999
protocol factory of disconnection.
1003
def clientConnectionLost(self, connector, reason):
1004
self.reason = reason
1005
a = amp.BinaryBoxProtocol(self)
1006
connectionLoser = protocol.Protocol()
1007
clientLoser = ClientLoser()
1008
a.makeConnection(self)
1010
a._switchTo(connectionLoser, clientLoser)
1011
connectionFailure = Failure(RuntimeError())
1012
a.connectionLost(connectionFailure)
1013
self.assertEquals(clientLoser.reason, connectionFailure)
1017
class AMPTest(unittest.TestCase):
1019
def test_interfaceDeclarations(self):
    """
    Each class in the amp module implements the interfaces declared for
    it.
    """
    declarations = [
        (amp.IBoxSender, amp.BinaryBoxProtocol),
        (amp.IBoxReceiver, amp.BoxDispatcher),
        (amp.IResponderLocator, amp.CommandLocator),
        (amp.IResponderLocator, amp.SimpleStringLocator),
        (amp.IBoxSender, amp.AMP),
        (amp.IBoxReceiver, amp.AMP),
        (amp.IResponderLocator, amp.AMP),
    ]
    for interface, implementation in declarations:
        message = "%s does not implements(%s)" % (implementation, interface)
        self.assertTrue(interface.implementedBy(implementation), message)
1035
def test_helloWorld(self):
1037
Verify that a simple command can be sent and its response received with
1038
the simple low-level string-based API.
1040
c, s, p = connectedServerAndClient()
1043
c.sendHello(HELLO).addCallback(L.append)
1045
self.assertEquals(L[0]['hello'], HELLO)
1048
def test_wireFormatRoundTrip(self):
1050
Verify that mixed-case, underscored and dashed arguments are mapped to
1051
their python names properly.
1053
c, s, p = connectedServerAndClient()
1056
c.sendHello(HELLO).addCallback(L.append)
1058
self.assertEquals(L[0]['hello'], HELLO)
1061
def test_helloWorldUnicode(self):
1063
Verify that unicode arguments can be encoded and decoded.
1065
c, s, p = connectedServerAndClient(
1066
ServerClass=SimpleSymmetricCommandProtocol,
1067
ClientClass=SimpleSymmetricCommandProtocol)
1070
HELLO_UNICODE = 'wor\u1234ld'
1071
c.sendUnicodeHello(HELLO, HELLO_UNICODE).addCallback(L.append)
1073
self.assertEquals(L[0]['hello'], HELLO)
1074
self.assertEquals(L[0]['Print'], HELLO_UNICODE)
1077
def test_callRemoteStringRequiresAnswerFalse(self):
1079
L{BoxDispatcher.callRemoteString} returns C{None} if C{requiresAnswer}
1082
c, s, p = connectedServerAndClient()
1083
ret = c.callRemoteString("WTF", requiresAnswer=False)
1084
self.assertIdentical(ret, None)
1087
def test_unknownCommandLow(self):
1089
Verify that unknown commands using low-level APIs will be rejected with an
1090
error, but will NOT terminate the connection.
1092
c, s, p = connectedServerAndClient()
1096
You can't propagate the error...
1098
e.trap(amp.UnhandledCommand)
1100
c.callRemoteString("WTF").addErrback(clearAndAdd).addCallback(L.append)
1102
self.assertEquals(L.pop(), "OK")
1104
c.sendHello(HELLO).addCallback(L.append)
1106
self.assertEquals(L[0]['hello'], HELLO)
1109
def test_unknownCommandHigh(self):
1111
Verify that unknown commands using high-level APIs will be rejected with an
1112
error, but will NOT terminate the connection.
1114
c, s, p = connectedServerAndClient()
1118
You can't propagate the error...
1120
e.trap(amp.UnhandledCommand)
1122
c.callRemote(WTF).addErrback(clearAndAdd).addCallback(L.append)
1124
self.assertEquals(L.pop(), "OK")
1126
c.sendHello(HELLO).addCallback(L.append)
1128
self.assertEquals(L[0]['hello'], HELLO)
1131
def test_brokenReturnValue(self):
1133
It can be very confusing if you write some code which responds to a
1134
command, but gets the return value wrong. Most commonly you end up
1135
returning None instead of a dictionary.
1137
Verify that if that happens, the framework logs a useful error.
1140
SimpleSymmetricCommandProtocol().dispatchCommand(
1141
amp.AmpBox(_command=BrokenReturn.commandName)).addErrback(L.append)
1142
blr = L[0].trap(amp.BadLocalReturn)
1143
self.failUnlessIn('None', repr(L[0].value))
1146
def test_unknownArgument(self):
1148
Verify that unknown arguments are ignored, and not passed to a Python
1149
function which can't accept them.
1151
c, s, p = connectedServerAndClient(
1152
ServerClass=SimpleSymmetricCommandProtocol,
1153
ClientClass=SimpleSymmetricCommandProtocol)
1156
# c.sendHello(HELLO).addCallback(L.append)
1157
c.callRemote(FutureHello,
1159
bonus="I'm not in the book!").addCallback(
1162
self.assertEquals(L[0]['hello'], HELLO)
1165
def test_simpleReprs(self):
    """
    The repr of each kind of box is a C{str}, and the repr of an
    L{amp.AmpBox} mentions its class name.
    """
    switchRepr = repr(amp._SwitchBox('a'))
    self.assertEquals(type(switchRepr), str)
    quitRepr = repr(amp.QuitBox())
    self.assertEquals(type(quitRepr), str)
    plainRepr = repr(amp.AmpBox())
    self.assertEquals(type(plainRepr), str)
    mentionsName = "AmpBox" in repr(amp.AmpBox())
    self.failUnless(mentionsName)
1175
def test_innerProtocolInRepr(self):
1177
Verify that L{AMP} objects output their innerProtocol when set.
1179
otherProto = TestProto(None, "outgoing data")
1181
a.innerProtocol = otherProto
1183
return {a: 0x1234}.get(obj, id(obj))
1184
self.addCleanup(setIDFunction, setIDFunction(fakeID))
1187
repr(a), "<AMP inner <TestProto #%d> at 0x1234>" % (
1188
otherProto.instanceId,))
1191
def test_innerProtocolNotInRepr(self):
1193
Verify that L{AMP} objects do not output 'inner' when no innerProtocol
1198
return {a: 0x4321}.get(obj, id(obj))
1199
self.addCleanup(setIDFunction, setIDFunction(fakeID))
1200
self.assertEquals(repr(a), "<AMP at 0x4321>")
1203
def test_simpleSSLRepr(self):
    """
    The repr of an L{amp._TLSBox} is a C{str}.
    """
    text = repr(amp._TLSBox())
    self.assertEquals(type(text), str)

# Skipped with the module-level skipSSL message when SSL is unavailable.
test_simpleSSLRepr.skip = skipSSL
1212
def test_keyTooLong(self):
1214
Verify that a key that is too long will immediately raise a synchronous
1217
c, s, p = connectedServerAndClient()
1220
tl = self.assertRaises(amp.TooLong,
1221
c.callRemoteString, "Hello",
1223
self.failUnless(tl.isKey)
1224
self.failUnless(tl.isLocal)
1225
self.failUnlessIdentical(tl.keyName, None)
1226
self.failUnlessIdentical(tl.value, x)
1227
self.failUnless(str(len(x)) in repr(tl))
1228
self.failUnless("key" in repr(tl))
1231
def test_valueTooLong(self):
1233
Verify that attempting to send value longer than 64k will immediately
1236
c, s, p = connectedServerAndClient()
1238
x = "H" * (0xffff+1)
1239
tl = self.assertRaises(amp.TooLong, c.sendHello, x)
1241
self.failIf(tl.isKey)
1242
self.failUnless(tl.isLocal)
1243
self.assertEquals(tl.keyName, 'hello')
1244
self.failUnlessIdentical(tl.value, x)
1245
self.failUnless(str(len(x)) in repr(tl))
1246
self.failUnless("value" in repr(tl))
1247
self.failUnless('hello' in repr(tl))
1250
def test_helloWorldCommand(self):
1252
Verify that a simple command can be sent and its response received with
1253
the high-level value parsing API.
1255
c, s, p = connectedServerAndClient(
1256
ServerClass=SimpleSymmetricCommandProtocol,
1257
ClientClass=SimpleSymmetricCommandProtocol)
1260
c.sendHello(HELLO).addCallback(L.append)
1262
self.assertEquals(L[0]['hello'], HELLO)
1265
def test_helloErrorHandling(self):
1267
Verify that if a known error type is raised and handled, it will be
1268
properly relayed to the other end of the connection and translated into
1269
an exception, and no error will be logged.
1272
c, s, p = connectedServerAndClient(
1273
ServerClass=SimpleSymmetricCommandProtocol,
1274
ClientClass=SimpleSymmetricCommandProtocol)
1276
c.sendHello(HELLO).addErrback(L.append)
1278
L[0].trap(UnfriendlyGreeting)
1279
self.assertEquals(str(L[0].value), "Don't be a dick.")
1282
def test_helloFatalErrorHandling(self):
1284
Verify that if a known, fatal error type is raised and handled, it will
1285
be properly relayed to the other end of the connection and translated
1286
into an exception, no error will be logged, and the connection will be
1290
c, s, p = connectedServerAndClient(
1291
ServerClass=SimpleSymmetricCommandProtocol,
1292
ClientClass=SimpleSymmetricCommandProtocol)
1294
c.sendHello(HELLO).addErrback(L.append)
1296
L.pop().trap(DeathThreat)
1297
c.sendHello(HELLO).addErrback(L.append)
1299
L.pop().trap(error.ConnectionDone)
1303
def test_helloNoErrorHandling(self):
1305
Verify that if an unknown error type is raised, it will be relayed to
1306
the other end of the connection and translated into an exception, it
1307
will be logged, and then the connection will be dropped.
1310
c, s, p = connectedServerAndClient(
1311
ServerClass=SimpleSymmetricCommandProtocol,
1312
ClientClass=SimpleSymmetricCommandProtocol)
1313
HELLO = THING_I_DONT_UNDERSTAND
1314
c.sendHello(HELLO).addErrback(L.append)
1317
ure.trap(amp.UnknownRemoteError)
1318
c.sendHello(HELLO).addErrback(L.append)
1320
cl.trap(error.ConnectionDone)
1321
# The exception should have been logged.
1322
self.failUnless(self.flushLoggedErrors(ThingIDontUnderstandError))
1326
def test_lateAnswer(self):
1328
Verify that a command that does not get answered until after the
1329
connection terminates will not cause any errors.
1331
c, s, p = connectedServerAndClient(
1332
ServerClass=SimpleSymmetricCommandProtocol,
1333
ClientClass=SimpleSymmetricCommandProtocol)
1336
c.callRemote(WaitForever).addErrback(L.append)
1338
self.assertEquals(L, [])
1339
s.transport.loseConnection()
1341
L.pop().trap(error.ConnectionDone)
1342
# Just make sure that it doesn't error...
1343
s.waiting.callback({})
1347
def test_requiresNoAnswer(self):
1349
Verify that a command that requires no answer is run.
1352
c, s, p = connectedServerAndClient(
1353
ServerClass=SimpleSymmetricCommandProtocol,
1354
ClientClass=SimpleSymmetricCommandProtocol)
1356
c.callRemote(NoAnswerHello, hello=HELLO)
1358
self.failUnless(s.greeted)
1361
def test_requiresNoAnswerFail(self):
1363
Verify that commands sent after a failed no-answer request do not complete.
1366
c, s, p = connectedServerAndClient(
1367
ServerClass=SimpleSymmetricCommandProtocol,
1368
ClientClass=SimpleSymmetricCommandProtocol)
1370
c.callRemote(NoAnswerHello, hello=HELLO)
1372
# This should be logged locally.
1373
self.failUnless(self.flushLoggedErrors(amp.RemoteAmpError))
1375
c.callRemote(Hello, hello=HELLO).addErrback(L.append)
1377
L.pop().trap(error.ConnectionDone)
1378
self.failIf(s.greeted)
1381
def test_noAnswerResponderBadAnswer(self):
1383
Verify that responders of requiresAnswer=False commands have to return
1384
a dictionary anyway.
1386
(requiresAnswer is a hint from the _client_ - the server may be called
1387
upon to answer commands in any case, if the client wants to know when
1390
c, s, p = connectedServerAndClient(
1391
ServerClass=BadNoAnswerCommandProtocol,
1392
ClientClass=SimpleSymmetricCommandProtocol)
1393
c.callRemote(NoAnswerHello, hello="hello")
1395
le = self.flushLoggedErrors(amp.BadLocalReturn)
1396
self.assertEquals(len(le), 1)
1399
def test_noAnswerResponderAskedForAnswer(self):
1401
Verify that responders with requiresAnswer=False will actually respond
1402
if the client sets requiresAnswer=True. In other words, verify that
1403
requiresAnswer is a hint honored only by the client.
1405
c, s, p = connectedServerAndClient(
1406
ServerClass=NoAnswerCommandProtocol,
1407
ClientClass=SimpleSymmetricCommandProtocol)
1409
c.callRemote(Hello, hello="Hello!").addCallback(L.append)
1411
self.assertEquals(len(L), 1)
1412
self.assertEquals(L, [dict(hello="Hello!-noanswer",
1413
Print=None)]) # Optional response argument
1416
def test_ampListCommand(self):
1418
Test encoding of an argument that uses the AmpList encoding.
1420
c, s, p = connectedServerAndClient(
1421
ServerClass=SimpleSymmetricCommandProtocol,
1422
ClientClass=SimpleSymmetricCommandProtocol)
1424
c.callRemote(GetList, length=10).addCallback(L.append)
1426
values = L.pop().get('body')
1427
self.assertEquals(values, [{'x': 1}] * 10)
1430
def test_optionalAmpListOmitted(self):
    """
    Test that sending a command with an omitted AmpList argument that is
    designated as optional does not raise an InvalidSignature error.
    """
    # Constructing the command is the whole test: it must simply not raise,
    # so the resulting object does not need to be kept.
    DontRejectMe(magicWord=u'please')
1438
def test_optionalAmpListPresent(self):
1440
Sanity check that optional AmpList arguments are processed normally.
1442
dontRejectMeCommand = DontRejectMe(magicWord=u'please',
1443
list=[{'name': 'foo'}])
1444
c, s, p = connectedServerAndClient(
1445
ServerClass=SimpleSymmetricCommandProtocol,
1446
ClientClass=SimpleSymmetricCommandProtocol)
1448
c.callRemote(DontRejectMe, magicWord=u'please',
1449
list=[{'name': 'foo'}]).addCallback(L.append)
1451
response = L.pop().get('response')
1452
self.assertEquals(response, 'foo accepted')
1455
def test_failEarlyOnArgSending(self):
    """
    Verify that if we pass an invalid argument list (omitting an argument),
    an exception will be raised.
    """
    # Control case: supplying the argument must not raise.
    Hello(hello="What?")
    # Omitting it is rejected as soon as the command is constructed.
    self.assertRaises(amp.InvalidSignature, Hello)
1464
def test_doubleProtocolSwitch(self):
1466
As a debugging aid, a protocol system should raise a
1467
L{ProtocolSwitched} exception when asked to switch a protocol that is
1470
serverDeferred = defer.Deferred()
1471
serverProto = SimpleSymmetricCommandProtocol(serverDeferred)
1472
clientDeferred = defer.Deferred()
1473
clientProto = SimpleSymmetricCommandProtocol(clientDeferred)
1474
c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
1475
ClientClass=lambda: clientProto)
1476
def switched(result):
1477
self.assertRaises(amp.ProtocolSwitched, c.switchToTestProtocol)
1478
self.testSucceeded = True
1479
c.switchToTestProtocol().addCallback(switched)
1481
self.failUnless(self.testSucceeded)
1484
def test_protocolSwitch(self, switcher=SimpleSymmetricCommandProtocol,
1485
spuriousTraffic=False,
1486
spuriousError=False):
1488
Verify that it is possible to switch to another protocol mid-connection and
1489
send data to it successfully.
1491
self.testSucceeded = False
1493
serverDeferred = defer.Deferred()
1494
serverProto = switcher(serverDeferred)
1495
clientDeferred = defer.Deferred()
1496
clientProto = switcher(clientDeferred)
1497
c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
1498
ClientClass=lambda: clientProto)
1502
wfd = c.callRemote(WaitForever).addErrback(wfdr.append)
1503
switchDeferred = c.switchToTestProtocol()
1505
self.assertRaises(amp.ProtocolSwitched, c.sendHello, 'world')
1507
def cbConnsLost(((serverSuccess, serverData),
1508
(clientSuccess, clientData))):
1509
self.failUnless(serverSuccess)
1510
self.failUnless(clientSuccess)
1511
self.assertEquals(''.join(serverData), SWITCH_CLIENT_DATA)
1512
self.assertEquals(''.join(clientData), SWITCH_SERVER_DATA)
1513
self.testSucceeded = True
1515
def cbSwitch(proto):
1516
return defer.DeferredList(
1517
[serverDeferred, clientDeferred]).addCallback(cbConnsLost)
1519
switchDeferred.addCallback(cbSwitch)
1521
if serverProto.maybeLater is not None:
1522
serverProto.maybeLater.callback(serverProto.maybeLaterProto)
1525
# switch is done here; do this here to make sure that if we're
1526
# going to corrupt the connection, we do it before it's closed.
1528
s.waiting.errback(amp.RemoteAmpError(
1530
"Here's some traffic in the form of an error."))
1532
s.waiting.callback({})
1534
c.transport.loseConnection() # close it
1536
self.failUnless(self.testSucceeded)
1539
def test_protocolSwitchDeferred(self):
    """
    Verify that protocol-switching even works if the value returned from
    the command that does the switch is deferred.
    """
    # Same scenario as test_protocolSwitch, run with a different protocol
    # class on both ends.
    return self.test_protocolSwitch(
        switcher=DeferredSymmetricCommandProtocol)
1547
def test_protocolSwitchFail(self, switcher=SimpleSymmetricCommandProtocol):
1549
Verify that if we try to switch protocols and it fails, the connection
1550
stays up and we can go back to speaking AMP.
1552
self.testSucceeded = False
1554
serverDeferred = defer.Deferred()
1555
serverProto = switcher(serverDeferred)
1556
clientDeferred = defer.Deferred()
1557
clientProto = switcher(clientDeferred)
1558
c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
1559
ClientClass=lambda: clientProto)
1561
switchDeferred = c.switchToTestProtocol(fail=True).addErrback(L.append)
1563
L.pop().trap(UnknownProtocol)
1564
self.failIf(self.testSucceeded)
1565
# It's a known error, so let's send a "hello" on the same connection;
1567
c.sendHello('world').addCallback(L.append)
1569
self.assertEqual(L.pop()['hello'], 'world')
1572
def test_trafficAfterSwitch(self):
    """
    Verify that attempts to send traffic after a switch will not corrupt
    the nested protocol.
    """
    return self.test_protocolSwitch(spuriousTraffic=True)
1580
def test_errorAfterSwitch(self):
    """
    Exercise L{test_protocolSwitch} with both spurious traffic and a
    spurious error enabled, i.e. the case where an error is returned after
    a protocol switch.
    """
    return self.test_protocolSwitch(spuriousTraffic=True,
                                    spuriousError=True)
1589
def test_quitBoxQuits(self):
1591
Verify that commands with a responseType of QuitBox will in fact
1592
terminate the connection.
1594
c, s, p = connectedServerAndClient(
1595
ServerClass=SimpleSymmetricCommandProtocol,
1596
ClientClass=SimpleSymmetricCommandProtocol)
1600
GOODBYE = 'everyone'
1601
c.sendHello(HELLO).addCallback(L.append)
1603
self.assertEquals(L.pop()['hello'], HELLO)
1604
c.callRemote(Goodbye).addCallback(L.append)
1606
self.assertEquals(L.pop()['goodbye'], GOODBYE)
1607
c.sendHello(HELLO).addErrback(L.append)
1608
L.pop().trap(error.ConnectionDone)
1611
def test_basicLiteralEmit(self):
1613
Verify that the command dictionaries for a callRemoteN look correct
1614
after being serialized and parsed.
1616
c, s, p = connectedServerAndClient()
1618
s.ampBoxReceived = L.append
1619
c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test',
1620
dash_arg='x', underscore_arg='y')
1622
self.assertEquals(len(L), 1)
1623
for k, v in [('_command', Hello.commandName),
1624
('hello', 'hello test'),
1625
('mixedCase', 'mixed case arg test'),
1627
('underscore_arg', 'y')]:
1628
self.assertEquals(L[-1].pop(k), v)
1630
self.assertEquals(L[-1], {})
1633
def test_basicStructuredEmit(self):
1635
Verify that a call similar to basicLiteralEmit's is handled properly with
1636
high-level quoting and passing to Python methods, and that argument
1637
names are correctly handled.
1640
class StructuredHello(amp.AMP):
1641
def h(self, *a, **k):
1643
return dict(hello='aaa')
1645
c, s, p = connectedServerAndClient(ServerClass=StructuredHello)
1646
c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test',
1647
dash_arg='x', underscore_arg='y').addCallback(L.append)
1649
self.assertEquals(len(L), 2)
1650
self.assertEquals(L[0],
1653
mixedCase='mixed case arg test',
1657
# XXX - should optional arguments just not be passed?
1658
# passing None seems a little odd, looking at the way it
1659
# turns out here... -glyph
1660
From=('file', 'file'),
1664
self.assertEquals(L[1], dict(Print=None, hello='aaa'))
1666
class PretendRemoteCertificateAuthority:
1667
def checkIsPretendRemote(self):
1673
def options(self, *ign):
1676
def iosimVerify(self, otherCert):
1678
This isn't a real certificate, and wouldn't work on a real socket, but
1679
iosim specifies a different API so that we don't have to do any crypto
1680
math to demonstrate that the right functions get called in the right
1683
assert otherCert is self
1684
self.verifyCount += 1
1687
class OKCert(IOSimCert):
1688
def options(self, x):
1689
assert x.checkIsPretendRemote()
1692
class GrumpyCert(IOSimCert):
1693
def iosimVerify(self, otherCert):
1694
self.verifyCount += 1
1697
class DroppyCert(IOSimCert):
1698
def __init__(self, toDrop):
1699
self.toDrop = toDrop
1701
def iosimVerify(self, otherCert):
1702
self.verifyCount += 1
1703
self.toDrop.loseConnection()
1706
class SecurableProto(FactoryNotifier):
1710
def verifyFactory(self):
1711
return [PretendRemoteCertificateAuthority()]
1713
def getTLSVars(self):
1714
cert = self.certFactory()
1715
verify = self.verifyFactory()
1717
tls_localCertificate=cert,
1718
tls_verifyAuthorities=verify)
1719
amp.StartTLS.responder(getTLSVars)
1723
class TLSTest(unittest.TestCase):
1724
def test_startingTLS(self):
1726
Verify that starting TLS and succeeding at handshaking sends all the
1727
notifications to all the right places.
1729
cli, svr, p = connectedServerAndClient(
1730
ServerClass=SecurableProto,
1731
ClientClass=SecurableProto)
1734
svr.certFactory = lambda : okc
1737
amp.StartTLS, tls_localCertificate=okc,
1738
tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
1740
# let's buffer something to be delivered securely
1742
d = cli.callRemote(SecuredPing).addCallback(L.append)
1744
# once for client once for server
1745
self.assertEquals(okc.verifyCount, 2)
1747
d = cli.callRemote(SecuredPing).addCallback(L.append)
1749
self.assertEqual(L[0], {'pinged': True})
1752
def test_startTooManyTimes(self):
1754
Verify that the protocol will complain if we attempt to renegotiate TLS,
1755
which we don't support.
1757
cli, svr, p = connectedServerAndClient(
1758
ServerClass=SecurableProto,
1759
ClientClass=SecurableProto)
1762
svr.certFactory = lambda : okc
1764
cli.callRemote(amp.StartTLS,
1765
tls_localCertificate=okc,
1766
tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
1768
cli.noPeerCertificate = True # this is totally fake
1773
tls_localCertificate=okc,
1774
tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
1777
def test_negotiationFailed(self):
1779
Verify that starting TLS and failing on both sides at handshaking sends
1780
notifications to all the right places and terminates the connection.
1783
badCert = GrumpyCert()
1785
cli, svr, p = connectedServerAndClient(
1786
ServerClass=SecurableProto,
1787
ClientClass=SecurableProto)
1788
svr.certFactory = lambda : badCert
1790
cli.callRemote(amp.StartTLS,
1791
tls_localCertificate=badCert)
1794
# once for client once for server - but both fail
1795
self.assertEquals(badCert.verifyCount, 2)
1796
d = cli.callRemote(SecuredPing)
1798
self.assertFailure(d, iosim.NativeOpenSSLError)
1801
def test_negotiationFailedByClosing(self):
1803
Verify that starting TLS and failing by way of a lost connection
1804
notices that it is probably an SSL problem.
1807
cli, svr, p = connectedServerAndClient(
1808
ServerClass=SecurableProto,
1809
ClientClass=SecurableProto)
1810
droppyCert = DroppyCert(svr.transport)
1811
svr.certFactory = lambda : droppyCert
1813
secure = cli.callRemote(amp.StartTLS,
1814
tls_localCertificate=droppyCert)
1818
self.assertEquals(droppyCert.verifyCount, 2)
1820
d = cli.callRemote(SecuredPing)
1823
# it might be a good idea to move this exception somewhere more
1825
self.assertFailure(d, error.PeerVerifyError)
1831
class TLSNotAvailableTest(unittest.TestCase):
1833
Tests what happens when SSL is not available in the current installation.
1851
def test_callRemoteError(self):
1853
Check that callRemote raises an exception when called with a
1856
cli, svr, p = connectedServerAndClient(
1857
ServerClass=SecurableProto,
1858
ClientClass=SecurableProto)
1861
svr.certFactory = lambda : okc
1863
return self.assertFailure(cli.callRemote(
1864
amp.StartTLS, tls_localCertificate=okc,
1865
tls_verifyAuthorities=[PretendRemoteCertificateAuthority()]),
1869
def test_messageReceivedError(self):
1871
When a client with SSL enabled talks to a server without SSL, it
1872
should return a meaningful error.
1874
svr = SecurableProto()
1876
svr.certFactory = lambda : okc
1878
box['_command'] = 'StartTLS'
1881
svr.sendBox = boxes.append
1882
svr.makeConnection(StringTransport())
1883
svr.ampBoxReceived(box)
1884
self.assertEquals(boxes,
1885
[{'_error_code': 'TLS_ERROR',
1887
'_error_description': 'TLS not available'}])
1891
class InheritedError(Exception):
    """
    This error is used to check inheritance.
    """
1898
class OtherInheritedError(Exception):
    """
    This is a distinct error for checking inheritance.
    """
1905
class BaseCommand(amp.Command):
    """
    This provides a command that will be subclassed.
    """
    # Maps the exception type to the error code string declared for it.
    errors = {InheritedError: 'INHERITED_ERROR'}
1913
class InheritedCommand(BaseCommand):
    """
    This is a command which subclasses another command (L{BaseCommand}) but
    does not override anything it inherits from it.
    """
1921
class AddErrorsCommand(BaseCommand):
    """
    This is a command which subclasses another command (L{BaseCommand}) but
    adds an error of its own, plus a boolean argument.
    """
    arguments = [('other', amp.Boolean())]
    # Declared in addition to the C{errors} mapping on L{BaseCommand}.
    errors = {OtherInheritedError: 'OTHER_INHERITED_ERROR'}
1931
class NormalCommandProtocol(amp.AMP):
1933
This is a protocol which responds to L{BaseCommand}, and is used to test
1934
that inheritance does not interfere with the normal handling of errors.
1937
raise InheritedError()
1938
BaseCommand.responder(resp)
1942
class InheritedCommandProtocol(amp.AMP):
1944
This is a protocol which responds to L{InheritedCommand}, and is used to
1945
test that inherited commands inherit their bases' errors if they do not
1946
respond to any of their own.
1949
raise InheritedError()
1950
InheritedCommand.responder(resp)
1954
class AddedCommandProtocol(amp.AMP):
1956
This is a protocol which responds to L{AddErrorsCommand}, and is used to
1957
test that inherited commands can add their own new types of errors, but
1958
still respond in the same way to their parents types of errors.
1960
def resp(self, other):
1962
raise OtherInheritedError()
1964
raise InheritedError()
1965
AddErrorsCommand.responder(resp)
1969
class CommandInheritanceTests(unittest.TestCase):
1971
These tests verify that commands inherit error conditions properly.
1974
def errorCheck(self, err, proto, cmd, **kw):
1976
Check that the appropriate kind of error is raised when a given command
1977
is sent to a given protocol.
1979
c, s, p = connectedServerAndClient(ServerClass=proto,
1981
d = c.callRemote(cmd, **kw)
1982
d2 = self.failUnlessFailure(d, err)
1987
def test_basicErrorPropagation(self):
    """
    Verify that errors specified in a superclass are respected normally
    even if it has subclasses.
    """
    return self.errorCheck(
        InheritedError, NormalCommandProtocol, BaseCommand)
1996
def test_inheritedErrorPropagation(self):
    """
    Verify that errors specified in a superclass command are propagated
    when a subclass of that command (L{InheritedCommand}) is the one sent.
    """
    return self.errorCheck(
        InheritedError, InheritedCommandProtocol, InheritedCommand)
2005
def test_inheritedErrorAddition(self):
    """
    Verify that new errors specified in a subclass of an existing command
    are honored even if the superclass defines some errors.
    """
    return self.errorCheck(
        OtherInheritedError, AddedCommandProtocol, AddErrorsCommand,
        other=True)
2014
def test_additionWithOriginalError(self):
    """
    Verify that errors specified in a command's superclass are respected
    even if that command defines new errors itself.
    """
    return self.errorCheck(
        InheritedError, AddedCommandProtocol, AddErrorsCommand,
        other=False)
2023
def _loseAndPass(err, proto):
    """
    Errback which checks that a connection ended for an expected reason and
    then hands the failure to C{proto}'s own C{connectionLost}.

    @param err: the failure describing why the connection was lost; anything
        other than L{error.ConnectionLost} or L{error.ConnectionDone} is
        re-raised by C{trap}.
    @param proto: the protocol whose C{connectionLost} instance attribute is
        removed before C{connectionLost} is called.
    """
    # be specific, pass on the error to the client.
    err.trap(error.ConnectionLost, error.ConnectionDone)
    # Drop the instance-level override so the lookup below falls through to
    # the attribute defined on the class.
    del proto.connectionLost
    proto.connectionLost(err)
2032
Utility for connected reactor-using tests.
2037
Create an amp server and connect a client to it.
2039
from twisted.internet import reactor
2040
self.serverFactory = protocol.ServerFactory()
2041
self.serverFactory.protocol = self.serverProto
2042
self.clientFactory = protocol.ClientFactory()
2043
self.clientFactory.protocol = self.clientProto
2044
self.clientFactory.onMade = defer.Deferred()
2045
self.serverFactory.onMade = defer.Deferred()
2046
self.serverPort = reactor.listenTCP(0, self.serverFactory)
2047
self.addCleanup(self.serverPort.stopListening)
2048
self.clientConn = reactor.connectTCP(
2049
'127.0.0.1', self.serverPort.getHost().port,
2051
self.addCleanup(self.clientConn.disconnect)
2052
def getProtos(rlst):
2053
self.cli = self.clientFactory.theProto
2054
self.svr = self.serverFactory.theProto
2055
dl = defer.DeferredList([self.clientFactory.onMade,
2056
self.serverFactory.onMade])
2057
return dl.addCallback(getProtos)
2061
Cleanup client and server connections, and check the error got at
2065
for conn in self.cli, self.svr:
2066
if conn.transport is not None:
2067
# depend on amp's function connection-dropping behavior
2068
d = defer.Deferred().addErrback(_loseAndPass, conn)
2069
conn.connectionLost = d.errback
2070
conn.transport.loseConnection()
2072
return defer.gatherResults(L
2073
).addErrback(lambda first: first.value.subFailure)
2078
sys.stdout.write(x+'\n')
2082
def tempSelfSigned():
    """
    Generate a fresh key pair and a certificate for it, signed with that
    same key, using the distinguished name C{CN='shared'}.

    @return: the object produced by C{KeyPair.newCertificate} for the
        self-signed request.
    """
    from twisted.internet import ssl

    sharedDN = ssl.DN(CN='shared')
    key = ssl.KeyPair.generate()
    cr = key.certificateRequest(sharedDN)
    # Sign our own request; the callback approves every distinguished name.
    # NOTE(review): 1234567 is presumably the serial number -- confirm.
    sscrd = key.signCertificateRequest(
        sharedDN, cr, lambda dn: True, 1234567)
    cert = key.newCertificate(sscrd)
    # The module stores this call's result in C{tempcert}, so the
    # certificate has to be handed back to the caller.
    return cert
2094
tempcert = tempSelfSigned()
2097
class LiveFireTLSTestCase(LiveFireBase, unittest.TestCase):
2098
clientProto = SecurableProto
2099
serverProto = SecurableProto
2100
def test_liveFireCustomTLS(self):
2102
Using real, live TLS, actually negotiate a connection.
2104
This also looks at the 'peerCertificate' attribute's correctness, since
2105
that's actually loaded using OpenSSL calls, but the main purpose is to
2106
make sure that we didn't miss anything obvious in iosim about TLS
2112
self.svr.verifyFactory = lambda : [cert]
2113
self.svr.certFactory = lambda : cert
2114
# only needed on the server, we specify the client below.
2119
# Interesting. OpenSSL won't even _tell_ us about the peer
2120
# cert until we negotiate. we should be able to do this in
2121
# 'secured' instead, but it looks like we can't. I think this
2122
# is a bug somewhere far deeper than here.
2123
self.failUnlessEqual(x, self.cli.hostCertificate.digest())
2124
self.failUnlessEqual(x, self.cli.peerCertificate.digest())
2125
self.failUnlessEqual(x, self.svr.hostCertificate.digest())
2126
self.failUnlessEqual(x, self.svr.peerCertificate.digest())
2127
return self.cli.callRemote(SecuredPing).addCallback(pinged)
2128
return self.cli.callRemote(amp.StartTLS,
2129
tls_localCertificate=cert,
2130
tls_verifyAuthorities=[cert]).addCallback(secured)
2136
class SlightlySmartTLS(SimpleSymmetricCommandProtocol):
    """
    Specific implementation of server side protocol whose L{amp.StartTLS}
    responder supplies the global C{tempcert} as the local certificate.
    """
    def getTLSVars(self):
        """
        @return: the global C{tempcert} certificate as local certificate.
        """
        return dict(tls_localCertificate=tempcert)
    # Register the method above as this class's StartTLS responder.
    amp.StartTLS.responder(getTLSVars)
2149
class PlainVanillaLiveFire(LiveFireBase, unittest.TestCase):
2151
clientProto = SimpleSymmetricCommandProtocol
2152
serverProto = SimpleSymmetricCommandProtocol
2154
def test_liveFireDefaultTLS(self):
    """
    Verify that out of the box, we can start TLS to at least encrypt the
    connection, even if we don't have any certificates to use.
    """
    def secured(result):
        # Once StartTLS has been answered, issue one more command.
        return self.cli.callRemote(SecuredPing)
    return self.cli.callRemote(amp.StartTLS).addCallback(secured)
2167
class WithServerTLSVerification(LiveFireBase, unittest.TestCase):
2168
clientProto = SimpleSymmetricCommandProtocol
2169
serverProto = SlightlySmartTLS
2171
def test_anonymousVerifyingClient(self):
    """
    Verify that anonymous clients can verify server certificates.
    """
    def secured(result):
        # Once StartTLS has been answered, issue one more command.
        return self.cli.callRemote(SecuredPing)
    # The client passes no local certificate, only the authority to check
    # the server against.
    starting = self.cli.callRemote(amp.StartTLS,
                                   tls_verifyAuthorities=[tempcert])
    return starting.addCallback(secured)
2185
class ProtocolIncludingArgument(amp.Argument):
    """
    An L{amp.Argument} which encodes its parser and serializer
    arguments *including the protocol* into its parsed and serialized
    forms.
    """

    def fromStringProto(self, string, protocol):
        """
        Don't decode anything; just return all possible information.

        @return: A two-tuple of the input string and the protocol.
        """
        return (string, protocol)

    def toStringProto(self, obj, protocol):
        """
        Encode identifying information about C{obj} and protocol
        into a string for later verification.

        @type obj: L{object}
        @type protocol: L{amp.AMP}

        @return: the ids of C{obj} and C{protocol}, joined by a colon.
        """
        return "%s:%s" % (id(obj), id(protocol))
2212
class ProtocolIncludingCommand(amp.Command):
    """
    A command that has argument and response schemas which use
    L{ProtocolIncludingArgument}.
    """
    arguments = [('weird', ProtocolIncludingArgument())]
    response = [('weird', ProtocolIncludingArgument())]
2222
class MagicSchemaCommand(amp.Command):
2224
A command which overrides L{parseResponse}, L{parseArguments}, and
2227
def parseResponse(self, strings, protocol):
2229
Don't do any parsing, just jam the input strings and protocol
2230
onto the C{protocol.parseResponseArguments} attribute as a
2231
two-tuple. Return the original strings.
2233
protocol.parseResponseArguments = (strings, protocol)
2235
parseResponse = classmethod(parseResponse)
2238
def parseArguments(cls, strings, protocol):
2240
Don't do any parsing, just jam the input strings and protocol
2241
onto the C{protocol.parseArgumentsArguments} attribute as a
2242
two-tuple. Return the original strings.
2244
protocol.parseArgumentsArguments = (strings, protocol)
2246
parseArguments = classmethod(parseArguments)
2249
def makeArguments(cls, objects, protocol):
2251
Don't do any serializing, just jam the input strings and protocol
2252
onto the C{protocol.makeArgumentsArguments} attribute as a
2253
two-tuple. Return the original strings.
2255
protocol.makeArgumentsArguments = (objects, protocol)
2257
makeArguments = classmethod(makeArguments)
2261
class NoNetworkProtocol(amp.AMP):
    """
    An L{amp.AMP} subclass which overrides private methods to avoid
    testing the network.  It also provides a responder for
    L{MagicSchemaCommand} that does nothing, so that tests can test
    aspects of the interaction of L{amp.Command}s and L{amp.AMP}.

    @ivar parseArgumentsArguments: Arguments that have been passed to any
        L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
        this protocol.

    @ivar parseResponseArguments: Responses that have been returned from a
        L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
        this protocol.

    @ivar makeArgumentsArguments: Arguments that have been serialized by any
        L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
        this protocol.
    """
    def _sendBoxCommand(self, commandName, strings, requiresAnswer):
        """
        Return a Deferred which fires with the original strings.
        """
        return defer.succeed(strings)

    # A do-nothing responder: accepts the C{weird} argument and answers
    # with an empty dictionary.
    MagicSchemaCommand.responder(lambda s, weird: {})
2292
A unique dict subclass.
2297
class ProtocolIncludingCommandWithDifferentCommandType(
2298
ProtocolIncludingCommand):
2300
A L{ProtocolIncludingCommand} subclass whose commandType is L{MyBox}
2306
class CommandTestCase(unittest.TestCase):
2308
Tests for L{amp.Argument} and L{amp.Command}.
2310
def test_argumentInterface(self):
    """
    L{Argument} instances provide L{amp.IArgumentType}.
    """
    argument = amp.Argument()
    self.assertTrue(verifyObject(amp.IArgumentType, argument))
2317
def test_parseResponse(self):
2319
There should be a class method of Command which accepts a
2320
mapping of argument names to serialized forms and returns a
2321
similar mapping whose values have been parsed via the
2322
Command's response schema.
2326
strings = {'weird': result}
2328
ProtocolIncludingCommand.parseResponse(strings, protocol),
2329
{'weird': (result, protocol)})
2332
def test_callRemoteCallsParseResponse(self):
2334
Making a remote call on a L{amp.Command} subclass which
2335
overrides the C{parseResponse} method should call that
2336
C{parseResponse} method to get the response.
2338
client = NoNetworkProtocol()
2340
response = client.callRemote(MagicSchemaCommand, weird=thingy)
2341
def gotResponse(ign):
2342
self.assertEquals(client.parseResponseArguments,
2343
({"weird": thingy}, client))
2344
response.addCallback(gotResponse)
2348
def test_parseArguments(self):
2350
There should be a class method of L{amp.Command} which accepts
2351
a mapping of argument names to serialized forms and returns a
2352
similar mapping whose values have been parsed via the
2353
command's argument schema.
2357
strings = {'weird': result}
2359
ProtocolIncludingCommand.parseArguments(strings, protocol),
2360
{'weird': (result, protocol)})
2363
def test_responderCallsParseArguments(self):
2365
Making a remote call on a L{amp.Command} subclass which
2366
overrides the C{parseArguments} method should call that
2367
C{parseArguments} method to get the arguments.
2369
protocol = NoNetworkProtocol()
2370
responder = protocol.locateResponder(MagicSchemaCommand.commandName)
2372
response = responder(dict(weird=argument))
2373
response.addCallback(
2374
lambda ign: self.assertEqual(protocol.parseArgumentsArguments,
2375
({"weird": argument}, protocol)))
2379
def test_makeArguments(self):
2381
There should be a class method of L{amp.Command} which accepts
2382
a mapping of argument names to objects and returns a similar
2383
mapping whose values have been serialized via the command's
2388
objects = {'weird': argument}
2390
ProtocolIncludingCommand.makeArguments(objects, protocol),
2391
{'weird': "%d:%d" % (id(argument), id(protocol))})
2394
def test_makeArgumentsUsesCommandType(self):
2396
L{amp.Command.makeArguments}'s return type should be the type
2397
of the result of L{amp.Command.commandType}.
2400
objects = {"weird": "whatever"}
2402
result = ProtocolIncludingCommandWithDifferentCommandType.makeArguments(
2404
self.assertIdentical(type(result), MyBox)
2407
def test_callRemoteCallsMakeArguments(self):
2409
Making a remote call on a L{amp.Command} subclass which
2410
overrides the C{makeArguments} method should call that
2411
C{makeArguments} method to get the response.
2413
client = NoNetworkProtocol()
2415
response = client.callRemote(MagicSchemaCommand, weird=argument)
2416
def gotResponse(ign):
2417
self.assertEqual(client.makeArgumentsArguments,
2418
({"weird": argument}, client))
2419
response.addCallback(gotResponse)
2423
def test_extraArgumentsDisallowed(self):
    """
    L{Command.makeArguments} raises L{amp.InvalidSignature} if the objects
    dictionary passed to it includes a key which does not correspond to the
    Python identifier for a defined argument.
    """
    # C{bogusArgument} is the offending key; C{None} is passed in the
    # protocol position.
    self.assertRaises(
        amp.InvalidSignature,
        Hello.makeArguments,
        dict(hello="hello", bogusArgument=object()), None)
2435
def test_wireSpellingDisallowed(self):
    """
    If a command argument conflicts with a Python keyword, the
    untransformed argument name is not allowed as a key in the dictionary
    passed to L{Command.makeArguments}.  If it is supplied,
    L{amp.InvalidSignature} is raised.

    This may be a pointless implementation restriction which may be lifted.
    The current behavior is tested to verify that such arguments are not
    silently dropped on the floor (the previous behavior).
    """
    # "print" is a Python 2 keyword, so it cannot be written as a keyword
    # argument to dict(); it is supplied through ** expansion instead.
    self.assertRaises(
        amp.InvalidSignature,
        Hello.makeArguments,
        dict(hello="required", **{"print": "print value"}),
        None)
2453
class ListOfTestsMixin:
    """
    Base class for testing L{ListOf}, a parameterized zero-or-more argument
    type.

    @ivar elementType: Subclasses should set this to an L{Argument}
        instance.  The tests will make a L{ListOf} using this.

    @ivar strings: Subclasses should set this to a dictionary mapping some
        number of keys to the correct serialized form for some example
        values.  These should agree with what L{elementType}
        produces/accepts.

    @ivar objects: Subclasses should set this to a dictionary with the same
        keys as C{strings} and with values which are the lists which should
        serialize to the values in the C{strings} dictionary.
    """
    def test_toBox(self):
        """
        L{ListOf.toBox} extracts the list of objects from the C{objects}
        dictionary passed to it, using the C{name} key also passed to it,
        serializes each of the elements in that list using the L{Argument}
        instance previously passed to its initializer, combines the serialized
        results, and inserts the result into the C{strings} dictionary using
        the same C{name} key.
        """
        stringList = amp.ListOf(self.elementType)
        strings = amp.AmpBox()
        for key in self.objects:
            # Pass a copy so the class-level fixture cannot be modified by
            # the code under test.
            stringList.toBox(key, strings, self.objects.copy(), None)
        self.assertEqual(strings, self.strings)


    def test_fromBox(self):
        """
        L{ListOf.fromBox} reverses the operation performed by L{ListOf.toBox}.
        """
        stringList = amp.ListOf(self.elementType)
        # Parsed values are accumulated here, one key at a time.
        objects = {}
        for key in self.strings:
            # Pass a copy so the class-level fixture cannot be modified by
            # the code under test.
            stringList.fromBox(key, self.strings.copy(), objects, None)
        self.assertEqual(objects, self.objects)
2498
class ListOfStringsTests(unittest.TestCase, ListOfTestsMixin):
    """
    Tests for L{ListOf} combined with L{String}.
    """
    elementType = amp.String()

    # Serialized forms: in these fixtures each element appears as a two-byte
    # length followed by that many bytes; an empty list is the empty string.
    strings = {
        "empty": "",
        "single": "\x00\x03foo",
        "multiple": "\x00\x03bar\x00\x03baz\x00\x04quux"}

    # Parsed forms; the keys must match those of C{strings}.
    objects = {
        "empty": [],
        "single": ["foo"],
        "multiple": ["bar", "baz", "quux"]}
2515
class ListOfIntegersTests(unittest.TestCase, ListOfTestsMixin):
    """
    Tests for L{ListOf} combined with L{Integer}.
    """
    elementType = amp.Integer()

    # Serialized forms: in these fixtures each integer appears as its decimal
    # digits behind a two-byte length; an empty list is the empty string.
    strings = {
        "empty": "",
        "single": "\x00\x0210",
        "multiple": "\x00\x011\x00\x0220\x00\x03500"}

    # Parsed forms; the keys must match those of C{strings}.
    objects = {
        "empty": [],
        "single": [10],
        "multiple": [1, 20, 500]}
2532
class ListOfUnicodeTests(unittest.TestCase, ListOfTestsMixin):
    """
    Tests for L{ListOf} combined with L{Unicode}.
    """
    elementType = amp.Unicode()

    # Serialized forms: the length prefix counts encoded bytes, so the
    # single-character snowman is carried as three bytes ("\xe2\x98\x83").
    strings = {
        "empty": "",
        "single": "\x00\x03foo",
        "multiple": "\x00\x03\xe2\x98\x83\x00\x05Hello\x00\x05world"}

    # Parsed forms; the keys must match those of C{strings}.
    objects = {
        "empty": [],
        "single": [u"foo"],
        "multiple": [u"\N{SNOWMAN}", u"Hello", u"world"]}
2550
if not interfaces.IReactorSSL.providedBy(reactor):
    # The installed reactor does not provide IReactorSSL: mark each of the
    # test cases below as skipped, all with the same explanatory message.
    skipMsg = 'This test case requires SSL support in the reactor'
    for _sslDependentCase in (TLSTest,
                              LiveFireTLSTestCase,
                              PlainVanillaLiveFire,
                              WithServerTLSVerification):
        _sslDependentCase.skip = skipMsg
    # Remove the loop variable so the module does not keep an extra
    # module-level alias of the last test case class.
    del _sslDependentCase