~ntt-pf-lab/nova/monkey_patch_notification

« back to all changes in this revision

Viewing changes to vendor/Twisted-10.0.0/twisted/test/test_amp.py

  • Committer: Jesse Andrews
  • Date: 2010-05-28 06:05:26 UTC
  • Revision ID: git-v1:bf6e6e718cdc7488e2da87b21e258ccc065fe499
initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (c) 2005 Divmod, Inc.
 
2
# Copyright (c) 2007-2009 Twisted Matrix Laboratories.
 
3
# See LICENSE for details.
 
4
 
 
5
"""
 
6
Tests for L{twisted.protocols.amp}.
 
7
"""
 
8
 
 
9
from zope.interface.verify import verifyObject
 
10
 
 
11
from twisted.python.util import setIDFunction
 
12
from twisted.python import filepath
 
13
from twisted.python.failure import Failure
 
14
from twisted.protocols import amp
 
15
from twisted.trial import unittest
 
16
from twisted.internet import protocol, defer, error, reactor, interfaces
 
17
from twisted.test import iosim
 
18
from twisted.test.proto_helpers import StringTransport
 
19
 
 
20
try:
 
21
    from twisted.internet import ssl
 
22
except ImportError:
 
23
    ssl = None
 
24
if ssl and not ssl.supported:
 
25
    ssl = None
 
26
 
 
27
if ssl is None:
 
28
    skipSSL = "SSL not available"
 
29
else:
 
30
    skipSSL = None
 
31
 
 
32
 
 
33
class TestProto(protocol.Protocol):
 
34
    """
 
35
    A trivial protocol for use in testing where a L{Protocol} is expected.
 
36
 
 
37
    @ivar instanceId: the id of this instance
 
38
    @ivar onConnLost: deferred that will fired when the connection is lost
 
39
    @ivar dataToSend: data to send on the protocol
 
40
    """
 
41
 
 
42
    instanceCount = 0
 
43
 
 
44
    def __init__(self, onConnLost, dataToSend):
 
45
        self.onConnLost = onConnLost
 
46
        self.dataToSend = dataToSend
 
47
        self.instanceId = TestProto.instanceCount
 
48
        TestProto.instanceCount = TestProto.instanceCount + 1
 
49
 
 
50
    def connectionMade(self):
 
51
        self.data = []
 
52
        self.transport.write(self.dataToSend)
 
53
 
 
54
    def dataReceived(self, bytes):
 
55
        self.data.append(bytes)
 
56
        # self.transport.loseConnection()
 
57
 
 
58
    def connectionLost(self, reason):
 
59
        self.onConnLost.callback(self.data)
 
60
 
 
61
 
 
62
    def __repr__(self):
 
63
        """
 
64
        Custom repr for testing to avoid coupling amp tests with repr from
 
65
        L{Protocol}
 
66
 
 
67
        Returns a string which contains a unique identifier that can be looked
 
68
        up using the instanceId property::
 
69
 
 
70
            <TestProto #3>
 
71
        """
 
72
        return "<TestProto #%d>" % (self.instanceId,)
 
73
 
 
74
 
 
75
 
 
76
class SimpleSymmetricProtocol(amp.AMP):
 
77
 
 
78
    def sendHello(self, text):
 
79
        return self.callRemoteString(
 
80
            "hello",
 
81
            hello=text)
 
82
 
 
83
    def amp_HELLO(self, box):
 
84
        return amp.Box(hello=box['hello'])
 
85
 
 
86
    def amp_HOWDOYOUDO(self, box):
 
87
        return amp.QuitBox(howdoyoudo='world')
 
88
 
 
89
 
 
90
 
 
91
class UnfriendlyGreeting(Exception):
 
92
    """Greeting was insufficiently kind.
 
93
    """
 
94
 
 
95
class DeathThreat(Exception):
 
96
    """Greeting was insufficiently kind.
 
97
    """
 
98
 
 
99
class UnknownProtocol(Exception):
 
100
    """Asked to switch to the wrong protocol.
 
101
    """
 
102
 
 
103
 
 
104
class TransportPeer(amp.Argument):
 
105
    # this serves as some informal documentation for how to get variables from
 
106
    # the protocol or your environment and pass them to methods as arguments.
 
107
    def retrieve(self, d, name, proto):
 
108
        return ''
 
109
 
 
110
    def fromStringProto(self, notAString, proto):
 
111
        return proto.transport.getPeer()
 
112
 
 
113
    def toBox(self, name, strings, objects, proto):
 
114
        return
 
115
 
 
116
 
 
117
 
 
118
class Hello(amp.Command):
 
119
 
 
120
    commandName = 'hello'
 
121
 
 
122
    arguments = [('hello', amp.String()),
 
123
                 ('optional', amp.Boolean(optional=True)),
 
124
                 ('print', amp.Unicode(optional=True)),
 
125
                 ('from', TransportPeer(optional=True)),
 
126
                 ('mixedCase', amp.String(optional=True)),
 
127
                 ('dash-arg', amp.String(optional=True)),
 
128
                 ('underscore_arg', amp.String(optional=True))]
 
129
 
 
130
    response = [('hello', amp.String()),
 
131
                ('print', amp.Unicode(optional=True))]
 
132
 
 
133
    errors = {UnfriendlyGreeting: 'UNFRIENDLY'}
 
134
 
 
135
    fatalErrors = {DeathThreat: 'DEAD'}
 
136
 
 
137
class NoAnswerHello(Hello):
 
138
    commandName = Hello.commandName
 
139
    requiresAnswer = False
 
140
 
 
141
class FutureHello(amp.Command):
 
142
    commandName = 'hello'
 
143
 
 
144
    arguments = [('hello', amp.String()),
 
145
                 ('optional', amp.Boolean(optional=True)),
 
146
                 ('print', amp.Unicode(optional=True)),
 
147
                 ('from', TransportPeer(optional=True)),
 
148
                 ('bonus', amp.String(optional=True)), # addt'l arguments
 
149
                                                       # should generally be
 
150
                                                       # added at the end, and
 
151
                                                       # be optional...
 
152
                 ]
 
153
 
 
154
    response = [('hello', amp.String()),
 
155
                ('print', amp.Unicode(optional=True))]
 
156
 
 
157
    errors = {UnfriendlyGreeting: 'UNFRIENDLY'}
 
158
 
 
159
class WTF(amp.Command):
 
160
    """
 
161
    An example of an invalid command.
 
162
    """
 
163
 
 
164
 
 
165
class BrokenReturn(amp.Command):
 
166
    """ An example of a perfectly good command, but the handler is going to return
 
167
    None...
 
168
    """
 
169
 
 
170
    commandName = 'broken_return'
 
171
 
 
172
class Goodbye(amp.Command):
 
173
    # commandName left blank on purpose: this tests implicit command names.
 
174
    response = [('goodbye', amp.String())]
 
175
    responseType = amp.QuitBox
 
176
 
 
177
class Howdoyoudo(amp.Command):
 
178
    commandName = 'howdoyoudo'
 
179
    # responseType = amp.QuitBox
 
180
 
 
181
class WaitForever(amp.Command):
 
182
    commandName = 'wait_forever'
 
183
 
 
184
class GetList(amp.Command):
 
185
    commandName = 'getlist'
 
186
    arguments = [('length', amp.Integer())]
 
187
    response = [('body', amp.AmpList([('x', amp.Integer())]))]
 
188
 
 
189
class DontRejectMe(amp.Command):
 
190
    commandName = 'dontrejectme'
 
191
    arguments = [
 
192
            ('magicWord', amp.Unicode()),
 
193
            ('list', amp.AmpList([('name', amp.Unicode())], optional=True)),
 
194
            ]
 
195
    response = [('response', amp.Unicode())]
 
196
 
 
197
class SecuredPing(amp.Command):
 
198
    # XXX TODO: actually make this refuse to send over an insecure connection
 
199
    response = [('pinged', amp.Boolean())]
 
200
 
 
201
class TestSwitchProto(amp.ProtocolSwitchCommand):
 
202
    commandName = 'Switch-Proto'
 
203
 
 
204
    arguments = [
 
205
        ('name', amp.String()),
 
206
        ]
 
207
    errors = {UnknownProtocol: 'UNKNOWN'}
 
208
 
 
209
class SingleUseFactory(protocol.ClientFactory):
 
210
    def __init__(self, proto):
 
211
        self.proto = proto
 
212
        self.proto.factory = self
 
213
 
 
214
    def buildProtocol(self, addr):
 
215
        p, self.proto = self.proto, None
 
216
        return p
 
217
 
 
218
    reasonFailed = None
 
219
 
 
220
    def clientConnectionFailed(self, connector, reason):
 
221
        self.reasonFailed = reason
 
222
        return
 
223
 
 
224
THING_I_DONT_UNDERSTAND = 'gwebol nargo'
 
225
class ThingIDontUnderstandError(Exception):
 
226
    pass
 
227
 
 
228
class FactoryNotifier(amp.AMP):
 
229
    factory = None
 
230
    def connectionMade(self):
 
231
        if self.factory is not None:
 
232
            self.factory.theProto = self
 
233
            if hasattr(self.factory, 'onMade'):
 
234
                self.factory.onMade.callback(None)
 
235
 
 
236
    def emitpong(self):
 
237
        from twisted.internet.interfaces import ISSLTransport
 
238
        if not ISSLTransport.providedBy(self.transport):
 
239
            raise DeathThreat("only send secure pings over secure channels")
 
240
        return {'pinged': True}
 
241
    SecuredPing.responder(emitpong)
 
242
 
 
243
 
 
244
class SimpleSymmetricCommandProtocol(FactoryNotifier):
 
245
    maybeLater = None
 
246
    def __init__(self, onConnLost=None):
 
247
        amp.AMP.__init__(self)
 
248
        self.onConnLost = onConnLost
 
249
 
 
250
    def sendHello(self, text):
 
251
        return self.callRemote(Hello, hello=text)
 
252
 
 
253
    def sendUnicodeHello(self, text, translation):
 
254
        return self.callRemote(Hello, hello=text, Print=translation)
 
255
 
 
256
    greeted = False
 
257
 
 
258
    def cmdHello(self, hello, From, optional=None, Print=None,
 
259
                 mixedCase=None, dash_arg=None, underscore_arg=None):
 
260
        assert From == self.transport.getPeer()
 
261
        if hello == THING_I_DONT_UNDERSTAND:
 
262
            raise ThingIDontUnderstandError()
 
263
        if hello.startswith('fuck'):
 
264
            raise UnfriendlyGreeting("Don't be a dick.")
 
265
        if hello == 'die':
 
266
            raise DeathThreat("aieeeeeeeee")
 
267
        result = dict(hello=hello)
 
268
        if Print is not None:
 
269
            result.update(dict(Print=Print))
 
270
        self.greeted = True
 
271
        return result
 
272
    Hello.responder(cmdHello)
 
273
 
 
274
    def cmdGetlist(self, length):
 
275
        return {'body': [dict(x=1)] * length}
 
276
    GetList.responder(cmdGetlist)
 
277
 
 
278
    def okiwont(self, magicWord, list):
 
279
        return dict(response=u'%s accepted' % (list[0]['name']))
 
280
    DontRejectMe.responder(okiwont)
 
281
 
 
282
    def waitforit(self):
 
283
        self.waiting = defer.Deferred()
 
284
        return self.waiting
 
285
    WaitForever.responder(waitforit)
 
286
 
 
287
    def howdo(self):
 
288
        return dict(howdoyoudo='world')
 
289
    Howdoyoudo.responder(howdo)
 
290
 
 
291
    def saybye(self):
 
292
        return dict(goodbye="everyone")
 
293
    Goodbye.responder(saybye)
 
294
 
 
295
    def switchToTestProtocol(self, fail=False):
 
296
        if fail:
 
297
            name = 'no-proto'
 
298
        else:
 
299
            name = 'test-proto'
 
300
        p = TestProto(self.onConnLost, SWITCH_CLIENT_DATA)
 
301
        return self.callRemote(
 
302
            TestSwitchProto,
 
303
            SingleUseFactory(p), name=name).addCallback(lambda ign: p)
 
304
 
 
305
    def switchit(self, name):
 
306
        if name == 'test-proto':
 
307
            return TestProto(self.onConnLost, SWITCH_SERVER_DATA)
 
308
        raise UnknownProtocol(name)
 
309
    TestSwitchProto.responder(switchit)
 
310
 
 
311
    def donothing(self):
 
312
        return None
 
313
    BrokenReturn.responder(donothing)
 
314
 
 
315
 
 
316
class DeferredSymmetricCommandProtocol(SimpleSymmetricCommandProtocol):
 
317
    def switchit(self, name):
 
318
        if name == 'test-proto':
 
319
            self.maybeLaterProto = TestProto(self.onConnLost, SWITCH_SERVER_DATA)
 
320
            self.maybeLater = defer.Deferred()
 
321
            return self.maybeLater
 
322
        raise UnknownProtocol(name)
 
323
    TestSwitchProto.responder(switchit)
 
324
 
 
325
class BadNoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
 
326
    def badResponder(self, hello, From, optional=None, Print=None,
 
327
                     mixedCase=None, dash_arg=None, underscore_arg=None):
 
328
        """
 
329
        This responder does nothing and forgets to return a dictionary.
 
330
        """
 
331
    NoAnswerHello.responder(badResponder)
 
332
 
 
333
class NoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
 
334
    def goodNoAnswerResponder(self, hello, From, optional=None, Print=None,
 
335
                              mixedCase=None, dash_arg=None, underscore_arg=None):
 
336
        return dict(hello=hello+"-noanswer")
 
