~divmod-dev/divmod.org/1304710-storeless-adapter

« back to all changes in this revision

Viewing changes to Vertex/vertex/test/test_q2q.py

  • Committer: cyli
  • Date: 2013-06-27 06:02:46 UTC
  • mto: This revision was merged to the branch mainline in revision 2702.
  • Revision ID: cyli-20130627060246-ciict8hwvjuy9d81
Move Vertex out of the Divmod.org repository

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright 2005-2008 Divmod, Inc.  See LICENSE file for details
2
 
# -*- vertex.test.test_q2q.UDPConnection -*-
3
 
 
4
 
"""
5
 
Tests for L{vertex.q2q}.
6
 
"""
7
 
 
8
 
from cStringIO import StringIO
9
 
 
10
 
from twisted.trial import unittest
11
 
from twisted.application import service
12
 
from twisted.internet import reactor, protocol, defer
13
 
from twisted.internet.task import deferLater
14
 
from twisted.internet.ssl import DistinguishedName, PrivateCertificate, KeyPair
15
 
from twisted.protocols import basic
16
 
from twisted.python import log
17
 
from twisted.python import failure
18
 
from twisted.internet.error import ConnectionDone
19
 
# from twisted.internet.main import CONNECTION_DONE
20
 
 
21
 
from zope.interface import implements
22
 
from twisted.internet.interfaces import IResolverSimple
23
 
 
24
 
from twisted.protocols.amp import (
25
 
    UnhandledCommand, UnknownRemoteError, QuitBox, Command, AMP)
26
 
 
27
 
from vertex import q2q
28
 
 
29
 
 
30
 
def noResources(*ignored):
    """
    Protocol-factory lookup that never finds anything.

    Any positional arguments are accepted and ignored; a fresh empty list
    is returned on every call, so a service using this lookup exposes no
    resources.
    """
    nothing = []
    return nothing
32
 
 
33
 
class FakeConnectTCP:
    """
    Stand-in for both C{reactor.connectTCP} and the reactor's resolver.

    Each hostname registered through L{addHostPort} is given its own
    loopback address (127.0.0.1, 127.0.0.2, ... in registration order).
    Connection attempts to a registered C{(host, fakePort)} pair are
    redirected to the real port supplied at registration time.
    Unregistered pairs are passed through unchanged.
    """
    implements(IResolverSimple)

    def __init__(self, connectTCP):
        # The real connectTCP that rewritten connection attempts go to.
        self._connectTCP = connectTCP
        # (host, fake port) -> (host, real port).  Entries are keyed both by
        # hostname and by the loopback IP that hostname resolves to.
        self.hostPortToHostPort = {}
        # hostname -> assigned loopback IP.
        self.hostToLocalHost = {}
        # Last octet of the next loopback IP to hand out.
        self.counter = 1

    def addHostPort(self, hostname, fakePortNumber, realPortNumber):
        """
        Redirect connections to C{(hostname, fakePortNumber)} to
        C{realPortNumber}. A loopback IP is allocated for C{hostname} the
        first time it is seen.
        """
        if hostname in self.hostToLocalHost:
            localIP = self.hostToLocalHost[hostname]
        else:
            localIP = '127.0.0.%d' % (self.counter,)
            self.counter += 1
            self.hostToLocalHost[hostname] = localIP

        # Register the redirect under both the resolved IP and the name, so
        # callers that connect by either form are rewritten.
        self.hostPortToHostPort[(localIP, fakePortNumber)] = (localIP, realPortNumber)
        self.hostPortToHostPort[(hostname, fakePortNumber)] = (hostname, realPortNumber)

    def connectTCP(self, host, port, *args, **kw):
        """
        Connect with the real connectTCP, substituting the registered real
        address for C{(host, port)} when one exists.
        """
        localhost, localport = self.hostPortToHostPort.get((host, port), (host, port))
        return self._connectTCP(localhost, localport, *args, **kw)

    def getHostSync(self, name):
        """
        Return the loopback IP assigned to C{name}.

        @raise KeyError: if C{name} was never registered with L{addHostPort}.
        """
        return self.hostToLocalHost[name]

    def getHostByName(self, name, timeout=(1, 3, 11, 45)):
        """
        L{IResolverSimple} implementation.

        @param timeout: accepted for interface compatibility and ignored.
            The default matches the one declared by L{IResolverSimple}, so
            callers may omit it.
        @return: a Deferred that fires with the IP from L{getHostSync}, or
            fails with its exception.
        """
        return defer.maybeDeferred(self.getHostSync, name)
63
 
 
64
 
