~certify-web-dev/twisted/certify-trunk

« back to all changes in this revision

Viewing changes to twisted/test/test_amp.py

  • Committer: Bazaar Package Importer
  • Author(s): Matthias Klose
  • Date: 2007-01-17 14:52:35 UTC
  • mfrom: (1.1.5 upstream) (2.1.2 etch)
  • Revision ID: james.westby@ubuntu.com-20070117145235-btmig6qfmqfen0om
Tags: 2.5.0-0ubuntu1
New upstream version, compatible with python2.5.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright 2005 Divmod, Inc.  See LICENSE file for details
 
2
 
 
3
from twisted.python import filepath
 
4
from twisted.protocols import amp
 
5
from twisted.test import iosim
 
6
from twisted.trial import unittest
 
7
from twisted.internet import protocol, defer, error
 
8
 
 
9
from twisted.internet.error import PeerVerifyError
 
10
 
 
11
class TestProto(protocol.Protocol):
 
12
    def __init__(self, onConnLost, dataToSend):
 
13
        self.onConnLost = onConnLost
 
14
        self.dataToSend = dataToSend
 
15
 
 
16
    def connectionMade(self):
 
17
        self.data = []
 
18
        self.transport.write(self.dataToSend)
 
19
 
 
20
    def dataReceived(self, bytes):
 
21
        self.data.append(bytes)
 
22
        # self.transport.loseConnection()
 
23
 
 
24
    def connectionLost(self, reason):
 
25
        self.onConnLost.callback(self.data)
 
26
 
 
27
class SimpleSymmetricProtocol(amp.AMP):
 
28
 
 
29
    def sendHello(self, text):
 
30
        return self.callRemoteString(
 
31
            "hello",
 
32
            hello=text)
 
33
 
 
34
    def amp_HELLO(self, box):
 
35
        return amp.Box(hello=box['hello'])
 
36
 
 
37
    def amp_HOWDOYOUDO(self, box):
 
38
        return amp.QuitBox(howdoyoudo='world')
 
39
 
 
40
class UnfriendlyGreeting(Exception):
 
41
    """Greeting was insufficiently kind.
 
42
    """
 
43
 
 
44
class DeathThreat(Exception):
 
45
    """Greeting was insufficiently kind.
 
46
    """
 
47
 
 
48
class UnknownProtocol(Exception):
 
49
    """Asked to switch to the wrong protocol.
 
50
    """
 
51
 
 
52
 
 
53
class TransportPeer(amp.Argument):
 
54
    # this serves as some informal documentation for how to get variables from
 
55
    # the protocol or your environment and pass them to methods as arguments.
 
56
    def retrieve(self, d, name, proto):
 
57
        return ''
 
58
 
 
59
    def fromStringProto(self, notAString, proto):
 
60
        return proto.transport.getPeer()
 
61
 
 
62
    def toBox(self, name, strings, objects, proto):
 
63
        return
 
64
 
 
65
class Hello(amp.Command):
 
66
 
 
67
    commandName = 'hello'
 
68
 
 
69
    arguments = [('hello', amp.String()),
 
70
                 ('optional', amp.Boolean(optional=True)),
 
71
                 ('print', amp.Unicode(optional=True)),
 
72
                 ('from', TransportPeer(optional=True)),
 
73
                 ('mixedCase', amp.String(optional=True)),
 
74
                 ('dash-arg', amp.String(optional=True)),
 
75
                 ('underscore_arg', amp.String(optional=True))]
 
76
 
 
77
    response = [('hello', amp.String()),
 
78
                ('print', amp.Unicode(optional=True))]
 
79
 
 
80
    errors = {UnfriendlyGreeting: 'UNFRIENDLY'}
 
81
 
 
82
    fatalErrors = {DeathThreat: 'DEAD'}
 
83
 
 
84
class NoAnswerHello(Hello):
 
85
    commandName = Hello.commandName
 
86
    requiresAnswer = False
 
87
 
 
88
class FutureHello(amp.Command):
 
89
    commandName = 'hello'
 
90
 
 
91
    arguments = [('hello', amp.String()),
 
92
                 ('optional', amp.Boolean(optional=True)),
 
93
                 ('print', amp.Unicode(optional=True)),
 
94
                 ('from', TransportPeer(optional=True)),
 
95
                 ('bonus', amp.String(optional=True)), # addt'l arguments
 
96
                                                       # should generally be
 
97
                                                       # added at the end, and
 
98
                                                       # be optional...
 
99
                 ]
 
100
 
 
101
    response = [('hello', amp.String()),
 
102
                ('print', amp.Unicode(optional=True))]
 
103
 
 
104
    errors = {UnfriendlyGreeting: 'UNFRIENDLY'}
 
105
 
 
106
class WTF(amp.Command):
 
107
    """
 
108
    An example of an invalid command.
 
109
    """
 
110
 
 
111
 
 
112
class BrokenReturn(amp.Command):
 
113
    """ An example of a perfectly good command, but the handler is going to return
 
114
    None...
 
115
    """
 
116
 
 
117
    commandName = 'broken_return'
 
118
 
 
119
class Goodbye(amp.Command):
 
120
    # commandName left blank on purpose: this tests implicit command names.
 
121
    response = [('goodbye', amp.String())]
 
122
    responseType = amp.QuitBox
 
123
 
 
124
class Howdoyoudo(amp.Command):
 
125
    commandName = 'howdoyoudo'
 
126
    # responseType = amp.QuitBox
 
127
 
 
128
class WaitForever(amp.Command):
 
129
    commandName = 'wait_forever'
 
130
 
 
131
class GetList(amp.Command):
 
132
    commandName = 'getlist'
 
133
    arguments = [('length', amp.Integer())]
 
134
    response = [('body', amp.AmpList([('x', amp.Integer())]))]
 
135
 
 
136
class SecuredPing(amp.Command):
 
137
    # XXX TODO: actually make this refuse to send over an insecure connection
 
138
    response = [('pinged', amp.Boolean())]
 
139
 
 
140
class TestSwitchProto(amp.ProtocolSwitchCommand):
 
141
    commandName = 'Switch-Proto'
 
142
 
 
143
    arguments = [
 
144
        ('name', amp.String()),
 
145
        ]
 
146
    errors = {UnknownProtocol: 'UNKNOWN'}
 
147
 
 
148
class SingleUseFactory(protocol.ClientFactory):
 
149
    def __init__(self, proto):
 
150
        self.proto = proto
 
151
        self.proto.factory = self
 
152
 
 
153
    def buildProtocol(self, addr):
 
154
        p, self.proto = self.proto, None
 
155
        return p
 
156
 
 
157
    reasonFailed = None
 
158
 
 
159
    def clientConnectionFailed(self, connector, reason):
 
160
        self.reasonFailed = reason
 
161
        return
 
162
 
 
163
THING_I_DONT_UNDERSTAND = 'gwebol nargo'
 
164
class ThingIDontUnderstandError(Exception):
 
165
    pass
 
166
 
 
167
class FactoryNotifier(amp.AMP):
 
168
    factory = None
 
169
    def connectionMade(self):
 
170
        if self.factory is not None:
 