337
    NoAnswerHello.responder(goodNoAnswerResponder)
 
338
 
 
339
def connectedServerAndClient(ServerClass=SimpleSymmetricProtocol,
 
340
                             ClientClass=SimpleSymmetricProtocol,
 
341
                             *a, **kw):
 
342
    """Returns a 3-tuple: (client, server, pump)
 
343
    """
 
344
    return iosim.connectedServerAndClient(
 
345
        ServerClass, ClientClass,
 
346
        *a, **kw)
 
347
 
 
348
class TotallyDumbProtocol(protocol.Protocol):
 
349
    buf = ''
 
350
    def dataReceived(self, data):
 
351
        self.buf += data
 
352
 
 
353
class LiteralAmp(amp.AMP):
 
354
    def __init__(self):
 
355
        self.boxes = []
 
356
 
 
357
    def ampBoxReceived(self, box):
 
358
        self.boxes.append(box)
 
359
        return
 
360
 
 
361
class ParsingTest(unittest.TestCase):
 
362
 
 
363
    def test_booleanValues(self):
 
364
        """
 
365
        Verify that the Boolean parser parses 'True' and 'False', but nothing
 
366
        else.
 
367
        """
 
368
        b = amp.Boolean()
 
369
        self.assertEquals(b.fromString("True"), True)
 
370
        self.assertEquals(b.fromString("False"), False)
 
371
        self.assertRaises(TypeError, b.fromString, "ninja")
 
372
        self.assertRaises(TypeError, b.fromString, "true")
 
373
        self.assertRaises(TypeError, b.fromString, "TRUE")
 
374
        self.assertEquals(b.toString(True), 'True')
 
375
        self.assertEquals(b.toString(False), 'False')
 
376
 
 
377
    def test_pathValueRoundTrip(self):
 
378
        """
 
379
        Verify the 'Path' argument can parse and emit a file path.
 
380
        """
 
381
        fp = filepath.FilePath(self.mktemp())
 
382
        p = amp.Path()
 
383
        s = p.toString(fp)
 
384
        v = p.fromString(s)
 
385
        self.assertNotIdentical(fp, v) # sanity check
 
386
        self.assertEquals(fp, v)
 
387
 
 
388
 
 
389
    def test_sillyEmptyThing(self):
 
390
        """
 
391
        Test that empty boxes raise an error; they aren't supposed to be sent
 
392
        on purpose.
 
393
        """
 
394
        a = amp.AMP()
 
395
        return self.assertRaises(amp.NoEmptyBoxes, a.ampBoxReceived, amp.Box())
 
396
 
 
397
 
 
398
    def test_ParsingRoundTrip(self):
 
399
        """
 
400
        Verify that various kinds of data make it through the encode/parse
 
401
        round-trip unharmed.
 
402
        """
 
403
        c, s, p = connectedServerAndClient(ClientClass=LiteralAmp,
 
404
                                           ServerClass=LiteralAmp)
 
405
 
 
406
        SIMPLE = ('simple', 'test')
 
407
        CE = ('ceq', ': ')
 
408
        CR = ('crtest', 'test\r')
 
409
        LF = ('lftest', 'hello\n')
 
410
        NEWLINE = ('newline', 'test\r\none\r\ntwo')
 
411
        NEWLINE2 = ('newline2', 'test\r\none\r\n two')
 
412
        BLANKLINE = ('newline3', 'test\r\n\r\nblank\r\n\r\nline')
 
413
        BODYTEST = ('body', 'blah\r\n\r\ntesttest')
 
414
 
 
415
        testData = [
 
416
            [SIMPLE],
 
417
            [SIMPLE, BODYTEST],
 
418
            [SIMPLE, CE],
 
419
            [SIMPLE, CR],
 
420
            [SIMPLE, CE, CR, LF],
 
421
            [CE, CR, LF],
 
422
            [SIMPLE, NEWLINE, CE, NEWLINE2],
 
423
            [BODYTEST, SIMPLE, NEWLINE]
 
424
            ]
 
425
 
 
426
        for test in testData:
 
427
            jb = amp.Box()
 
428
            jb.update(dict(test))
 
429
            jb._sendTo(c)
 
430
            p.flush()
 
431
            self.assertEquals(s.boxes[-1], jb)
 
432
 
 
433
 
 
434
 
 
435
class FakeLocator(object):
 
436
    """
 
437
    This is a fake implementation of the interface implied by
 
438
    L{CommandLocator}.
 
439
    """
 
440
    def __init__(self):
 
441
        """
 
442
        Remember the given keyword arguments as a set of responders.
 
443
        """
 
444
        self.commands = {}
 
445
 
 
446
 
 
447
    def locateResponder(self, commandName):
 
448
        """
 
449
        Look up and return a function passed as a keyword argument of the given
 
450
        name to the constructor.
 
451
        """
 
452
        return self.commands[commandName]
 
453
 
 
454
 
 
455
class FakeSender:
 
456
    """
 
457
    This is a fake implementation of the 'box sender' interface implied by
 
458
    L{AMP}.
 
459
    """
 
460
    def __init__(self):
 
461
        """
 
462
        Create a fake sender and initialize the list of received boxes and
 
463
        unhandled errors.
 
464
        """
 
465
        self.sentBoxes = []
 
466
        self.unhandledErrors = []
 
467
        self.expectedErrors = 0
 
468
 
 
469
 
 
470
    def expectError(self):
 
471
        """
 
472
        Expect one error, so that the test doesn't fail.
 
473
        """
 
474
        self.expectedErrors += 1
 
475
 
 
476
 
 
477
    def sendBox(self, box):
 
478
        """
 
479
        Accept a box, but don't do anything.
 
480
        """
 
481
        self.sentBoxes.append(box)
 
482
 
 
483
 
 
484
    def unhandledError(self, failure):
 
485
        """
 
486
        Deal with failures by instantly re-raising them for easier debugging.
 
487
        """
 
488
        self.expectedErrors -= 1
 
489
        if self.expectedErrors < 0:
 
490
            failure.raiseException()
 
491
        else:
 
492
            self.unhandledErrors.append(failure)
 
493
 
 
494
 
 
495
 
 
496
class CommandDispatchTests(unittest.TestCase):
 
497
    """
 
498
    The AMP CommandDispatcher class dispatches converts AMP boxes into commands
 
499
    and responses using Command.responder decorator.
 
500
 
 
501
    Note: Originally, AMP's factoring was such that many tests for this
 
502
    functionality are now implemented as full round-trip tests in L{AMPTest}.
 
503
    Future tests should be written at this level instead, to ensure API
 
504
    compatibility and to provide more granular, readable units of test
 
505
    coverage.
 
506
    """
 
507
 
 
508
    def setUp(self):
 
509
        """
 
510
        Create a dispatcher to use.
 
511
        """
 
512
        self.locator = FakeLocator()
 
513
        self.sender = FakeSender()
 
514
        self.dispatcher = amp.BoxDispatcher(self.locator)
 
515
        self.dispatcher.startReceivingBoxes(self.sender)
 
516
 
 
517
 
 
518
    def test_receivedAsk(self):
 
519
        """
 
520
        L{CommandDispatcher.ampBoxReceived} should locate the appropriate
 
521
        command in its responder lookup, based on the '_ask' key.
 
522
        """
 
523
        received = []
 
524
        def thunk(box):
 
525
            received.append(box)
 
526
            return amp.Box({"hello": "goodbye"})
 
527
        input = amp.Box(_command="hello",
 
528
                        _ask="test-command-id",
 
529
                        hello="world")
 
530
        self.locator.commands['hello'] = thunk
 
531
        self.dispatcher.ampBoxReceived(input)
 
532
        self.assertEquals(received, [input])
 
533
 
 
534
 
 
535
    def test_sendUnhandledError(self):
 
536
        """
 
537
        L{CommandDispatcher} should relay its unhandled errors in responding to
 
538
        boxes to its boxSender.
 
539
        """
 
540
        err = RuntimeError("something went wrong, oh no")
 
541
        self.sender.expectError()
 
542
        self.dispatcher.unhandledError(Failure(err))
 
543
        self.assertEqual(len(self.sender.unhandledErrors), 1)
 
544
        self.assertEqual(self.sender.unhandledErrors[0].value, err)
 
545
 
 
546
 
 
547
    def test_unhandledSerializationError(self):
 
548
        """
 
549
        Errors during serialization ought to be relayed to the sender's
 
550
        unhandledError method.
 
551
        """
 
552
        err = RuntimeError("something undefined went wrong")
 
553
        def thunk(result):
 
554
            class BrokenBox(amp.Box):
 
555
                def _sendTo(self, proto):
 
556
                    raise err
 
557
            return BrokenBox()
 
558
        self.locator.commands['hello'] = thunk
 
559
        input = amp.Box(_command="hello",
 
560
                        _ask="test-command-id",
 
561
                        hello="world")
 
562
        self.sender.expectError()
 
563
        self.dispatcher.ampBoxReceived(input)
 
564
        self.assertEquals(len(self.sender.unhandledErrors), 1)
 
565
        self.assertEquals(self.sender.unhandledErrors[0].value, err)
 
566
 
 
567
 
 
568
    def test_callRemote(self):
 
569
        """
 
570
        L{CommandDispatcher.callRemote} should emit a properly formatted '_ask'
 
571
        box to its boxSender and record an outstanding L{Deferred}.  When a
 
572
        corresponding '_answer' packet is received, the L{Deferred} should be
 
573
        fired, and the results translated via the given L{Command}'s response
 
574
        de-serialization.
 
575
        """
 
576
        D = self.dispatcher.callRemote(Hello, hello='world')
 
577
        self.assertEquals(self.sender.sentBoxes,
 
578
                          [amp.AmpBox(_command="hello",
 
579
                                      _ask="1",
 
580
                                      hello="world")])
 
581
        answers = []
 
582
        D.addCallback(answers.append)
 
583
        self.assertEquals(answers, [])
 
584
        self.dispatcher.ampBoxReceived(amp.AmpBox({'hello': "yay",
 
585
                                                   'print': "ignored",
 
586
                                                   '_answer': "1"}))
 
587
        self.assertEquals(answers, [dict(hello="yay",
 
588
                                         Print=u"ignored")])
 
589
 
 
590
 
 
591
class SimpleGreeting(amp.Command):
 
592
    """
 
593
    A very simple greeting command that uses a few basic argument types.
 
594
    """
 
595
    commandName = 'simple'
 
596
    arguments = [('greeting', amp.Unicode()),
 
597
                 ('cookie', amp.Integer())]
 
598
    response = [('cookieplus', amp.Integer())]
 
599
 
 
600
 
 
601
class TestLocator(amp.CommandLocator):
 
602
    """
 
603
    A locator which implements a responder to a 'hello' command.
 
604
    """
 
605
    def __init__(self):
 
606
        self.greetings = []
 
607
 
 
608
 
 
609
    def greetingResponder(self, greeting, cookie):
 
610
        self.greetings.append((greeting, cookie))
 
611
        return dict(cookieplus=cookie + 3)
 
612
    greetingResponder = SimpleGreeting.responder(greetingResponder)
 
613
 
 
614
 
 
615
 
 
616
class OverrideLocatorAMP(amp.AMP):
 
617
    def __init__(self):
 
618
        amp.AMP.__init__(self)
 
619
        self.customResponder = object()
 
620
        self.expectations = {"custom": self.customResponder}
 
621
        self.greetings = []
 
622
 
 
623
 
 
624
    def lookupFunction(self, name):
 
625
        """
 
626
        Override the deprecated lookupFunction function.
 
627
        """
 
628
        if name in self.expectations:
 
629
            result = self.expectations[name]
 
630
            return result
 
631
        else:
 
632
            return super(OverrideLocatorAMP, self).lookupFunction(name)
 
633
 
 
634
 
 
635
    def greetingResponder(self, greeting, cookie):
 
636
        self.greetings.append((greeting, cookie))
 
637
        return dict(cookieplus=cookie + 3)
 
638
    greetingResponder = SimpleGreeting.responder(greetingResponder)
 
639
 
 
640
 
 
641
 
 
642
 
 
643
class CommandLocatorTests(unittest.TestCase):
 
644
    """
 
645
    The CommandLocator should enable users to specify responders to commands as
 
646
    functions that take structured objects, annotated with metadata.
 
647
    """
 
648
 
 
649
    def test_responderDecorator(self):
 
650
        """
 
651
        A method on a L{CommandLocator} subclass decorated with a L{Command}
 
652
        subclass's L{responder} decorator should be returned from
 
653
        locateResponder, wrapped in logic to serialize and deserialize its
 
654
        arguments.
 
655
        """
 
656
        locator = TestLocator()
 
657
        responderCallable = locator.locateResponder("simple")
 
658
        result = responderCallable(amp.Box(greeting="ni hao", cookie="5"))
 
659
        def done(values):
 
660
            self.assertEquals(values, amp.AmpBox(cookieplus='8'))
 
661
        return result.addCallback(done)
 
662
 
 
663
 
 
664
    def test_lookupFunctionDeprecatedOverride(self):
 
665
        """
 
666
        Subclasses which override locateResponder under its old name,
 
667
        lookupFunction, should have the override invoked instead.  (This tests
 
668
        an AMP subclass, because in the version of the code that could invoke
 
669
        this deprecated code path, there was no L{CommandLocator}.)
 
670
        """
 
671
        locator = OverrideLocatorAMP()
 
672
        customResponderObject = self.assertWarns(
 
673
            PendingDeprecationWarning,
 
674
            "Override locateResponder, not lookupFunction.",
 
675
            __file__, lambda : locator.locateResponder("custom"))
 
