1
# Copyright 2005 Divmod, Inc. See LICENSE file for details
3
from twisted.python import filepath
4
from twisted.protocols import amp
5
from twisted.test import iosim
6
from twisted.trial import unittest
7
from twisted.internet import protocol, defer, error
9
from twisted.internet.error import PeerVerifyError
11
class TestProto(protocol.Protocol):
12
def __init__(self, onConnLost, dataToSend):
13
self.onConnLost = onConnLost
14
self.dataToSend = dataToSend
16
def connectionMade(self):
18
self.transport.write(self.dataToSend)
20
def dataReceived(self, bytes):
21
self.data.append(bytes)
22
# self.transport.loseConnection()
24
def connectionLost(self, reason):
25
self.onConnLost.callback(self.data)
27
class SimpleSymmetricProtocol(amp.AMP):
29
def sendHello(self, text):
30
return self.callRemoteString(
34
def amp_HELLO(self, box):
35
return amp.Box(hello=box['hello'])
37
def amp_HOWDOYOUDO(self, box):
38
return amp.QuitBox(howdoyoudo='world')
40
class UnfriendlyGreeting(Exception):
41
"""Greeting was insufficiently kind.
44
class DeathThreat(Exception):
45
"""Greeting was insufficiently kind.
48
class UnknownProtocol(Exception):
49
"""Asked to switch to the wrong protocol.
53
class TransportPeer(amp.Argument):
54
# this serves as some informal documentation for how to get variables from
55
# the protocol or your environment and pass them to methods as arguments.
56
def retrieve(self, d, name, proto):
59
def fromStringProto(self, notAString, proto):
60
return proto.transport.getPeer()
62
def toBox(self, name, strings, objects, proto):
65
class Hello(amp.Command):
69
arguments = [('hello', amp.String()),
70
('optional', amp.Boolean(optional=True)),
71
('print', amp.Unicode(optional=True)),
72
('from', TransportPeer(optional=True)),
73
('mixedCase', amp.String(optional=True)),
74
('dash-arg', amp.String(optional=True)),
75
('underscore_arg', amp.String(optional=True))]
77
response = [('hello', amp.String()),
78
('print', amp.Unicode(optional=True))]
80
errors = {UnfriendlyGreeting: 'UNFRIENDLY'}
82
fatalErrors = {DeathThreat: 'DEAD'}
84
class NoAnswerHello(Hello):
85
commandName = Hello.commandName
86
requiresAnswer = False
88
class FutureHello(amp.Command):
91
arguments = [('hello', amp.String()),
92
('optional', amp.Boolean(optional=True)),
93
('print', amp.Unicode(optional=True)),
94
('from', TransportPeer(optional=True)),
95
('bonus', amp.String(optional=True)), # addt'l arguments
97
# added at the end, and
101
response = [('hello', amp.String()),
102
('print', amp.Unicode(optional=True))]
104
errors = {UnfriendlyGreeting: 'UNFRIENDLY'}
106
class WTF(amp.Command):
108
An example of an invalid command.
112
class BrokenReturn(amp.Command):
113
""" An example of a perfectly good command, but the handler is going to return
117
commandName = 'broken_return'
119
class Goodbye(amp.Command):
120
# commandName left blank on purpose: this tests implicit command names.
121
response = [('goodbye', amp.String())]
122
responseType = amp.QuitBox
124
class Howdoyoudo(amp.Command):
125
commandName = 'howdoyoudo'
126
# responseType = amp.QuitBox
128
class WaitForever(amp.Command):
129
commandName = 'wait_forever'
131
class GetList(amp.Command):
132
commandName = 'getlist'
133
arguments = [('length', amp.Integer())]
134
response = [('body', amp.AmpList([('x', amp.Integer())]))]
136
class SecuredPing(amp.Command):
137
# XXX TODO: actually make this refuse to send over an insecure connection
138
response = [('pinged', amp.Boolean())]
140
class TestSwitchProto(amp.ProtocolSwitchCommand):
141
commandName = 'Switch-Proto'
144
('name', amp.String()),
146
errors = {UnknownProtocol: 'UNKNOWN'}
148
class SingleUseFactory(protocol.ClientFactory):
149
def __init__(self, proto):
151
self.proto.factory = self
153
def buildProtocol(self, addr):
154
p, self.proto = self.proto, None
159
def clientConnectionFailed(self, connector, reason):
160
self.reasonFailed = reason
163
THING_I_DONT_UNDERSTAND = 'gwebol nargo'
164
class ThingIDontUnderstandError(Exception):
167
class FactoryNotifier(amp.AMP):
169
def connectionMade(self):
170
if self.factory is not None:
171
self.factory.theProto = self
172
if hasattr(self.factory, 'onMade'):
173
self.factory.onMade.callback(None)
176
from twisted.internet.interfaces import ISSLTransport
177
if not ISSLTransport.providedBy(self.transport):
178
raise DeathThreat("only send secure pings over secure channels")
179
return {'pinged': True}
180
SecuredPing.responder(emitpong)
183
class SimpleSymmetricCommandProtocol(FactoryNotifier):
185
def __init__(self, onConnLost=None):
186
amp.AMP.__init__(self)
187
self.onConnLost = onConnLost
189
def sendHello(self, text):
190
return self.callRemote(Hello, hello=text)
192
def sendUnicodeHello(self, text, translation):
193
return self.callRemote(Hello, hello=text, Print=translation)
197
def cmdHello(self, hello, From, optional=None, Print=None,
198
mixedCase=None, dash_arg=None, underscore_arg=None):
199
assert From == self.transport.getPeer()
200
if hello == THING_I_DONT_UNDERSTAND:
201
raise ThingIDontUnderstandError()
202
if hello.startswith('fuck'):
203
raise UnfriendlyGreeting("Don't be a dick.")
205
raise DeathThreat("aieeeeeeeee")
206
result = dict(hello=hello)
207
if Print is not None:
208
result.update(dict(Print=Print))
211
Hello.responder(cmdHello)
213
def cmdGetlist(self, length):
214
return {'body': [dict(x=1)] * length}
215
GetList.responder(cmdGetlist)
218
self.waiting = defer.Deferred()
220
WaitForever.responder(waitforit)
223
return dict(howdoyoudo='world')
224
Howdoyoudo.responder(howdo)
227
return dict(goodbye="everyone")
228
Goodbye.responder(saybye)
230
def switchToTestProtocol(self, fail=False):
235
p = TestProto(self.onConnLost, SWITCH_CLIENT_DATA)
236
return self.callRemote(
238
SingleUseFactory(p), name=name).addCallback(lambda ign: p)
240
def switchit(self, name):
241
if name == 'test-proto':
242
return TestProto(self.onConnLost, SWITCH_SERVER_DATA)
243
raise UnknownProtocol(name)
244
TestSwitchProto.responder(switchit)
248
BrokenReturn.responder(donothing)
251
class DeferredSymmetricCommandProtocol(SimpleSymmetricCommandProtocol):
252
def switchit(self, name):
253
if name == 'test-proto':
254
self.maybeLaterProto = TestProto(self.onConnLost, SWITCH_SERVER_DATA)
255
self.maybeLater = defer.Deferred()
256
return self.maybeLater
257
raise UnknownProtocol(name)
258
TestSwitchProto.responder(switchit)
260
class BadNoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
261
def badResponder(self, hello, From, optional=None, Print=None,
262
mixedCase=None, dash_arg=None, underscore_arg=None):
264
This responder does nothing and forgets to return a dictionary.
266
NoAnswerHello.responder(badResponder)
268
class NoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
269
def goodNoAnswerResponder(self, hello, From, optional=None, Print=None,
270
mixedCase=None, dash_arg=None, underscore_arg=None):
271
return dict(hello=hello+"-noanswer")
272
NoAnswerHello.responder(goodNoAnswerResponder)
274
def connectedServerAndClient(ServerClass=SimpleSymmetricProtocol,
275
ClientClass=SimpleSymmetricProtocol,
277
"""Returns a 3-tuple: (client, server, pump)
279
return iosim.connectedServerAndClient(
280
ServerClass, ClientClass,
283
class TotallyDumbProtocol(protocol.Protocol):
285
def dataReceived(self, data):
288
class LiteralAmp(amp.AMP):
292
def ampBoxReceived(self, box):
293
self.boxes.append(box)
296
class ParsingTest(unittest.TestCase):
298
def test_booleanValues(self):
300
Verify that the Boolean parser parses 'True' and 'False', but nothing
304
self.assertEquals(b.fromString("True"), True)
305
self.assertEquals(b.fromString("False"), False)
306
self.assertRaises(TypeError, b.fromString, "ninja")
307
self.assertRaises(TypeError, b.fromString, "true")
308
self.assertRaises(TypeError, b.fromString, "TRUE")
309
self.assertEquals(b.toString(True), 'True')
310
self.assertEquals(b.toString(False), 'False')
312
def test_pathValueRoundTrip(self):
314
Verify the 'Path' argument can parse and emit a file path.
316
fp = filepath.FilePath(self.mktemp())
320
self.assertNotIdentical(fp, v) # sanity check
321
self.assertEquals(fp, v)
324
def test_sillyEmptyThing(self):
326
Test that empty boxes raise an error; they aren't supposed to be sent
330
return self.assertRaises(amp.NoEmptyBoxes, a.ampBoxReceived, amp.Box())
333
def test_ParsingRoundTrip(self):
335
Verify that various kinds of data make it through the encode/parse
338
c, s, p = connectedServerAndClient(ClientClass=LiteralAmp,
339
ServerClass=LiteralAmp)
341
SIMPLE = ('simple', 'test')
343
CR = ('crtest', 'test\r')
344
LF = ('lftest', 'hello\n')
345
NEWLINE = ('newline', 'test\r\none\r\ntwo')
346
NEWLINE2 = ('newline2', 'test\r\none\r\n two')
347
BLANKLINE = ('newline3', 'test\r\n\r\nblank\r\n\r\nline')
348
BODYTEST = ('body', 'blah\r\n\r\ntesttest')
355
[SIMPLE, CE, CR, LF],
357
[SIMPLE, NEWLINE, CE, NEWLINE2],
358
[BODYTEST, SIMPLE, NEWLINE]
361
for test in testData:
363
jb.update(dict(test))
366
self.assertEquals(s.boxes[-1], jb)
368
SWITCH_CLIENT_DATA = 'Success!'
369
SWITCH_SERVER_DATA = 'No, really. Success.'
371
class AMPTest(unittest.TestCase):
373
def test_helloWorld(self):
375
Verify that a simple command can be sent and its response received with
376
the simple low-level string-based API.
378
c, s, p = connectedServerAndClient()
381
c.sendHello(HELLO).addCallback(L.append)
383
self.assertEquals(L[0]['hello'], HELLO)
386
def test_wireFormatRoundTrip(self):
388
Verify that mixed-case, underscored and dashed arguments are mapped to
389
their python names properly.
391
c, s, p = connectedServerAndClient()
394
c.sendHello(HELLO).addCallback(L.append)
396
self.assertEquals(L[0]['hello'], HELLO)
399
def test_helloWorldUnicode(self):
401
Verify that unicode arguments can be encoded and decoded.
403
c, s, p = connectedServerAndClient(
404
ServerClass=SimpleSymmetricCommandProtocol,
405
ClientClass=SimpleSymmetricCommandProtocol)
408
HELLO_UNICODE = 'wor\u1234ld'
409
c.sendUnicodeHello(HELLO, HELLO_UNICODE).addCallback(L.append)
411
self.assertEquals(L[0]['hello'], HELLO)
412
self.assertEquals(L[0]['Print'], HELLO_UNICODE)
415
def test_unknownCommandLow(self):
417
Verify that unknown commands using low-level APIs will be rejected with an
418
error, but will NOT terminate the connection.
420
c, s, p = connectedServerAndClient()
424
You can't propagate the error...
426
e.trap(amp.UnhandledCommand)
428
c.callRemoteString("WTF").addErrback(clearAndAdd).addCallback(L.append)
430
self.assertEquals(L.pop(), "OK")
432
c.sendHello(HELLO).addCallback(L.append)
434
self.assertEquals(L[0]['hello'], HELLO)
437
def test_unknownCommandHigh(self):
439
Verify that unknown commands using high-level APIs will be rejected with an
440
error, but will NOT terminate the connection.
442
c, s, p = connectedServerAndClient()
446
You can't propagate the error...
448
e.trap(amp.UnhandledCommand)
450
c.callRemote(WTF).addErrback(clearAndAdd).addCallback(L.append)
452
self.assertEquals(L.pop(), "OK")
454
c.sendHello(HELLO).addCallback(L.append)
456
self.assertEquals(L[0]['hello'], HELLO)
459
def test_brokenReturnValue(self):
461
It can be very confusing if you write some code which responds to a
462
command, but gets the return value wrong. Most commonly you end up
463
returning None instead of a dictionary.
465
Verify that if that happens, the framework logs a useful error.
468
SimpleSymmetricCommandProtocol().dispatchCommand(
469
amp.AmpBox(_command=BrokenReturn.commandName)).addErrback(L.append)
470
blr = L[0].trap(amp.BadLocalReturn)
471
self.failUnlessIn('None', repr(L[0].value))
475
def test_unknownArgument(self):
477
Verify that unknown arguments are ignored, and not passed to a Python
478
function which can't accept them.
480
c, s, p = connectedServerAndClient(
481
ServerClass=SimpleSymmetricCommandProtocol,
482
ClientClass=SimpleSymmetricCommandProtocol)
485
# c.sendHello(HELLO).addCallback(L.append)
486
c.callRemote(FutureHello,
488
bonus="I'm not in the book!").addCallback(
491
self.assertEquals(L[0]['hello'], HELLO)
494
def test_simpleReprs(self):
496
Verify that the various Box objects repr properly, for debugging.
498
self.assertEquals(type(repr(amp._TLSBox())), str)
499
self.assertEquals(type(repr(amp._SwitchBox('a'))), str)
500
self.assertEquals(type(repr(amp.QuitBox())), str)
501
self.assertEquals(type(repr(amp.AmpBox())), str)
502
self.failUnless("AmpBox" in repr(amp.AmpBox()))
504
def test_keyTooLong(self):
506
Verify that a key that is too long will immediately raise a synchronous
509
c, s, p = connectedServerAndClient()
512
tl = self.assertRaises(amp.TooLong,
513
c.callRemoteString, "Hello",
515
self.failUnless(tl.isKey)
516
self.failUnless(tl.isLocal)
517
self.failUnlessIdentical(tl.keyName, None)
518
self.failUnlessIdentical(tl.value, x)
519
self.failUnless(str(len(x)) in repr(tl))
520
self.failUnless("key" in repr(tl))
523
def test_valueTooLong(self):
525
Verify that attempting to send value longer than 64k will immediately
528
c, s, p = connectedServerAndClient()
531
tl = self.assertRaises(amp.TooLong, c.sendHello, x)
533
self.failIf(tl.isKey)
534
self.failUnless(tl.isLocal)
535
self.failUnlessIdentical(tl.keyName, 'hello')
536
self.failUnlessIdentical(tl.value, x)
537
self.failUnless(str(len(x)) in repr(tl))
538
self.failUnless("value" in repr(tl))
539
self.failUnless('hello' in repr(tl))
542
def test_helloWorldCommand(self):
544
Verify that a simple command can be sent and its response received with
545
the high-level value parsing API.
547
c, s, p = connectedServerAndClient(
548
ServerClass=SimpleSymmetricCommandProtocol,
549
ClientClass=SimpleSymmetricCommandProtocol)
552
c.sendHello(HELLO).addCallback(L.append)
554
self.assertEquals(L[0]['hello'], HELLO)
557
def test_helloErrorHandling(self):
559
Verify that if a known error type is raised and handled, it will be
560
properly relayed to the other end of the connection and translated into
561
an exception, and no error will be logged.
564
c, s, p = connectedServerAndClient(
565
ServerClass=SimpleSymmetricCommandProtocol,
566
ClientClass=SimpleSymmetricCommandProtocol)
568
c.sendHello(HELLO).addErrback(L.append)
570
L[0].trap(UnfriendlyGreeting)
571
self.assertEquals(str(L[0].value), "Don't be a dick.")
574
def test_helloFatalErrorHandling(self):
576
Verify that if a known, fatal error type is raised and handled, it will
577
be properly relayed to the other end of the connection and translated
578
into an exception, no error will be logged, and the connection will be
582
c, s, p = connectedServerAndClient(
583
ServerClass=SimpleSymmetricCommandProtocol,
584
ClientClass=SimpleSymmetricCommandProtocol)
586
c.sendHello(HELLO).addErrback(L.append)
588
L.pop().trap(DeathThreat)
589
c.sendHello(HELLO).addErrback(L.append)
591
L.pop().trap(error.ConnectionDone)
595
def test_helloNoErrorHandling(self):
597
Verify that if an unknown error type is raised, it will be relayed to
598
the other end of the connection and translated into an exception, it
599
will be logged, and then the connection will be dropped.
602
c, s, p = connectedServerAndClient(
603
ServerClass=SimpleSymmetricCommandProtocol,
604
ClientClass=SimpleSymmetricCommandProtocol)
605
HELLO = THING_I_DONT_UNDERSTAND
606
c.sendHello(HELLO).addErrback(L.append)
609
ure.trap(amp.UnknownRemoteError)
610
c.sendHello(HELLO).addErrback(L.append)
612
cl.trap(error.ConnectionDone)
613
# The exception should have been logged.
614
self.failUnless(self.flushLoggedErrors(ThingIDontUnderstandError))
618
def test_lateAnswer(self):
620
Verify that a command that does not get answered until after the
621
connection terminates will not cause any errors.
623
c, s, p = connectedServerAndClient(
624
ServerClass=SimpleSymmetricCommandProtocol,
625
ClientClass=SimpleSymmetricCommandProtocol)
628
c.callRemote(WaitForever).addErrback(L.append)
630
self.assertEquals(L, [])
631
s.transport.loseConnection()
633
L.pop().trap(error.ConnectionDone)
634
# Just make sure that it doesn't error...
635
s.waiting.callback({})
639
def test_requiresNoAnswer(self):
641
Verify that a command that requires no answer is run.
644
c, s, p = connectedServerAndClient(
645
ServerClass=SimpleSymmetricCommandProtocol,
646
ClientClass=SimpleSymmetricCommandProtocol)
648
c.callRemote(NoAnswerHello, hello=HELLO)
650
self.failUnless(s.greeted)
653
def test_requiresNoAnswerFail(self):
655
Verify that commands sent after a failed no-answer request do not complete.
658
c, s, p = connectedServerAndClient(
659
ServerClass=SimpleSymmetricCommandProtocol,
660
ClientClass=SimpleSymmetricCommandProtocol)
662
c.callRemote(NoAnswerHello, hello=HELLO)
664
# This should be logged locally.
665
self.failUnless(self.flushLoggedErrors(amp.RemoteAmpError))
667
c.callRemote(Hello, hello=HELLO).addErrback(L.append)
669
L.pop().trap(error.ConnectionDone)
670
self.failIf(s.greeted)
673
def test_noAnswerResponderBadAnswer(self):
675
Verify that responders of requiresAnswer=False commands have to return
678
(requiresAnswer is a hint from the _client_ - the server may be called
679
upon to answer commands in any case, if the client wants to know when
682
c, s, p = connectedServerAndClient(
683
ServerClass=BadNoAnswerCommandProtocol,
684
ClientClass=SimpleSymmetricCommandProtocol)
685
c.callRemote(NoAnswerHello, hello="hello")
687
le = self.flushLoggedErrors(amp.BadLocalReturn)
688
self.assertEquals(len(le), 1)
691
def test_noAnswerResponderAskedForAnswer(self):
693
Verify that responders with requiresAnswer=False will actually respond
694
if the client sets requiresAnswer=True. In other words, verify that
695
requiresAnswer is a hint honored only by the client.
697
c, s, p = connectedServerAndClient(
698
ServerClass=NoAnswerCommandProtocol,
699
ClientClass=SimpleSymmetricCommandProtocol)
701
c.callRemote(Hello, hello="Hello!").addCallback(L.append)
703
self.assertEquals(len(L), 1)
704
self.assertEquals(L, [dict(hello="Hello!-noanswer",
705
Print=None)]) # Optional response argument
708
def test_ampListCommand(self):
710
Test encoding of an argument that uses the AmpList encoding.
712
c, s, p = connectedServerAndClient(
713
ServerClass=SimpleSymmetricCommandProtocol,
714
ClientClass=SimpleSymmetricCommandProtocol)
716
c.callRemote(GetList, length=10).addCallback(L.append)
718
values = L.pop().get('body')
719
self.assertEquals(values, [{'x': 1}] * 10)
722
def test_failEarlyOnArgSending(self):
724
Verify that if we pass an invalid argument list (omitting an argument), an
725
exception will be raised.
727
okayCommand = Hello(hello="What?")
728
self.assertRaises(amp.InvalidSignature, Hello)
731
def test_protocolSwitch(self, switcher=SimpleSymmetricCommandProtocol,
732
spuriousTraffic=False):
734
Verify that it is possible to switch to another protocol mid-connection and
735
send data to it successfully.
737
self.testSucceeded = False
739
serverDeferred = defer.Deferred()
740
serverProto = switcher(serverDeferred)
741
clientDeferred = defer.Deferred()
742
clientProto = switcher(clientDeferred)
743
c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
744
ClientClass=lambda: clientProto)
748
wfd = c.callRemote(WaitForever).addErrback(wfdr.append)
749
switchDeferred = c.switchToTestProtocol()
751
self.assertRaises(amp.ProtocolSwitched, c.sendHello, 'world')
753
def cbConnsLost(((serverSuccess, serverData),
754
(clientSuccess, clientData))):
755
self.failUnless(serverSuccess)
756
self.failUnless(clientSuccess)
757
self.assertEquals(''.join(serverData), SWITCH_CLIENT_DATA)
758
self.assertEquals(''.join(clientData), SWITCH_SERVER_DATA)
759
self.testSucceeded = True
762
return defer.DeferredList(
763
[serverDeferred, clientDeferred]).addCallback(cbConnsLost)
765
switchDeferred.addCallback(cbSwitch)
767
if serverProto.maybeLater is not None:
768
serverProto.maybeLater.callback(serverProto.maybeLaterProto)
771
# switch is done here; do this here to make sure that if we're
772
# going to corrupt the connection, we do it before it's closed.
773
s.waiting.callback({})
775
c.transport.loseConnection() # close it
777
self.failUnless(self.testSucceeded)
780
def test_protocolSwitchDeferred(self):
782
Verify that protocol-switching even works if the value returned from
783
the command that does the switch is deferred.
785
return self.test_protocolSwitch(switcher=DeferredSymmetricCommandProtocol)
787
def test_protocolSwitchFail(self, switcher=SimpleSymmetricCommandProtocol):
789
Verify that if we try to switch protocols and it fails, the connection
790
stays up and we can go back to speaking AMP.
792
self.testSucceeded = False
794
serverDeferred = defer.Deferred()
795
serverProto = switcher(serverDeferred)
796
clientDeferred = defer.Deferred()
797
clientProto = switcher(clientDeferred)
798
c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
799
ClientClass=lambda: clientProto)
801
switchDeferred = c.switchToTestProtocol(fail=True).addErrback(L.append)
803
L.pop().trap(UnknownProtocol)
804
self.failIf(self.testSucceeded)
805
# It's a known error, so let's send a "hello" on the same connection;
807
c.sendHello('world').addCallback(L.append)
809
self.assertEqual(L.pop()['hello'], 'world')
812
def test_trafficAfterSwitch(self):
814
Verify that attempts to send traffic after a switch will not corrupt
817
return self.test_protocolSwitch(spuriousTraffic=True)
820
def test_quitBoxQuits(self):
822
Verify that commands with a responseType of QuitBox will in fact
823
terminate the connection.
825
c, s, p = connectedServerAndClient(
826
ServerClass=SimpleSymmetricCommandProtocol,
827
ClientClass=SimpleSymmetricCommandProtocol)
832
c.sendHello(HELLO).addCallback(L.append)
834
self.assertEquals(L.pop()['hello'], HELLO)
835
c.callRemote(Goodbye).addCallback(L.append)
837
self.assertEquals(L.pop()['goodbye'], GOODBYE)
838
c.sendHello(HELLO).addErrback(L.append)
839
L.pop().trap(error.ConnectionDone)
843
def test_basicLiteralEmit(self):
845
Verify that the command dictionaries for a callRemoteN look correct
846
after being serialized and parsed.
848
c, s, p = connectedServerAndClient()
850
s.ampBoxReceived = L.append
851
c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test',
852
dash_arg='x', underscore_arg='y')
854
self.assertEquals(len(L), 1)
855
for k, v in [('_command', Hello.commandName),
856
('hello', 'hello test'),
857
('mixedCase', 'mixed case arg test'),
859
('underscore_arg', 'y')]:
860
self.assertEquals(L[-1].pop(k), v)
862
self.assertEquals(L[-1], {})
865
def test_basicStructuredEmit(self):
867
Verify that a call similar to basicLiteralEmit's is handled properly with
868
high-level quoting and passing to Python methods, and that argument
869
names are correctly handled.
872
class StructuredHello(amp.AMP):
873
def h(self, *a, **k):
875
return dict(hello='aaa')
877
c, s, p = connectedServerAndClient(ServerClass=StructuredHello)
878
c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test',
879
dash_arg='x', underscore_arg='y').addCallback(L.append)
881
self.assertEquals(len(L), 2)
882
self.assertEquals(L[0],
885
mixedCase='mixed case arg test',
889
# XXX - should optional arguments just not be passed?
890
# passing None seems a little odd, looking at the way it
891
# turns out here... -glyph
892
From=('file', 'file'),
896
self.assertEquals(L[1], dict(Print=None, hello='aaa'))
898
class PretendRemoteCertificateAuthority:
899
def checkIsPretendRemote(self):
905
def options(self, *ign):
908
def iosimVerify(self, otherCert):
910
This isn't a real certificate, and wouldn't work on a real socket, but
911
iosim specifies a different API so that we don't have to do any crypto
912
math to demonstrate that the right functions get called in the right
915
assert otherCert is self
916
self.verifyCount += 1
919
class OKCert(IOSimCert):
920
def options(self, x):
921
assert x.checkIsPretendRemote()
924
class GrumpyCert(IOSimCert):
925
def iosimVerify(self, otherCert):
926
self.verifyCount += 1
929
class DroppyCert(IOSimCert):
930
def __init__(self, toDrop):
933
def iosimVerify(self, otherCert):
934
self.verifyCount += 1
935
self.toDrop.loseConnection()
938
class SecurableProto(FactoryNotifier):
942
def verifyFactory(self):
943
return [PretendRemoteCertificateAuthority()]
945
def getTLSVars(self):
946
cert = self.certFactory()
947
verify = self.verifyFactory()
949
tls_localCertificate=cert,
950
tls_verifyAuthorities=verify)
951
amp.StartTLS.responder(getTLSVars)
955
class TLSTest(unittest.TestCase):
956
def test_startingTLS(self):
958
Verify that starting TLS and succeeding at handshaking sends all the
959
notifications to all the right places.
961
cli, svr, p = connectedServerAndClient(
962
ServerClass=SecurableProto,
963
ClientClass=SecurableProto)
966
svr.certFactory = lambda : okc
969
amp.StartTLS, tls_localCertificate=okc,
970
tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
972
# let's buffer something to be delivered securely
974
d = cli.callRemote(SecuredPing).addCallback(L.append)
976
# once for client once for server
977
self.assertEquals(okc.verifyCount, 2)
979
d = cli.callRemote(SecuredPing).addCallback(L.append)
981
self.assertEqual(L[0], {'pinged': True})
983
def test_startTooManyTimes(self):
985
Verify that the protocol will complain if we attempt to renegotiate TLS,
986
which we don't support.
988
cli, svr, p = connectedServerAndClient(
989
ServerClass=SecurableProto,
990
ClientClass=SecurableProto)
993
svr.certFactory = lambda : okc
995
# print c, c.transport
996
cli.callRemote(amp.StartTLS,
997
tls_localCertificate=okc,
998
tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
1000
cli.noPeerCertificate = True # this is totally fake
1005
tls_localCertificate=okc,
1006
tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
1008
def test_negotiationFailed(self):
1010
Verify that starting TLS and failing on both sides at handshaking sends
1011
notifications to all the right places and terminates the connection.
1014
badCert = GrumpyCert()
1016
cli, svr, p = connectedServerAndClient(
1017
ServerClass=SecurableProto,
1018
ClientClass=SecurableProto)
1019
svr.certFactory = lambda : badCert
1021
cli.callRemote(amp.StartTLS,
1022
tls_localCertificate=badCert)
1025
# once for client once for server - but both fail
1026
self.assertEquals(badCert.verifyCount, 2)
1027
d = cli.callRemote(SecuredPing)
1029
self.assertFailure(d, iosim.OpenSSLVerifyError)
1031
def test_negotiationFailedByClosing(self):
1033
Verify that starting TLS and failing by way of a lost connection
1034
notices that it is probably an SSL problem.
1037
cli, svr, p = connectedServerAndClient(
1038
ServerClass=SecurableProto,
1039
ClientClass=SecurableProto)
1040
droppyCert = DroppyCert(svr.transport)
1041
svr.certFactory = lambda : droppyCert
1043
secure = cli.callRemote(amp.StartTLS,
1044
tls_localCertificate=droppyCert)
1048
self.assertEquals(droppyCert.verifyCount, 2)
1050
d = cli.callRemote(SecuredPing)
1053
# it might be a good idea to move this exception somewhere more
1055
self.assertFailure(d, PeerVerifyError)
1059
class InheritedError(Exception):
1061
This error is used to check inheritance.
1066
class OtherInheritedError(Exception):
1068
This is a distinct error for checking inheritance.
1073
class BaseCommand(amp.Command):
1075
This provides a command that will be subclassed.
1077
errors = {InheritedError: 'INHERITED_ERROR'}
1081
class InheritedCommand(BaseCommand):
1083
This is a command which subclasses another command but does not override
1089
class AddErrorsCommand(BaseCommand):
1091
This is a command which subclasses another command but adds errors to the
1094
arguments = [('other', amp.Boolean())]
1095
errors = {OtherInheritedError: 'OTHER_INHERITED_ERROR'}
1099
class NormalCommandProtocol(amp.AMP):
1101
This is a protocol which responds to L{BaseCommand}, and is used to test
1102
that inheritance does not interfere with the normal handling of errors.
1105
raise InheritedError()
1106
BaseCommand.responder(resp)
1110
class InheritedCommandProtocol(amp.AMP):
1112
This is a protocol which responds to L{InheritedCommand}, and is used to
1113
test that inherited commands inherit their bases' errors if they do not
1114
respond to any of their own.
1117
raise InheritedError()
1118
InheritedCommand.responder(resp)
1122
class AddedCommandProtocol(amp.AMP):
1124
This is a protocol which responds to L{AddErrorsCommand}, and is used to
1125
test that inherited commands can add their own new types of errors, but
1126
still respond in the same way to their parents types of errors.
1128
def resp(self, other):
1130
raise OtherInheritedError()
1132
raise InheritedError()
1133
AddErrorsCommand.responder(resp)
1137
class CommandInheritanceTests(unittest.TestCase):
1139
These tests verify that commands inherit error conditions properly.
1142
def errorCheck(self, err, proto, cmd, **kw):
1144
Check that the appropriate kind of error is raised when a given command
1145
is sent to a given protocol.
1147
c, s, p = connectedServerAndClient(ServerClass=proto,
1149
d = c.callRemote(cmd, **kw)
1150
d2 = self.failUnlessFailure(d, err)
1155
def test_basicErrorPropagation(self):
1157
Verify that errors specified in a superclass are respected normally
1158
even if it has subclasses.
1160
return self.errorCheck(
1161
InheritedError, NormalCommandProtocol, BaseCommand)
1164
def test_inheritedErrorPropagation(self):
1166
Verify that errors specified in a superclass command are propagated to
1169
return self.errorCheck(
1170
InheritedError, InheritedCommandProtocol, InheritedCommand)
1173
def test_inheritedErrorAddition(self):
1175
Verify that new errors specified in a subclass of an existing command
1176
are honored even if the superclass defines some errors.
1178
return self.errorCheck(
1179
OtherInheritedError, AddedCommandProtocol, AddErrorsCommand, other=True)
1182
def test_additionWithOriginalError(self):
1184
Verify that errors specified in a command's superclass are respected
1185
even if that command defines new errors itself.
1187
return self.errorCheck(
1188
InheritedError, AddedCommandProtocol, AddErrorsCommand, other=False)
1192
def _loseAndPass(err, proto):
1193
# be specific, pass on the error to the client.
1194
err.trap(error.ConnectionLost, error.ConnectionDone)
1195
del proto.connectionLost
1196
proto.connectionLost(err)
1200
Utility for connected reactor-using tests.
1204
from twisted.internet import reactor
1205
self.serverFactory = protocol.ServerFactory()
1206
self.serverFactory.protocol = self.serverProto
1207
self.clientFactory = protocol.ClientFactory()
1208
self.clientFactory.protocol = self.clientProto
1209
self.clientFactory.onMade = defer.Deferred()
1210
self.serverFactory.onMade = defer.Deferred()
1211
self.serverPort = reactor.listenTCP(0, self.serverFactory)
1212
self.clientConn = reactor.connectTCP(
1213
'127.0.0.1', self.serverPort.getHost().port,
1215
def getProtos(rlst):
1216
self.cli = self.clientFactory.theProto
1217
self.svr = self.serverFactory.theProto
1218
dl = defer.DeferredList([self.clientFactory.onMade,
1219
self.serverFactory.onMade])
1220
return dl.addCallback(getProtos)
1224
for conn in self.cli, self.svr:
1225
if conn.transport is not None:
1226
# depend on amp's function connection-dropping behavior
1227
d = defer.Deferred().addErrback(_loseAndPass, conn)
1228
conn.connectionLost = d.errback
1229
conn.transport.loseConnection()
1231
if self.serverPort is not None:
1232
L.append(defer.maybeDeferred(self.serverPort.stopListening))
1233
if self.clientConn is not None:
1234
self.clientConn.disconnect()
1235
return defer.DeferredList(L)
1239
sys.stdout.write(x+'\n')
1242
def tempSelfSigned():
1243
from twisted.internet import ssl
1245
sharedDN = ssl.DN(CN='shared')
1246
key = ssl.KeyPair.generate()
1247
cr = key.certificateRequest(sharedDN)
1248
sscrd = key.signCertificateRequest(
1249
sharedDN, cr, lambda dn: True, 1234567)
1250
cert = key.newCertificate(sscrd)
1253
tempcert = tempSelfSigned()
1255
class LiveFireTLSTestCase(LiveFireBase, unittest.TestCase):
1256
clientProto = SecurableProto
1257
serverProto = SecurableProto
1258
def test_liveFireCustomTLS(self):
1260
Using real, live TLS, actually negotiate a connection.
1262
This also looks at the 'peerCertificate' attribute's correctness, since
1263
that's actually loaded using OpenSSL calls, but the main purpose is to
1264
make sure that we didn't miss anything obvious in iosim about TLS
1270
self.svr.verifyFactory = lambda : [cert]
1271
self.svr.certFactory = lambda : cert
1272
# only needed on the server, we specify the client below.
1277
# Interesting. OpenSSL won't even _tell_ us about the peer
1278
# cert until we negotiate. we should be able to do this in
1279
# 'secured' instead, but it looks like we can't. I think this
1280
# is a bug somewhere far deeper than here.
1281
self.failUnlessEqual(x, self.cli.hostCertificate.digest())
1282
self.failUnlessEqual(x, self.cli.peerCertificate.digest())
1283
self.failUnlessEqual(x, self.svr.hostCertificate.digest())
1284
self.failUnlessEqual(x, self.svr.peerCertificate.digest())
1285
return self.cli.callRemote(SecuredPing).addCallback(pinged)
1286
return self.cli.callRemote(amp.StartTLS,
1287
tls_localCertificate=cert,
1288
tls_verifyAuthorities=[cert]).addCallback(secured)
1290
class SlightlySmartTLS(SimpleSymmetricCommandProtocol):
1292
return dict(tls_localCertificate=tempcert)
1294
class PlainVanillaLiveFire(LiveFireBase, unittest.TestCase):
1296
clientProto = SimpleSymmetricCommandProtocol
1297
serverProto = SimpleSymmetricCommandProtocol
1299
def test_liveFireDefaultTLS(self):
1301
Verify that out of the box, we can start TLS to at least encrypt the
1302
connection, even if we don't have any certificates to use.
1304
def secured(result):
1305
return self.cli.callRemote(SecuredPing)
1306
return self.cli.callRemote(amp.StartTLS).addCallback(secured)
1308
class WithServerTLSVerification(LiveFireBase, unittest.TestCase):
1309
clientProto = SimpleSymmetricCommandProtocol
1310
serverProto = SlightlySmartTLS
1312
def test_anonymousVerifyingClient(self):
1314
Verify that anonymous clients can verify server certificates.
1316
def secured(result):
1317
return self.cli.callRemote(SecuredPing)
1318
return self.cli.callRemote(amp.StartTLS, tls_verifyAuthorities=[tempcert])