def runOneDeferred(d):
    """
    Run the global reactor until C{d} has fired, then return its result.

    The reactor is crashed (not stopped) from C{d}'s callback chain once
    C{d} has a result.  If that result is a L{failure.Failure}, it is
    re-raised via C{Failure.trap()} with no exception types.  If the reactor
    returns without C{d} having fired, L{unittest.FailTest} is raised.
    """
    outcome = []
    # L.append returns None, so after this the chain always carries a
    # success result and the crash callback below is guaranteed to run.
    d.addBoth(outcome.append)

    def stopReactor(ignored):
        reactor.crash()

    reactor.callLater(0, d.addCallback, stopReactor)
    reactor.run()

    if not outcome:
        raise unittest.FailTest("Keyboard Interrupt")

    result = outcome[0]
    if isinstance(result, failure.Failure):
        # trap() with no exception types matches nothing, so it re-raises.
        result.trap()
    return result
74
 
 
75
 
class Utility(unittest.TestCase):
    """
    Tests for basic L{q2q.Q2QService} construction.
    """

    def testServiceInitialization(self):
        """
        A private certificate added to a fresh service's certificate storage
        can be read back, and its public key matches its private key.
        """
        domainName = "test.domain"
        q2qService = q2q.Q2QService(noResources)
        storage = q2qService.certificateStorage

        storage.addPrivateCertificate(domainName)
        stored = storage.getPrivateCertificate(domainName)

        keysAgree = stored.getPublicKey().matches(stored.privateKey)
        self.failUnless(keysAgree)
82
 
 
83
 
class OneTrickPony(AMP):
    """
    AMP server that answers the plain-string command 'trick' with a
    L{QuitBox} carrying C{tricked='True'}.
    """

    def amp_TRICK(self, box):
        answer = QuitBox(tricked='True')
        return answer
86
 
 
87
 
class OneTrickPonyServerFactory(protocol.ServerFactory):
    """
    Server factory whose connections are handled by L{OneTrickPony}.
    """
    protocol = OneTrickPony
89
 
 
90
 
class OneTrickPonyClient(AMP):
    """
    AMP client that sends the 'trick' string command as soon as it is
    connected. The outcome is forwarded to its factory's C{ponged}
    Deferred.
    """

    def connectionMade(self):
        trickAnswer = self.callRemoteString('trick')
        trickAnswer.chainDeferred(self.factory.ponged)
93
 
 
94
 
class OneTrickPonyClientFactory(protocol.ClientFactory):
    """
    Client factory for L{OneTrickPonyClient}.

    @ivar ponged: Deferred that receives the answer to the 'trick' command.
        It is errbacked if the connection attempt fails.
    @ivar proto: the protocol most recently built by L{buildProtocol}.
    """
    protocol = OneTrickPonyClient

    def __init__(self, ponged):
        self.ponged = ponged

    def buildProtocol(self, addr):
        # Remember the protocol so tests can inspect its transport later.
        self.proto = protocol.ClientFactory.buildProtocol(self, addr)
        return self.proto

    def clientConnectionFailed(self, connector, reason):
        waiting = self.ponged
        waiting.errback(reason)
107
 
 
108
 
 
109
 
class DataEater(protocol.Protocol):
    """
    Protocol (and its own factory) that accumulates every chunk it receives.
    Callers can wait until a given number of bytes has arrived.

    @ivar data: list of received chunks, in arrival order.
    @ivar count: total number of bytes received so far.
    @ivar waiters: list of C{(size, Deferred)} pairs still waiting, kept
        sorted by descending size.
    """

    def __init__(self):
        self.waiters = []
        self.data = []
        self.count = 0

    def dataReceived(self, data):
        if not data:
            raise RuntimeError("Empty string delivered to DataEater")
        self.data.append(data)
        self.count += len(data)
        # Iterate over a copy: firing a waiter removes it from self.waiters
        # (removeD is the first callback on every waiter's Deferred).
        for count, waiter in self.waiters[:]:
            if self.count >= count:
                waiter.callback(self.count)

    def removeD(self, result, size, d):
        """
        Callback/errback that unregisters the C{(size, d)} waiter and passes
        C{result} through unchanged.

        Removal is done from C{d}'s own callback chain, not from
        L{dataReceived}. The original author's note says this was because
        Twisted 1.3's C{util.wait} fired the Deferred itself.
        """
        self.waiters.remove((size, d))
        return result

    def waitForCount(self, size):
        """
        Return a Deferred that fires with the running byte count once at
        least C{size} bytes have been received.

        If that many bytes have already arrived, the Deferred fires
        immediately. Otherwise nothing would ever fire it, because waiters
        are only checked when new data arrives.
        """
        d = defer.Deferred()
        self.waiters.append((size, d))
        # Largest sizes first. Sort on the size alone so Deferreds with
        # equal sizes are never compared with each other.
        self.waiters.sort(key=lambda waiter: waiter[0], reverse=True)
        d.addBoth(self.removeD, size, d)
        if self.count >= size:
            d.callback(self.count)
        return d

    def buildProtocol(self, addr):
        return self