676
        self.assertEquals(locator.customResponder, customResponderObject)
 
677
        # Make sure upcalling works too
 
678
        normalResponderObject = self.assertWarns(
 
679
            PendingDeprecationWarning,
 
680
            "Override locateResponder, not lookupFunction.",
 
681
            __file__, lambda : locator.locateResponder("simple"))
 
682
        result = normalResponderObject(amp.Box(greeting="ni hao", cookie="5"))
 
683
        def done(values):
 
684
            self.assertEquals(values, amp.AmpBox(cookieplus='8'))
 
685
        return result.addCallback(done)
 
686
 
 
687
 
 
688
    def test_lookupFunctionDeprecatedInvoke(self):
 
689
        """
 
690
        Invoking locateResponder under its old name, lookupFunction, should
 
691
        emit a deprecation warning, but do the same thing.
 
692
        """
 
693
        locator = TestLocator()
 
694
        responderCallable = self.assertWarns(
 
695
            PendingDeprecationWarning,
 
696
            "Call locateResponder, not lookupFunction.", __file__,
 
697
            lambda : locator.lookupFunction("simple"))
 
698
        result = responderCallable(amp.Box(greeting="ni hao", cookie="5"))
 
699
        def done(values):
 
700
            self.assertEquals(values, amp.AmpBox(cookieplus='8'))
 
701
        return result.addCallback(done)
 
702
 
 
703
 
 
704
 
 
705
SWITCH_CLIENT_DATA = 'Success!'
 
706
SWITCH_SERVER_DATA = 'No, really.  Success.'
 
707
 
 
708
 
 
709
class BinaryProtocolTests(unittest.TestCase):
 
710
    """
 
711
    Tests for L{amp.BinaryBoxProtocol}.
 
712
 
 
713
    @ivar _boxSender: After C{startReceivingBoxes} is called, the L{IBoxSender}
 
714
        which was passed to it.
 
715
    """
 
716
 
 
717
    def setUp(self):
 
718
        """
 
719
        Keep track of all boxes received by this test in its capacity as an
 
720
        L{IBoxReceiver} implementor.
 
721
        """
 
722
        self.boxes = []
 
723
        self.data = []
 
724
 
 
725
 
 
726
    def startReceivingBoxes(self, sender):
 
727
        """
 
728
        Implement L{IBoxReceiver.startReceivingBoxes} to just remember the
 
729
        value passed in.
 
730
        """
 
731
        self._boxSender = sender
 
732
 
 
733
 
 
734
    def ampBoxReceived(self, box):
 
735
        """
 
736
        A box was received by the protocol.
 
737
        """
 
738
        self.boxes.append(box)
 
739
 
 
740
    stopReason = None
 
741
    def stopReceivingBoxes(self, reason):
 
742
        """
 
743
        Record the reason that we stopped receiving boxes.
 
744
        """
 
745
        self.stopReason = reason
 
746
 
 
747
 
 
748
    # fake ITransport
 
749
    def getPeer(self):
 
750
        return 'no peer'
 
751
 
 
752
 
 
753
    def getHost(self):
 
754
        return 'no host'
 
755
 
 
756
 
 
757
    def write(self, data):
 
758
        self.data.append(data)
 
759
 
 
760
 
 
761
    def test_startReceivingBoxes(self):
 
762
        """
 
763
        When L{amp.BinaryBoxProtocol} is connected to a transport, it calls
 
764
        C{startReceivingBoxes} on its L{IBoxReceiver} with itself as the
 
765
        L{IBoxSender} parameter.
 
766
        """
 
767
        protocol = amp.BinaryBoxProtocol(self)
 
768
        protocol.makeConnection(None)
 
769
        self.assertIdentical(self._boxSender, protocol)
 
770
 
 
771
 
 
772
    def test_sendBoxInStartReceivingBoxes(self):
 
773
        """
 
774
        The L{IBoxReceiver} which is started when L{amp.BinaryBoxProtocol} is
 
775
        connected to a transport can call C{sendBox} on the L{IBoxSender}
 
776
        passed to it before C{startReceivingBoxes} returns and have that box
 
777
        sent.
 
778
        """
 
779
        class SynchronouslySendingReceiver:
 
780
            def startReceivingBoxes(self, sender):
 
781
                sender.sendBox(amp.Box({'foo': 'bar'}))
 
782
 
 
783
        transport = StringTransport()
 
784
        protocol = amp.BinaryBoxProtocol(SynchronouslySendingReceiver())
 
785
        protocol.makeConnection(transport)
 
786
        self.assertEqual(
 
787
            transport.value(),
 
788
            '\x00\x03foo\x00\x03bar\x00\x00')
 
789
 
 
790
 
 
791
    def test_receiveBoxStateMachine(self):
 
792
        """
 
793
        When a binary box protocol receives:
 
794
            * a key
 
795
            * a value
 
796
            * an empty string
 
797
        it should emit a box and send it to its boxReceiver.
 
798
        """
 
799
        a = amp.BinaryBoxProtocol(self)
 
800
        a.stringReceived("hello")
 
801
        a.stringReceived("world")
 
802
        a.stringReceived("")
 
803
        self.assertEquals(self.boxes, [amp.AmpBox(hello="world")])
 
804
 
 
805
 
 
806
    def test_firstBoxFirstKeyExcessiveLength(self):
 
807
        """
 
808
        L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
 
809
        the first a key it receives is larger than 255.
 
810
        """
 
811
        transport = StringTransport()
 
812
        protocol = amp.BinaryBoxProtocol(self)
 
813
        protocol.makeConnection(transport)
 
814
        protocol.dataReceived('\x01\x00')
 
815
        self.assertTrue(transport.disconnecting)
 
816
 
 
817
 
 
818
    def test_firstBoxSubsequentKeyExcessiveLength(self):
 
819
        """
 
820
        L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
 
821
        a subsequent key in the first box it receives is larger than 255.
 
822
        """
 
823
        transport = StringTransport()
 
824
        protocol = amp.BinaryBoxProtocol(self)
 
825
        protocol.makeConnection(transport)
 
826
        protocol.dataReceived('\x00\x01k\x00\x01v')
 
827
        self.assertFalse(transport.disconnecting)
 
828
        protocol.dataReceived('\x01\x00')
 
829
        self.assertTrue(transport.disconnecting)
 
830
 
 
831
 
 
832
    def test_subsequentBoxFirstKeyExcessiveLength(self):
 
833
        """
 
834
        L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
 
835
        the first key in a subsequent box it receives is larger than 255.
 
836
        """
 
837
        transport = StringTransport()
 
838
        protocol = amp.BinaryBoxProtocol(self)
 
839
        protocol.makeConnection(transport)
 
840
        protocol.dataReceived('\x00\x01k\x00\x01v\x00\x00')
 
841
        self.assertFalse(transport.disconnecting)
 
842
        protocol.dataReceived('\x01\x00')
 
843
        self.assertTrue(transport.disconnecting)
 
844
 
 
845
 
 
846
    def test_excessiveKeyFailure(self):
 
847
        """
 
848
        If L{amp.BinaryBoxProtocol} disconnects because it received a key
 
849
        length prefix which was too large, the L{IBoxReceiver}'s
 
850
        C{stopReceivingBoxes} method is called with a L{TooLong} failure.
 
851
        """
 
852
        protocol = amp.BinaryBoxProtocol(self)
 
853
        protocol.makeConnection(StringTransport())
 
854
        protocol.dataReceived('\x01\x00')
 
855
        protocol.connectionLost(
 
856
            Failure(error.ConnectionDone("simulated connection done")))
 
857
        self.stopReason.trap(amp.TooLong)
 
858
        self.assertTrue(self.stopReason.value.isKey)
 
859
        self.assertFalse(self.stopReason.value.isLocal)
 
860
        self.assertIdentical(self.stopReason.value.value, None)
 
861
        self.assertIdentical(self.stopReason.value.keyName, None)
 
862
 
 
863
 
 
864
    def test_receiveBoxData(self):
 
865
        """
 
866
        When a binary box protocol receives the serialized form of an AMP box,
 
867
        it should emit a similar box to its boxReceiver.
 
868
        """
 
869
        a = amp.BinaryBoxProtocol(self)
 
870
        a.dataReceived(amp.Box({"testKey": "valueTest",
 
871
                                "anotherKey": "anotherValue"}).serialize())
 
872
        self.assertEquals(self.boxes,
 
873
                          [amp.Box({"testKey": "valueTest",
 
874
                                    "anotherKey": "anotherValue"})])
 
875
 
 
876
 
 
877
    def test_receiveLongerBoxData(self):
 
878
        """
 
879
        An L{amp.BinaryBoxProtocol} can receive serialized AMP boxes with
 
880
        values of up to (2 ** 16 - 1) bytes.
 
881
        """
 
882
        length = (2 ** 16 - 1)
 
883
        value = 'x' * length
 
884
        transport = StringTransport()
 
885
        protocol = amp.BinaryBoxProtocol(self)
 
886
        protocol.makeConnection(transport)
 
887
        protocol.dataReceived(amp.Box({'k': value}).serialize())
 
888
        self.assertEqual(self.boxes, [amp.Box({'k': value})])
 
889
        self.assertFalse(transport.disconnecting)
 
890
 
 
891
 
 
892
    def test_sendBox(self):
 
893
        """
 
894
        When a binary box protocol sends a box, it should emit the serialized
 
895
        bytes of that box to its transport.
 
896
        """
 
897
        a = amp.BinaryBoxProtocol(self)
 
898
        a.makeConnection(self)
 
899
        aBox = amp.Box({"testKey": "valueTest",
 
900
                        "someData": "hello"})
 
901
        a.makeConnection(self)
 
902
        a.sendBox(aBox)
 
903
        self.assertEquals(''.join(self.data), aBox.serialize())
 
904
 
 
905
 
 
906
    def test_connectionLostStopSendingBoxes(self):
 
907
        """
 
908
        When a binary box protocol loses its connection, it should notify its
 
909
        box receiver that it has stopped receiving boxes.
 
910
        """
 
911
        a = amp.BinaryBoxProtocol(self)
 
912
        a.makeConnection(self)
 
913
        aBox = amp.Box({"sample": "data"})
 
914
        a.makeConnection(self)
 
915
        connectionFailure = Failure(RuntimeError())
 
916
        a.connectionLost(connectionFailure)
 
917
        self.assertIdentical(self.stopReason, connectionFailure)
 
918
 
 
919
 
 
920
    def test_protocolSwitch(self):
 
921
        """
 
922
        L{BinaryBoxProtocol} has the capacity to switch to a different protocol
 
923
        on a box boundary.  When a protocol is in the process of switching, it
 
924
        cannot receive traffic.
 
925
        """
 
926
        otherProto = TestProto(None, "outgoing data")
 
927
        test = self
 
928
        class SwitchyReceiver:
 
929
            switched = False
 
930
            def startReceivingBoxes(self, sender):
 
931
                pass
 
932
            def ampBoxReceived(self, box):
 
933
                test.assertFalse(self.switched,
 
934
                                 "Should only receive one box!")
 
935
                self.switched = True
 
936
                a._lockForSwitch()
 
937
                a._switchTo(otherProto)
 
938
        a = amp.BinaryBoxProtocol(SwitchyReceiver())
 
939
        anyOldBox = amp.Box({"include": "lots",
 
940
                             "of": "data"})
 
941
        a.makeConnection(self)
 
942
        # Include a 0-length box at the beginning of the next protocol's data,
 
943
        # to make sure that AMP doesn't eat the data or try to deliver extra
 
944
        # boxes either...
 
945
        moreThanOneBox = anyOldBox.serialize() + "\x00\x00Hello, world!"
 
946
        a.dataReceived(moreThanOneBox)
 
947
        self.assertIdentical(otherProto.transport, self)
 
948
        self.assertEquals("".join(otherProto.data), "\x00\x00Hello, world!")
 
949
        self.assertEquals(self.data, ["outgoing data"])
 
950
        a.dataReceived("more data")
 
951
        self.assertEquals("".join(otherProto.data),
 
952
                          "\x00\x00Hello, world!more data")
 
953
        self.assertRaises(amp.ProtocolSwitched, a.sendBox, anyOldBox)
 
954
 
 
955
 
 
956
    def test_protocolSwitchInvalidStates(self):
 
957
        """
 
958
        In order to make sure the protocol never gets any invalid data sent
 
959
        into the middle of a box, it must be locked for switching before it is
 
960
        switched.  It can only be unlocked if the switch failed, and attempting
 
961
        to send a box while it is locked should raise an exception.
 
962
        """
 
963
        a = amp.BinaryBoxProtocol(self)
 
964
        a.makeConnection(self)
 
965
        sampleBox = amp.Box({"some": "data"})
 
966
        a._lockForSwitch()
 
967
        self.assertRaises(amp.ProtocolSwitched, a.sendBox, sampleBox)
 
968
        a._unlockFromSwitch()
 
969
        a.sendBox(sampleBox)
 
970
        self.assertEquals(''.join(self.data), sampleBox.serialize())
 
971
        a._lockForSwitch()
 
972
        otherProto = TestProto(None, "outgoing data")
 
973
        a._switchTo(otherProto)
 
974
        self.assertRaises(amp.ProtocolSwitched, a._unlockFromSwitch)
 
975
 
 
976
 
 
977
    def test_protocolSwitchLoseConnection(self):
 
978
        """
 
979
        When the protocol is switched, it should notify its nested protocol of
 
980
        disconnection.
 
981
        """
 
982
        class Loser(protocol.Protocol):
 
983
            reason = None
 
984
            def connectionLost(self, reason):
 
985
                self.reason = reason
 
986
        connectionLoser = Loser()
 
