1
# Copyright 2005-2008 Divmod, Inc. See LICENSE file for details
2
# -*- vertex.test.test_q2q.UDPConnection -*-
5
Tests for L{vertex.q2q}.
8
from cStringIO import StringIO
10
from twisted.trial import unittest
11
from twisted.application import service
12
from twisted.internet import reactor, protocol, defer
13
from twisted.internet.task import deferLater
14
from twisted.internet.ssl import DistinguishedName, PrivateCertificate, KeyPair
15
from twisted.protocols import basic
16
from twisted.python import log
17
from twisted.python import failure
18
from twisted.internet.error import ConnectionDone
19
# from twisted.internet.main import CONNECTION_DONE
21
from zope.interface import implements
22
from twisted.internet.interfaces import IResolverSimple
24
from twisted.protocols.amp import (
25
UnhandledCommand, UnknownRemoteError, QuitBox, Command, AMP)
27
from vertex import q2q
34
# NOTE(review): this copy of the file has lost its indentation and many
# lines; the bare numbers are the original line numbers, and gaps in them
# mark dropped lines. The class statement that should enclose this
# declaration and the methods that follow (setUp instantiates the class as
# FakeConnectTCP(reactor.connectTCP)) is among the missing lines.
# Restore the file from version control.
# Declares that instances provide IResolverSimple, so one can be handed to
# reactor.installResolver(), as setUp does.
implements(IResolverSimple)
36
# Remember the real connectTCP callable to delegate to, and start with
# empty address-rewriting tables (both are filled in by addHostPort).
# NOTE(review): original lines 40-41 are missing after this method in this
# copy. addHostPort reads self.counter, which none of the visible lines
# initialise; presumably it was set in the missing part. Confirm against VCS.
def __init__(self, connectTCP):
37
self._connectTCP = connectTCP
38
# (host, fakePort) -> (host, realPort); consulted by connectTCP.
self.hostPortToHostPort = {}
39
# hostname -> loopback IP string ('127.0.0.N'); consulted by getHostSync.
self.hostToLocalHost = {}
42
# Register that connections to (hostname, fakePortNumber) should go to
# realPortNumber instead (the host part is left unchanged), and give
# hostname a loopback address for getHostSync to return. Both the hostname
# and its loopback IP get a port mapping.
# NOTE(review): the line markers jump 44 -> 46 -> 48 -> 50, so lines 45, 47
# and 49 were dropped here. As the remaining lines read, localIP is always
# overwritten with a fresh '127.0.0.<counter>' and self.counter is never
# advanced. Presumably an `else:` and a counter increment were lost.
# Restore before relying on this.
def addHostPort(self, hostname, fakePortNumber, realPortNumber):
43
# Reuse the loopback address already assigned to this hostname.
if hostname in self.hostToLocalHost:
44
localIP = self.hostToLocalHost[hostname]
46
localIP = '127.0.0.%d' % (self.counter,)
48
self.hostToLocalHost[hostname] = localIP
50
self.hostPortToHostPort[(localIP, fakePortNumber)] = (localIP, realPortNumber)
51
self.hostPortToHostPort[(hostname, fakePortNumber)] = (hostname, realPortNumber)
53
def connectTCP(self, host, port, *args, **kw):
    """
    Stand-in for reactor.connectTCP.

    If (host, port) was registered with addHostPort, connect to the mapped
    address instead; unmapped addresses pass through unchanged. Any extra
    positional and keyword arguments are forwarded untouched to the
    wrapped connectTCP, whose result is returned.
    """
    key = (host, port)
    realHost, realPort = self.hostPortToHostPort.get(key, key)
    return self._connectTCP(realHost, realPort, *args, **kw)
57
# Synchronously map a hostname registered with addHostPort to its loopback
# IP. An unregistered name raises KeyError from the dict lookup.
# NOTE(review): original lines 59-60 are missing in this copy and no
# `return` is visible, so as shown this would return None. Presumably
# `return result` was among the dropped lines. Confirm against VCS.
def getHostSync(self,name):
58
result = self.hostToLocalHost[name]
61
def getHostByName(self, name, timeout=(1, 3, 11, 45)):
    """
    IResolverSimple.getHostByName: resolve name via getHostSync.

    Returns a Deferred. maybeDeferred turns a synchronous failure from
    getHostSync (e.g. KeyError for an unregistered name) into an errback
    instead of raising.

    timeout is accepted for interface compatibility and otherwise ignored.
    It carries the same default that IResolverSimple declares, so callers
    that pass only the name (such as Twisted's SimpleResolverComplexifier)
    work instead of raising TypeError.
    """
    return defer.maybeDeferred(self.getHostSync, name)
