~cbehrens/nova/lp844160-build-works-with-zones

« back to all changes in this revision

Viewing changes to vendor/Twisted-10.0.0/twisted/protocols/policies.py

  • Committer: Jesse Andrews
  • Date: 2010-05-28 06:05:26 UTC
  • Revision ID: git-v1:bf6e6e718cdc7488e2da87b21e258ccc065fe499
initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# -*- test-case-name: twisted.test.test_policies -*-
 
2
# Copyright (c) 2001-2009 Twisted Matrix Laboratories.
 
3
# See LICENSE for details.
 
4
 
 
5
"""
 
6
Resource limiting policies.
 
7
 
 
8
@seealso: See also L{twisted.protocols.htb} for rate limiting.
 
9
"""
 
10
 
 
11
# system imports
 
12
import sys, operator
 
13
 
 
14
from zope.interface import directlyProvides, providedBy
 
15
 
 
16
# twisted imports
 
17
from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
 
18
from twisted.internet import error
 
19
from twisted.python import log
 
20
 
 
21
 
 
22
class ProtocolWrapper(Protocol):
 
23
    """
 
24
    Wraps protocol instances and acts as their transport as well.
 
25
 
 
26
    @ivar wrappedProtocol: An L{IProtocol} provider to which L{IProtocol}
 
27
        method calls onto this L{ProtocolWrapper} will be proxied.
 
28
 
 
29
    @ivar factory: The L{WrappingFactory} which created this
 
30
        L{ProtocolWrapper}.
 
31
    """
 
32
 
 
33
    disconnecting = 0
 
34
 
 
35
    def __init__(self, factory, wrappedProtocol):
 
36
        self.wrappedProtocol = wrappedProtocol
 
37
        self.factory = factory
 
38
 
 
39
    def makeConnection(self, transport):
 
40
        """
 
41
        When a connection is made, register this wrapper with its factory,
 
42
        save the real transport, and connect the wrapped protocol to this
 
43
        L{ProtocolWrapper} to intercept any transport calls it makes.
 
44
        """
 
45
        directlyProvides(self, providedBy(transport))
 
46
        Protocol.makeConnection(self, transport)
 
47
        self.factory.registerProtocol(self)
 
48
        self.wrappedProtocol.makeConnection(self)
 
49
 
 
50
    # Transport relaying
 
51
 
 
52
    def write(self, data):
 
53
        self.transport.write(data)
 
54
 
 
55
    def writeSequence(self, data):
 
56
        self.transport.writeSequence(data)
 
57
 
 
58
    def loseConnection(self):
 
59
        self.disconnecting = 1
 
60
        self.transport.loseConnection()
 
61
 
 
62
    def getPeer(self):
 
63
        return self.transport.getPeer()
 
64
 
 
65
    def getHost(self):
 
66
        return self.transport.getHost()
 
67
 
 
68
    def registerProducer(self, producer, streaming):
 
69
        self.transport.registerProducer(producer, streaming)
 
70
 
 
71
    def unregisterProducer(self):
 
72
        self.transport.unregisterProducer()
 
73
 
 
74
    def stopConsuming(self):
 
75
        self.transport.stopConsuming()
 
76
 
 
77
    def __getattr__(self, name):
 
78
        return getattr(self.transport, name)
 
79
 
 
80
    # Protocol relaying
 
81
 
 
82
    def dataReceived(self, data):
 
83
        self.wrappedProtocol.dataReceived(data)
 
84
 
 
85
    def connectionLost(self, reason):
 
86
        self.factory.unregisterProtocol(self)
 
87
        self.wrappedProtocol.connectionLost(reason)
 
88
 
 
89
 
 
90
class WrappingFactory(ClientFactory):
 
91
    """Wraps a factory and its protocols, and keeps track of them."""
 
92
 
 
93
    protocol = ProtocolWrapper
 
94
 
 
95
    def __init__(self, wrappedFactory):
 
96
        self.wrappedFactory = wrappedFactory
 
97
        self.protocols = {}
 
98
 
 
99
    def doStart(self):
 
100
        self.wrappedFactory.doStart()
 
101
        ClientFactory.doStart(self)
 