987
        a = amp.BinaryBoxProtocol(self)
 
988
        a.makeConnection(self)
 
989
        a._lockForSwitch()
 
990
        a._switchTo(connectionLoser)
 
991
        connectionFailure = Failure(RuntimeError())
 
992
        a.connectionLost(connectionFailure)
 
993
        self.assertEquals(connectionLoser.reason, connectionFailure)
 
994
 
 
995
 
 
996
    def test_protocolSwitchLoseClientConnection(self):
 
997
        """
 
998
        When the protocol is switched, it should notify its nested client
 
999
        protocol factory of disconnection.
 
1000
        """
 
1001
        class ClientLoser:
 
1002
            reason = None
 
1003
            def clientConnectionLost(self, connector, reason):
 
1004
                self.reason = reason
 
1005
        a = amp.BinaryBoxProtocol(self)
 
1006
        connectionLoser = protocol.Protocol()
 
1007
        clientLoser = ClientLoser()
 
1008
        a.makeConnection(self)
 
1009
        a._lockForSwitch()
 
1010
        a._switchTo(connectionLoser, clientLoser)
 
1011
        connectionFailure = Failure(RuntimeError())
 
1012
        a.connectionLost(connectionFailure)
 
1013
        self.assertEquals(clientLoser.reason, connectionFailure)
 
1014
 
 
1015
 
 
1016
 
 
1017
class AMPTest(unittest.TestCase):
 
1018
 
 
1019
    def test_interfaceDeclarations(self):
 
1020
        """
 
1021
        The classes in the amp module ought to implement the interfaces that
 
1022
        are declared for their benefit.
 
1023
        """
 
1024
        for interface, implementation in [(amp.IBoxSender, amp.BinaryBoxProtocol),
 
1025
                                          (amp.IBoxReceiver, amp.BoxDispatcher),
 
1026
                                          (amp.IResponderLocator, amp.CommandLocator),
 
1027
                                          (amp.IResponderLocator, amp.SimpleStringLocator),
 
1028
                                          (amp.IBoxSender, amp.AMP),
 
1029
                                          (amp.IBoxReceiver, amp.AMP),
 
1030
                                          (amp.IResponderLocator, amp.AMP)]:
 
1031
            self.failUnless(interface.implementedBy(implementation),
 
1032
                            "%s does not implements(%s)" % (implementation, interface))
 
1033
 
 
1034
 
 
1035
    def test_helloWorld(self):
 
1036
        """
 
1037
        Verify that a simple command can be sent and its response received with
 
1038
        the simple low-level string-based API.
 
1039
        """
 
1040
        c, s, p = connectedServerAndClient()
 
1041
        L = []
 
1042
        HELLO = 'world'
 
1043
        c.sendHello(HELLO).addCallback(L.append)
 
1044
        p.flush()
 
1045
        self.assertEquals(L[0]['hello'], HELLO)
 
1046
 
 
1047
 
 
1048
    def test_wireFormatRoundTrip(self):
 
1049
        """
 
1050
        Verify that mixed-case, underscored and dashed arguments are mapped to
 
1051
        their python names properly.
 
1052
        """
 
1053
        c, s, p = connectedServerAndClient()
 
1054
        L = []
 
1055
        HELLO = 'world'
 
1056
        c.sendHello(HELLO).addCallback(L.append)
 
1057
        p.flush()
 
1058
        self.assertEquals(L[0]['hello'], HELLO)
 
1059
 
 
1060
 
 
1061
    def test_helloWorldUnicode(self):
 
1062
        """
 
1063
        Verify that unicode arguments can be encoded and decoded.
 
1064
        """
 
1065
        c, s, p = connectedServerAndClient(
 
1066
            ServerClass=SimpleSymmetricCommandProtocol,
 
1067
            ClientClass=SimpleSymmetricCommandProtocol)
 
1068
        L = []
 
1069
        HELLO = 'world'
 
1070
        HELLO_UNICODE = 'wor\u1234ld'
 
1071
        c.sendUnicodeHello(HELLO, HELLO_UNICODE).addCallback(L.append)
 
1072
        p.flush()
 
1073
        self.assertEquals(L[0]['hello'], HELLO)
 
1074
        self.assertEquals(L[0]['Print'], HELLO_UNICODE)
 
1075
 
 
1076
 
 
1077
    def test_callRemoteStringRequiresAnswerFalse(self):
 
1078
        """
 
1079
        L{BoxDispatcher.callRemoteString} returns C{None} if C{requiresAnswer}
 
1080
        is C{False}.
 
1081
        """
 
1082
        c, s, p = connectedServerAndClient()
 
1083
        ret = c.callRemoteString("WTF", requiresAnswer=False)
 
1084
        self.assertIdentical(ret, None)
 
1085
 
 
1086
 
 
1087
    def test_unknownCommandLow(self):
 
1088
        """
 
1089
        Verify that unknown commands using low-level APIs will be rejected with an
 
1090
        error, but will NOT terminate the connection.
 
1091
        """
 
1092
        c, s, p = connectedServerAndClient()
 
1093
        L = []
 
1094
        def clearAndAdd(e):
 
1095
            """
 
1096
            You can't propagate the error...
 
1097
            """
 
1098
            e.trap(amp.UnhandledCommand)
 
1099
            return "OK"
 
1100
        c.callRemoteString("WTF").addErrback(clearAndAdd).addCallback(L.append)
 
1101
        p.flush()
 
1102
        self.assertEquals(L.pop(), "OK")
 
1103
        HELLO = 'world'
 
1104
        c.sendHello(HELLO).addCallback(L.append)
 
1105
        p.flush()
 
1106
        self.assertEquals(L[0]['hello'], HELLO)
 
1107
 
 
1108
 
 
1109
    def test_unknownCommandHigh(self):
 
1110
        """
 
1111
        Verify that unknown commands using high-level APIs will be rejected with an
 
1112
        error, but will NOT terminate the connection.
 
1113
        """
 
1114
        c, s, p = connectedServerAndClient()
 
1115
        L = []
 
1116
        def clearAndAdd(e):
 
1117
            """
 
1118
            You can't propagate the error...
 
1119
            """
 
1120
            e.trap(amp.UnhandledCommand)
 
1121
            return "OK"
 
1122
        c.callRemote(WTF).addErrback(clearAndAdd).addCallback(L.append)
 
1123
        p.flush()
 
1124
        self.assertEquals(L.pop(), "OK")
 
1125
        HELLO = 'world'
 
1126
        c.sendHello(HELLO).addCallback(L.append)
 
1127
        p.flush()
 
1128
        self.assertEquals(L[0]['hello'], HELLO)
 
1129
 
 
1130
 
 
1131
    def test_brokenReturnValue(self):
 
1132
        """
 
1133
        It can be very confusing if you write some code which responds to a
 
1134
        command, but gets the return value wrong.  Most commonly you end up
 
1135
        returning None instead of a dictionary.
 
1136
 
 
1137
        Verify that if that happens, the framework logs a useful error.
 
1138
        """
 
1139
        L = []
 
1140
        SimpleSymmetricCommandProtocol().dispatchCommand(
 
1141
            amp.AmpBox(_command=BrokenReturn.commandName)).addErrback(L.append)
 
1142
        blr = L[0].trap(amp.BadLocalReturn)
 
1143
        self.failUnlessIn('None', repr(L[0].value))
 
1144
 
 
1145
 
 
1146
    def test_unknownArgument(self):
 
1147
        """
 
1148
        Verify that unknown arguments are ignored, and not passed to a Python
 
1149
        function which can't accept them.
 
1150
        """
 
1151
        c, s, p = connectedServerAndClient(
 
1152
            ServerClass=SimpleSymmetricCommandProtocol,
 
1153
            ClientClass=SimpleSymmetricCommandProtocol)
 
1154
        L = []
 
1155
        HELLO = 'world'
 
1156
        # c.sendHello(HELLO).addCallback(L.append)
 
1157
        c.callRemote(FutureHello,
 
1158
                     hello=HELLO,
 
1159
                     bonus="I'm not in the book!").addCallback(
 
1160
            L.append)
 
1161
        p.flush()
 
1162
        self.assertEquals(L[0]['hello'], HELLO)
 
1163
 
 
1164
 
 
1165
    def test_simpleReprs(self):
 
1166
        """
 
1167
        Verify that the various Box objects repr properly, for debugging.
 
1168
        """
 
1169
        self.assertEquals(type(repr(amp._SwitchBox('a'))), str)
 
1170
        self.assertEquals(type(repr(amp.QuitBox())), str)
 
1171
        self.assertEquals(type(repr(amp.AmpBox())), str)
 
1172
        self.failUnless("AmpBox" in repr(amp.AmpBox()))
 
1173
 
 
1174
 
 
1175
    def test_innerProtocolInRepr(self):
 
1176
        """
 
1177
        Verify that L{AMP} objects output their innerProtocol when set.
 
1178
        """
 
1179
        otherProto = TestProto(None, "outgoing data")
 
1180
        a = amp.AMP()
 
1181
        a.innerProtocol = otherProto
 
1182
        def fakeID(obj):
 
1183
            return {a: 0x1234}.get(obj, id(obj))
 
1184
        self.addCleanup(setIDFunction, setIDFunction(fakeID))
 
1185
 
 
1186
        self.assertEquals(
 
1187
            repr(a), "<AMP inner <TestProto #%d> at 0x1234>" % (
 
1188
                otherProto.instanceId,))
 
1189
 
 
1190
 
 
1191
    def test_innerProtocolNotInRepr(self):
 
1192
        """
 
1193
        Verify that L{AMP} objects do not output 'inner' when no innerProtocol
 
1194
        is set.
 
1195
        """
 
1196
        a = amp.AMP()
 
1197
        def fakeID(obj):
 
1198
            return {a: 0x4321}.get(obj, id(obj))
 
1199
        self.addCleanup(setIDFunction, setIDFunction(fakeID))
 
1200
        self.assertEquals(repr(a), "<AMP at 0x4321>")
 
1201
 
 
1202
 
 
1203
    def test_simpleSSLRepr(self):
 
1204
        """
 
1205
        L{amp._TLSBox.__repr__} returns a string.
 
1206
        """
 
1207
        self.assertEquals(type(repr(amp._TLSBox())), str)
 
1208
 
 
1209
    test_simpleSSLRepr.skip = skipSSL
 
1210
 
 
1211
 
 
1212
    def test_keyTooLong(self):
 
1213
        """
 
1214
        Verify that a key that is too long will immediately raise a synchronous
 
1215
        exception.
 
1216
        """
 
1217
        c, s, p = connectedServerAndClient()
 
1218
        L = []
 
1219
        x = "H" * (0xff+1)
 
1220
        tl = self.assertRaises(amp.TooLong,
 
1221
                               c.callRemoteString, "Hello",
 
1222
                               **{x: "hi"})
 
1223
        self.failUnless(tl.isKey)
 
1224
        self.failUnless(tl.isLocal)
 
1225
        self.failUnlessIdentical(tl.keyName, None)
 
1226
        self.failUnlessIdentical(tl.value, x)
 
1227
        self.failUnless(str(len(x)) in repr(tl))
 
1228
        self.failUnless("key" in repr(tl))
 
1229
 
 
1230
 
 
1231
    def test_valueTooLong(self):
 
1232
        """
 
1233
        Verify that attempting to send value longer than 64k will immediately
 
1234
        raise an exception.
 
1235
        """
 
1236
        c, s, p = connectedServerAndClient()
 
1237
        L = []
 
1238
        x = "H" * (0xffff+1)
 
1239
        tl = self.assertRaises(amp.TooLong, c.sendHello, x)
 
1240
        p.flush()
 
1241
        self.failIf(tl.isKey)
 
1242
        self.failUnless(tl.isLocal)
 
1243
        self.assertEquals(tl.keyName, 'hello')
 
1244
        self.failUnlessIdentical(tl.value, x)
 
1245
        self.failUnless(str(len(x)) in repr(tl))
 
1246
        self.failUnless("value" in repr(tl))
 
1247
        self.failUnless('hello' in repr(tl))
 
1248
 
 
1249
 
 
1250
    def test_helloWorldCommand(self):
 
1251
        """
 
1252
        Verify that a simple command can be sent and its response received with
 
1253
        the high-level value parsing API.
 
1254
        """
 
1255
        c, s, p = connectedServerAndClient(
 
1256
            ServerClass=SimpleSymmetricCommandProtocol,
 
1257
            ClientClass=SimpleSymmetricCommandProtocol)
 
1258
        L = []
 
1259
        HELLO = 'world'
 
1260
        c.sendHello(HELLO).addCallback(L.append)
 
1261
        p.flush()
 
1262
        self.assertEquals(L[0]['hello'], HELLO)
 
1263
 
 
1264
 
 
1265
    def test_helloErrorHandling(self):
 
1266
        """
 
1267
        Verify that if a known error type is raised and handled, it will be
 
1268
        properly relayed to the other end of the connection and translated into
 
1269
        an exception, and no error will be logged.
 
1270
        """
 
1271
        L=[]
 
1272
        c, s, p = connectedServerAndClient(
 
1273
            ServerClass=SimpleSymmetricCommandProtocol,
 
1274
            ClientClass=SimpleSymmetricCommandProtocol)
 
1275
        HELLO = 'fuck you'
 
1276
        c.sendHello(HELLO).addErrback(L.append)
 
1277
        p.flush()
 
1278
        L[0].trap(UnfriendlyGreeting)
 
1279
        self.assertEquals(str(L[0].value), "Don't be a dick.")
 
1280
 
 
1281
 
 
1282
    def test_helloFatalErrorHandling(self):
 