64
# Apparently: run the reactor until Deferred d fires (a reactor.crash()
# callback is added to d on the next reactor iteration), then fail the
# test if the captured result is a Failure.
# NOTE(review): original lines 65-66, 68-69, 71-72 and 74 are missing from
# this copy. In particular `L`, read below, is never defined in the visible
# lines, and the only visible caller is commented out at the end of the
# file. The "Keyboard Interrupt" message may belong to a dropped branch
# rather than to the isinstance check shown here.
def runOneDeferred(d):
67
reactor.callLater(0, d.addCallback, lambda ign: reactor.crash())
70
if isinstance(L[0], failure.Failure):
73
raise unittest.FailTest("Keyboard Interrupt")
75
class Utility(unittest.TestCase):
    """
    Tests for Q2QService that do not start the service.
    """

    def testServiceInitialization(self):
        """
        A private certificate added to a fresh service's certificate
        storage can be read back, and its public key matches the private
        key stored with it.
        """
        q2qService = q2q.Q2QService(noResources)
        q2qService.certificateStorage.addPrivateCertificate("test.domain")
        cert = q2qService.certificateStorage.getPrivateCertificate(
            "test.domain")
        self.failUnless(cert.getPublicKey().matches(cert.privateKey))
83
class OneTrickPony(AMP):
    """
    Server-side AMP protocol that answers the raw 'trick' command sent by
    OneTrickPonyClient (AMP finds this responder by its amp_TRICK name).
    """

    def amp_TRICK(self, box):
        """
        Reply with a QuitBox carrying tricked='True', whose presence the
        client-side tests assert; a QuitBox also asks AMP to close the
        connection once it has been sent. The request box is ignored.
        """
        return QuitBox(tricked='True')
87
class OneTrickPonyServerFactory(protocol.ServerFactory):
    """
    Server factory building OneTrickPony protocols; the tests register it
    as the 'pony' / 'pony2' Q2Q resources.
    """
    protocol = OneTrickPony
90
class OneTrickPonyClient(AMP):
    """
    Client-side AMP protocol: as soon as it is connected it sends the raw
    'trick' command and forwards the outcome (answer box or failure) to
    its factory's ``ponged`` Deferred.
    """

    def connectionMade(self):
        answer = self.callRemoteString('trick')
        answer.chainDeferred(self.factory.ponged)