139
 
 
140
 
class DataFeeder(protocol.Protocol):
    """
    Protocol (and its own factory) that, once connected, sends the whole
    of a file-like object to the transport with L{basic.FileSender}.
    """

    def __init__(self, fobj):
        self.fobj = fobj

    def buildProtocol(self, addr):
        return self

    def connectionMade(self):
        sender = basic.FileSender()
        sender.beginFileTransfer(self.fobj, self.transport)

    def clientConnectionFailed(self, connector, reason):
        log.msg("DataFeeder client connection failed:")
        log.err(reason)

    def clientConnectionLost(self, connector, reason):
        # Deliberately ignored.
        pass
156
 
 
157
 
class StreamingDataFeeder(protocol.Protocol):
    """
    Push producer (and its own factory) that writes a file to the transport
    in C{CHUNK}-sized pieces, one every C{DELAY} seconds.

    It records how often it is paused, resumed and stopped, and how many
    bytes it has written.
    """
    DELAY = 0.01
    CHUNK = 1024
    paused = False
    pauseCount = 0
    resumeCount = 0
    stopCount = 0
    # Total number of bytes handed to the transport so far.
    outCount = 0
    # Pending reactor.callLater for the next _keepGoing, or None.
    call = None

    def __init__(self, infile):
        self.file = infile

    def clientConnectionFailed(self, connector, reason):
        log.msg("StreamingDataFeeder client connection failed:")
        log.err(reason)

    def clientConnectionLost(self, connector, reason):
        pass

    def connectionMade(self):
        # Read one chunk ahead so _keepGoing can tell whether more remains.
        self.nextChunk = self.file.read(self.CHUNK)
        # True: register as a streaming (push) producer.
        self.transport.registerProducer(self, True)
        self.call = reactor.callLater(self.DELAY, self._keepGoing)

    def _keepGoing(self):
        """
        Write the read-ahead chunk and reschedule while data remains.
        Does nothing while paused.
        """
        self.call = None
        if self.paused:
            return
        chunk = self.nextChunk
        self.nextChunk = self.file.read(self.CHUNK)
        self.outCount += len(chunk)
        if chunk:
            self.transport.write(chunk)
        if self.nextChunk:
            self.call = reactor.callLater(self.DELAY, self._keepGoing)

    def pauseProducing(self):
        self.paused = True
        self.pauseCount += 1
        if self.call is not None:
            self.cancelMe()

    def resumeProducing(self):
        self.paused = False
        # Replace any pending write with a freshly scheduled one.
        if self.call is not None:
            self.cancelMe()
        self.call = reactor.callLater(self.DELAY, self._keepGoing)
        self.resumeCount += 1

    def cancelMe(self):
        """Cancel the pending delayed call; C{self.call} must not be None."""
        self.call.cancel()
        self.call = None

    def stopProducing(self):
        self.paused = True
        self.stopCount += 1
        if self.call is not None:
            self.cancelMe()

    def buildProtocol(self, addr):
        return self
220
 
 
221
 
 
222
 
class ErroneousClientError(Exception):
    """
    Raised by L{Erroneous}'s responder for L{Break}. L{Break} declares no
    errors, so AMP treats this exception as undeclared.
    """
224
 
 
225
 
class EngenderError(Command):
    """
    Ask the peer to send a L{Break} command back to the sender.
    """
    commandName = 'Engender-Error'
227
 
 
228
 
class Break(Command):
    """
    Command whose responder on L{Erroneous} always raises the undeclared
    L{ErroneousClientError}.
    """
    commandName = 'Break'
230
 
 
231
 
class Flag(Command):
    """
    Command whose responder on L{Erroneous} sets that protocol's C{flag}.
    """
    commandName = 'Flag'
233
 
 
234
 
class FatalError(Exception):
    """
    Raised by L{Erroneous}'s responder for L{Fatal}. That command declares
    it as a fatal error.
    """
236
 
 
237
 
class Fatal(Command):
    """
    Command whose responder raises L{FatalError}, which is declared here as
    a fatal error with the error code "quite bad".
    """
    fatalErrors = {FatalError: "quite bad"}