102
 
 
103
    def doStop(self):
 
104
        self.wrappedFactory.doStop()
 
105
        ClientFactory.doStop(self)
 
106
 
 
107
    def startedConnecting(self, connector):
 
108
        self.wrappedFactory.startedConnecting(connector)
 
109
 
 
110
    def clientConnectionFailed(self, connector, reason):
 
111
        self.wrappedFactory.clientConnectionFailed(connector, reason)
 
112
 
 
113
    def clientConnectionLost(self, connector, reason):
 
114
        self.wrappedFactory.clientConnectionLost(connector, reason)
 
115
 
 
116
    def buildProtocol(self, addr):
 
117
        return self.protocol(self, self.wrappedFactory.buildProtocol(addr))
 
118
 
 
119
    def registerProtocol(self, p):
 
120
        """Called by protocol to register itself."""
 
121
        self.protocols[p] = 1
 
122
 
 
123
    def unregisterProtocol(self, p):
 
124
        """Called by protocols when they go away."""
 
125
        del self.protocols[p]
 
126
 
 
127
 
 
128
class ThrottlingProtocol(ProtocolWrapper):
 
129
    """Protocol for ThrottlingFactory."""
 
130
 
 
131
    # wrap API for tracking bandwidth
 
132
 
 
133
    def write(self, data):
 
134
        self.factory.registerWritten(len(data))
 
135
        ProtocolWrapper.write(self, data)
 
136
 
 
137
    def writeSequence(self, seq):
 
138
        self.factory.registerWritten(reduce(operator.add, map(len, seq)))
 
139
        ProtocolWrapper.writeSequence(self, seq)
 
140
 
 
141
    def dataReceived(self, data):
 
142
        self.factory.registerRead(len(data))
 
143
        ProtocolWrapper.dataReceived(self, data)
 
144
 
 
145
    def registerProducer(self, producer, streaming):
 
146
        self.producer = producer
 
147
        ProtocolWrapper.registerProducer(self, producer, streaming)
 
148
 
 
149
    def unregisterProducer(self):
 
150
        del self.producer
 
151
        ProtocolWrapper.unregisterProducer(self)
 
152
 
 
153
 
 
154
    def throttleReads(self):
 
155
        self.transport.pauseProducing()
 
156
 
 
157
    def unthrottleReads(self):
 
158
        self.transport.resumeProducing()
 
159
 
 
160
    def throttleWrites(self):
 
161
        if hasattr(self, "producer"):
 
162
            self.producer.pauseProducing()
 
163
 
 
164
    def unthrottleWrites(self):
 
165
        if hasattr(self, "producer"):
 
166
            self.producer.resumeProducing()
 
167
 
 
168
 
 
169
class ThrottlingFactory(WrappingFactory):
 
170
    """
 
171
    Throttles bandwidth and number of connections.
 
172
 
 
173
    Write bandwidth will only be throttled if there is a producer
 
174
    registered.
 
175
    """
 
176
 
 
177
    protocol = ThrottlingProtocol
 
178
 
 
179
    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
 
180
                 readLimit=None, writeLimit=None):
 
181
        WrappingFactory.__init__(self, wrappedFactory)
 
182
        self.connectionCount = 0
 
183
        self.maxConnectionCount = maxConnectionCount
 
184
        self.readLimit = readLimit # max bytes we should read per second
 
185
        self.writeLimit = writeLimit # max bytes we should write per second
 
186
        self.readThisSecond = 0
 
187
        self.writtenThisSecond = 0
 
188
        self.unthrottleReadsID = None
 
189
        self.checkReadBandwidthID = None
 
190
        self.unthrottleWritesID = None
 
191
        self.checkWriteBandwidthID = None
 
192
 
 
193
 
 
194
    def callLater(self, period, func):
 
195
        """
 
196
        Wrapper around L{reactor.callLater} for test purpose.
 
197
        """
 
198
        from twisted.internet import reactor
 
199
        return reactor.callLater(period, func)
 
200
 
 
201
 
 
202
    def registerWritten(self, length):
 
203
        """
 
204
        Called by protocol to tell us more bytes were written.
 
205
        """
 