94
# Client factory for the 'pony' tests: builds OneTrickPonyClient protocols
# and reports the outcome through the `ponged` Deferred given to __init__.
# NOTE(review): lines are missing from this copy (markers skip 98-99 and
# 102-104). The __init__ body is absent, yet OneTrickPonyClient reads
# self.factory.ponged and clientConnectionFailed uses self.ponged, so it
# presumably stored `ponged`. buildProtocol shows no return, yet
# testListening reads otpcf.proto, so it presumably saved `result` as
# self.proto and returned it. Restore from version control.
class OneTrickPonyClientFactory(protocol.ClientFactory):
95
protocol = OneTrickPonyClient
97
def __init__(self, ponged):
100
def buildProtocol(self, addr):
101
# `protocol` here is the twisted.internet.protocol module: the class
# attribute of the same name is not in scope inside method bodies.
result = protocol.ClientFactory.buildProtocol(self, addr)
105
# A failed connection attempt fails the test's Deferred directly.
def clientConnectionFailed(self, connector, reason):
106
self.ponged.errback(reason)
109
# Server-side sink for the 'eat' resource: records every chunk received
# and lets tests wait until a total byte count has arrived. setUp
# registers an instance directly as the 'eat' resource, so it apparently
# doubles as its own factory via buildProtocol.
# NOTE(review): many lines are missing from this copy (markers skip
# 110-114, 116, 128-129, 131, 133 and 138-139, among others). The
# initialisation of self.data, self.count and self.waiters, the guard
# before the `raise` in dataReceived, the creation of `D` in waitForCount
# and the body of buildProtocol are not visible. Restore from version
# control.
class DataEater(protocol.Protocol):
115
def dataReceived(self, data):
117
# Presumably guarded by an emptiness check on `data` in a dropped line.
raise RuntimeError("Empty string delivered to DataEater")
118
self.data.append(data)
119
self.count += len(data)
120
# Iterate over a copy: firing a waiter runs removeD (added in
# waitForCount), which removes that entry from self.waiters.
for count, waiter in self.waiters[:]:
121
if self.count >= count:
122
waiter.callback(self.count)
124
# Callback/errback that unregisters a waiter once its Deferred fires.
def removeD(self, result, size, d):
125
# XXX done as a callback because 1.3 util.wait actually calls a
126
# callback on the deferred
127
self.waiters.remove((size, d))
130
# Register a waiter that dataReceived fires, with the running byte count,
# once at least `size` bytes have arrived in total.
def waitForCount(self, size):
132
self.waiters.append((size, D))
134
self.waiters.reverse()
135
return D.addBoth(self.removeD, size, D)
137
def buildProtocol(self, addr):
140
# Client-side protocol for the 'eat' tests: once connected, it streams
# self.fobj to the transport with basic.FileSender. It also has
# buildProtocol and clientConnection* methods, so an instance is passed to
# connectQ2Q as the client factory.
# NOTE(review): lines are missing from this copy (markers skip 142-143,
# 146-147, 149-150 and 155-156). The __init__ body (presumably storing
# fobj), the rest of the failed/lost handlers and the body of buildProtocol
# are not visible.
class DataFeeder(protocol.Protocol):
141
def __init__(self, fobj):
144
def clientConnectionFailed(self, connector, reason):
145
log.msg("DataFeeder client connection failed:")
148
def clientConnectionLost(self, connector, reason):
151
def connectionMade(self):
152
basic.FileSender().beginFileTransfer(self.fobj, self.transport)
154
def buildProtocol(self, addr):
157
# Client-side streaming producer for the 'eat' tests. It reads self.CHUNK
# bytes at a time from self.file and writes one chunk every self.DELAY
# seconds via reactor.callLater. It registers itself with the transport
# as a streaming (push) producer, so it can be paused and resumed.
# testSendingFiles checks pauseCount/resumeCount to confirm flow control
# happened.
# NOTE(review): many lines are missing from this copy (e.g. markers skip
# 158-166, 183-185 and 207-211). The CHUNK/DELAY/counter attributes, the
# __init__ body, the end-of-file handling in _keepGoing and most of the
# producer methods are not visible. Restore from version control.
class StreamingDataFeeder(protocol.Protocol):
167
def __init__(self, infile):
170
# NOTE(review): the first parameter is misspelled `sef`; this is harmless
# only as long as the body never refers to self.
def clientConnectionFailed(sef, connector, reason):
171
log.msg("StreamingDataFeeder client connection failed:")
174
def clientConnectionLost(self, connector, reason):
177
def connectionMade(self):
178
# Read the first chunk ahead, register as a streaming producer, and
# schedule the first write.
self.nextChunk = self.file.read(self.CHUNK)
179
self.transport.registerProducer(self, True)
180
self.call = reactor.callLater(self.DELAY, self._keepGoing)
182
def _keepGoing(self):
186
# Write the chunk read ahead last time, read the next one, and
# reschedule.
chunk = self.nextChunk
187
self.nextChunk = self.file.read(self.CHUNK)
188
self.outCount += len(chunk)
190
self.transport.write(chunk)
192
self.call = reactor.callLater(self.DELAY, self._keepGoing)
195
def pauseProducing(self):
198
if self.call is not None:
201
def resumeProducing(self):
203
if self.call is not None:
205
self.call = reactor.callLater(self.DELAY, self._keepGoing)
206
self.resumeCount += 1
212
def stopProducing(self):
215
if self.call is not None:
218
def buildProtocol(self, addr):
222
# Raised by Erroneous's Break responder. testClientSideUnhandledException
# flushes it from the logged errors.
# NOTE(review): the class body (original lines 223-224) is missing from
# this copy, so this line alone is not valid Python.
class ErroneousClientError(Exception):
225
class EngenderError(Command):
    """
    Ask an Erroneous peer to misbehave: its responder answers by calling
    Break back on the requesting side.
    """
    commandName = 'Engender-Error'
228
class Break(Command):
    """
    Command whose responder on an Erroneous peer raises
    ErroneousClientError, which this command does not declare.
    """
    commandName = 'Break'
234
# Raised by Erroneous's Fatal responder; declared in Fatal.fatalErrors.
# NOTE(review): the class body (original lines 235-236) is missing from
# this copy, so this line alone is not valid Python.
class FatalError(Exception):
237
class Fatal(Command):
    """
    Command whose responder on an Erroneous peer raises FatalError.

    FatalError is listed in fatalErrors, so AMP reports it to the caller
    as FatalError and treats it as fatal to the connection.
    """
    fatalErrors = {FatalError: "quite bad"}