171
            self.factory.theProto = self
 
172
            if hasattr(self.factory, 'onMade'):
 
173
                self.factory.onMade.callback(None)
 
174
 
 
175
    def emitpong(self):
 
176
        from twisted.internet.interfaces import ISSLTransport
 
177
        if not ISSLTransport.providedBy(self.transport):
 
178
            raise DeathThreat("only send secure pings over secure channels")
 
179
        return {'pinged': True}
 
180
    SecuredPing.responder(emitpong)
 
181
 
 
182
 
 
183
class SimpleSymmetricCommandProtocol(FactoryNotifier):
 
184
    maybeLater = None
 
185
    def __init__(self, onConnLost=None):
 
186
        amp.AMP.__init__(self)
 
187
        self.onConnLost = onConnLost
 
188
 
 
189
    def sendHello(self, text):
 
190
        return self.callRemote(Hello, hello=text)
 
191
 
 
192
    def sendUnicodeHello(self, text, translation):
 
193
        return self.callRemote(Hello, hello=text, Print=translation)
 
194
 
 
195
    greeted = False
 
196
 
 
197
    def cmdHello(self, hello, From, optional=None, Print=None,
 
198
                 mixedCase=None, dash_arg=None, underscore_arg=None):
 
199
        assert From == self.transport.getPeer()
 
200
        if hello == THING_I_DONT_UNDERSTAND:
 
201
            raise ThingIDontUnderstandError()
 
202
        if hello.startswith('fuck'):
 
203
            raise UnfriendlyGreeting("Don't be a dick.")
 
204
        if hello == 'die':
 
205
            raise DeathThreat("aieeeeeeeee")
 
206
        result = dict(hello=hello)
 
207
        if Print is not None:
 
208
            result.update(dict(Print=Print))
 
209
        self.greeted = True
 
210
        return result
 
211
    Hello.responder(cmdHello)
 
212
 
 
213
    def cmdGetlist(self, length):
 
214
        return {'body': [dict(x=1)] * length}
 
215
    GetList.responder(cmdGetlist)
 
216
 
 
217
    def waitforit(self):
 
218
        self.waiting = defer.Deferred()
 
219
        return self.waiting
 
220
    WaitForever.responder(waitforit)
 
221
 
 
222
    def howdo(self):
 
223
        return dict(howdoyoudo='world')
 
224
    Howdoyoudo.responder(howdo)
 
225
 
 
226
    def saybye(self):
 
227
        return dict(goodbye="everyone")
 
228
    Goodbye.responder(saybye)
 
229
 
 
230
    def switchToTestProtocol(self, fail=False):
 
231
        if fail:
 
232
            name = 'no-proto'
 
233
        else:
 
234
            name = 'test-proto'
 
235
        p = TestProto(self.onConnLost, SWITCH_CLIENT_DATA)
 
236
        return self.callRemote(
 
237
            TestSwitchProto,
 
238
            SingleUseFactory(p), name=name).addCallback(lambda ign: p)
 
239
 
 
240
    def switchit(self, name):
 
241
        if name == 'test-proto':
 
242
            return TestProto(self.onConnLost, SWITCH_SERVER_DATA)
 
243
        raise UnknownProtocol(name)
 
244
    TestSwitchProto.responder(switchit)
 
245
 
 
246
    def donothing(self):
 
247
        return None
 
248
    BrokenReturn.responder(donothing)
 
249
 
 
250
 
 
251
class DeferredSymmetricCommandProtocol(SimpleSymmetricCommandProtocol):
 
252
    def switchit(self, name):
 
253
        if name == 'test-proto':
 
254
            self.maybeLaterProto = TestProto(self.onConnLost, SWITCH_SERVER_DATA)
 
255
            self.maybeLater = defer.Deferred()
 
256
            return self.maybeLater
 
257
        raise UnknownProtocol(name)
 
258
    TestSwitchProto.responder(switchit)
 
259
 
 
260
class BadNoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
 
261
    def badResponder(self, hello, From, optional=None, Print=None,
 
262
                     mixedCase=None, dash_arg=None, underscore_arg=None):
 
263
        """
 
264
        This responder does nothing and forgets to return a dictionary.
 
265
        """
 
266
    NoAnswerHello.responder(badResponder)
 
267
 
 
268
class NoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
 
269
    def goodNoAnswerResponder(self, hello, From, optional=None, Print=None,
 
270
                              mixedCase=None, dash_arg=None, underscore_arg=None):
 
271
        return dict(hello=hello+"-noanswer")
 
272
    NoAnswerHello.responder(goodNoAnswerResponder)
 
273
 
 
274
def connectedServerAndClient(ServerClass=SimpleSymmetricProtocol,
 
275
                             ClientClass=SimpleSymmetricProtocol,
 
276
                             *a, **kw):
 
277
    """Returns a 3-tuple: (client, server, pump)
 
278
    """
 
279
    return iosim.connectedServerAndClient(
 
280
        ServerClass, ClientClass,
 
281
        *a, **kw)
 
282
 
 
283
class TotallyDumbProtocol(protocol.Protocol):
 
284
    buf = ''
 
285
    def dataReceived(self, data):
 
286
        self.buf += data
 
287
 
 
288
class LiteralAmp(amp.AMP):
 
289
    def __init__(self):
 
290
        self.boxes = []
 
291
 
 
292
    def ampBoxReceived(self, box):
 
293
        self.boxes.append(box)
 
294
        return
 
295
 
 
296
class ParsingTest(unittest.TestCase):
 
297
 
 
298
    def test_booleanValues(self):
 
299
        """
 
300
        Verify that the Boolean parser parses 'True' and 'False', but nothing
 
301
        else.
 
302
        """
 
303
        b = amp.Boolean()
 
304
        self.assertEquals(b.fromString("True"), True)
 
305
        self.assertEquals(b.fromString("False"), False)
 
306
        self.assertRaises(TypeError, b.fromString, "ninja")
 
307
        self.assertRaises(TypeError, b.fromString, "true")
 
308
        self.assertRaises(TypeError, b.fromString, "TRUE")
 
309
        self.assertEquals(b.toString(True), 'True')
 
310
        self.assertEquals(b.toString(False), 'False')
 
311
 
 
312
    def test_pathValueRoundTrip(self):
 
313
        """
 
314
        Verify the 'Path' argument can parse and emit a file path.
 
315
        """
 
316
        fp = filepath.FilePath(self.mktemp())
 
317
        p = amp.Path()
 
318
        s = p.toString(fp)
 
319
        v = p.fromString(s)
 
320
        self.assertNotIdentical(fp, v) # sanity check
 
321
        self.assertEquals(fp, v)
 
322
 
 
323
 
 
324
    def test_sillyEmptyThing(self):
 
325
        """
 
326
        Test that empty boxes raise an error; they aren't supposed to be sent
 
327
        on purpose.
 
328
        """
 
329
        a = amp.AMP()
 
330
        return self.assertRaises(amp.NoEmptyBoxes, a.ampBoxReceived, amp.Box())
 
331
 
 
332
 
 
333
    def test_ParsingRoundTrip(self):
 