1283
        """
 
1284
        Verify that if a known, fatal error type is raised and handled, it will
 
1285
        be properly relayed to the other end of the connection and translated
 
1286
        into an exception, no error will be logged, and the connection will be
 
1287
        terminated.
 
1288
        """
 
1289
        L=[]
 
1290
        c, s, p = connectedServerAndClient(
 
1291
            ServerClass=SimpleSymmetricCommandProtocol,
 
1292
            ClientClass=SimpleSymmetricCommandProtocol)
 
1293
        HELLO = 'die'
 
1294
        c.sendHello(HELLO).addErrback(L.append)
 
1295
        p.flush()
 
1296
        L.pop().trap(DeathThreat)
 
1297
        c.sendHello(HELLO).addErrback(L.append)
 
1298
        p.flush()
 
1299
        L.pop().trap(error.ConnectionDone)
 
1300
 
 
1301
 
 
1302
 
 
1303
    def test_helloNoErrorHandling(self):
 
1304
        """
 
1305
        Verify that if an unknown error type is raised, it will be relayed to
 
1306
        the other end of the connection and translated into an exception, it
 
1307
        will be logged, and then the connection will be dropped.
 
1308
        """
 
1309
        L=[]
 
1310
        c, s, p = connectedServerAndClient(
 
1311
            ServerClass=SimpleSymmetricCommandProtocol,
 
1312
            ClientClass=SimpleSymmetricCommandProtocol)
 
1313
        HELLO = THING_I_DONT_UNDERSTAND
 
1314
        c.sendHello(HELLO).addErrback(L.append)
 
1315
        p.flush()
 
1316
        ure = L.pop()
 
1317
        ure.trap(amp.UnknownRemoteError)
 
1318
        c.sendHello(HELLO).addErrback(L.append)
 
1319
        cl = L.pop()
 
1320
        cl.trap(error.ConnectionDone)
 
1321
        # The exception should have been logged.
 
1322
        self.failUnless(self.flushLoggedErrors(ThingIDontUnderstandError))
 
1323
 
 
1324
 
 
1325
 
 
1326
    def test_lateAnswer(self):
 
1327
        """
 
1328
        Verify that a command that does not get answered until after the
 
1329
        connection terminates will not cause any errors.
 
1330
        """
 
1331
        c, s, p = connectedServerAndClient(
 
1332
            ServerClass=SimpleSymmetricCommandProtocol,
 
1333
            ClientClass=SimpleSymmetricCommandProtocol)
 
1334
        L = []
 
1335
        HELLO = 'world'
 
1336
        c.callRemote(WaitForever).addErrback(L.append)
 
1337
        p.flush()
 
1338
        self.assertEquals(L, [])
 
1339
        s.transport.loseConnection()
 
1340
        p.flush()
 
1341
        L.pop().trap(error.ConnectionDone)
 
1342
        # Just make sure that it doesn't error...
 
1343
        s.waiting.callback({})
 
1344
        return s.waiting
 
1345
 
 
1346
 
 
1347
    def test_requiresNoAnswer(self):
 
1348
        """
 
1349
        Verify that a command that requires no answer is run.
 
1350
        """
 
1351
        L=[]
 
1352
        c, s, p = connectedServerAndClient(
 
1353
            ServerClass=SimpleSymmetricCommandProtocol,
 
1354
            ClientClass=SimpleSymmetricCommandProtocol)
 
1355
        HELLO = 'world'
 
1356
        c.callRemote(NoAnswerHello, hello=HELLO)
 
1357
        p.flush()
 
1358
        self.failUnless(s.greeted)
 
1359
 
 
1360
 
 
1361
    def test_requiresNoAnswerFail(self):
 
1362
        """
 
1363
        Verify that commands sent after a failed no-answer request do not complete.
 
1364
        """
 
1365
        L=[]
 
1366
        c, s, p = connectedServerAndClient(
 
1367
            ServerClass=SimpleSymmetricCommandProtocol,
 
1368
            ClientClass=SimpleSymmetricCommandProtocol)
 
1369
        HELLO = 'fuck you'
 
1370
        c.callRemote(NoAnswerHello, hello=HELLO)
 
1371
        p.flush()
 
1372
        # This should be logged locally.
 
1373
        self.failUnless(self.flushLoggedErrors(amp.RemoteAmpError))
 
1374
        HELLO = 'world'
 
1375
        c.callRemote(Hello, hello=HELLO).addErrback(L.append)
 
1376
        p.flush()
 
1377
        L.pop().trap(error.ConnectionDone)
 
1378
        self.failIf(s.greeted)
 
1379
 
 
1380
 
 
1381
    def test_noAnswerResponderBadAnswer(self):
 
1382
        """
 
1383
        Verify that responders of requiresAnswer=False commands have to return
 
1384
        a dictionary anyway.
 
1385
 
 
1386
        (requiresAnswer is a hint from the _client_ - the server may be called
 
1387
        upon to answer commands in any case, if the client wants to know when
 
1388
        they complete.)
 
1389
        """
 
1390
        c, s, p = connectedServerAndClient(
 
1391
            ServerClass=BadNoAnswerCommandProtocol,
 
1392
            ClientClass=SimpleSymmetricCommandProtocol)
 
1393
        c.callRemote(NoAnswerHello, hello="hello")
 
1394
        p.flush()
 
1395
        le = self.flushLoggedErrors(amp.BadLocalReturn)
 
1396
        self.assertEquals(len(le), 1)
 
1397
 
 
1398
 
 
1399
    def test_noAnswerResponderAskedForAnswer(self):
 
1400
        """
 
1401
        Verify that responders with requiresAnswer=False will actually respond
 
1402
        if the client sets requiresAnswer=True.  In other words, verify that
 
1403
        requiresAnswer is a hint honored only by the client.
 
1404
        """
 
1405
        c, s, p = connectedServerAndClient(
 
1406
            ServerClass=NoAnswerCommandProtocol,
 
1407
            ClientClass=SimpleSymmetricCommandProtocol)
 
1408
        L = []
 
1409
        c.callRemote(Hello, hello="Hello!").addCallback(L.append)
 
1410
        p.flush()
 
1411
        self.assertEquals(len(L), 1)
 
1412
        self.assertEquals(L, [dict(hello="Hello!-noanswer",
 
1413
                                   Print=None)])  # Optional response argument
 
1414
 
 
1415
 
 
1416
    def test_ampListCommand(self):
 
1417
        """
 
1418
        Test encoding of an argument that uses the AmpList encoding.
 
1419
        """
 
1420
        c, s, p = connectedServerAndClient(
 
1421
            ServerClass=SimpleSymmetricCommandProtocol,
 
1422
            ClientClass=SimpleSymmetricCommandProtocol)
 
1423
        L = []
 
1424
        c.callRemote(GetList, length=10).addCallback(L.append)
 
1425
        p.flush()
 
1426
        values = L.pop().get('body')
 
1427
        self.assertEquals(values, [{'x': 1}] * 10)
 
1428
 
 
1429
 
 
1430
    def test_optionalAmpListOmitted(self):
 
1431
        """
 
1432
        Test that sending a command with an omitted AmpList argument that is
 
1433
        designated as optional does not raise an InvalidSignature error.
 
1434
        """
 
1435
        dontRejectMeCommand = DontRejectMe(magicWord=u'please')
 
1436
 
 
1437
 
 
1438
    def test_optionalAmpListPresent(self):
 
1439
        """
 
1440
        Sanity check that optional AmpList arguments are processed normally.
 
1441
        """
 
1442
        dontRejectMeCommand = DontRejectMe(magicWord=u'please',
 
1443
                list=[{'name': 'foo'}])
 
1444
        c, s, p = connectedServerAndClient(
 
1445
            ServerClass=SimpleSymmetricCommandProtocol,
 
1446
            ClientClass=SimpleSymmetricCommandProtocol)
 
1447
        L = []
 
1448
        c.callRemote(DontRejectMe, magicWord=u'please',
 
1449
                list=[{'name': 'foo'}]).addCallback(L.append)
 
1450
        p.flush()
 
1451
        response = L.pop().get('response')
 
1452
        self.assertEquals(response, 'foo accepted')
 
1453
 
 
1454
 
 
1455
    def test_failEarlyOnArgSending(self):
 
1456
        """
 
1457
        Verify that if we pass an invalid argument list (omitting an argument), an
 
1458
        exception will be raised.
 
1459
        """
 
1460
        okayCommand = Hello(hello="What?")
 
1461
        self.assertRaises(amp.InvalidSignature, Hello)
 
1462
 
 
1463
 
 
1464
    def test_doubleProtocolSwitch(self):
 
1465
        """
 
1466
        As a debugging aid, a protocol system should raise a
 
1467
        L{ProtocolSwitched} exception when asked to switch a protocol that is
 
1468
        already switched.
 
1469
        """
 
1470
        serverDeferred = defer.Deferred()
 
1471
        serverProto = SimpleSymmetricCommandProtocol(serverDeferred)
 
1472
        clientDeferred = defer.Deferred()
 
1473
        clientProto = SimpleSymmetricCommandProtocol(clientDeferred)
 
1474
        c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
 
1475
                                           ClientClass=lambda: clientProto)
 
1476
        def switched(result):
 
1477
            self.assertRaises(amp.ProtocolSwitched, c.switchToTestProtocol)
 
1478
            self.testSucceeded = True
 
1479
        c.switchToTestProtocol().addCallback(switched)
 
1480
        p.flush()
 
1481
        self.failUnless(self.testSucceeded)
 
1482
 
 
1483
 
 
1484
    def test_protocolSwitch(self, switcher=SimpleSymmetricCommandProtocol,
 
1485
                            spuriousTraffic=False,
 
1486
                            spuriousError=False):
 
1487
        """
 
1488
        Verify that it is possible to switch to another protocol mid-connection and
 
1489
        send data to it successfully.
 
1490
        """
 
1491
        self.testSucceeded = False
 
1492
 
 
1493
        serverDeferred = defer.Deferred()
 
1494
        serverProto = switcher(serverDeferred)
 
1495
        clientDeferred = defer.Deferred()
 
1496
        clientProto = switcher(clientDeferred)
 
1497
        c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
 
1498
                                           ClientClass=lambda: clientProto)
 
1499
 
 
1500
        if spuriousTraffic:
 
1501
            wfdr = []           # remote
 
1502
            wfd = c.callRemote(WaitForever).addErrback(wfdr.append)
 
1503
        switchDeferred = c.switchToTestProtocol()
 
1504
        if spuriousTraffic:
 
1505
            self.assertRaises(amp.ProtocolSwitched, c.sendHello, 'world')
 
1506
 
 
1507
        def cbConnsLost(((serverSuccess, serverData),
 
1508
                         (clientSuccess, clientData))):
 
1509
            self.failUnless(serverSuccess)
 
1510
            self.failUnless(clientSuccess)
 
1511
            self.assertEquals(''.join(serverData), SWITCH_CLIENT_DATA)
 
1512
            self.assertEquals(''.join(clientData), SWITCH_SERVER_DATA)
 
1513
            self.testSucceeded = True
 
1514
 
 
1515
        def cbSwitch(proto):
 
1516
            return defer.DeferredList(
 
1517
                [serverDeferred, clientDeferred]).addCallback(cbConnsLost)
 
1518
 
 
1519
        switchDeferred.addCallback(cbSwitch)
 
1520
        p.flush()
 
1521
        if serverProto.maybeLater is not None:
 
1522
            serverProto.maybeLater.callback(serverProto.maybeLaterProto)
 
1523
            p.flush()
 
1524
        if spuriousTraffic:
 
1525
            # switch is done here; do this here to make sure that if we're
 
1526
            # going to corrupt the connection, we do it before it's closed.
 
1527
            if spuriousError:
 
1528
                s.waiting.errback(amp.RemoteAmpError(
 
1529
                        "SPURIOUS",
 
1530
                        "Here's some traffic in the form of an error."))
 
1531
            else:
 
1532
                s.waiting.callback({})
 
1533
            p.flush()
 
1534
        c.transport.loseConnection() # close it
 
1535
        p.flush()
 
1536
        self.failUnless(self.testSucceeded)
 
1537
 
 
1538
 
 
1539
    def test_protocolSwitchDeferred(self):
 
1540
        """
 
1541
        Verify that protocol-switching even works if the value returned from
 
1542
        the command that does the switch is deferred.
 
1543
        """
 
1544
        return self.test_protocolSwitch(switcher=DeferredSymmetricCommandProtocol)
 
1545
 
 
1546
 
 
1547
    def test_protocolSwitchFail(self, switcher=SimpleSymmetricCommandProtocol):
 
1548
        """
 
1549
        Verify that if we try to switch protocols and it fails, the connection
 
1550
        stays up and we can go back to speaking AMP.
 
1551
        """
 
1552
        self.testSucceeded = False
 
1553
 
 
1554
        serverDeferred = defer.Deferred()
 
1555
        serverProto = switcher(serverDeferred)
 
1556
        clientDeferred = defer.Deferred()
 
1557
        clientProto = switcher(clientDeferred)
 
1558
        c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
 
1559
                                           ClientClass=lambda: clientProto)
 
1560
        L = []
 
1561
        switchDeferred = c.switchToTestProtocol(fail=True).addErrback(L.append)
 
1562
        p.flush()
 
1563
        L.pop().trap(UnknownProtocol)
 
1564
        self.failIf(self.testSucceeded)
 
1565
        # It's a known error, so let's send a "hello" on the same connection;
 
1566
        # it should work.
 
1567
        c.sendHello('world').addCallback(L.append)
 
1568
        p.flush()
 