239
 
 
240
 
class Erroneous(AMP):
    """
    AMP protocol whose responders misbehave in different ways. It is used
    to exercise error propagation across Q2Q connections.

    @ivar flag: becomes C{True} once a L{Flag} command has been received.
    """
    flag = False

    @Fatal.responder
    def _fatal(self):
        raise FatalError("This is fatal.")

    @Break.responder
    def _break(self):
        raise ErroneousClientError("Zoop")

    @EngenderError.responder
    def _engenderError(self):
        def connectionGone(reason):
            reason.trap(ConnectionDone)
            # The connection is already dead, so the answer can never reach
            # the peer. Turning the failure into an empty success keeps it
            # from being logged as an error. -exarkun
            return {}
        breaking = self.callRemote(Break)
        breaking.addErrback(connectionGone)
        return breaking

    @Flag.responder
    def _flag(self):
        self.flag = True
        return {}
264
 
 
265
 
class ErroneousServerFactory(protocol.ServerFactory):
    """
    Server factory whose connections are handled by L{Erroneous}.
    """
    protocol = Erroneous
267
 
 
268
 
class ErroneousClientFactory(protocol.ClientFactory):
    """
    Client factory whose connections are handled by L{Erroneous}.
    """
    protocol = Erroneous
270
 
 
271
 
class Greet(Command):
    """
    Argument-less command answered by L{Greeter}, which records that it was
    greeted.
    """
    commandName = 'Greet'
273
 
 
274
 
class Greeter(AMP, protocol.ServerFactory, protocol.ClientFactory):
    """
    AMP protocol that is also its own server/client factory. It sends a
    L{Greet} command to its peer as soon as it is connected.

    @ivar isServer: which side of the connection this instance is meant
        for. It is stored but not otherwise used in this class.
    @ivar startupD: Deferred fired with the peer's answer to our L{Greet}.
    @ivar greeted: whether the peer has sent us a L{Greet}.
    """
    # Default, so that checking greeted on an instance that never received
    # a Greet yields False instead of raising AttributeError.
    greeted = False

    def __init__(self, isServer, startupD):
        self.isServer = isServer
        AMP.__init__(self)
        self.startupD = startupD

    def buildProtocol(self, addr):
        return self

    def connectionMade(self):
        self.callRemote(Greet).chainDeferred(self.startupD)

    def _greet(self):
        self.greeted = True
        return dict()
    Greet.responder(_greet)
290
 
 
291
 