334
        """
 
335
        Verify that various kinds of data make it through the encode/parse
 
336
        round-trip unharmed.
 
337
        """
 
338
        c, s, p = connectedServerAndClient(ClientClass=LiteralAmp,
 
339
                                           ServerClass=LiteralAmp)
 
340
 
 
341
        SIMPLE = ('simple', 'test')
 
342
        CE = ('ceq', ': ')
 
343
        CR = ('crtest', 'test\r')
 
344
        LF = ('lftest', 'hello\n')
 
345
        NEWLINE = ('newline', 'test\r\none\r\ntwo')
 
346
        NEWLINE2 = ('newline2', 'test\r\none\r\n two')
 
347
        BLANKLINE = ('newline3', 'test\r\n\r\nblank\r\n\r\nline')
 
348
        BODYTEST = ('body', 'blah\r\n\r\ntesttest')
 
349
 
 
350
        testData = [
 
351
            [SIMPLE],
 
352
            [SIMPLE, BODYTEST],
 
353
            [SIMPLE, CE],
 
354
            [SIMPLE, CR],
 
355
            [SIMPLE, CE, CR, LF],
 
356
            [CE, CR, LF],
 
357
            [SIMPLE, NEWLINE, CE, NEWLINE2],
 
358
            [BODYTEST, SIMPLE, NEWLINE]
 
359
            ]
 
360
 
 
361
        for test in testData:
 
362
            jb = amp.Box()
 
363
            jb.update(dict(test))
 
364
            jb._sendTo(c)
 
365
            p.flush()
 
366
            self.assertEquals(s.boxes[-1], jb)
 
367
 
 
368
SWITCH_CLIENT_DATA = 'Success!'
 
369
SWITCH_SERVER_DATA = 'No, really.  Success.'
 
370
 
 
371
class AMPTest(unittest.TestCase):
 
372
 
 
373
    def test_helloWorld(self):
 
374
        """
 
375
        Verify that a simple command can be sent and its response received with
 
376
        the simple low-level string-based API.
 
377
        """
 
378
        c, s, p = connectedServerAndClient()
 
379
        L = []
 
380
        HELLO = 'world'
 
381
        c.sendHello(HELLO).addCallback(L.append)
 
382
        p.flush()
 
383
        self.assertEquals(L[0]['hello'], HELLO)
 
384
 
 
385
 
 
386
    def test_wireFormatRoundTrip(self):
 
387
        """
 
388
        Verify that mixed-case, underscored and dashed arguments are mapped to
 
389
        their python names properly.
 
390
        """
 
391
        c, s, p = connectedServerAndClient()
 
392
        L = []
 
393
        HELLO = 'world'
 
394
        c.sendHello(HELLO).addCallback(L.append)
 
395
        p.flush()
 
396
        self.assertEquals(L[0]['hello'], HELLO)
 
397
 
 
398
 
 
399
    def test_helloWorldUnicode(self):
 
400
        """
 
401
        Verify that unicode arguments can be encoded and decoded.
 
402
        """
 
403
        c, s, p = connectedServerAndClient(
 
404
            ServerClass=SimpleSymmetricCommandProtocol,
 
405
            ClientClass=SimpleSymmetricCommandProtocol)
 
406
        L = []
 
407
        HELLO = 'world'
 
408
        HELLO_UNICODE = 'wor\u1234ld'
 
409
        c.sendUnicodeHello(HELLO, HELLO_UNICODE).addCallback(L.append)
 
410
        p.flush()
 
411
        self.assertEquals(L[0]['hello'], HELLO)
 
412
        self.assertEquals(L[0]['Print'], HELLO_UNICODE)
 
413
 
 
414
 
 
415
    def test_unknownCommandLow(self):
 
416
        """
 
417
        Verify that unknown commands using low-level APIs will be rejected with an
 
418
        error, but will NOT terminate the connection.
 
419
        """
 
420
        c, s, p = connectedServerAndClient()
 
421
        L = []
 
422
        def clearAndAdd(e):
 
423
            """
 
424
            You can't propagate the error...
 
425
            """
 
426
            e.trap(amp.UnhandledCommand)
 
427
            return "OK"
 
428
        c.callRemoteString("WTF").addErrback(clearAndAdd).addCallback(L.append)
 
429
        p.flush()
 
430
        self.assertEquals(L.pop(), "OK")
 
431
        HELLO = 'world'
 
432
        c.sendHello(HELLO).addCallback(L.append)
 
433
        p.flush()
 
434
        self.assertEquals(L[0]['hello'], HELLO)
 
435
 
 
436
 
 
437
    def test_unknownCommandHigh(self):
 
438
        """
 
439
        Verify that unknown commands using high-level APIs will be rejected with an
 
440
        error, but will NOT terminate the connection.
 
441
        """
 
442
        c, s, p = connectedServerAndClient()
 
443
        L = []
 
444
        def clearAndAdd(e):
 
445
            """
 
446
            You can't propagate the error...
 
447
            """
 
448
            e.trap(amp.UnhandledCommand)
 
449
            return "OK"
 
450
        c.callRemote(WTF).addErrback(clearAndAdd).addCallback(L.append)
 
451
        p.flush()
 
452
        self.assertEquals(L.pop(), "OK")
 
453
        HELLO = 'world'
 
454
        c.sendHello(HELLO).addCallback(L.append)
 
455
        p.flush()
 
456
        self.assertEquals(L[0]['hello'], HELLO)
 
457
 
 
458
 
 
459
    def test_brokenReturnValue(self):
 
460
        """
 
461
        It can be very confusing if you write some code which responds to a
 
462
        command, but gets the return value wrong.  Most commonly you end up
 
463
        returning None instead of a dictionary.
 
464
 
 
465
        Verify that if that happens, the framework logs a useful error.
 
466
        """
 
467
        L = []
 
468
        SimpleSymmetricCommandProtocol().dispatchCommand(
 
469
            amp.AmpBox(_command=BrokenReturn.commandName)).addErrback(L.append)
 
470
        blr = L[0].trap(amp.BadLocalReturn)
 
471
        self.failUnlessIn('None', repr(L[0].value))
 
472
 
 
473
 
 
474
 
 
475
    def test_unknownArgument(self):
 
476
        """
 
477
        Verify that unknown arguments are ignored, and not passed to a Python
 
478
        function which can't accept them.
 
479
        """
 
480
        c, s, p = connectedServerAndClient(
 
481
            ServerClass=SimpleSymmetricCommandProtocol,
 
482
            ClientClass=SimpleSymmetricCommandProtocol)
 
483
        L = []
 
484
        HELLO = 'world'
 
485
        # c.sendHello(HELLO).addCallback(L.append)
 
486
        c.callRemote(FutureHello,
 
487
                     hello=HELLO,
 
488
                     bonus="I'm not in the book!").addCallback(
 
489
            L.append)
 
490
        p.flush()
 
491
        self.assertEquals(L[0]['hello'], HELLO)
 
492
 
 
493
 
 
494
    def test_simpleReprs(self):
 
495
        """
 
496
        Verify that the various Box objects repr properly, for debugging.
 
497
        """
 