206
        self.writtenThisSecond += length
 
207
 
 
208
 
 
209
    def registerRead(self, length):
 
210
        """
 
211
        Called by protocol to tell us more bytes were read.
 
212
        """
 
213
        self.readThisSecond += length
 
214
 
 
215
 
 
216
    def checkReadBandwidth(self):
 
217
        """
 
218
        Checks if we've passed bandwidth limits.
 
219
        """
 
220
        if self.readThisSecond > self.readLimit:
 
221
            self.throttleReads()
 
222
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
 
223
            self.unthrottleReadsID = self.callLater(throttleTime,
 
224
                                                    self.unthrottleReads)
 
225
        self.readThisSecond = 0
 
226
        self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)
 
227
 
 
228
 
 
229
    def checkWriteBandwidth(self):
 
230
        if self.writtenThisSecond > self.writeLimit:
 
231
            self.throttleWrites()
 
232
            throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
 
233
            self.unthrottleWritesID = self.callLater(throttleTime,
 
234
                                                        self.unthrottleWrites)
 
235
        # reset for next round
 
236
        self.writtenThisSecond = 0
 
237
        self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)
 
238
 
 
239
 
 
240
    def throttleReads(self):
 
241
        """
 
242
        Throttle reads on all protocols.
 
243
        """
 
244
        log.msg("Throttling reads on %s" % self)
 
245
        for p in self.protocols.keys():
 
246
            p.throttleReads()
 
247
 
 
248
 
 
249
    def unthrottleReads(self):
 
250
        """
 
251
        Stop throttling reads on all protocols.
 
252
        """
 
253
        self.unthrottleReadsID = None
 
254
        log.msg("Stopped throttling reads on %s" % self)
 
255
        for p in self.protocols.keys():
 
256
            p.unthrottleReads()
 
257
 
 
258
 
 
259
    def throttleWrites(self):
 
260
        """
 
261
        Throttle writes on all protocols.
 
262
        """
 
263
        log.msg("Throttling writes on %s" % self)
 
264
        for p in self.protocols.keys():
 
265
            p.throttleWrites()
 
266
 
 
267
 
 
268
    def unthrottleWrites(self):
 
269
        """
 
270
        Stop throttling writes on all protocols.
 
271
        """
 
272
        self.unthrottleWritesID = None
 
273
        log.msg("Stopped throttling writes on %s" % self)
 
274
        for p in self.protocols.keys():
 
275
            p.unthrottleWrites()
 
276
 
 
277
 
 
278
    def buildProtocol(self, addr):
 
279
        if self.connectionCount == 0:
 
280
            if self.readLimit is not None:
 
281
                self.checkReadBandwidth()
 
282
            if self.writeLimit is not None:
 
283
                self.checkWriteBandwidth()
 
284
 
 
285
        if self.connectionCount < self.maxConnectionCount:
 
286
            self.connectionCount += 1
 
287
            return WrappingFactory.buildProtocol(self, addr)
 
288
        else:
 
289
            log.msg("Max connection count reached!")
 
290
            return None
 
291
 
 
292
 
 
293
    def unregisterProtocol(self, p):
 
294
        WrappingFactory.unregisterProtocol(self, p)
 
295
        self.connectionCount -= 1
 
296
        if self.connectionCount == 0:
 
297
            if self.unthrottleReadsID is not None:
 
298
                self.unthrottleReadsID.cancel()
 
299
            if self.checkReadBandwidthID is not None:
 
300
                self.checkReadBandwidthID.cancel()
 
301
            if self.unthrottleWritesID is not None:
 
302
                self.unthrottleWritesID.cancel()
 
303
            if self.checkWriteBandwidthID is not None:
 
304
                self.checkWriteBandwidthID.cancel()
 
305
 
 
306
 
 
307
 
 
308
class SpewingProtocol(ProtocolWrapper):
 
309
    def dataReceived(self, data):
 
310
        log.msg("Received: %r" % data)
 
311
        ProtocolWrapper.dataReceived(self,data)
 
312
 
 
313
    def write(self, data):
 
314
        log.msg("Sending: %r" % data)
 