class Q2QConnectionTestCase(unittest.TestCase):
    """
    Base fixture: one Q2Q service for C{fromDomain} and one for C{toDomain},
    both listening on loopback.

    C{reactor.connectTCP} and the reactor's resolver are replaced by a
    L{FakeConnectTCP}, so connections to either domain's port 8788 reach
    the right service.

    Subclasses choose the transports to test through C{inboundTCPPortnum},
    C{udpEnabled} and C{virtualEnabled}.
    """
    streamer = None

    fromResource = 'clientResource'
    toResource = 'serverResource'

    fromDomain = 'origin.domain.example.com'
    fromIP = '127.0.0.1'
    spoofedDomain = 'spoofed.domain.example.com'
    toDomain = 'destination.domain.example.org'
    toIP = '127.0.0.2'

    userReverseDNS = 'i.watch.too.much.tv'
    inboundTCPPortnum = 0
    udpEnabled = False
    virtualEnabled = False

    def _makeQ2QService(self, certificateEntity, publicIP, pff=None):
        """
        Create a L{q2q.Q2QService} with this test case's transport settings
        and a q2q port of 0.

        A private certificate for C{certificateEntity} is added to the
        service's storage unless the entity names a user (contains '@').
        """
        svc = q2q.Q2QService(pff, q2qPortnum=0,
                             inboundTCPPortnum=self.inboundTCPPortnum,
                             publicIP=publicIP)
        svc.udpEnabled = self.udpEnabled
        svc.virtualEnabled = self.virtualEnabled
        if '@' not in certificateEntity:
            svc.certificateStorage.addPrivateCertificate(certificateEntity)
        svc.debugName = certificateEntity
        return svc

    def _addQ2QProtocol(self, name, factory):
        """
        Make C{factory} available as resource C{name} for connections from
        C{fromAddress} to C{toAddress}.
        """
        resourceKey = (self.fromAddress,
                       self.toAddress, name)
        self.resourceMap[resourceKey] = factory

    def protocolFactoryLookup(self, *key):
        """
        Resource lookup used by C{serverService}.

        @return: a one-element list of C{(factory, description)} for a
            resource registered with L{_addQ2QProtocol}, or an empty list.
        """
        if key in self.resourceMap:
            return [(self.resourceMap[key], 'test-description')]
        return []

    def setUp(self):
        self.fromAddress = q2q.Q2QAddress(self.fromDomain, self.fromResource)
        self.toAddress = q2q.Q2QAddress(self.toDomain, self.toResource)

        # Route connections for the fake domains to the services started
        # below. Each registered domain gets its own 127.0.0.x address, and
        # its port 8788 is mapped to the port that domain's service actually
        # listens on.
        fakeDNS = FakeConnectTCP(reactor.connectTCP)
        self._oldResolver = reactor.resolver
        reactor.connectTCP = fakeDNS.connectTCP
        reactor.installResolver(fakeDNS)

        # trial does not call tearDown when setUp fails. Undo the global
        # reactor patching here in that case, so later tests are unaffected.
        started = False
        try:
            self._startServices(fakeDNS)
            started = True
        finally:
            if not started:
                self._restoreReactor()

    def _startServices(self, fakeDNS):
        """
        Create and start the services for both domains, point C{fakeDNS} at
        their q2q ports, and register the default test resources.
        """
        # A know-nothing service for the client half of the conversation.
        self.serverService2 = self._makeQ2QService(self.fromDomain, self.fromIP, noResources)

        # Do likewise for the server half. Test methods may also register
        # trivial resources in resourceMap for the client to access.
        self.resourceMap = {}
        self.serverService = self._makeQ2QService(self.toDomain, self.toIP,
                                                  self.protocolFactoryLookup)

        self.msvc = service.MultiService()
        self.serverService2.setServiceParent(self.msvc)
        self.serverService.setServiceParent(self.msvc)

        # Let the kernel allocate a random port for each of these services'
        # listeners.
        self.msvc.startService()

        # Registration order matters: the first domain registered is given
        # 127.0.0.1 (fromIP), the second 127.0.0.2 (toIP).
        fakeDNS.addHostPort(
            self.fromDomain, 8788,
            self.serverService2.q2qPort.getHost().port)
        fakeDNS.addHostPort(
            self.toDomain, 8788,
            self.serverService.q2qPort.getHost().port)

        self._addQ2QProtocol('pony', OneTrickPonyServerFactory())

        self.dataEater = DataEater()
        self._addQ2QProtocol('eat', self.dataEater)

        self._addQ2QProtocol('error', ErroneousServerFactory())

    def _restoreReactor(self):
        """
        Undo the C{reactor.connectTCP} and resolver patching done in
        L{setUp}.
        """
        reactor.installResolver(self._oldResolver)
        # Removing the instance attribute re-exposes the class's method.
        del reactor.connectTCP

    def tearDown(self):
        self._restoreReactor()
        return self.msvc.stopService()
382
 
 
383
 
 
384
 
 
385
 