498
        self.assertEquals(type(repr(amp._TLSBox())), str)
 
499
        self.assertEquals(type(repr(amp._SwitchBox('a'))), str)
 
500
        self.assertEquals(type(repr(amp.QuitBox())), str)
 
501
        self.assertEquals(type(repr(amp.AmpBox())), str)
 
502
        self.failUnless("AmpBox" in repr(amp.AmpBox()))
 
503
 
 
504
    def test_keyTooLong(self):
 
505
        """
 
506
        Verify that a key that is too long will immediately raise a synchronous
 
507
        exception.
 
508
        """
 
509
        c, s, p = connectedServerAndClient()
 
510
        L = []
 
511
        x = "H" * (0xff+1)
 
512
        tl = self.assertRaises(amp.TooLong,
 
513
                               c.callRemoteString, "Hello",
 
514
                               **{x: "hi"})
 
515
        self.failUnless(tl.isKey)
 
516
        self.failUnless(tl.isLocal)
 
517
        self.failUnlessIdentical(tl.keyName, None)
 
518
        self.failUnlessIdentical(tl.value, x)
 
519
        self.failUnless(str(len(x)) in repr(tl))
 
520
        self.failUnless("key" in repr(tl))
 
521
 
 
522
 
 
523
    def test_valueTooLong(self):
 
524
        """
 
525
        Verify that attempting to send value longer than 64k will immediately
 
526
        raise an exception.
 
527
        """
 
528
        c, s, p = connectedServerAndClient()
 
529
        L = []
 
530
        x = "H" * (0xffff+1)
 
531
        tl = self.assertRaises(amp.TooLong, c.sendHello, x)
 
532
        p.flush()
 
533
        self.failIf(tl.isKey)
 
534
        self.failUnless(tl.isLocal)
 
535
        self.failUnlessIdentical(tl.keyName, 'hello')
 
536
        self.failUnlessIdentical(tl.value, x)
 
537
        self.failUnless(str(len(x)) in repr(tl))
 
538
        self.failUnless("value" in repr(tl))
 
539
        self.failUnless('hello' in repr(tl))
 
540
 
 
541
 
 
542
    def test_helloWorldCommand(self):
 
543
        """
 
544
        Verify that a simple command can be sent and its response received with
 
545
        the high-level value parsing API.
 
546
        """
 
547
        c, s, p = connectedServerAndClient(
 
548
            ServerClass=SimpleSymmetricCommandProtocol,
 
549
            ClientClass=SimpleSymmetricCommandProtocol)
 
550
        L = []
 
551
        HELLO = 'world'
 
552
        c.sendHello(HELLO).addCallback(L.append)
 
553
        p.flush()
 
554
        self.assertEquals(L[0]['hello'], HELLO)
 
555
 
 
556
 
 
557
    def test_helloErrorHandling(self):
 
558
        """
 
559
        Verify that if a known error type is raised and handled, it will be
 
560
        properly relayed to the other end of the connection and translated into
 
561
        an exception, and no error will be logged.
 
562
        """
 
563
        L=[]
 
564
        c, s, p = connectedServerAndClient(
 
565
            ServerClass=SimpleSymmetricCommandProtocol,
 
566
            ClientClass=SimpleSymmetricCommandProtocol)
 
567
        HELLO = 'fuck you'
 
568
        c.sendHello(HELLO).addErrback(L.append)
 
569
        p.flush()
 
570
        L[0].trap(UnfriendlyGreeting)
 
571
        self.assertEquals(str(L[0].value), "Don't be a dick.")
 
572
 
 
573
 
 
574
    def test_helloFatalErrorHandling(self):
 
575
        """
 
576
        Verify that if a known, fatal error type is raised and handled, it will
 
577
        be properly relayed to the other end of the connection and translated
 
578
        into an exception, no error will be logged, and the connection will be
 
579
        terminated.
 
580
        """
 
581
        L=[]
 
582
        c, s, p = connectedServerAndClient(
 
583
            ServerClass=SimpleSymmetricCommandProtocol,
 
584
            ClientClass=SimpleSymmetricCommandProtocol)
 
585
        HELLO = 'die'
 
586
        c.sendHello(HELLO).addErrback(L.append)
 
587
        p.flush()
 
588
        L.pop().trap(DeathThreat)
 
589
        c.sendHello(HELLO).addErrback(L.append)
 
590
        p.flush()
 
591
        L.pop().trap(error.ConnectionDone)
 
592
 
 
593
 
 
594
 
 
595
    def test_helloNoErrorHandling(self):
 
596
        """
 
597
        Verify that if an unknown error type is raised, it will be relayed to
 
598
        the other end of the connection and translated into an exception, it
 
599
        will be logged, and then the connection will be dropped.
 
600
        """
 
601
        L=[]
 
602
        c, s, p = connectedServerAndClient(
 
603
            ServerClass=SimpleSymmetricCommandProtocol,
 
604
            ClientClass=SimpleSymmetricCommandProtocol)
 
605
        HELLO = THING_I_DONT_UNDERSTAND
 
606
        c.sendHello(HELLO).addErrback(L.append)
 
607
        p.flush()
 
608
        ure = L.pop()
 
609
        ure.trap(amp.UnknownRemoteError)
 
610
        c.sendHello(HELLO).addErrback(L.append)
 
611
        cl = L.pop()
 
612
        cl.trap(error.ConnectionDone)
 
613
        # The exception should have been logged.
 
614
        self.failUnless(self.flushLoggedErrors(ThingIDontUnderstandError))
 
615
 
 
616
 
 
617
 
 
618
    def test_lateAnswer(self):
 
619
        """
 
620
        Verify that a command that does not get answered until after the
 
621
        connection terminates will not cause any errors.
 
622
        """
 
623
        c, s, p = connectedServerAndClient(
 
624
            ServerClass=SimpleSymmetricCommandProtocol,
 
625
            ClientClass=SimpleSymmetricCommandProtocol)
 
626
        L = []
 
627
        HELLO = 'world'
 
628
        c.callRemote(WaitForever).addErrback(L.append)
 
629
        p.flush()
 
630
        self.assertEquals(L, [])
 
631
        s.transport.loseConnection()
 
632
        p.flush()
 
633
        L.pop().trap(error.ConnectionDone)
 
634
        # Just make sure that it doesn't error...
 
635
        s.waiting.callback({})
 
636
        return s.waiting
 
637
 
 
638
 
 
639
    def test_requiresNoAnswer(self):
 
640
        """
 
641
        Verify that a command that requires no answer is run.
 
642
        """
 
643
        L=[]
 
644
        c, s, p = connectedServerAndClient(
 
645
            ServerClass=SimpleSymmetricCommandProtocol,
 
646
            ClientClass=SimpleSymmetricCommandProtocol)
 
647
        HELLO = 'world'
 
648
        c.callRemote(NoAnswerHello, hello=HELLO)
 
649
        p.flush()
 
650
        self.failUnless(s.greeted)
 
651
 
 
652
 
 
653
    def test_requiresNoAnswerFail(self):
 
654
        """
 
655
        Verify that commands sent after a failed no-answer request do not complete.
 
656
        """
 
657
        L=[]
 