315
        ProtocolWrapper.write(self,data)
 
316
 
 
317
 
 
318
 
 
319
class SpewingFactory(WrappingFactory):
 
320
    protocol = SpewingProtocol
 
321
 
 
322
 
 
323
 
 
324
class LimitConnectionsByPeer(WrappingFactory):
 
325
 
 
326
    maxConnectionsPerPeer = 5
 
327
 
 
328
    def startFactory(self):
 
329
        self.peerConnections = {}
 
330
 
 
331
    def buildProtocol(self, addr):
 
332
        peerHost = addr[0]
 
333
        connectionCount = self.peerConnections.get(peerHost, 0)
 
334
        if connectionCount >= self.maxConnectionsPerPeer:
 
335
            return None
 
336
        self.peerConnections[peerHost] = connectionCount + 1
 
337
        return WrappingFactory.buildProtocol(self, addr)
 
338
 
 
339
    def unregisterProtocol(self, p):
 
340
        peerHost = p.getPeer()[1]
 
341
        self.peerConnections[peerHost] -= 1
 
342
        if self.peerConnections[peerHost] == 0:
 
343
            del self.peerConnections[peerHost]
 
344
 
 
345
 
 
346
class LimitTotalConnectionsFactory(ServerFactory):
 
347
    """
 
348
    Factory that limits the number of simultaneous connections.
 
349
 
 
350
    @type connectionCount: C{int}
 
351
    @ivar connectionCount: number of current connections.
 
352
    @type connectionLimit: C{int} or C{None}
 
353
    @cvar connectionLimit: maximum number of connections.
 
354
    @type overflowProtocol: L{Protocol} or C{None}
 
355
    @cvar overflowProtocol: Protocol to use for new connections when
 
356
        connectionLimit is exceeded.  If C{None} (the default value), excess
 
357
        connections will be closed immediately.
 
358
    """
 
359
    connectionCount = 0
 
360
    connectionLimit = None
 
361
    overflowProtocol = None
 
362
 
 
363
    def buildProtocol(self, addr):
 
364
        if (self.connectionLimit is None or
 
365
            self.connectionCount < self.connectionLimit):
 
366
                # Build the normal protocol
 
367
                wrappedProtocol = self.protocol()
 
368
        elif self.overflowProtocol is None:
 
369
            # Just drop the connection
 
370
            return None
 
371
        else:
 
372
            # Too many connections, so build the overflow protocol
 
373
            wrappedProtocol = self.overflowProtocol()
 
374
 
 
375
        wrappedProtocol.factory = self
 
376
        protocol = ProtocolWrapper(self, wrappedProtocol)
 
377
        self.connectionCount += 1
 
378
        return protocol
 
379
 
 
380
    def registerProtocol(self, p):
 
381
        pass
 
382
 
 
383
    def unregisterProtocol(self, p):
 
384
        self.connectionCount -= 1
 
385
 
 
386
 
 
387
 
 
388
class TimeoutProtocol(ProtocolWrapper):
 
389
    """
 
390
    Protocol that automatically disconnects when the connection is idle.
 
391
    """
 
392
 
 
393
    def __init__(self, factory, wrappedProtocol, timeoutPeriod):
 
394
        """
 
395
        Constructor.
 
396
 
 
397
        @param factory: An L{IFactory}.
 
398
        @param wrappedProtocol: A L{Protocol} to wrapp.
 
399
        @param timeoutPeriod: Number of seconds to wait for activity before
 
400
            timing out.
 
401
        """
 
402
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
 
403
        self.timeoutCall = None
 
404
        self.setTimeout(timeoutPeriod)
 
405
 
 
406
 
 
407
    def setTimeout(self, timeoutPeriod=None):
 
408
        """
 
409
        Set a timeout.
 
410
 
 
411
        This will cancel any existing timeouts.
 
412
 
 
413
        @param timeoutPeriod: If not C{None}, change the timeout period.
 
414
            Otherwise, use the existing value.
 
415
        """
 
416
        self.cancelTimeout()
 
417
        if timeoutPeriod is not None:
 
418
            self.timeoutPeriod = timeoutPeriod
 