1569
        self.assertEqual(L.pop()['hello'], 'world')
 
1570
 
 
1571
 
 
1572
    def test_trafficAfterSwitch(self):
 
1573
        """
 
1574
        Verify that attempts to send traffic after a switch will not corrupt
 
1575
        the nested protocol.
 
1576
        """
 
1577
        return self.test_protocolSwitch(spuriousTraffic=True)
 
1578
 
 
1579
 
 
1580
    def test_errorAfterSwitch(self):
 
1581
        """
 
1582
        Returning an error after a protocol switch should record the underlying
 
1583
        error.
 
1584
        """
 
1585
        return self.test_protocolSwitch(spuriousTraffic=True,
 
1586
                                        spuriousError=True)
 
1587
 
 
1588
 
 
1589
    def test_quitBoxQuits(self):
 
1590
        """
 
1591
        Verify that commands with a responseType of QuitBox will in fact
 
1592
        terminate the connection.
 
1593
        """
 
1594
        c, s, p = connectedServerAndClient(
 
1595
            ServerClass=SimpleSymmetricCommandProtocol,
 
1596
            ClientClass=SimpleSymmetricCommandProtocol)
 
1597
 
 
1598
        L = []
 
1599
        HELLO = 'world'
 
1600
        GOODBYE = 'everyone'
 
1601
        c.sendHello(HELLO).addCallback(L.append)
 
1602
        p.flush()
 
1603
        self.assertEquals(L.pop()['hello'], HELLO)
 
1604
        c.callRemote(Goodbye).addCallback(L.append)
 
1605
        p.flush()
 
1606
        self.assertEquals(L.pop()['goodbye'], GOODBYE)
 
1607
        c.sendHello(HELLO).addErrback(L.append)
 
1608
        L.pop().trap(error.ConnectionDone)
 
1609
 
 
1610
 
 
1611
    def test_basicLiteralEmit(self):
 
1612
        """
 
1613
        Verify that the command dictionaries for a callRemoteN look correct
 
1614
        after being serialized and parsed.
 
1615
        """
 
1616
        c, s, p = connectedServerAndClient()
 
1617
        L = []
 
1618
        s.ampBoxReceived = L.append
 
1619
        c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test',
 
1620
                     dash_arg='x', underscore_arg='y')
 
1621
        p.flush()
 
1622
        self.assertEquals(len(L), 1)
 
1623
        for k, v in [('_command', Hello.commandName),
 
1624
                     ('hello', 'hello test'),
 
1625
                     ('mixedCase', 'mixed case arg test'),
 
1626
                     ('dash-arg', 'x'),
 
1627
                     ('underscore_arg', 'y')]:
 
1628
            self.assertEquals(L[-1].pop(k), v)
 
1629
        L[-1].pop('_ask')
 
1630
        self.assertEquals(L[-1], {})
 
1631
 
 
1632
 
 
1633
    def test_basicStructuredEmit(self):
 
1634
        """
 
1635
        Verify that a call similar to basicLiteralEmit's is handled properly with
 
1636
        high-level quoting and passing to Python methods, and that argument
 
1637
        names are correctly handled.
 
1638
        """
 
1639
        L = []
 
1640
        class StructuredHello(amp.AMP):
 
1641
            def h(self, *a, **k):
 
1642
                L.append((a, k))
 
1643
                return dict(hello='aaa')
 
1644
            Hello.responder(h)
 
1645
        c, s, p = connectedServerAndClient(ServerClass=StructuredHello)
 
1646
        c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test',
 
1647
                     dash_arg='x', underscore_arg='y').addCallback(L.append)
 
1648
        p.flush()
 
1649
        self.assertEquals(len(L), 2)
 
1650
        self.assertEquals(L[0],
 
1651
                          ((), dict(
 
1652
                    hello='hello test',
 
1653
                    mixedCase='mixed case arg test',
 
1654
                    dash_arg='x',
 
1655
                    underscore_arg='y',
 
1656
 
 
1657
                    # XXX - should optional arguments just not be passed?
 
1658
                    # passing None seems a little odd, looking at the way it
 
1659
                    # turns out here... -glyph
 
1660
                    From=('file', 'file'),
 
1661
                    Print=None,
 
1662
                    optional=None,
 
1663
                    )))
 
1664
        self.assertEquals(L[1], dict(Print=None, hello='aaa'))
 
1665
 
 
1666
class PretendRemoteCertificateAuthority:
 
1667
    def checkIsPretendRemote(self):
 
1668
        return True
 
1669
 
 
1670
class IOSimCert:
 
1671
    verifyCount = 0
 
1672
 
 
1673
    def options(self, *ign):
 
1674
        return self
 
1675
 
 
1676
    def iosimVerify(self, otherCert):
 
1677
        """
 
1678
        This isn't a real certificate, and wouldn't work on a real socket, but
 
1679
        iosim specifies a different API so that we don't have to do any crypto
 
1680
        math to demonstrate that the right functions get called in the right
 
1681
        places.
 
1682
        """
 
1683
        assert otherCert is self
 
1684
        self.verifyCount += 1
 
1685
        return True
 
1686
 
 
1687
class OKCert(IOSimCert):
 
1688
    def options(self, x):
 
1689
        assert x.checkIsPretendRemote()
 
1690
        return self
 
1691
 
 
1692
class GrumpyCert(IOSimCert):
 
1693
    def iosimVerify(self, otherCert):
 
1694
        self.verifyCount += 1
 
1695
        return False
 
1696
 
 
1697
class DroppyCert(IOSimCert):
 
1698
    def __init__(self, toDrop):
 
1699
        self.toDrop = toDrop
 
1700
 
 
1701
    def iosimVerify(self, otherCert):
 
1702
        self.verifyCount += 1
 
1703
        self.toDrop.loseConnection()
 
1704
        return True
 
1705
 
 
1706
class SecurableProto(FactoryNotifier):
 
1707
 
 
1708
    factory = None
 
1709
 
 
1710
    def verifyFactory(self):
 
1711
        return [PretendRemoteCertificateAuthority()]
 
1712
 
 
1713
    def getTLSVars(self):
 
1714
        cert = self.certFactory()
 
1715
        verify = self.verifyFactory()
 
1716
        return dict(
 
1717
            tls_localCertificate=cert,
 
1718
            tls_verifyAuthorities=verify)
 
1719
    amp.StartTLS.responder(getTLSVars)
 
1720
 
 
1721
 
 
1722
 
 
1723
class TLSTest(unittest.TestCase):
 
1724
    def test_startingTLS(self):
 
1725
        """
 
1726
        Verify that starting TLS and succeeding at handshaking sends all the
 
1727
        notifications to all the right places.
 
1728
        """
 
1729
        cli, svr, p = connectedServerAndClient(
 
1730
            ServerClass=SecurableProto,
 
1731
            ClientClass=SecurableProto)
 
1732
 
 
1733
        okc = OKCert()
 
1734
        svr.certFactory = lambda : okc
 