658
        c, s, p = connectedServerAndClient(
 
659
            ServerClass=SimpleSymmetricCommandProtocol,
 
660
            ClientClass=SimpleSymmetricCommandProtocol)
 
661
        HELLO = 'fuck you'
 
662
        c.callRemote(NoAnswerHello, hello=HELLO)
 
663
        p.flush()
 
664
        # This should be logged locally.
 
665
        self.failUnless(self.flushLoggedErrors(amp.RemoteAmpError))
 
666
        HELLO = 'world'
 
667
        c.callRemote(Hello, hello=HELLO).addErrback(L.append)
 
668
        p.flush()
 
669
        L.pop().trap(error.ConnectionDone)
 
670
        self.failIf(s.greeted)
 
671
 
 
672
 
 
673
    def test_noAnswerResponderBadAnswer(self):
 
674
        """
 
675
        Verify that responders of requiresAnswer=False commands have to return
 
676
        a dictionary anyway.
 
677
 
 
678
        (requiresAnswer is a hint from the _client_ - the server may be called
 
679
        upon to answer commands in any case, if the client wants to know when
 
680
        they complete.)
 
681
        """
 
682
        c, s, p = connectedServerAndClient(
 
683
            ServerClass=BadNoAnswerCommandProtocol,
 
684
            ClientClass=SimpleSymmetricCommandProtocol)
 
685
        c.callRemote(NoAnswerHello, hello="hello")
 
686
        p.flush()
 
687
        le = self.flushLoggedErrors(amp.BadLocalReturn)
 
688
        self.assertEquals(len(le), 1)
 
689
 
 
690
 
 
691
    def test_noAnswerResponderAskedForAnswer(self):
 
692
        """
 
693
        Verify that responders with requiresAnswer=False will actually respond
 
694
        if the client sets requiresAnswer=True.  In other words, verify that
 
695
        requiresAnswer is a hint honored only by the client.
 
696
        """
 
697
        c, s, p = connectedServerAndClient(
 
698
            ServerClass=NoAnswerCommandProtocol,
 
699
            ClientClass=SimpleSymmetricCommandProtocol)
 
700
        L = []
 
701
        c.callRemote(Hello, hello="Hello!").addCallback(L.append)
 
702
        p.flush()
 
703
        self.assertEquals(len(L), 1)
 
704
        self.assertEquals(L, [dict(hello="Hello!-noanswer",
 
705
                                   Print=None)])  # Optional response argument
 
706
 
 
707
 
 
708
    def test_ampListCommand(self):
 
709
        """
 
710
        Test encoding of an argument that uses the AmpList encoding.
 
711
        """
 
712
        c, s, p = connectedServerAndClient(
 
713
            ServerClass=SimpleSymmetricCommandProtocol,
 
714
            ClientClass=SimpleSymmetricCommandProtocol)
 
715
        L = []
 
716
        c.callRemote(GetList, length=10).addCallback(L.append)
 
717
        p.flush()
 
718
        values = L.pop().get('body')
 
719
        self.assertEquals(values, [{'x': 1}] * 10)
 
720
 
 
721
 
 
722
    def test_failEarlyOnArgSending(self):
 
723
        """
 
724
        Verify that if we pass an invalid argument list (omitting an argument), an
 
725
        exception will be raised.
 
726
        """
 
727
        okayCommand = Hello(hello="What?")
 
728
        self.assertRaises(amp.InvalidSignature, Hello)
 
729
 
 
730
 
 
731
    def test_protocolSwitch(self, switcher=SimpleSymmetricCommandProtocol,
 
732
                            spuriousTraffic=False):
 
733
        """
 
734
        Verify that it is possible to switch to another protocol mid-connection and
 
735
        send data to it successfully.
 
736
        """
 
737
        self.testSucceeded = False
 
738
 
 
739
        serverDeferred = defer.Deferred()
 
740
        serverProto = switcher(serverDeferred)
 
741
        clientDeferred = defer.Deferred()
 
742
        clientProto = switcher(clientDeferred)
 
743
        c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
 
744
                                           ClientClass=lambda: clientProto)
 
745
 
 
746
        if spuriousTraffic:
 
747
            wfdr = []           # remote
 
748
            wfd = c.callRemote(WaitForever).addErrback(wfdr.append)
 
749
        switchDeferred = c.switchToTestProtocol()
 
750
        if spuriousTraffic:
 
751
            self.assertRaises(amp.ProtocolSwitched, c.sendHello, 'world')
 
752
 
 
753
        def cbConnsLost(((serverSuccess, serverData),
 
754
                         (clientSuccess, clientData))):
 
755
            self.failUnless(serverSuccess)
 
756
            self.failUnless(clientSuccess)
 
757
            self.assertEquals(''.join(serverData), SWITCH_CLIENT_DATA)
 
758
            self.assertEquals(''.join(clientData), SWITCH_SERVER_DATA)
 
759
            self.testSucceeded = True
 
760
 
 
761
        def cbSwitch(proto):
 
762
            return defer.DeferredList(
 
763
                [serverDeferred, clientDeferred]).addCallback(cbConnsLost)
 
764
 
 
765
        switchDeferred.addCallback(cbSwitch)
 
766
        p.flush()
 
767
        if serverProto.maybeLater is not None:
 
768
            serverProto.maybeLater.callback(serverProto.maybeLaterProto)
 
769
            p.flush()
 
770
        if spuriousTraffic:
 
771
            # switch is done here; do this here to make sure that if we're
 
772
            # going to corrupt the connection, we do it before it's closed.
 
773
            s.waiting.callback({})
 
774
            p.flush()
 
775
        c.transport.loseConnection() # close it
 
776
        p.flush()
 
777
        self.failUnless(self.testSucceeded)
 
778
 
 
779
 
 
780
    def test_protocolSwitchDeferred(self):
 
781
        """
 
782
        Verify that protocol-switching even works if the value returned from
 
783
        the command that does the switch is deferred.
 
784
        """
 
785
        return self.test_protocolSwitch(switcher=DeferredSymmetricCommandProtocol)
 
786
 
 
787
    def test_protocolSwitchFail(self, switcher=SimpleSymmetricCommandProtocol):
 
788
        """
 
789
        Verify that if we try to switch protocols and it fails, the connection
 
790
        stays up and we can go back to speaking AMP.
 
791
        """
 
792
        self.testSucceeded = False
 
793
 
 
794
        serverDeferred = defer.Deferred()
 
795
        serverProto = switcher(serverDeferred)
 
796
        clientDeferred = defer.Deferred()
 
797
        clientProto = switcher(clientDeferred)
 
798
        c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
 
799
                                           ClientClass=lambda: clientProto)
 
800
        L = []
 
801
        switchDeferred = c.switchToTestProtocol(fail=True).addErrback(L.append)
 
802
        p.flush()
 
803
        L.pop().trap(UnknownProtocol)
 
804
        self.failIf(self.testSucceeded)
 
805
        # It's a known error, so let's send a "hello" on the same connection;
 
806
        # it should work.
 
807
        c.sendHello('world').addCallback(L.append)
 
808
        p.flush()
 