class ConnectionTestMixin:
    """
    Connection tests shared by the TCP, UDP and virtual-connection cases.

    Intended to be mixed into subclasses of L{Q2QConnectionTestCase}. That
    fixture provides:
      - C{serverService}: the service for C{toDomain}, exposing the 'pony',
        'eat' and 'error' resources.
      - C{serverService2}: the service for C{fromDomain}.
      - C{fromAddress}, C{toAddress}, C{dataEater} and C{msvc}.
    """

    def testConnectWithIntroduction(self):
        """
        Connecting to a resource registered with the destination domain's
        service reaches that resource and returns its answer.
        """
        ponged = defer.Deferred()
        self.serverService2.connectQ2Q(self.fromAddress,
                                       self.toAddress,
                                       'pony',
                                       OneTrickPonyClientFactory(ponged))
        return ponged.addCallback(lambda answerBox: self.failUnless('tricked' in answerBox))

    def addClientService(self, toAddress, secret, serverService):
        """
        Create and authorize a client service for the user C{toAddress}.
        Its resource is used as the username and its domain as the server
        domain.

        @return: a Deferred firing with the new client service.
        """
        return self._addClientService(
            toAddress.resource, secret, serverService, toAddress.domain)

    def _addClientService(self, username,
                          privateSecret, serverService,
                          serverDomain):
        """
        Register C{username} with C{serverService}, then start and authorize
        a client service for C{username@serverDomain}.
        """
        svc = self._makeQ2QService(username + '@' + serverDomain, None)
        serverService.certificateStorage.addUser(serverDomain,
                                                 username,
                                                 privateSecret)
        svc.setServiceParent(self.msvc)
        return svc.authorize(q2q.Q2QAddress(serverDomain, username),
                             privateSecret).addCallback(lambda x: svc)

    def testListening(self):
        """
        A client service can listen on behalf of its user, and another
        client can connect to that listener. The resulting transport reports
        the expected Q2Q peer and host addresses.
        """
        _1 = self.addClientService(self.toAddress, 'aaaa', self.serverService)
        def _1c(_1result):
            self.clientServerService = _1result
            ponyFactory = OneTrickPonyServerFactory()
            _2 = self.clientServerService.listenQ2Q(self.toAddress,
                                                    {'pony2': ponyFactory},
                                                    'ponies suck')

            def _2c(ignored):
                _3 = self.addClientService(
                        self.fromAddress, 'bbbb', self.serverService2)
                def _3c(_3result):
                    self.clientClientService = _3result

                    _4 = defer.Deferred()
                    otpcf = OneTrickPonyClientFactory(_4)
                    self.clientClientService.connectQ2Q(self.fromAddress,
                                                        self.toAddress,
                                                        'pony2',
                                                        otpcf)
                    def _4c(answerBox):
                        T = otpcf.proto.transport
                        self.assertEquals(T.getQ2QPeer(), self.toAddress)
                        self.assertEquals(T.getQ2QHost(), self.fromAddress)
                        self.failUnless('tricked' in answerBox)

                    return _4.addCallback(_4c)
                return _3.addCallback(_3c)
            return _2.addCallback(_2c)
        return _1.addCallback(_1c)

    def testChooserGetsThreeChoices(self):
        """
        When two listeners and the server's own resource all offer 'pony',
        a chooser passed to connectQ2Q is shown all three descriptions.
        """

        def actualTest(ign):
            ponyFactory = OneTrickPonyServerFactory()
            _1 = self.addClientService(
                self.toAddress, 'aaaa', self.serverService)
            def _1c(_1result):
                self.clientServerService2 = _1result

                _2 = self.clientServerService2.listenQ2Q(self.toAddress,
                                                         {'pony': ponyFactory},
                                                         'ponies are weird')
                def _2c(ign):
                    _3 = self.clientServerService.listenQ2Q(self.toAddress,
                                                            {'pony': ponyFactory},
                                                            'ponies rule')
                    def _3c(ign):
                        expectedList = ['ponies rule', 'ponies are weird', 'test-description']
                        def chooser(servers):
                            self.failUnlessEqual(len(servers), 3)
                            for server in servers:
                                expectedList.remove(server['description'])
                                if server['description'] == 'ponies rule':
                                    self.assertEquals(
                                        self.clientServerService.certificateStorage.getPrivateCertificate(str(self.toAddress)),
                                        server['certificate'])
                                    yield server

                        factory = protocol.ClientFactory()
                        factory.protocol = AMP
                        _4 = self.clientClientService.connectQ2Q(
                            self.fromAddress,
                            self.toAddress,
                            'pony',
                            factory,
                            chooser=chooser)
                        def _4c(ign):
                            self.failUnlessEqual(expectedList, [])
                        return _4.addCallback(_4c)
                    return _3.addCallback(_3c)
                return _2.addCallback(_2c)
            return _1.addCallback(_1c)
        # Relies on clientServerService / clientClientService set up by
        # testListening.
        return self.testListening().addCallback(actualTest)

    def testTwoGreetings(self):
        """
        Both ends of a connection can send a command to each other as soon
        as the connection is made.
        """
        d1 = defer.Deferred()
        d2 = defer.Deferred()
        client = Greeter(False, d1)
        server = Greeter(True, d2)
        self._addQ2QProtocol('greet', server)
        self.serverService2.connectQ2Q(self.fromAddress,
                                       self.toAddress,
                                       'greet',
                                       client)
        def _(x):
            self.failUnless(client.greeted)
            self.failUnless(server.greeted)
        return defer.DeferredList([d1, d2]).addCallback(_)

    def testSendingFiles(self):
        """
        Pausing the server's connections stops data delivery, which resumes
        once the connections are resumed.
        """
        SIZE = 1024 * 500
        self.streamer = StreamingDataFeeder(StringIO('x' * SIZE))
        self.streamer.CHUNK = 8192
        a = self.serverService2.connectQ2Q(self.fromAddress,
                                           self.toAddress, 'eat',
                                           DataFeeder(StringIO('y' * SIZE)))
        b = self.serverService2.connectQ2Q(self.fromAddress,
                                           self.toAddress, 'eat',
                                           self.streamer)

        def dotest(ign):
            # self.assertEquals( len(self.serverService.liveConnections), 1)
            # XXX currently there are 2 connections but there should only be 1: the
            # connection cache is busted, need a separate test for that
            for liveConnection in self.serverService.iterconnections():
                liveConnection.transport.pauseProducing()
            wfc = self.dataEater.waitForCount(SIZE * 2)
            resumed = [False]
            def shouldntHappen(x):
                if resumed[0]:
                    return x
                else:
                    self.fail("wfc fired with: " + repr(x))
            wfc.addBoth(shouldntHappen)
            def keepGoing(ign):
                resumed[0] = True
                for liveConnection in self.serverService.iterconnections():
                    liveConnection.transport.resumeProducing()
                def assertSomeStuff(ign):
                    self.failUnless(self.streamer.pauseCount > 0)
                    self.failUnless(self.streamer.resumeCount > 0)
                return self.dataEater.waitForCount(SIZE * 2).addCallback(assertSomeStuff)
            return deferLater(reactor, 3, lambda: None).addCallback(keepGoing)
        return defer.DeferredList([a, b]).addCallback(dotest)
    testSendingFiles.skip = "hangs forever"

    def testBadIssuerOnSelfSignedCert(self):
        """
        A connection attempt that presents a self-signed certificate
        claiming the wrong domain fails with L{q2q.VerifyError}.
        """
        x = self.testConnectWithIntroduction()
        def actualTest(result):
            ponged = defer.Deferred()
            signer = self.serverService2.certificateStorage.getPrivateCertificate(
                self.fromDomain).privateKey
            req = signer.requestObject(DistinguishedName(commonName=self.toDomain))
            sreq = signer.signRequestObject(
                DistinguishedName(commonName=self.fromDomain), req, 12345)
            selfSignedLie = PrivateCertificate.fromCertificateAndKeyPair(
                sreq, signer)
            self.serverService2.connectQ2Q(self.fromAddress,
                                           self.toAddress,
                                           'pony',
                                           OneTrickPonyClientFactory(ponged),
                                           selfSignedLie,
                                           fakeFromDomain=self.toDomain).addErrback(
                lambda e: e.trap(q2q.VerifyError))

            return self.assertFailure(ponged, q2q.VerifyError)
        return x.addCallback(actualTest)

    def testBadCertRequestSubject(self):
        """
        Asking to have a certificate signed whose subject does not match
        the requester fails with L{q2q.BadCertificateRequest}.
        """
        kp = KeyPair.generate()
        subject = DistinguishedName(commonName='HACKERX',
                                    localityName='INTERNETANIA')
        reqobj = kp.requestObject(subject)

        fakereq = kp.requestObject(subject)
        ssigned = kp.signRequestObject(subject, fakereq, 1)
        certpair = PrivateCertificate.fromCertificateAndKeyPair
        fakecert = certpair(ssigned, kp)
        apc = self.serverService2.certificateStorage.addPrivateCertificate

        def _2(secured):
            D = secured.callRemote(
                q2q.Sign,
                certificate_request=reqobj,
                password='itdoesntmatter')
            def _1(dcert):
                cert = dcert['certificate']
                privcert = certpair(cert, kp)
                apc(str(self.fromAddress), privcert)
            return D.addCallback(_1)

        d = self.serverService2.getSecureConnection(
            self.fromAddress, self.fromAddress.domainAddress(), authorize=False,
            usePrivateCertificate=fakecert,
            ).addCallback(_2)

        def unexpectedSuccess(result):
            self.fail("Expected BadCertificateRequest, got %r" % (result,))
        def expectedFailure(err):
            err.trap(q2q.BadCertificateRequest)
        d.addCallbacks(unexpectedSuccess, expectedFailure)
        return d

    def testClientSideUnhandledException(self):
        """
        An undeclared exception raised by a client-side responder closes
        the connection and is logged exactly once.
        """
        d = self.serverService2.connectQ2Q(
            self.fromAddress, self.toAddress, 'error',
            ErroneousClientFactory())
        def connected(proto):
            return proto.callRemote(EngenderError)
        d.addCallback(connected)
        # The unhandled, undeclared error causes the connection to be closed
        # from the other side.
        d = self.assertFailure(d, ConnectionDone, UnknownRemoteError)
        def cbDisconnected(err):
            self.assertEqual(
                len(self.flushLoggedErrors(ErroneousClientError)),
                1)
        d.addCallback(cbDisconnected)
        return d

    def successIsFailure(self, success):
        self.fail()

    def testTwoBadWrites(self):
        """
        After a command fails with a declared fatal error, a further command
        on the same connection is expected to fail with C{ConnectionDone}.
        """
        d = self.serverService2.connectQ2Q(
            self.fromAddress, self.toAddress, 'error',
            ErroneousClientFactory())

        def connected(proto):
            d1 = self.assertFailure(proto.callRemote(Fatal), FatalError)
            def noMoreCalls(_):
                # Return the assertion so the test actually waits for and
                # checks the second call's outcome; previously it was dropped.
                return self.assertFailure(proto.callRemote(Flag),
                                          ConnectionDone)
            d1.addCallback(noMoreCalls)
            return d1
        d.addCallback(connected)
        return d