1735
 
 
1736
        cli.callRemote(
 
1737
            amp.StartTLS, tls_localCertificate=okc,
 
1738
            tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
 
1739
 
 
1740
        # let's buffer something to be delivered securely
 
1741
        L = []
 
1742
        d = cli.callRemote(SecuredPing).addCallback(L.append)
 
1743
        p.flush()
 
1744
        # once for client once for server
 
1745
        self.assertEquals(okc.verifyCount, 2)
 
1746
        L = []
 
1747
        d = cli.callRemote(SecuredPing).addCallback(L.append)
 
1748
        p.flush()
 
1749
        self.assertEqual(L[0], {'pinged': True})
 
1750
 
 
1751
 
 
1752
    def test_startTooManyTimes(self):
 
1753
        """
 
1754
        Verify that the protocol will complain if we attempt to renegotiate TLS,
 
1755
        which we don't support.
 
1756
        """
 
1757
        cli, svr, p = connectedServerAndClient(
 
1758
            ServerClass=SecurableProto,
 
1759
            ClientClass=SecurableProto)
 
1760
 
 
1761
        okc = OKCert()
 
1762
        svr.certFactory = lambda : okc
 
1763
 
 
1764
        cli.callRemote(amp.StartTLS,
 
1765
                       tls_localCertificate=okc,
 
1766
                       tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
 
1767
        p.flush()
 
1768
        cli.noPeerCertificate = True # this is totally fake
 
1769
        self.assertRaises(
 
1770
            amp.OnlyOneTLS,
 
1771
            cli.callRemote,
 
1772
            amp.StartTLS,
 
1773
            tls_localCertificate=okc,
 
1774
            tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
 
1775
 
 
1776
 
 
1777
    def test_negotiationFailed(self):
 
1778
        """
 
1779
        Verify that starting TLS and failing on both sides at handshaking sends
 
1780
        notifications to all the right places and terminates the connection.
 
1781
        """
 
1782
 
 
1783
        badCert = GrumpyCert()
 
1784
 
 
1785
        cli, svr, p = connectedServerAndClient(
 
1786
            ServerClass=SecurableProto,
 
1787
            ClientClass=SecurableProto)
 
1788
        svr.certFactory = lambda : badCert
 
1789
 
 
1790
        cli.callRemote(amp.StartTLS,
 
1791
                       tls_localCertificate=badCert)
 
1792
 
 
1793
        p.flush()
 
1794
        # once for client once for server - but both fail
 
1795
        self.assertEquals(badCert.verifyCount, 2)
 
1796
        d = cli.callRemote(SecuredPing)
 
1797
        p.flush()
 
1798
        self.assertFailure(d, iosim.NativeOpenSSLError)
 
1799
 
 
1800
 
 
1801
    def test_negotiationFailedByClosing(self):
 
1802
        """
 
1803
        Verify that starting TLS and failing by way of a lost connection
 
1804
        notices that it is probably an SSL problem.
 
1805
        """
 
1806
 
 
1807
        cli, svr, p = connectedServerAndClient(
 
1808
            ServerClass=SecurableProto,
 
1809
            ClientClass=SecurableProto)
 
1810
        droppyCert = DroppyCert(svr.transport)
 
1811
        svr.certFactory = lambda : droppyCert
 
1812
 
 
1813
        secure = cli.callRemote(amp.StartTLS,
 
1814
                                tls_localCertificate=droppyCert)
 
1815
 
 
1816
        p.flush()
 
1817
 
 
1818
        self.assertEquals(droppyCert.verifyCount, 2)
 
1819
 
 
1820
        d = cli.callRemote(SecuredPing)
 
1821
        p.flush()
 
1822
 
 
1823
        # it might be a good idea to move this exception somewhere more
 
1824
        # reasonable.
 
1825
        self.assertFailure(d, error.PeerVerifyError)
 
1826
 
 
1827
    skip = skipSSL
 
1828
 
 
1829
 
 
1830
 
 
1831
class TLSNotAvailableTest(unittest.TestCase):
 
1832
    """
 
1833
    Tests what happened when ssl is not available in current installation.
 
1834
    """
 
1835
 
 
1836
    def setUp(self):
 
1837
        """
 
1838
        Disable ssl in amp.
 
1839
        """
 
1840
        self.ssl = amp.ssl
 
1841
        amp.ssl = None
 
1842
 
 
1843
 
 
1844
    def tearDown(self):
 
1845
        """
 
1846
        Restore ssl module.
 
1847
        """
 
1848
        amp.ssl = self.ssl
 
1849
 
 
1850
 
 
1851
    def test_callRemoteError(self):
 
1852
        """
 
1853
        Check that callRemote raises an exception when called with a
 
1854
        L{amp.StartTLS}.
 
1855
        """
 
1856
        cli, svr, p = connectedServerAndClient(
 
1857
            ServerClass=SecurableProto,
 
1858
            ClientClass=SecurableProto)
 
1859
 
 
1860
        okc = OKCert()
 
1861
        svr.certFactory = lambda : okc
 
1862
 
 
1863
        return self.assertFailure(cli.callRemote(
 
1864
            amp.StartTLS, tls_localCertificate=okc,
 
1865
            tls_verifyAuthorities=[PretendRemoteCertificateAuthority()]),
 
1866
            RuntimeError)
 
1867
 
 
1868
 
 
1869
    def test_messageReceivedError(self):
 
1870
        """
 
1871
        When a client with SSL enabled talks to a server without SSL, it
 
1872
        should return a meaningful error.
 
1873
        """
 
1874
        svr = SecurableProto()
 
1875
        okc = OKCert()
 
1876
        svr.certFactory = lambda : okc
 
1877
        box = amp.Box()
 
1878
        box['_command'] = 'StartTLS'
 
1879
        box['_ask'] = '1'
 
1880
        boxes = []
 
1881
        svr.sendBox = boxes.append
 
1882
        svr.makeConnection(StringTransport())
 
1883
        svr.ampBoxReceived(box)
 
1884
        self.assertEquals(boxes,
 
1885
            [{'_error_code': 'TLS_ERROR',
 
1886
              '_error': '1',
 
1887
              '_error_description': 'TLS not available'}])
 
1888
 
 
1889
 
 
1890
 
 
1891
class InheritedError(Exception):
 
1892
    """
 
1893
    This error is used to check inheritance.
 
1894
    """
 
1895
 
 
1896
 
 
1897
 
 
1898
class OtherInheritedError(Exception):
 
1899
    """
 
1900
    This is a distinct error for checking inheritance.
 
1901
    """
 
1902
 
 
1903
 
 
1904
 
 
1905
class BaseCommand(amp.Command):
 
1906
    """
 
1907
    This provides a command that will be subclassed.
 
1908
    """
 
1909
    errors = {InheritedError: 'INHERITED_ERROR'}
 
1910
 
 
1911
 
 
1912
 
 
1913
class InheritedCommand(BaseCommand):
 
1914
    """
 
1915
    This is a command which subclasses another command but does not override
 
1916
    anything.
 
1917
    """
 
1918
 
 
1919
 
 
1920
 
 
1921
class AddErrorsCommand(BaseCommand):
 
1922
    """
 
1923
    This is a command which subclasses another command but adds errors to the
 
1924
    list.
 
1925
    """
 
1926
    arguments = [('other', amp.Boolean())]
 
1927
    errors = {OtherInheritedError: 'OTHER_INHERITED_ERROR'}
 
1928
 
 
1929
 
 
1930
 
 
1931
class NormalCommandProtocol(amp.AMP):
 
1932
    """
 
1933
    This is a protocol which responds to L{BaseCommand}, and is used to test
 
1934
    that inheritance does not interfere with the normal handling of errors.
 
1935
    """
 
1936
    def resp(self):
 
1937
        raise InheritedError()
 
1938
    BaseCommand.responder(resp)
 
1939
 
 
1940
 
 
1941
 
 
1942
class InheritedCommandProtocol(amp.AMP):
 
1943
    """
 
1944
    This is a protocol which responds to L{InheritedCommand}, and is used to
 
1945
    test that inherited commands inherit their bases' errors if they do not
 
1946
    respond to any of their own.
 
1947
    """
 
1948
    def resp(self):
 
1949
        raise InheritedError()
 
1950
    InheritedCommand.responder(resp)
 
1951
 
 
1952
 
 
1953
 
 
1954
class AddedCommandProtocol(amp.AMP):
 
1955
    """
 
1956
    This is a protocol which responds to L{AddErrorsCommand}, and is used to
 
1957
    test that inherited commands can add their own new types of errors, but
 
1958
    still respond in the same way to their parents types of errors.
 
1959
    """
 
1960
    def resp(self, other):
 
1961
        if other:
 
1962
            raise OtherInheritedError()
 
1963
        else:
 
1964
            raise InheritedError()
 
1965
    AddErrorsCommand.responder(resp)
 
1966
 
 
1967
 
 
1968
 
 
1969
class CommandInheritanceTests(unittest.TestCase):
 
1970
    """
 
1971
    These tests verify that commands inherit error conditions properly.
 
1972
    """
 
1973
 
 
1974
    def errorCheck(self, err, proto, cmd, **kw):
 
1975
        """
 
1976
        Check that the appropriate kind of error is raised when a given command
 
1977
        is sent to a given protocol.
 
1978
        """
 
1979
        c, s, p = connectedServerAndClient(ServerClass=proto,
 
1980
                                           ClientClass=proto)
 
1981
        d = c.callRemote(cmd, **kw)
 
1982
        d2 = self.failUnlessFailure(d, err)
 
1983
        p.flush()
 
1984
        return d2
 
1985
 
 
1986
 
 
1987
    def test_basicErrorPropagation(self):
 
1988
        """
 
1989
        Verify that errors specified in a superclass are respected normally
 
1990
        even if it has subclasses.
 
1991
        """
 
1992
        return self.errorCheck(
 
1993
            InheritedError, NormalCommandProtocol, BaseCommand)
 
1994
 
 
1995
 
 
1996
    def test_inheritedErrorPropagation(self):
 
1997
        """
 
1998
        Verify that errors specified in a superclass command are propagated to
 
1999
        its subclasses.
 
2000
        """
 
2001
        return self.errorCheck(
 
2002
            InheritedError, InheritedCommandProtocol, InheritedCommand)
 
2003
 
 
2004
 
 
2005
    def test_inheritedErrorAddition(self):
 
2006
        """
 
2007
        Verify that new errors specified in a subclass of an existing command
 
2008
        are honored even if the superclass defines some errors.
 
2009
        """
 
2010
        return self.errorCheck(
 
2011
            OtherInheritedError, AddedCommandProtocol, AddErrorsCommand, other=True)
 
2012
 
 
2013
 
 
2014
    def test_additionWithOriginalError(self):
 
2015
        """
 
2016
        Verify that errors specified in a command's superclass are respected
 
2017
        even if that command defines new errors itself.
 
2018
        """
 
2019
        return self.errorCheck(
 
2020
            InheritedError, AddedCommandProtocol, AddErrorsCommand, other=False)
 
2021
 
 
2022
 
 
2023
def _loseAndPass(err, proto):
 
2024
    # be specific, pass on the error to the client.
 
2025
    err.trap(error.ConnectionLost, error.ConnectionDone)
 
2026
    del proto.connectionLost
 
2027
    proto.connectionLost(err)
 
2028
 
 
2029
 
 
2030
class LiveFireBase:
 
2031
    """
 
2032
    Utility for connected reactor-using tests.
 
2033
    """
 
2034
 
 
2035
    def setUp(self):
 
2036
        """
 
2037
        Create an amp server and connect a client to it.
 
2038
        """
 
2039
        from twisted.internet import reactor
 
2040
        self.serverFactory = protocol.ServerFactory()
 
2041
        self.serverFactory.protocol = self.serverProto
 
2042
        self.clientFactory = protocol.ClientFactory()
 
2043
        self.clientFactory.protocol = self.clientProto
 
2044
        self.clientFactory.onMade = defer.Deferred()
 
2045
        self.serverFactory.onMade = defer.Deferred()
 
2046
        self.serverPort = reactor.listenTCP(0, self.serverFactory)
 
2047
        self.addCleanup(self.serverPort.stopListening)
 
2048
        self.clientConn = reactor.connectTCP(
 
2049
            '127.0.0.1', self.serverPort.getHost().port,
 
2050
            self.clientFactory)
 
2051
        self.addCleanup(self.clientConn.disconnect)
 
2052
        def getProtos(rlst):
 
2053
            self.cli = self.clientFactory.theProto
 
2054
            self.svr = self.serverFactory.theProto
 
2055
        dl = defer.DeferredList([self.clientFactory.onMade,
 
2056
                                 self.serverFactory.onMade])
 
2057
        return dl.addCallback(getProtos)
 
2058
 
 
2059
    def tearDown(self):
 
2060
        """
 
2061
        Cleanup client and server connections, and check the error got at
 
2062
        C{connectionLost}.
 
2063
        """
 
2064
        L = []
 
2065
        for conn in self.cli, self.svr:
 
2066
            if conn.transport is not None:
 
2067
                # depend on amp's function connection-dropping behavior
 
2068
                d = defer.Deferred().addErrback(_loseAndPass, conn)
 
2069
                conn.connectionLost = d.errback
 
2070
                conn.transport.loseConnection()
 
2071
                L.append(d)
 
2072
        return defer.gatherResults(L
 
2073
            ).addErrback(lambda first: first.value.subFailure)
 
2074
 
 
2075
 
 
2076
def show(x):
 
2077
    import sys
 
2078
    sys.stdout.write(x+'\n')
 
2079
    sys.stdout.flush()
 
2080
 
 
2081
 
 
2082
def tempSelfSigned():
 
2083
    from twisted.internet import ssl
 
2084
 
 
2085
    sharedDN = ssl.DN(CN='shared')
 
2086
    key = ssl.KeyPair.generate()
 
2087
    cr = key.certificateRequest(sharedDN)
 
2088
    sscrd = key.signCertificateRequest(
 
2089
        sharedDN, cr, lambda dn: True, 1234567)
 
2090
    cert = key.newCertificate(sscrd)
 
2091
    return cert
 
2092
 
 
2093
if ssl is not None:
 
2094
    tempcert = tempSelfSigned()
 
2095
 
 
2096
 
 
2097
class LiveFireTLSTestCase(LiveFireBase, unittest.TestCase):
 
2098
    clientProto = SecurableProto
 
2099
    serverProto = SecurableProto
 
2100
    def test_liveFireCustomTLS(self):
 
2101
        """
 
2102
        Using real, live TLS, actually negotiate a connection.
 
2103
 
 
2104
        This also looks at the 'peerCertificate' attribute's correctness, since
 
2105
        that's actually loaded using OpenSSL calls, but the main purpose is to
 
2106
        make sure that we didn't miss anything obvious in iosim about TLS
 
2107
        negotiations.
 
2108
        """
 
2109
 
 
2110
        cert = tempcert
 
2111
 
 
2112
        self.svr.verifyFactory = lambda : [cert]
 
2113
        self.svr.certFactory = lambda : cert
 
2114
        # only needed on the server, we specify the client below.
 
2115
 
 
2116
        def secured(rslt):
 
2117
            x = cert.digest()
 
2118
            def pinged(rslt2):
 
2119
                # Interesting.  OpenSSL won't even _tell_ us about the peer
 
2120
                # cert until we negotiate.  we should be able to do this in
 
2121
                # 'secured' instead, but it looks like we can't.  I think this
 
2122
                # is a bug somewhere far deeper than here.
 
2123
                self.failUnlessEqual(x, self.cli.hostCertificate.digest())
 
2124
                self.failUnlessEqual(x, self.cli.peerCertificate.digest())
 
2125
                self.failUnlessEqual(x, self.svr.hostCertificate.digest())
 
2126
                self.failUnlessEqual(x, self.svr.peerCertificate.digest())
 
2127
            return self.cli.callRemote(SecuredPing).addCallback(pinged)
 
2128
        return self.cli.callRemote(amp.StartTLS,
 
2129
                                   tls_localCertificate=cert,
 
2130
                                   tls_verifyAuthorities=[cert]).addCallback(secured)
 
2131
 
 
2132
    skip = skipSSL
 
2133
 
 
2134
 
 
2135
 
 
2136
class SlightlySmartTLS(SimpleSymmetricCommandProtocol):
 
2137
    """
 
2138
    Specific implementation of server side protocol with different
 
2139
    management of TLS.
 
2140
    """
 
2141
    def getTLSVars(self):
 
2142
        """
 
2143
        @return: the global C{tempcert} certificate as local certificate.
 
2144
        """
 
2145
        return dict(tls_localCertificate=tempcert)
 
2146
    amp.StartTLS.responder(getTLSVars)
 
2147
 
 
2148
 
 
2149
class PlainVanillaLiveFire(LiveFireBase, unittest.TestCase):
 
2150
 
 
2151
    clientProto = SimpleSymmetricCommandProtocol
 
2152
    serverProto = SimpleSymmetricCommandProtocol
 
2153
 
 
2154
    def test_liveFireDefaultTLS(self):
 
2155
        """
 
2156
        Verify that out of the box, we can start TLS to at least encrypt the
 
2157
        connection, even if we don't have any certificates to use.
 
2158
        """
 
2159
        def secured(result):
 
2160
            return self.cli.callRemote(SecuredPing)
 
2161
        return self.cli.callRemote(amp.StartTLS).addCallback(secured)
 
2162
 
 
2163
    skip = skipSSL
 
2164
 
 
2165
 
 
2166
 
 
2167
class WithServerTLSVerification(LiveFireBase, unittest.TestCase):
 
2168
    clientProto = SimpleSymmetricCommandProtocol
 
2169
    serverProto = SlightlySmartTLS
 
2170
 
 
2171
    def test_anonymousVerifyingClient(self):
 
2172
        """
 
2173
        Verify that anonymous clients can verify server certificates.
 
2174
        """
 
2175
        def secured(result):
 
2176
            return self.cli.callRemote(SecuredPing)
 
2177
        return self.cli.callRemote(amp.StartTLS,
 
2178
                                   tls_verifyAuthorities=[tempcert]
 
2179
            ).addCallback(secured)
 
2180
 
 
2181
    skip = skipSSL
 
2182
 
 
2183
 
 
2184
 
 
2185
class ProtocolIncludingArgument(amp.Argument):
 
2186
    """
 
2187
    An L{amp.Argument} which encodes its parser and serializer
 
2188
    arguments *including the protocol* into its parsed and serialized
 
2189
    forms.
 
2190
    """
 
2191
 
 
2192
    def fromStringProto(self, string, protocol):
 
2193
        """
 
2194
        Don't decode anything; just return all possible information.
 
2195
 
 
2196
        @return: A two-tuple of the input string and the protocol.
 
2197
        """
 
2198
        return (string, protocol)
 
2199
 
 
2200
    def toStringProto(self, obj, protocol):
 
2201
        """
 
2202
        Encode identifying information about L{object} and protocol
 
2203
        into a string for later verification.
 
2204
 
 
2205
        @type obj: L{object}
 
2206
        @type protocol: L{amp.AMP}
 
2207
        """
 
2208
        return "%s:%s" % (id(obj), id(protocol))
 
2209
 
 
2210
 
 
2211
 
 
2212
class ProtocolIncludingCommand(amp.Command):
 
2213
    """
 
2214
    A command that has argument and response schemas which use
 
2215
    L{ProtocolIncludingArgument}.
 
2216
    """
 
2217
    arguments = [('weird', ProtocolIncludingArgument())]
 
2218
    response = [('weird', ProtocolIncludingArgument())]
 
2219
 
 
2220
 
 
2221
 
 
2222
class MagicSchemaCommand(amp.Command):
 
2223
    """
 
2224
    A command which overrides L{parseResponse}, L{parseArguments}, and
 
2225
    L{makeResponse}.
 
2226
    """
 
2227
    def parseResponse(self, strings, protocol):
 
2228
        """
 
2229
        Don't do any parsing, just jam the input strings and protocol
 
2230
        onto the C{protocol.parseResponseArguments} attribute as a
 
2231
        two-tuple. Return the original strings.
 
2232
        """
 
2233
        protocol.parseResponseArguments = (strings, protocol)
 
2234
        return strings
 
2235
    parseResponse = classmethod(parseResponse)
 
2236
 
 
2237
 
 
2238
    def parseArguments(cls, strings, protocol):
 
2239
        """
 
2240
        Don't do any parsing, just jam the input strings and protocol
 
2241
        onto the C{protocol.parseArgumentsArguments} attribute as a
 
2242
        two-tuple. Return the original strings.
 
2243
        """
 
2244
        protocol.parseArgumentsArguments = (strings, protocol)
 
2245
        return strings
 
2246
    parseArguments = classmethod(parseArguments)
 
2247
 
 
2248
 
 
2249
    def makeArguments(cls, objects, protocol):
 
2250
        """
 
2251
        Don't do any serializing, just jam the input strings and protocol
 
2252
        onto the C{protocol.makeArgumentsArguments} attribute as a
 
2253
        two-tuple. Return the original strings.
 
2254
        """
 
2255
        protocol.makeArgumentsArguments = (objects, protocol)
 
2256
        return objects
 
2257
    makeArguments = classmethod(makeArguments)
 
2258
 
 
2259
 
 
2260
 
 
2261
class NoNetworkProtocol(amp.AMP):
 
2262
    """
 
2263
    An L{amp.AMP} subclass which overrides private methods to avoid
 
2264
    testing the network. It also provides a responder for
 
2265
    L{MagicSchemaCommand} that does nothing, so that tests can test
 
2266
    aspects of the interaction of L{amp.Command}s and L{amp.AMP}.
 
2267
 
 
2268
    @ivar parseArgumentsArguments: Arguments that have been passed to any
 
2269
    L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
 
2270
    this protocol.
 
2271
 
 
2272
    @ivar parseResponseArguments: Responses that have been returned from a
 
2273
    L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
 
2274
    this protocol.
 
2275
 
 
2276
    @ivar makeArgumentsArguments: Arguments that have been serialized by any
 
2277
    L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
 
2278
    this protocol.
 
2279
    """
 
2280
    def _sendBoxCommand(self, commandName, strings, requiresAnswer):
 
2281
        """
 
2282
        Return a Deferred which fires with the original strings.
 
2283
        """
 
2284
        return defer.succeed(strings)
 
2285
 
 
2286
    MagicSchemaCommand.responder(lambda s, weird: {})
 
2287
 
 
2288
 
 
2289
 
 
2290
class MyBox(dict):
 
2291
    """
 
2292
    A unique dict subclass.
 
2293
    """
 
2294
 
 
2295
 
 
2296
 
 
2297
class ProtocolIncludingCommandWithDifferentCommandType(
 
2298
    ProtocolIncludingCommand):
 
2299
    """
 
2300
    A L{ProtocolIncludingCommand} subclass whose commandType is L{MyBox}
 
2301
    """
 
2302
    commandType = MyBox
 
2303
 
 
2304
 
 
2305
 
 
2306
class CommandTestCase(unittest.TestCase):
 
2307
    """
 
2308
    Tests for L{amp.Argument} and L{amp.Command}.
 
2309
    """
 
2310
    def test_argumentInterface(self):
 
2311
        """
 
2312
        L{Argument} instances provide L{amp.IArgumentType}.
 
2313
        """
 
2314
        self.assertTrue(verifyObject(amp.IArgumentType, amp.Argument()))
 
2315
 
 
2316
 
 
2317
    def test_parseResponse(self):
 
2318
        """
 
2319
        There should be a class method of Command which accepts a
 
2320
        mapping of argument names to serialized forms and returns a
 
2321
        similar mapping whose values have been parsed via the
 
2322
        Command's response schema.
 
2323
        """
 
2324
        protocol = object()
 
2325
        result = 'whatever'
 
2326
        strings = {'weird': result}
 
2327
        self.assertEqual(
 
2328
            ProtocolIncludingCommand.parseResponse(strings, protocol),
 
2329
            {'weird': (result, protocol)})
 
2330
 
 
2331
 
 
2332
    def test_callRemoteCallsParseResponse(self):
 
2333
        """
 
2334
        Making a remote call on a L{amp.Command} subclass which
 
2335
        overrides the C{parseResponse} method should call that
 
2336
        C{parseResponse} method to get the response.
 
2337
        """
 
2338
        client = NoNetworkProtocol()
 
2339
        thingy = "weeoo"
 
2340
        response = client.callRemote(MagicSchemaCommand, weird=thingy)
 
2341
        def gotResponse(ign):
 
2342
            self.assertEquals(client.parseResponseArguments,
 
2343
                              ({"weird": thingy}, client))
 
2344
        response.addCallback(gotResponse)
 
2345
        return response
 
2346
 
 
2347
 
 
2348
    def test_parseArguments(self):
 
2349
        """
 
2350
        There should be a class method of L{amp.Command} which accepts
 
2351
        a mapping of argument names to serialized forms and returns a
 
2352
        similar mapping whose values have been parsed via the
 
2353
        command's argument schema.
 
2354
        """
 
2355
        protocol = object()
 
2356
        result = 'whatever'
 
2357
        strings = {'weird': result}
 
2358
        self.assertEqual(
 
2359
            ProtocolIncludingCommand.parseArguments(strings, protocol),
 
2360
            {'weird': (result, protocol)})
 
2361
 
 
2362
 
 
2363
    def test_responderCallsParseArguments(self):
 
2364
        """
 
2365
        Making a remote call on a L{amp.Command} subclass which
 
2366
        overrides the C{parseArguments} method should call that
 
2367
        C{parseArguments} method to get the arguments.
 
2368
        """
 
2369
        protocol = NoNetworkProtocol()
 
2370
        responder = protocol.locateResponder(MagicSchemaCommand.commandName)
 
2371
        argument = object()
 
2372
        response = responder(dict(weird=argument))
 
2373
        response.addCallback(
 
2374
            lambda ign: self.assertEqual(protocol.parseArgumentsArguments,
 
2375
                                         ({"weird": argument}, protocol)))
 
2376
        return response
 
2377
 
 
2378
 
 
2379
    def test_makeArguments(self):
 
2380
        """
 
2381
        There should be a class method of L{amp.Command} which accepts
 
2382
        a mapping of argument names to objects and returns a similar
 
2383
        mapping whose values have been serialized via the command's
 
2384
        argument schema.
 
2385
        """
 
2386
        protocol = object()
 
2387
        argument = object()
 
2388
        objects = {'weird': argument}
 
2389
        self.assertEqual(
 
2390
            ProtocolIncludingCommand.makeArguments(objects, protocol),
 
2391
            {'weird': "%d:%d" % (id(argument), id(protocol))})
 
2392
 
 
2393
 
 
2394
    def test_makeArgumentsUsesCommandType(self):
 
2395
        """
 
2396
        L{amp.Command.makeArguments}'s return type should be the type
 
2397
        of the result of L{amp.Command.commandType}.
 
2398
        """
 
2399
        protocol = object()
 
2400
        objects = {"weird": "whatever"}
 
2401
 
 
2402
        result = ProtocolIncludingCommandWithDifferentCommandType.makeArguments(
 
2403
            objects, protocol)
 
2404
        self.assertIdentical(type(result), MyBox)
 
2405
 
 
2406
 
 
2407
    def test_callRemoteCallsMakeArguments(self):
 
2408
        """
 
2409
        Making a remote call on a L{amp.Command} subclass which
 
2410
        overrides the C{makeArguments} method should call that
 
2411
        C{makeArguments} method to get the response.
 
2412
        """
 
2413
        client = NoNetworkProtocol()
 
2414
        argument = object()
 
2415
        response = client.callRemote(MagicSchemaCommand, weird=argument)
 
2416
        def gotResponse(ign):
 
2417
            self.assertEqual(client.makeArgumentsArguments,
 
2418
                             ({"weird": argument}, client))
 
2419
        response.addCallback(gotResponse)
 
2420
        return response
 
2421
 
 
2422
 
 
2423
    def test_extraArgumentsDisallowed(self):
 
2424
        """
 
2425
        L{Command.makeArguments} raises L{amp.InvalidSignature} if the objects
 
2426
        dictionary passed to it includes a key which does not correspond to the
 
2427
        Python identifier for a defined argument.
 
2428
        """
 
2429
        self.assertRaises(
 
2430
            amp.InvalidSignature,
 
2431
            Hello.makeArguments,
 
2432
            dict(hello="hello", bogusArgument=object()), None)
 
2433
 
 
2434
 
 
2435
    def test_wireSpellingDisallowed(self):
 
2436
        """
 
2437
        If a command argument conflicts with a Python keyword, the
 
2438
        untransformed argument name is not allowed as a key in the dictionary
 
2439
        passed to L{Command.makeArguments}.  If it is supplied,
 
2440
        L{amp.InvalidSignature} is raised.
 
2441
 
 
2442
        This may be a pointless implementation restriction which may be lifted.
 
2443
        The current behavior is tested to verify that such arguments are not
 
2444
        silently dropped on the floor (the previous behavior).
 
2445
        """
 
2446
        self.assertRaises(
 
2447
            amp.InvalidSignature,
 
2448
            Hello.makeArguments,
 
2449
            dict(hello="required", **{"print": "print value"}),
 
2450
            None)
 
2451
 
 
2452
 
 
2453
class ListOfTestsMixin:
 
2454
    """
 
2455
    Base class for testing L{ListOf}, a parameterized zero-or-more argument
 
2456
    type.
 
2457
 
 
2458
    @ivar elementType: Subclasses should set this to an L{Argument}
 
2459
        instance.  The tests will make a L{ListOf} using this.
 
2460
 
 
2461
    @ivar strings: Subclasses should set this to a dictionary mapping some
 
2462
        number of keys to the correct serialized form for some example
 
2463
        values.  These should agree with what L{elementType}
 
2464
        produces/accepts.
 
2465
 
 
2466
    @ivar objects: Subclasses should set this to a dictionary with the same
 
2467
        keys as C{strings} and with values which are the lists which should
 
2468
        serialize to the values in the C{strings} dictionary.
 
2469
    """
 
2470
    def test_toBox(self):
 
2471
        """
 
2472
        L{ListOf.toBox} extracts the list of objects from the C{objects}
 
2473
        dictionary passed to it, using the C{name} key also passed to it,
 
2474
        serializes each of the elements in that list using the L{Argument}
 
2475
        instance previously passed to its initializer, combines the serialized
 
2476
        results, and inserts the result into the C{strings} dictionary using
 
2477
        the same C{name} key.
 
2478
        """
 
2479
        stringList = amp.ListOf(self.elementType)
 
2480
        strings = amp.AmpBox()
 
2481
        for key in self.objects:
 
2482
            stringList.toBox(key, strings, self.objects.copy(), None)
 
2483
        self.assertEquals(strings, self.strings)
 
2484
 
 
2485
 
 
2486
    def test_fromBox(self):
 
2487
        """
 
2488
        L{ListOf.fromBox} reverses the operation performed by L{ListOf.toBox}.
 
2489
        """
 
2490
        stringList = amp.ListOf(self.elementType)
 
2491
        objects = {}
 
2492
        for key in self.strings:
 
2493
            stringList.fromBox(key, self.strings.copy(), objects, None)
 
2494
        self.assertEquals(objects, self.objects)
 
2495
 
 
2496
 
 
2497
 
 
2498
class ListOfStringsTests(unittest.TestCase, ListOfTestsMixin):
 
2499
    """
 
2500
    Tests for L{ListOf} combined with L{String}.
 
2501
    """
 
2502
    elementType = amp.String()
 
2503
 
 
2504
    strings = {
 
2505
        "empty": "",
 
2506
        "single": "\x00\x03foo",
 
2507
        "multiple": "\x00\x03bar\x00\x03baz\x00\x04quux"}
 
2508
 
 
2509
    objects = {
 
2510
        "empty": [],
 
2511
        "single": ["foo"],
 
2512
        "multiple": ["bar", "baz", "quux"]}
 
2513
 
 
2514
 
 
2515
class ListOfIntegersTests(unittest.TestCase, ListOfTestsMixin):
 
2516
    """
 
2517
    Tests for L{ListOf} combined with L{Integer}.
 
2518
    """
 
2519
    elementType = amp.Integer()
 
2520
 
 
2521
    strings = {
 
2522
        "empty": "",
 
2523
        "single": "\x00\x0210",
 
2524
        "multiple": "\x00\x011\x00\x0220\x00\x03500"}
 
2525
 
 
2526
    objects = {
 
2527
        "empty": [],
 
2528
        "single": [10],
 
2529
        "multiple": [1, 20, 500]}
 
2530
 
 
2531
 
 
2532
class ListOfUnicodeTests(unittest.TestCase, ListOfTestsMixin):
 
2533
    """
 
2534
    Tests for L{ListOf} combined with L{Unicode}.
 
2535
    """
 
2536
    elementType = amp.Unicode()
 
2537
 
 
2538
    strings = {
 
2539
        "empty": "",
 
2540
        "single": "\x00\x03foo",
 
2541
        "multiple": "\x00\x03\xe2\x98\x83\x00\x05Hello\x00\x05world"}
 
2542
 
 
2543
    objects = {
 
2544
        "empty": [],
 
2545
        "single": [u"foo"],
 
2546
        "multiple": [u"\N{SNOWMAN}", u"Hello", u"world"]}
 
2547
 
 
2548
 
 
2549
 
 
2550
if not interfaces.IReactorSSL.providedBy(reactor):
 
2551
    skipMsg = 'This test case requires SSL support in the reactor'
 
2552
    TLSTest.skip = skipMsg
 
2553
    LiveFireTLSTestCase.skip = skipMsg
 
2554
    PlainVanillaLiveFire.skip = skipMsg
 
2555
    WithServerTLSVerification.skip = skipMsg