809
        self.assertEqual(L.pop()['hello'], 'world')
 
810
 
 
811
 
 
812
    def test_trafficAfterSwitch(self):
 
813
        """
 
814
        Verify that attempts to send traffic after a switch will not corrupt
 
815
        the nested protocol.
 
816
        """
 
817
        return self.test_protocolSwitch(spuriousTraffic=True)
 
818
 
 
819
 
 
820
    def test_quitBoxQuits(self):
 
821
        """
 
822
        Verify that commands with a responseType of QuitBox will in fact
 
823
        terminate the connection.
 
824
        """
 
825
        c, s, p = connectedServerAndClient(
 
826
            ServerClass=SimpleSymmetricCommandProtocol,
 
827
            ClientClass=SimpleSymmetricCommandProtocol)
 
828
 
 
829
        L = []
 
830
        HELLO = 'world'
 
831
        GOODBYE = 'everyone'
 
832
        c.sendHello(HELLO).addCallback(L.append)
 
833
        p.flush()
 
834
        self.assertEquals(L.pop()['hello'], HELLO)
 
835
        c.callRemote(Goodbye).addCallback(L.append)
 
836
        p.flush()
 
837
        self.assertEquals(L.pop()['goodbye'], GOODBYE)
 
838
        c.sendHello(HELLO).addErrback(L.append)
 
839
        L.pop().trap(error.ConnectionDone)
 
840
 
 
841
 
 
842
 
 
843
    def test_basicLiteralEmit(self):
 
844
        """
 
845
        Verify that the command dictionaries for a callRemoteN look correct
 
846
        after being serialized and parsed.
 
847
        """
 
848
        c, s, p = connectedServerAndClient()
 
849
        L = []
 
850
        s.ampBoxReceived = L.append
 
851
        c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test',
 
852
                     dash_arg='x', underscore_arg='y')
 
853
        p.flush()
 
854
        self.assertEquals(len(L), 1)
 
855
        for k, v in [('_command', Hello.commandName),
 
856
                     ('hello', 'hello test'),
 
857
                     ('mixedCase', 'mixed case arg test'),
 
858
                     ('dash-arg', 'x'),
 
859
                     ('underscore_arg', 'y')]:
 
860
            self.assertEquals(L[-1].pop(k), v)
 
861
        L[-1].pop('_ask')
 
862
        self.assertEquals(L[-1], {})
 
863
 
 
864
 
 
865
    def test_basicStructuredEmit(self):
 
866
        """
 
867
        Verify that a call similar to basicLiteralEmit's is handled properly with
 
868
        high-level quoting and passing to Python methods, and that argument
 
869
        names are correctly handled.
 
870
        """
 
871
        L = []
 
872
        class StructuredHello(amp.AMP):
 
873
            def h(self, *a, **k):
 
874
                L.append((a, k))
 
875
                return dict(hello='aaa')
 
876
            Hello.responder(h)
 
877
        c, s, p = connectedServerAndClient(ServerClass=StructuredHello)
 
878
        c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test',
 
879
                     dash_arg='x', underscore_arg='y').addCallback(L.append)
 
880
        p.flush()
 
881
        self.assertEquals(len(L), 2)
 
882
        self.assertEquals(L[0],
 
883
                          ((), dict(
 
884
                    hello='hello test',
 
885
                    mixedCase='mixed case arg test',
 
886
                    dash_arg='x',
 
887
                    underscore_arg='y',
 
888
 
 
889
                    # XXX - should optional arguments just not be passed?
 
890
                    # passing None seems a little odd, looking at the way it
 
891
                    # turns out here... -glyph
 
892
                    From=('file', 'file'),
 
893
                    Print=None,
 
894
                    optional=None,
 
895
                    )))
 
896
        self.assertEquals(L[1], dict(Print=None, hello='aaa'))
 
897
 
 
898
class PretendRemoteCertificateAuthority:
 
899
    def checkIsPretendRemote(self):
 
900
        return True
 
901
 
 
902
class IOSimCert:
 
903
    verifyCount = 0
 
904
 
 
905
    def options(self, *ign):
 
906
        return self
 
907
 
 
908
    def iosimVerify(self, otherCert):
 
909
        """
 
910
        This isn't a real certificate, and wouldn't work on a real socket, but
 
911
        iosim specifies a different API so that we don't have to do any crypto
 
912
        math to demonstrate that the right functions get called in the right
 
913
        places.
 
914
        """
 
915
        assert otherCert is self
 
916
        self.verifyCount += 1
 
917
        return True
 
918
 
 
919
class OKCert(IOSimCert):
 
920
    def options(self, x):
 
921
        assert x.checkIsPretendRemote()
 
922
        return self
 
923
 
 
924
class GrumpyCert(IOSimCert):
 
925
    def iosimVerify(self, otherCert):
 
926
        self.verifyCount += 1
 
927
        return False
 
928
 
 
929
class DroppyCert(IOSimCert):
 
930
    def __init__(self, toDrop):
 
931
        self.toDrop = toDrop
 
932
 
 
933
    def iosimVerify(self, otherCert):
 
934
        self.verifyCount += 1
 
935
        self.toDrop.loseConnection()
 
936
        return True
 
937
 
 
938
class SecurableProto(FactoryNotifier):
 
939
 
 
940
    factory = None
 
941
 
 
942
    def verifyFactory(self):
 
943
        return [PretendRemoteCertificateAuthority()]
 
944
 
 
945
    def getTLSVars(self):
 
946
        cert = self.certFactory()
 
947
        verify = self.verifyFactory()
 
948
        return dict(
 
949
            tls_localCertificate=cert,
 
950
            tls_verifyAuthorities=verify)
 
951
    amp.StartTLS.responder(getTLSVars)
 
952
 
 
953
 
 
954
 
 
955
class TLSTest(unittest.TestCase):
 
956
    def test_startingTLS(self):
 
957
        """
 
958
        Verify that starting TLS and succeeding at handshaking sends all the
 
959
        notifications to all the right places.
 
960
        """
 
961
        cli, svr, p = connectedServerAndClient(
 
962
            ServerClass=SecurableProto,
 
963
            ClientClass=SecurableProto)
 
964
 
 
965
        okc = OKCert()
 
966
        svr.certFactory = lambda : okc
 