419
        self.timeoutCall = self.factory.callLater(self.timeoutPeriod, self.timeoutFunc)
 
420
 
 
421
 
 
422
    def cancelTimeout(self):
 
423
        """
 
424
        Cancel the timeout.
 
425
 
 
426
        If the timeout was already cancelled, this does nothing.
 
427
        """
 
428
        if self.timeoutCall:
 
429
            try:
 
430
                self.timeoutCall.cancel()
 
431
            except error.AlreadyCalled:
 
432
                pass
 
433
            self.timeoutCall = None
 
434
 
 
435
 
 
436
    def resetTimeout(self):
 
437
        """
 
438
        Reset the timeout, usually because some activity just happened.
 
439
        """
 
440
        if self.timeoutCall:
 
441
            self.timeoutCall.reset(self.timeoutPeriod)
 
442
 
 
443
 
 
444
    def write(self, data):
 
445
        self.resetTimeout()
 
446
        ProtocolWrapper.write(self, data)
 
447
 
 
448
 
 
449
    def writeSequence(self, seq):
 
450
        self.resetTimeout()
 
451
        ProtocolWrapper.writeSequence(self, seq)
 
452
 
 
453
 
 
454
    def dataReceived(self, data):
 
455
        self.resetTimeout()
 
456
        ProtocolWrapper.dataReceived(self, data)
 
457
 
 
458
 
 
459
    def connectionLost(self, reason):
 
460
        self.cancelTimeout()
 
461
        ProtocolWrapper.connectionLost(self, reason)
 
462
 
 
463
 
 
464
    def timeoutFunc(self):
 
465
        """
 
466
        This method is called when the timeout is triggered.
 
467
 
 
468
        By default it calls L{loseConnection}.  Override this if you want
 
469
        something else to happen.
 
470
        """
 
471
        self.loseConnection()
 
472
 
 
473
 
 
474
 
 
475
class TimeoutFactory(WrappingFactory):
 
476
    """
 
477
    Factory for TimeoutWrapper.
 
478
    """
 
479
    protocol = TimeoutProtocol
 
480
 
 
481
 
 
482
    def __init__(self, wrappedFactory, timeoutPeriod=30*60):
 
483
        self.timeoutPeriod = timeoutPeriod
 
484
        WrappingFactory.__init__(self, wrappedFactory)
 
485
 
 
486
 
 
487
    def buildProtocol(self, addr):
 
488
        return self.protocol(self, self.wrappedFactory.buildProtocol(addr),
 
489
                             timeoutPeriod=self.timeoutPeriod)
 
490
 
 
491
 
 
492
    def callLater(self, period, func):
 
493
        """
 
494
        Wrapper around L{reactor.callLater} for test purpose.
 
495
        """
 
496
        from twisted.internet import reactor
 
497
        return reactor.callLater(period, func)
 
498
 
 
499
 
 
500
 
 
501
class TrafficLoggingProtocol(ProtocolWrapper):
 
502
 
 
503
    def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None,
 
504
                 number=0):
 
505
        """
 
506
        @param factory: factory which created this protocol.
 
507
        @type factory: C{protocol.Factory}.
 
508
        @param wrappedProtocol: the underlying protocol.
 
509
        @type wrappedProtocol: C{protocol.Protocol}.
 
510
        @param logfile: file opened for writing used to write log messages.
 
511
        @type logfile: C{file}
 
512
        @param lengthLimit: maximum size of the datareceived logged.
 
513
        @type lengthLimit: C{int}
 
514
        @param number: identifier of the connection.
 
515
        @type number: C{int}.
 
516
        """
 
517
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
 
518
        self.logfile = logfile
 
519
        self.lengthLimit = lengthLimit
 
520
        self._number = number
 
521
 
 
522
 
 
523
    def _log(self, line):
 
524
        self.logfile.write(line + '\n')
 
525
        self.logfile.flush()
 
526
 
 
527
 
 
528
    def _mungeData(self, data):
 
529
        if self.lengthLimit and len(data) > self.lengthLimit:
 
530
            data = data[:self.lengthLimit - 12] + '<... elided>'
 
531
        return data
 
532
 
 
533
 
 
534
    # IProtocol
 