636
 
 
637
 
 
638
 
 
639
 
 
640
 
class VirtualConnection(Q2QConnectionTestCase, ConnectionTestMixin):
    """
    Run the shared connection tests with this configuration:
      - virtual connections enabled;
      - UDP disabled;
      - C{inboundTCPPortnum} set to C{None}.
    """
    virtualEnabled = True
    udpEnabled = False
    inboundTCPPortnum = None

    def testListening(self):
        pass
    testListening.skip = 'virtual port forwarding not implemented'

    def testChooserGetsThreeChoices(self):
        pass
    testChooserGetsThreeChoices.skip = 'cant do this without testListening'
653
 
 
654
 
class UDPConnection(Q2QConnectionTestCase, ConnectionTestMixin):
    """
    Run the shared connection tests with this configuration:
      - UDP enabled;
      - virtual connections disabled;
      - C{inboundTCPPortnum} set to C{None}.
    """
    udpEnabled = True
    virtualEnabled = False
    inboundTCPPortnum = None
659
 
 
660
 
class TCPConnection(Q2QConnectionTestCase, ConnectionTestMixin):
    """
    Run the shared connection tests with this configuration:
      - UDP and virtual connections disabled;
      - C{inboundTCPPortnum} set to 0.
    """
    udpEnabled = False
    virtualEnabled = False
    inboundTCPPortnum = 0