967
 
 
968
        cli.callRemote(
 
969
            amp.StartTLS, tls_localCertificate=okc,
 
970
            tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
 
971
 
 
972
        # let's buffer something to be delivered securely
 
973
        L = []
 
974
        d = cli.callRemote(SecuredPing).addCallback(L.append)
 
975
        p.flush()
 
976
        # once for client once for server
 
977
        self.assertEquals(okc.verifyCount, 2)
 
978
        L = []
 
979
        d = cli.callRemote(SecuredPing).addCallback(L.append)
 
980
        p.flush()
 
981
        self.assertEqual(L[0], {'pinged': True})
 
982
 
 
983
    def test_startTooManyTimes(self):
 
984
        """
 
985
        Verify that the protocol will complain if we attempt to renegotiate TLS,
 
986
        which we don't support.
 
987
        """
 
988
        cli, svr, p = connectedServerAndClient(
 
989
            ServerClass=SecurableProto,
 
990
            ClientClass=SecurableProto)
 
991
 
 
992
        okc = OKCert()
 
993
        svr.certFactory = lambda : okc
 
994
 
 
995
        # print c, c.transport
 
996
        cli.callRemote(amp.StartTLS,
 
997
                       tls_localCertificate=okc,
 
998
                       tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
 
999
        p.flush()
 
1000
        cli.noPeerCertificate = True # this is totally fake
 
1001
        self.assertRaises(
 
1002
            amp.OnlyOneTLS,
 
1003
            cli.callRemote,
 
1004
            amp.StartTLS,
 
1005
            tls_localCertificate=okc,
 
1006
            tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
 
1007
 
 
1008
    def test_negotiationFailed(self):
 
1009
        """
 
1010
        Verify that starting TLS and failing on both sides at handshaking sends
 
1011
        notifications to all the right places and terminates the connection.
 
1012
        """
 
1013
 
 
1014
        badCert = GrumpyCert()
 
1015
 
 
1016
        cli, svr, p = connectedServerAndClient(
 
1017
            ServerClass=SecurableProto,
 
1018
            ClientClass=SecurableProto)
 
1019
        svr.certFactory = lambda : badCert
 
1020
 
 
1021
        cli.callRemote(amp.StartTLS,
 
1022
                       tls_localCertificate=badCert)
 
1023
 
 
1024
        p.flush()
 
1025
        # once for client once for server - but both fail
 
1026
        self.assertEquals(badCert.verifyCount, 2)
 
1027
        d = cli.callRemote(SecuredPing)
 
1028
        p.flush()
 
1029
        self.assertFailure(d, iosim.OpenSSLVerifyError)
 
1030
 
 
1031
    def test_negotiationFailedByClosing(self):
 
1032
        """
 
1033
        Verify that starting TLS and failing by way of a lost connection
 
1034
        notices that it is probably an SSL problem.
 
1035
        """
 
1036
 
 
1037
        cli, svr, p = connectedServerAndClient(
 
1038
            ServerClass=SecurableProto,
 
1039
            ClientClass=SecurableProto)
 
1040
        droppyCert = DroppyCert(svr.transport)
 
1041
        svr.certFactory = lambda : droppyCert
 
1042
 
 
1043
        secure = cli.callRemote(amp.StartTLS,
 
1044
                                tls_localCertificate=droppyCert)
 
1045
 
 
1046
        p.flush()
 
1047
 
 
1048
        self.assertEquals(droppyCert.verifyCount, 2)
 
1049
 
 
1050
        d = cli.callRemote(SecuredPing)
 
1051
        p.flush()
 
1052
 
 
1053
        # it might be a good idea to move this exception somewhere more
 
1054
        # reasonable.
 
1055
        self.assertFailure(d, PeerVerifyError)
 
1056
 
 
1057
 
 
1058
 
 
1059
class InheritedError(Exception):
 
1060
    """
 
1061
    This error is used to check inheritance.
 
1062
    """
 
1063
 
 
1064
 
 
1065
 
 
1066
class OtherInheritedError(Exception):
 
1067
    """
 
1068
    This is a distinct error for checking inheritance.
 
1069
    """
 
1070
 
 
1071
 
 
1072
 
 
1073
class BaseCommand(amp.Command):
 
1074
    """
 
1075
    This provides a command that will be subclassed.
 
1076
    """
 
1077
    errors = {InheritedError: 'INHERITED_ERROR'}
 
1078
 
 
1079
 
 
1080
 
 
1081
class InheritedCommand(BaseCommand):
 
1082
    """
 
1083
    This is a command which subclasses another command but does not override
 
1084
    anything.
 
1085
    """
 
1086
 
 
1087
 
 
1088
 
 
1089
class AddErrorsCommand(BaseCommand):
 
1090
    """
 
1091
    This is a command which subclasses another command but adds errors to the
 
1092
    list.
 
1093
    """
 
1094
    arguments = [('other', amp.Boolean())]
 
1095
    errors = {OtherInheritedError: 'OTHER_INHERITED_ERROR'}
 
1096
 
 
1097
 
 
1098
 
 
1099
class NormalCommandProtocol(amp.AMP):
 
1100
    """
 
1101
    This is a protocol which responds to L{BaseCommand}, and is used to test
 
1102
    that inheritance does not interfere with the normal handling of errors.
 
1103
    """
 
1104
    def resp(self):
 
1105
        raise InheritedError()
 
1106
    BaseCommand.responder(resp)
 
1107
 
 
1108
 
 
1109
 
 
1110
class InheritedCommandProtocol(amp.AMP):
 
1111
    """
 
1112
    This is a protocol which responds to L{InheritedCommand}, and is used to
 
1113
    test that inherited commands inherit their bases' errors if they do not
 
1114
    respond to any of their own.
 
1115
    """
 
1116
    def resp(self):
 
1117
        raise InheritedError()
 
1118
    InheritedCommand.responder(resp)
 
1119
 
 
1120
 
 
1121
 
 
1122
class AddedCommandProtocol(amp.AMP):
 
1123
    """
 
1124
    This is a protocol which responds to L{AddErrorsCommand}, and is used to
 
1125
    test that inherited commands can add their own new types of errors, but
 
1126
    still respond in the same way to their parents types of errors.
 
1127
    """
 
1128
    def resp(self, other):
 
1129
        if other:
 
1130
            raise OtherInheritedError()
 
1131
        else:
 
1132
            raise InheritedError()
 
1133
    AddErrorsCommand.responder(resp)
 
1134
 
 
1135
 
 
1136
 
 
1137
class CommandInheritanceTests(unittest.TestCase):
 
1138
    """
 
1139
    These tests verify that commands inherit error conditions properly.
 
1140
    """
 
1141
 
 
1142
    def errorCheck(self, err, proto, cmd, **kw):
 
1143
        """
 
1144
        Check that the appropriate kind of error is raised when a given command
 
1145
        is sent to a given protocol.
 
1146
        """
 
1147
        c, s, p = connectedServerAndClient(ServerClass=proto,
 
1148
                                           ClientClass=proto)
 
1149
        d = c.callRemote(cmd, **kw)
 
1150
        d2 = self.failUnlessFailure(d, err)
 
1151
        p.flush()
 
1152
        return d2
 
1153
 
 
1154
 
 
1155
    def test_basicErrorPropagation(self):
 
1156
        """
 
1157
        Verify that errors specified in a superclass are respected normally
 
1158
        even if it has subclasses.
 
1159
        """
 
1160
        return self.errorCheck(
 
1161
            InheritedError, NormalCommandProtocol, BaseCommand)
 
1162
 
 
1163
 
 
1164
    def test_inheritedErrorPropagation(self):
 
1165
        """
 
1166
        Verify that errors specified in a superclass command are propagated to
 
1167
        its subclasses.
 
1168
        """
 
1169
        return self.errorCheck(
 
1170
            InheritedError, InheritedCommandProtocol, InheritedCommand)
 
1171
 
 
1172
 
 
1173
    def test_inheritedErrorAddition(self):
 
1174
        """
 
1175
        Verify that new errors specified in a subclass of an existing command
 
1176
        are honored even if the superclass defines some errors.
 
1177
        """
 
1178
        return self.errorCheck(
 
1179
            OtherInheritedError, AddedCommandProtocol, AddErrorsCommand, other=True)
 
1180
 
 
1181
 
 
1182
    def test_additionWithOriginalError(self):
 
1183
        """
 
1184
        Verify that errors specified in a command's superclass are respected
 
1185
        even if that command defines new errors itself.
 
1186
        """
 
1187
        return self.errorCheck(
 
1188
            InheritedError, AddedCommandProtocol, AddErrorsCommand, other=False)
 
1189
 
 
1190
 
 
1191
 
 
1192
def _loseAndPass(err, proto):
 
1193
    # be specific, pass on the error to the client.
 
1194
    err.trap(error.ConnectionLost, error.ConnectionDone)
 
1195
    del proto.connectionLost
 
1196
    proto.connectionLost(err)
 
1197
 
 
1198
class LiveFireBase:
 
1199
    """
 
1200
    Utility for connected reactor-using tests.
 
1201
    """
 
1202
 
 
1203
    def setUp(self):
 
1204
        from twisted.internet import reactor
 
1205
        self.serverFactory = protocol.ServerFactory()
 
1206
        self.serverFactory.protocol = self.serverProto
 
1207
        self.clientFactory = protocol.ClientFactory()
 
1208
        self.clientFactory.protocol = self.clientProto
 
1209
        self.clientFactory.onMade = defer.Deferred()
 
1210
        self.serverFactory.onMade = defer.Deferred()
 
1211
        self.serverPort = reactor.listenTCP(0, self.serverFactory)
 
1212
        self.clientConn = reactor.connectTCP(
 
1213
            '127.0.0.1', self.serverPort.getHost().port,
 
1214
            self.clientFactory)
 
1215
        def getProtos(rlst):
 
1216
            self.cli = self.clientFactory.theProto
 
1217
            self.svr = self.serverFactory.theProto
 
1218
        dl = defer.DeferredList([self.clientFactory.onMade,
 
1219
                                 self.serverFactory.onMade])
 
1220
        return dl.addCallback(getProtos)
 
1221
 
 
1222
    def tearDown(self):
 
1223
        L = []
 
1224
        for conn in self.cli, self.svr:
 
1225
            if conn.transport is not None:
 
1226
                # depend on amp's function connection-dropping behavior
 
1227
                d = defer.Deferred().addErrback(_loseAndPass, conn)
 
1228
                conn.connectionLost = d.errback
 
1229
                conn.transport.loseConnection()
 
1230
                L.append(d)
 
1231
        if self.serverPort is not None:
 
1232
            L.append(defer.maybeDeferred(self.serverPort.stopListening))
 
1233
        if self.clientConn is not None:
 
1234
            self.clientConn.disconnect()
 
1235
        return defer.DeferredList(L)
 
1236
 
 
1237
def show(x):
 
1238
    import sys
 
1239
    sys.stdout.write(x+'\n')
 
1240
    sys.stdout.flush()
 
1241
 
 
1242
def tempSelfSigned():
 
1243
    from twisted.internet import ssl
 
1244
 
 
1245
    sharedDN = ssl.DN(CN='shared')
 
1246
    key = ssl.KeyPair.generate()
 
1247
    cr = key.certificateRequest(sharedDN)
 
1248
    sscrd = key.signCertificateRequest(
 
1249
        sharedDN, cr, lambda dn: True, 1234567)
 
1250
    cert = key.newCertificate(sscrd)
 
1251
    return cert
 
1252
 
 
1253
tempcert = tempSelfSigned()
 
1254
 
 
1255
class LiveFireTLSTestCase(LiveFireBase, unittest.TestCase):
 
1256
    clientProto = SecurableProto
 
1257
    serverProto = SecurableProto
 
1258
    def test_liveFireCustomTLS(self):
 
1259
        """
 
1260
        Using real, live TLS, actually negotiate a connection.
 
1261
 
 
1262
        This also looks at the 'peerCertificate' attribute's correctness, since
 
1263
        that's actually loaded using OpenSSL calls, but the main purpose is to
 
1264
        make sure that we didn't miss anything obvious in iosim about TLS
 
1265
        negotiations.
 
1266
        """
 
1267
 
 
1268
        cert = tempcert
 
1269
 
 
1270
        self.svr.verifyFactory = lambda : [cert]
 
1271
        self.svr.certFactory = lambda : cert
 
1272
        # only needed on the server, we specify the client below.
 
1273
 
 
1274
        def secured(rslt):
 
1275
            x = cert.digest()
 
1276
            def pinged(rslt2):
 
1277
                # Interesting.  OpenSSL won't even _tell_ us about the peer
 
1278
                # cert until we negotiate.  we should be able to do this in
 
1279
                # 'secured' instead, but it looks like we can't.  I think this
 
1280
                # is a bug somewhere far deeper than here.
 
1281
                self.failUnlessEqual(x, self.cli.hostCertificate.digest())
 
1282
                self.failUnlessEqual(x, self.cli.peerCertificate.digest())
 
1283
                self.failUnlessEqual(x, self.svr.hostCertificate.digest())
 
1284
                self.failUnlessEqual(x, self.svr.peerCertificate.digest())
 
1285
            return self.cli.callRemote(SecuredPing).addCallback(pinged)
 
1286
        return self.cli.callRemote(amp.StartTLS,
 
1287
                                   tls_localCertificate=cert,
 
1288
                                   tls_verifyAuthorities=[cert]).addCallback(secured)
 
1289
 
 
1290
class SlightlySmartTLS(SimpleSymmetricCommandProtocol):
 
1291
    def tlisfy(self):
 
1292
        return dict(tls_localCertificate=tempcert)
 
1293
 
 
1294
class PlainVanillaLiveFire(LiveFireBase, unittest.TestCase):
 
1295
 
 
1296
    clientProto = SimpleSymmetricCommandProtocol
 
1297
    serverProto = SimpleSymmetricCommandProtocol
 
1298
 
 
1299
    def test_liveFireDefaultTLS(self):
 
1300
        """
 
1301
        Verify that out of the box, we can start TLS to at least encrypt the
 
1302
        connection, even if we don't have any certificates to use.
 
1303
        """
 
1304
        def secured(result):
 
1305
            return self.cli.callRemote(SecuredPing)
 
1306
        return self.cli.callRemote(amp.StartTLS).addCallback(secured)
 
1307
 
 
1308
class WithServerTLSVerification(LiveFireBase, unittest.TestCase):
 
1309
    clientProto = SimpleSymmetricCommandProtocol
 
1310
    serverProto = SlightlySmartTLS
 
1311
 
 
1312
    def test_anonymousVerifyingClient(self):
 
1313
        """
 
1314
        Verify that anonymous clients can verify server certificates.
 
1315
        """
 
1316
        def secured(result):
 
1317
            return self.cli.callRemote(SecuredPing)
 
1318
        return self.cli.callRemote(amp.StartTLS, tls_verifyAuthorities=[tempcert])