535
    def connectionMade(self):
 
536
        self._log('*')
 
537
        return ProtocolWrapper.connectionMade(self)
 
538
 
 
539
 
 
540
    def dataReceived(self, data):
 
541
        self._log('C %d: %r' % (self._number, self._mungeData(data)))
 
542
        return ProtocolWrapper.dataReceived(self, data)
 
543
 
 
544
 
 
545
    def connectionLost(self, reason):
 
546
        self._log('C %d: %r' % (self._number, reason))
 
547
        return ProtocolWrapper.connectionLost(self, reason)
 
548
 
 
549
 
 
550
    # ITransport
 
551
    def write(self, data):
 
552
        self._log('S %d: %r' % (self._number, self._mungeData(data)))
 
553
        return ProtocolWrapper.write(self, data)
 
554
 
 
555
 
 
556
    def writeSequence(self, iovec):
 
557
        self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec]))
 
558
        return ProtocolWrapper.writeSequence(self, iovec)
 
559
 
 
560
 
 
561
    def loseConnection(self):
 
562
        self._log('S %d: *' % (self._number,))
 
563
        return ProtocolWrapper.loseConnection(self)
 
564
 
 
565
 
 
566
 
 
567
class TrafficLoggingFactory(WrappingFactory):
 
568
    protocol = TrafficLoggingProtocol
 
569
 
 
570
    _counter = 0
 
571
 
 
572
    def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
 
573
        self.logfilePrefix = logfilePrefix
 
574
        self.lengthLimit = lengthLimit
 
575
        WrappingFactory.__init__(self, wrappedFactory)
 
576
 
 
577
 
 
578
    def open(self, name):
 
579
        return file(name, 'w')
 
580
 
 
581
 
 
582
    def buildProtocol(self, addr):
 
583
        self._counter += 1
 
584
        logfile = self.open(self.logfilePrefix + '-' + str(self._counter))
 
585
        return self.protocol(self, self.wrappedFactory.buildProtocol(addr),
 
586
                             logfile, self.lengthLimit, self._counter)
 
587
 
 
588
 
 
589
    def resetCounter(self):
 
590
        """
 
591
        Reset the value of the counter used to identify connections.
 
592
        """
 
593
        self._counter = 0
 
594
 
 
595
 
 
596
 
 
597
class TimeoutMixin:
 
598
    """Mixin for protocols which wish to timeout connections
 
599
 
 
600
    @cvar timeOut: The number of seconds after which to timeout the connection.
 
601
    """
 
602
    timeOut = None
 
603
 
 
604
    __timeoutCall = None
 
605
 
 
606
    def callLater(self, period, func):
 
607
        from twisted.internet import reactor
 
608
        return reactor.callLater(period, func)
 
609
 
 
610
 
 
611
    def resetTimeout(self):
 
612
        """Reset the timeout count down"""
 
613
        if self.__timeoutCall is not None and self.timeOut is not None:
 
614
            self.__timeoutCall.reset(self.timeOut)
 
615
 
 
616
    def setTimeout(self, period):
 
617
        """Change the timeout period
 
618
 
 
619
        @type period: C{int} or C{NoneType}
 
620
        @param period: The period, in seconds, to change the timeout to, or
 
621
        C{None} to disable the timeout.
 
622
        """
 
623
        prev = self.timeOut
 
624
        self.timeOut = period
 
625
 
 
626
        if self.__timeoutCall is not None:
 
627
            if period is None:
 
628
                self.__timeoutCall.cancel()
 
629
                self.__timeoutCall = None
 
630
            else:
 
631
                self.__timeoutCall.reset(period)
 
632
        elif period is not None:
 
633
            self.__timeoutCall = self.callLater(period, self.__timedOut)
 
634
 
 
635
        return prev
 
636
 
 
637
    def __timedOut(self):
 
638
        self.__timeoutCall = None
 
639
        self.timeoutConnection()
 
640
 
 
641
    def timeoutConnection(self):
 
642
        """Called when the connection times out.
 
643
        Override to define behavior other than dropping the connection.
 
644
        """
 
645
        self.transport.loseConnection()