664
 
 
665
 
# class LiveServerMixin:
666
 
#     serverDomain = 'test.domain.example.com'
667
 
 
668
 
#     def jackDNS(self, *info):
669
 
#         self.fakeDNS = FakeConnectTCP(reactor.connectTCP)
670
 
#         reactor.connectTCP = self.fakeDNS.connectTCP
671
 
#         self._oldResolver = reactor.resolver
672
 
#         reactor.installResolver(self.fakeDNS)
673
 
 
674
 
#         for (hostname, oldport, newport) in info:
675
 
#             self.fakeDNS.addHostPort(hostname, oldport, newport)
676
 
 
677
 
#     def unjackDNS(self):
678
 
#         del reactor.connectTCP
679
 
#         reactor.installResolver(self._oldResolver)
680
 
 
681
 
# class AuthorizeTestCase(unittest.TestCase, LiveServerMixin):
682
 
#     authUser = 'authtestuser'
683
 
#     authPass = 'p4ssw0rd'
684
 
 
685
 
#     def setUp(self):
686
 
#         self.testDir, serverService = self.deploy()
687
 
#         self.serverService = service.IServiceCollection(serverService)
688
 
#         self.serverService.startService()
689
 
#         self.clientDir = os.path.join(self.testDir, 'client')
690
 
#         self.clientService = q2qclient.ClientQ2QService(self.clientDir)
691
 
 
692
 
#         q2qPort = self.getQ2QService().q2qPort.getHost().port
693
 
#         self.jackDNS((self.serverDomain, 8788, q2qPort))
694
 
 
695
 
#     def tearDown(self):
696
 
#         self.unjackDNS()
697
 
#         util.wait(self.serverService.stopService())
698
 
#         util.wait(self.clientService.stopService())
699
 
 
700
 
#     def testAuthorize(self):
701
 
#         self.createUser(self.authUser, self.authPass)
702
 
 
703
 
#         d = self.clientService.authorize(
704
 
#             q2q.Q2QAddress(self.serverDomain, self.authUser),
705
 
#             self.authPass)
706
 
 
707
 
#         result = runOneDeferred(d)
708
 
 
709
 
#         self.failUnless(os.path.exists(os.path.join(self.clientDir, 'public')))
710
 
#         self.failUnless(os.path.exists(os.path.join(self.clientDir, 'public', self.serverDomain + '.pem')))
711
 
#         self.failUnless(os.path.exists(os.path.join(self.clientDir, 'private')))
712
 
#         self.failUnless(os.path.exists(os.path.join(self.clientDir, 'private', self.authUser + '@' + self.serverDomain + '.pem')))