240
# Server-side AMP protocol for the 'error' resource; each responder
# misbehaves in a different way for the error-handling tests.
# NOTE(review): many lines are missing from this copy (markers skip 241,
# 244-246, 249, 251, 256, 259-262 and 264). These include the `def` lines
# for _fatal and _break, the errback ebBroken, and the _flag responder.
# The Flag command class itself is not visible anywhere. Restore from
# version control.
class Erroneous(AMP):
242
# Fatal responder: FatalError is declared in Fatal.fatalErrors.
raise FatalError("This is fatal.")
243
Fatal.responder(_fatal)
247
# Break responder: ErroneousClientError is not declared in the visible
# Break definition.
raise ErroneousClientError("Zoop")
248
Break.responder(_break)
250
# EngenderError responder: call Break back on the requesting side.
def _engenderError(self):
252
err.trap(ConnectionDone)
253
# This connection is dead. Avoid having an error logged by turning
254
# this into success; the result can't possibly get to the other
255
# side, anyway. -exarkun
257
return self.callRemote(Break).addErrback(ebBroken)
258
EngenderError.responder(_engenderError)
263
Flag.responder(_flag)
265
# Server factory that setUp registers as the 'error' resource.
# NOTE(review): the class body (original lines 266-267, presumably
# `protocol = Erroneous`) is missing from this copy.
class ErroneousServerFactory(protocol.ServerFactory):
268
# Client factory passed to connectQ2Q for the 'error' resource. The tests'
# connected() callbacks issue remote commands on the protocol it builds.
# NOTE(review): the class body (original lines 269-270) is missing from
# this copy.
class ErroneousClientFactory(protocol.ClientFactory):
271
class Greet(Command):
    """
    Command each Greeter sends to its peer as soon as it is connected.
    """
    commandName = 'Greet'
274
# An AMP protocol that is also its own server/client factory, used by
# testTwoGreetings. Once connected, it sends Greet and chains the answer
# into startupD.
# NOTE(review): lines are missing from this copy (markers skip 277,
# 281-282 and 285-288). The test reads client.greeted and server.greeted,
# but no `greeted` attribute is visible; neither is the body of
# buildProtocol nor the definition of the _greet responder. Restore from
# version control.
class Greeter(AMP, protocol.ServerFactory, protocol.ClientFactory):
275
def __init__(self, isServer, startupD):
276
self.isServer = isServer
278
self.startupD = startupD
280
def buildProtocol(self, addr):
283
def connectionMade(self):
284
self.callRemote(Greet).chainDeferred(self.startupD)
289
Greet.responder(_greet)
291
# Base TestCase for the Q2Q connection tests. Setup starts two
# Q2QServices (one for fromDomain, one for toDomain) under a MultiService.
# It replaces reactor.connectTCP and the resolver with a FakeConnectTCP, so
# that connections to the fake domains on port 8788 reach the services'
# kernel-allocated q2q ports. Teardown undoes the reactor patching and
# stops the services. Subclasses choose the transport through
# inboundTCPPortnum / udpEnabled / virtualEnabled.
# NOTE(review): this copy has lost its indentation and many lines (the
# bare numbers are original line numbers). Among the missing lines are:
# - the `def` lines of what are presumably setUp and tearDown (near
#   328-331 and 377-378);
# - the fromIP/toIP/udpEnabled attributes read below;
# - the end of the Q2QService(...) call and the return of _makeQ2QService;
# - the opening lines of the two FakeConnectTCP.addHostPort calls.
# Restore from version control.
class Q2QConnectionTestCase(unittest.TestCase):
294
fromResource = 'clientResource'
295
toResource = 'serverResource'
297
fromDomain = 'origin.domain.example.com'
299
spoofedDomain = 'spoofed.domain.example.com'
300
toDomain = 'destination.domain.example.org'
303
userReverseDNS = 'i.watch.too.much.tv'
304
# Transport settings copied onto each service by _makeQ2QService;
# subclasses override them.
inboundTCPPortnum = 0
306
virtualEnabled = False
308
# Create (but do not start) a Q2QService for certificateEntity, with
# protocol-factory lookup pff.
def _makeQ2QService(self, certificateEntity, publicIP, pff=None):
309
svc = q2q.Q2QService(pff, q2qPortnum=0,
310
inboundTCPPortnum=self.inboundTCPPortnum,
312
# NOTE(review): udpEnabled is not among the visible class attributes
# (presumably dropped line 305).
svc.udpEnabled = self.udpEnabled
313
svc.virtualEnabled = self.virtualEnabled
314
# Only bare domains get a private certificate here; user@domain services
# are authorized separately (see ConnectionTestMixin._addClientService).
if '@' not in certificateEntity:
315
svc.certificateStorage.addPrivateCertificate(certificateEntity)
316
svc.debugName = certificateEntity
320
# NOTE(review): no `return svc` is visible above (317-319 missing), yet
# callers use the result as the service.
# Offer `factory` as resource `name` for connections from fromAddress to
# toAddress; protocolFactoryLookup serves it to the server service.
def _addQ2QProtocol(self, name, factory):
321
resourceKey = (self.fromAddress,
322
self.toAddress, name)
323
self.resourceMap[resourceKey] = factory
325
# Protocol-factory lookup for self.serverService: returns a one-element
# list of (factory, description) for a registered key.
# NOTE(review): the handling of unknown keys (328-329) is missing.
def protocolFactoryLookup(self, *key):
326
if key in self.resourceMap:
327
return [(self.resourceMap[key], 'test-description')]
332
# NOTE(review): presumably the start of setUp; its `def` line is missing.
self.fromAddress = q2q.Q2QAddress(self.fromDomain, self.fromResource)
333
self.toAddress = q2q.Q2QAddress(self.toDomain, self.toResource)
335
# A mapping of host names to port numbers Our connectTCP will always
336
# connect to 127.0.0.1 and on a port which is a value in this
338
# (The comment above is cut off; original line 337 is missing.)
fakeDNS = FakeConnectTCP(reactor.connectTCP)
339
# Shadow the reactor's connectTCP method with an instance attribute;
# tearDown's `del reactor.connectTCP` removes it, restoring the method.
reactor.connectTCP = fakeDNS.connectTCP
341
# ALSO WE MUST DO OTHER SIMILAR THINGS
342
self._oldResolver = reactor.resolver
343
reactor.installResolver(fakeDNS)
345
# Set up a know-nothing service object for the client half of the
347
# (The comment above is cut off; original line 346 is missing. self.fromIP
# is not defined in the visible lines.)
self.serverService2 = self._makeQ2QService(self.fromDomain, self.fromIP, noResources)
349
# Do likewise for the server half of the conversation. Also, allow
350
# test methods to set up some trivial resources which we can attempt to
351
# access from the client.
352
self.resourceMap = {}
353
self.serverService = self._makeQ2QService(self.toDomain, self.toIP,
354
self.protocolFactoryLookup)
356
self.msvc = service.MultiService()
357
self.serverService2.setServiceParent(self.msvc)
358
self.serverService.setServiceParent(self.msvc)
360
# Let the kernel allocate a random port for each of these services' listeners
361
self.msvc.startService()
364
# NOTE(review): the start of this statement (362-363, presumably
# `fakeDNS.addHostPort(`) is missing. It maps fromDomain:8788 (8788
# appears to be the well-known q2q port) to serverService2's real port.
self.fromDomain, 8788,
365
self.serverService2.q2qPort.getHost().port)
369
# NOTE(review): lines 366-368 are missing; presumably the same mapping for
# toDomain onto serverService's real port.
self.serverService.q2qPort.getHost().port)
371
self._addQ2QProtocol('pony', OneTrickPonyServerFactory())
373
self.dataEater = DataEater()
374
self._addQ2QProtocol('eat', self.dataEater)
376
self._addQ2QProtocol('error', ErroneousServerFactory())
379
# NOTE(review): presumably the start of tearDown; its `def` line is
# missing. Undo setUp's reactor patching, then stop both services.
reactor.installResolver(self._oldResolver)
380
del reactor.connectTCP
381
return self.msvc.stopService()
385
# Connection tests shared by TCPConnection, UDPConnection and
# VirtualConnection, which mix this class into Q2QConnectionTestCase. It
# relies on that base class for serverService/serverService2, the
# from/to addresses, msvc, _makeQ2QService and _addQ2QProtocol.
# NOTE(review): this copy has lost its indentation and many lines (the
# bare numbers are original line numbers; gaps mark dropped lines).
# Several callbacks used below (_1c.._4c, actualTest, _, dotest, keepGoing,
# noMoreCalls, and whatever binds secured/dcert) are referenced, but their
# `def` lines are missing. Restore from version control before relying on
# these tests.
class ConnectionTestMixin:
387
# Connect from serverService2 with a OneTrickPonyClientFactory and check
# that the answer box contains 'tricked'.
# NOTE(review): the middle of the connectQ2Q call (390-391) is missing;
# presumably it names toAddress and the 'pony' resource.
def testConnectWithIntroduction(self):
388
ponged = defer.Deferred()
389
self.serverService2.connectQ2Q(self.fromAddress,
392
OneTrickPonyClientFactory(ponged))
393
return ponged.addCallback(lambda answerBox: self.failUnless('tricked' in answerBox))
395
# Create, register and authorize a client service for toAddress (a
# resource@domain address); the returned Deferred fires with the service.
def addClientService(self, toAddress, secret, serverService):
396
return self._addClientService(
397
toAddress.resource, secret, serverService, toAddress.domain)
399
# Build a Q2QService for username@serverDomain, add the user to
# serverService's certificate storage, attach the service to self.msvc,
# then authorize it. Fires with the new service once authorization
# succeeds.
# NOTE(review): the last parameter line (401, presumably `serverDomain`)
# and the rest of the addUser call (404-405) are missing.
def _addClientService(self, username,
400
privateSecret, serverService,
402
svc = self._makeQ2QService(username + '@' + serverDomain, None)
403
serverService.certificateStorage.addUser(serverDomain,
406
svc.setServiceParent(self.msvc)
407
return svc.authorize(q2q.Q2QAddress(serverDomain, username),
408
privateSecret).addCallback(lambda x: svc)
411
# Authorize a client service for toAddress and have it listen for 'pony2'
# with a OneTrickPonyServerFactory. Then authorize a client service for
# fromAddress, connect with a OneTrickPonyClientFactory, and check the
# transport's Q2Q peer/host addresses and the 'tricked' answer.
# _1result/_3result are presumably the parameters of the missing
# `def _1c`/`def _3c` lines.
# NOTE(review): lines are missing (e.g. 413, 418-420, 429-432, 437).
def testListening(self):
412
_1 = self.addClientService(self.toAddress, 'aaaa', self.serverService)
414
self.clientServerService = _1result
415
ponyFactory = OneTrickPonyServerFactory()
416
_2 = self.clientServerService.listenQ2Q(self.toAddress,
417
{'pony2': ponyFactory},
421
_3 = self.addClientService(
422
self.fromAddress, 'bbbb', self.serverService2)
424
self.clientClientService = _3result
426
_4 = defer.Deferred()
427
otpcf = OneTrickPonyClientFactory(_4)
428
self.clientClientService.connectQ2Q(self.fromAddress,
433
T = otpcf.proto.transport
434
self.assertEquals(T.getQ2QPeer(), self.toAddress)
435
self.assertEquals(T.getQ2QHost(), self.fromAddress)
436
self.failUnless('tricked' in answerBox)
438
return _4.addCallback(_4c)
439
return _3.addCallback(_3c)
440
return _2.addCallback(_2c)
441
return _1.addCallback(_1c)
443
# Runs after testListening (see the last line of this method). Two more
# client-side services listen for 'pony'. Together with the server's own
# 'pony' resource ('test-description' from protocolFactoryLookup), the
# chooser must be offered exactly three servers whose descriptions match
# expectedList. The 'ponies rule' entry's certificate is presumably
# compared against clientServerService's private certificate.
# NOTE(review): the `def actualTest` line (444-445), the descriptions
# passed to listenQ2Q (455-456, 459-460), the start of the certificate
# assertion (467) and the connectQ2Q arguments (475-480) are missing.
def testChooserGetsThreeChoices(self):
446
ponyFactory = OneTrickPonyServerFactory()
447
_1 = self.addClientService(
448
self.toAddress, 'aaaa', self.serverService)
450
self.clientServerService2 = _1result
451
# print 'ultra frack'
453
_2 = self.clientServerService2.listenQ2Q(self.toAddress,
454
{'pony': ponyFactory},
457
_3 = self.clientServerService.listenQ2Q(self.toAddress,
458
{'pony': ponyFactory},
461
expectedList = ['ponies rule', 'ponies are weird', 'test-description']
462
def chooser(servers):
463
self.failUnlessEqual(len(servers), 3)
464
for server in servers:
465
expectedList.remove(server['description'])
466
if server['description'] == 'ponies rule':
468
self.clientServerService.certificateStorage.getPrivateCertificate(str(self.toAddress)),
469
server['certificate'])
472
factory = protocol.ClientFactory()
473
factory.protocol = AMP
474
_4 = self.clientClientService.connectQ2Q(
481
self.failUnlessEqual(expectedList, [])
482
return _4.addCallback(_4c)
483
return _3.addCallback(_3c)
484
return _2.addCallback(_2c)
485
return _1.addCallback(_1c)
486
return self.testListening().addCallback(actualTest)
491
# Register a server Greeter as the 'greet' resource and connect a client
# Greeter. Once both startup Deferreds have fired, both sides must have
# been greeted.
# NOTE(review): the connectQ2Q arguments and the `def _` line (498-501)
# are missing.
def testTwoGreetings(self):
492
d1 = defer.Deferred()
493
d2 = defer.Deferred()
494
client = Greeter(False, d1)
495
server = Greeter(True, d2)
496
self._addQ2QProtocol('greet', server)
497
self.serverService.connectQ2Q(self.fromAddress,
502
self.failUnless(client.greeted)
503
self.failUnless(server.greeted)
504
return defer.DeferredList([d1, d2]).addCallback(_)
507
# Send SIZE bytes from a DataFeeder and (presumably) SIZE more from the
# StreamingDataFeeder to the 'eat' resource, then pause the server's
# transports. Apparently it checks that the DataEater does not reach
# SIZE * 2 within 3 seconds, then resumes and waits for all the data,
# checking that the streamer was paused and resumed. Skipped: it hangs.
# NOTE(review): SIZE is not defined in the visible lines, and several
# lines are missing (e.g. 508, 516-518, 527-529), including the
# `def dotest`/`def keepGoing` lines.
def testSendingFiles(self):
509
self.streamer = StreamingDataFeeder(StringIO('x' * SIZE))
510
self.streamer.CHUNK = 8192
511
a = self.serverService2.connectQ2Q(self.fromAddress,
512
self.toAddress, 'eat',
513
DataFeeder(StringIO('y' * SIZE)))
514
b = self.serverService2.connectQ2Q(self.fromAddress,
515
self.toAddress, 'eat',
519
# self.assertEquals( len(self.serverService.liveConnections), 1)
520
# XXX currently there are 2 connections but there should only be 1: the
521
# connection cache is busted, need a separate test for that
522
for liveConnection in self.serverService.iterconnections():
523
liveConnection.transport.pauseProducing()
524
wfc = self.dataEater.waitForCount(SIZE * 2)
526
def shouldntHappen(x):
530
self.fail("wfc fired with: " + repr(x))
531
wfc.addBoth(shouldntHappen)
534
for liveConnection in self.serverService.iterconnections():
535
liveConnection.transport.resumeProducing()
536
def assertSomeStuff(ign):
537
self.failUnless(self.streamer.pauseCount > 0)
538
self.failUnless(self.streamer.resumeCount > 0)
539
return self.dataEater.waitForCount(SIZE * 2).addCallback(assertSomeStuff)
540
return deferLater(reactor, 3, lambda: None).addCallback(keepGoing)
541
return defer.DeferredList([a, b]).addCallback(dotest)
542
testSendingFiles.skip = "hangs forever"
544
# After a successful introduction, build a certificate request for
# toDomain's name, signed with fromDomain's private key and issuer name.
# Then connect while claiming to be toDomain (fakeFromDomain); the
# connection must fail with q2q.VerifyError.
# NOTE(review): lines 554, 556-557, 559 and 562 are missing: the rest of
# the selfSignedLie construction and the middle of the connectQ2Q call.
def testBadIssuerOnSelfSignedCert(self):
545
x = self.testConnectWithIntroduction()
546
def actualTest(result):
547
ponged = defer.Deferred()
548
signer = self.serverService2.certificateStorage.getPrivateCertificate(
549
self.fromDomain).privateKey
550
req = signer.requestObject(DistinguishedName(commonName=self.toDomain))
551
sreq = signer.signRequestObject(
552
DistinguishedName(commonName=self.fromDomain), req, 12345)
553
selfSignedLie = PrivateCertificate.fromCertificateAndKeyPair(
555
self.serverService2.connectQ2Q(self.fromAddress,
558
OneTrickPonyClientFactory(ponged),
560
fakeFromDomain=self.toDomain).addErrback(
561
lambda e: e.trap(q2q.VerifyError))
563
return self.assertFailure(ponged, q2q.VerifyError)
564
return x.addCallback(actualTest)
567
# Connect to fromAddress's domain using a throwaway self-signed
# certificate and submit a certificate request whose subject (commonName
# 'HACKERX') is not the requesting address. The request must fail with
# q2q.BadCertificateRequest.
# NOTE(review): lines 572, 578-579, 581, 584, 589 and 593-594 are missing,
# including the `def` lines that bind `secured`, `dcert` and `_1`, and the
# end of the getSecureConnection call. No `return d` is visible (600-601
# missing); without it trial would not wait for d.
def testBadCertRequestSubject(self):
568
kp = KeyPair.generate()
569
subject = DistinguishedName(commonName='HACKERX',
570
localityName='INTERNETANIA')
571
reqobj = kp.requestObject(subject)
573
fakereq = kp.requestObject(subject)
574
ssigned = kp.signRequestObject(subject, fakereq, 1)
575
certpair = PrivateCertificate.fromCertificateAndKeyPair
576
fakecert = certpair(ssigned, kp)
577
apc = self.serverService2.certificateStorage.addPrivateCertificate
580
D = secured.callRemote(
582
certificate_request=reqobj,
583
password='itdoesntmatter')
585
cert = dcert['certificate']
586
privcert = certpair(cert, kp)
587
apc(str(self.fromAddress), privcert)
588
return D.addCallback(_1)
590
d = self.serverService2.getSecureConnection(
591
self.fromAddress, self.fromAddress.domainAddress(), authorize=False,
592
usePrivateCertificate=fakecert,
595
def unexpectedSuccess(result):
596
self.fail("Expected BadCertificateRequest, got %r" % (result,))
597
def expectedFailure(err):
598
err.trap(q2q.BadCertificateRequest)
599
d.addCallbacks(unexpectedSuccess, expectedFailure)
602
# Connect to the 'error' resource and send EngenderError; the peer calls
# Break back, and the resulting unhandled ErroneousClientError is expected
# to close the connection and to be logged.
# NOTE(review): lines 613 and 615 (the assertion wrapping the
# flushLoggedErrors count) and the trailing `return d` (617-618) are
# missing.
def testClientSideUnhandledException(self):
603
d = self.serverService2.connectQ2Q(
604
self.fromAddress, self.toAddress, 'error',
605
ErroneousClientFactory())
606
def connected(proto):
607
return proto.callRemote(EngenderError)
608
d.addCallback(connected)
609
# The unhandled, undeclared error causes the connection to be closed
610
# from the other side.
611
d = self.assertFailure(d, ConnectionDone, UnknownRemoteError)
612
def cbDisconnected(err):
614
len(self.flushLoggedErrors(ErroneousClientError)),
616
d.addCallback(cbDisconnected)
619
# NOTE(review): the body (620-621) is missing; judging by the name, it
# fails the test when a Deferred unexpectedly succeeds.
def successIsFailure(self, success):
622
# Connect to the 'error' resource and call Fatal, which is expected to
# fail with FatalError (declared fatal). A following Flag call is also
# expected to fail.
# NOTE(review): lines 626, 629, 631, 633 and 635-639 are missing: the
# expected failure type for Flag, the `def noMoreCalls` line and the
# trailing `return d`.
def testTwoBadWrites(self):
623
d = self.serverService2.connectQ2Q(
624
self.fromAddress, self.toAddress, 'error',
625
ErroneousClientFactory())
627
def connected(proto):
628
d1 = self.assertFailure(proto.callRemote(Fatal), FatalError)
630
self.assertFailure(proto.callRemote(Flag),
632
d1.addCallback(noMoreCalls)
634
d.addCallback(connected)
640
# Runs the ConnectionTestMixin tests with virtual connections enabled and
# no inbound TCP port. The listening-based tests are skipped for this
# transport.
# NOTE(review): lines 642, 646-647 and 649-650 are missing from this copy.
# They presumably held the udpEnabled setting (read by _makeQ2QService)
# and the bodies of the two overriding test methods below.
class VirtualConnection(Q2QConnectionTestCase, ConnectionTestMixin):
641
inboundTCPPortnum = None
643
virtualEnabled = True
645
def testListening(self):
648
def testChooserGetsThreeChoices(self):
651
testListening.skip = 'virtual port forwarding not implemented'
652
testChooserGetsThreeChoices.skip = 'cant do this without testListening'
654
# Runs the ConnectionTestMixin tests with no inbound TCP port and virtual
# connections off, presumably relying on UDP.
# NOTE(review): original lines 655 and 657 are missing from this copy. The
# udpEnabled attribute read by _makeQ2QService is not visible here and was
# presumably set in one of them.
class UDPConnection(Q2QConnectionTestCase, ConnectionTestMixin):
656
inboundTCPPortnum = None
658
virtualEnabled = False
660
# Runs the ConnectionTestMixin tests over plain TCP: inboundTCPPortnum = 0
# lets the kernel pick the port, and virtual connections are off.
# NOTE(review): original line 662 is missing from this copy; presumably it
# set udpEnabled, which _makeQ2QService reads.
class TCPConnection(Q2QConnectionTestCase, ConnectionTestMixin):
661
inboundTCPPortnum = 0
663
virtualEnabled = False
665
# class LiveServerMixin:
666
# serverDomain = 'test.domain.example.com'
668
# def jackDNS(self, *info):
669
# self.fakeDNS = FakeConnectTCP(reactor.connectTCP)
670
# reactor.connectTCP = self.fakeDNS.connectTCP
671
# self._oldResolver = reactor.resolver
672
# reactor.installResolver(self.fakeDNS)
674
# for (hostname, oldport, newport) in info:
675
# self.fakeDNS.addHostPort(hostname, oldport, newport)
677
# def unjackDNS(self):
678
# del reactor.connectTCP
679
# reactor.installResolver(self._oldResolver)
681
# class AuthorizeTestCase(unittest.TestCase, LiveServerMixin):
682
# authUser = 'authtestuser'
683
# authPass = 'p4ssw0rd'
686
# self.testDir, serverService = self.deploy()
687
# self.serverService = service.IServiceCollection(serverService)
688
# self.serverService.startService()
689
# self.clientDir = os.path.join(self.testDir, 'client')
690
# self.clientService = q2qclient.ClientQ2QService(self.clientDir)
692
# q2qPort = self.getQ2QService().q2qPort.getHost().port
693
# self.jackDNS((self.serverDomain, 8788, q2qPort))
695
# def tearDown(self):
697
# util.wait(self.serverService.stopService())
698
# util.wait(self.clientService.stopService())
700
# def testAuthorize(self):
701
# self.createUser(self.authUser, self.authPass)
703
# d = self.clientService.authorize(
704
# q2q.Q2QAddress(self.serverDomain, self.authUser),
707
# result = runOneDeferred(d)
709
# self.failUnless(os.path.exists(os.path.join(self.clientDir, 'public')))
710
# self.failUnless(os.path.exists(os.path.join(self.clientDir, 'public', self.serverDomain + '.pem')))
711
# self.failUnless(os.path.exists(os.path.join(self.clientDir, 'private')))
712
# self.failUnless(os.path.exists(os.path.join(self.clientDir, 'private', self.authUser + '@' + self.serverDomain + '.pem')))