1
# -*- test-case-name: twisted.test.test_policies -*-
2
# Copyright (c) 2001-2009 Twisted Matrix Laboratories.
3
# See LICENSE for details.
6
"""
Resource limiting policies.

@seealso: See also L{twisted.protocols.htb} for rate limiting.
"""
14
import operator
import sys

from zope.interface import directlyProvides, providedBy

from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
from twisted.internet import error
from twisted.python import log
22
class ProtocolWrapper(Protocol):
    """
    Wraps protocol instances and acts as their transport as well.

    @ivar wrappedProtocol: An L{IProtocol} provider to which L{IProtocol}
        method calls onto this L{ProtocolWrapper} will be proxied.

    @ivar factory: The L{WrappingFactory} which created this
        L{ProtocolWrapper}.

    @ivar disconnecting: Set to a true value once L{loseConnection} has been
        called on this wrapper.
    """

    # The wrapper owns this flag (see loseConnection).  Without a class
    # default, reading it before loseConnection() would fall through
    # __getattr__ and report the underlying transport's attribute instead.
    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        self.wrappedProtocol = wrappedProtocol
        self.factory = factory

    def makeConnection(self, transport):
        """
        When a connection is made, register this wrapper with its factory,
        save the real transport, and connect the wrapped protocol to this
        L{ProtocolWrapper} to intercept any transport calls it makes.
        """
        # The wrapped protocol sees this object as its transport, so
        # advertise the same interfaces as the real transport.
        directlyProvides(self, providedBy(transport))
        Protocol.makeConnection(self, transport)
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    # Transport relaying

    def write(self, data):
        self.transport.write(data)

    def writeSequence(self, data):
        self.transport.writeSequence(data)

    def loseConnection(self):
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        return self.transport.getPeer()

    def getHost(self):
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        self.transport.unregisterProducer()

    def stopConsuming(self):
        self.transport.stopConsuming()

    def __getattr__(self, name):
        """
        Forward any attribute not found on the wrapper to the real transport.

        Note that this also makes C{hasattr} checks on the wrapper fall
        through to the transport.
        """
        return getattr(self.transport, name)

    # Protocol relaying

    def dataReceived(self, data):
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)
90
class WrappingFactory(ClientFactory):
    """
    Wraps a factory and its protocols, and keeps track of them.

    @ivar wrappedFactory: The factory whose protocols are wrapped.

    @ivar protocols: A C{dict} whose keys are the currently registered
        wrapper protocols (the values are always C{1}).
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.wrappedFactory = wrappedFactory
        # Maintained by registerProtocol / unregisterProtocol.
        self.protocols = {}

    def doStart(self):
        """
        Start the wrapped factory, then this one.
        """
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        """
        Stop the wrapped factory, then this one.
        """
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)

    def startedConnecting(self, connector):
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        self.wrappedFactory.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        """
        Build a protocol with the wrapped factory and wrap it in
        C{self.protocol}.

        @return: The wrapper, or C{None} if the wrapped factory refused the
            connection by returning C{None}.
        """
        wrappedProtocol = self.wrappedFactory.buildProtocol(addr)
        if wrappedProtocol is None:
            # Propagate the refusal instead of wrapping None, which would
            # only fail later in makeConnection.
            return None
        return self.protocol(self, wrappedProtocol)

    def registerProtocol(self, p):
        """Called by protocol to register itself."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by protocols when they go away."""
        del self.protocols[p]
128
class ThrottlingProtocol(ProtocolWrapper):
    """
    Protocol for ThrottlingFactory.

    Reports every read and write to the factory so it can account for
    bandwidth, and pauses/resumes the transport (reads) or the registered
    producer (writes) when the factory asks it to.

    @ivar producer: The producer registered through L{registerProducer},
        or C{None} when no producer is registered.
    """

    # Declared on the class so the checks below never fall through
    # ProtocolWrapper.__getattr__ to an attribute of the transport.
    producer = None

    # wrap API for tracking bandwidth

    def write(self, data):
        self.factory.registerWritten(len(data))
        ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        # The sequence is both measured and forwarded, so materialize it once:
        # a one-shot iterator must not be exhausted by the measurement.
        # sum() also copes with an empty sequence, where reduce() without an
        # initial value raises TypeError.
        seq = list(seq)
        self.factory.registerWritten(sum([len(chunk) for chunk in seq]))
        ProtocolWrapper.writeSequence(self, seq)

    def dataReceived(self, data):
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        self.producer = None
        ProtocolWrapper.unregisterProducer(self)

    def throttleReads(self):
        self.transport.pauseProducing()

    def unthrottleReads(self):
        self.transport.resumeProducing()

    def throttleWrites(self):
        if self.producer is not None:
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        if self.producer is not None:
            self.producer.resumeProducing()
169
class ThrottlingFactory(WrappingFactory):
    """
    Throttles bandwidth and number of connections.

    Write bandwidth will only be throttled if there is a producer
    registered.

    @ivar maxConnectionCount: Maximum number of simultaneous connections;
        further connections are refused.
    @ivar readLimit: Maximum number of bytes to read per second, or C{None}
        for no read limit.
    @ivar writeLimit: Maximum number of bytes to write per second, or
        C{None} for no write limit.
    """

    protocol = ThrottlingProtocol

    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
                 readLimit=None, writeLimit=None):
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit # max bytes we should read per second
        self.writeLimit = writeLimit # max bytes we should write per second
        self.readThisSecond = 0
        self.writtenThisSecond = 0
        # Pending delayed calls (None when not scheduled).
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None

    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        from twisted.internet import reactor
        return reactor.callLater(period, func)

    def registerWritten(self, length):
        """
        Called by protocol to tell us more bytes were written.
        """
        self.writtenThisSecond += length

    def registerRead(self, length):
        """
        Called by protocol to tell us more bytes were read.
        """
        self.readThisSecond += length

    def checkReadBandwidth(self):
        """
        Checks if we've passed bandwidth limits.

        If more than C{readLimit} bytes were read since the last check, reads
        are throttled for the proportional overshoot (e.g. reading 1.5 times
        the limit throttles reads for 0.5 seconds).  Reschedules itself to
        run again one second later.
        """
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            self.unthrottleReadsID = self.callLater(throttleTime,
                                                    self.unthrottleReads)
        self.readThisSecond = 0
        self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)

    def checkWriteBandwidth(self):
        """
        Write-side counterpart of L{checkReadBandwidth}, using C{writeLimit}.
        """
        if self.writtenThisSecond > self.writeLimit:
            self.throttleWrites()
            throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
            self.unthrottleWritesID = self.callLater(throttleTime,
                                                     self.unthrottleWrites)
        # reset for next round
        self.writtenThisSecond = 0
        self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)

    def throttleReads(self):
        """
        Throttle reads on all protocols.
        """
        log.msg("Throttling reads on %s" % self)
        # Iterate over a snapshot in case a protocol unregisters meanwhile.
        for p in list(self.protocols):
            p.throttleReads()

    def unthrottleReads(self):
        """
        Stop throttling reads on all protocols.
        """
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in list(self.protocols):
            p.unthrottleReads()

    def throttleWrites(self):
        """
        Throttle writes on all protocols.
        """
        log.msg("Throttling writes on %s" % self)
        for p in list(self.protocols):
            p.throttleWrites()

    def unthrottleWrites(self):
        """
        Stop throttling writes on all protocols.
        """
        self.unthrottleWritesID = None
        log.msg("Stopped throttling writes on %s" % self)
        for p in list(self.protocols):
            p.unthrottleWrites()

    def buildProtocol(self, addr):
        """
        Wrap a new protocol unless C{maxConnectionCount} has been reached.

        The periodic bandwidth checks start with the first accepted
        connection.

        @return: The wrapped protocol, or C{None} to refuse the connection.
        """
        if self.connectionCount >= self.maxConnectionCount:
            log.msg("Max connection count reached!")
            return None

        protocol = WrappingFactory.buildProtocol(self, addr)
        if protocol is None:
            # The wrapped factory refused the connection; don't count it and
            # don't start checks that no unregisterProtocol would ever stop.
            return None

        if self.connectionCount == 0:
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()
        self.connectionCount += 1
        return protocol

    def unregisterProtocol(self, p):
        WrappingFactory.unregisterProtocol(self, p)
        self.connectionCount -= 1
        if self.connectionCount == 0:
            self._stopBandwidthChecks()

    def _stopBandwidthChecks(self):
        """
        Cancel every pending throttling/check call and forget it.

        Clearing the attributes matters: a call that was cancelled but left
        in place would be cancelled a second time when the last connection
        of a later round goes away.
        """
        for name in ("unthrottleReadsID", "checkReadBandwidthID",
                     "unthrottleWritesID", "checkWriteBandwidthID"):
            call = getattr(self, name)
            if call is not None:
                call.cancel()
                setattr(self, name, None)
308
class SpewingProtocol(ProtocolWrapper):
    """
    Protocol wrapper that logs, with L{log.msg}, the data it receives and
    the data passed to its L{write} method.
    """

    def dataReceived(self, data):
        """
        Log incoming C{data}, then hand it to the wrapped protocol.
        """
        message = "Received: %r" % data
        log.msg(message)
        ProtocolWrapper.dataReceived(self, data)

    def write(self, data):
        """
        Log outgoing C{data}, then write it to the real transport.
        """
        message = "Sending: %r" % data
        log.msg(message)
        ProtocolWrapper.write(self, data)
319
class SpewingFactory(WrappingFactory):
    """
    L{WrappingFactory} that wraps every connection in a L{SpewingProtocol},
    so its traffic is logged.
    """

    protocol = SpewingProtocol
324
class LimitConnectionsByPeer(WrappingFactory):
    """
    Refuse new connections from a peer that already has
    C{maxConnectionsPerPeer} open connections.

    @ivar peerConnections: C{dict} mapping peer keys (see L{_peerKey}) to the
        number of currently open connections from that peer.
    """

    maxConnectionsPerPeer = 5

    def __init__(self, wrappedFactory):
        WrappingFactory.__init__(self, wrappedFactory)
        # Also initialised here so buildProtocol works even if the factory
        # is used before startFactory has run.
        self.peerConnections = {}

    def startFactory(self):
        self.peerConnections = {}

    def _peerKey(self, address):
        """
        Return the key under which connections from C{address} are counted.

        Both the address given to L{buildProtocol} and the one returned by
        C{getPeer()} go through this method, so increments and decrements
        always use the same key.
        """
        host = getattr(address, "host", None)
        if host is None:
            # Fall back to indexing for tuple-style addresses, where the
            # host is element 1.
            # NOTE(review): confirm no remaining caller passes such tuples.
            host = address[1]
        return host

    def buildProtocol(self, addr):
        peerHost = self._peerKey(addr)
        connectionCount = self.peerConnections.get(peerHost, 0)
        if connectionCount >= self.maxConnectionsPerPeer:
            return None
        protocol = WrappingFactory.buildProtocol(self, addr)
        if protocol is not None:
            # Only count connections that were actually accepted.
            self.peerConnections[peerHost] = connectionCount + 1
        return protocol

    def unregisterProtocol(self, p):
        # Let WrappingFactory forget the wrapper as well; otherwise every
        # protocol ever registered stays in self.protocols.
        WrappingFactory.unregisterProtocol(self, p)
        peerHost = self._peerKey(p.getPeer())
        remaining = self.peerConnections.get(peerHost, 0) - 1
        if remaining > 0:
            self.peerConnections[peerHost] = remaining
        else:
            self.peerConnections.pop(peerHost, None)
346
class LimitTotalConnectionsFactory(ServerFactory):
    """
    Factory that limits the number of simultaneous connections.

    @type connectionCount: C{int}
    @ivar connectionCount: number of current connections.
    @type connectionLimit: C{int} or C{None}
    @cvar connectionLimit: maximum number of connections.
    @type overflowProtocol: L{Protocol} subclass or C{None}
    @cvar overflowProtocol: Protocol to use for new connections when
        connectionLimit is exceeded.  If C{None} (the default value), excess
        connections will be refused.
    """

    connectionCount = 0
    connectionLimit = None
    overflowProtocol = None

    def buildProtocol(self, addr):
        """
        Build C{self.protocol}, or the overflow protocol once
        C{connectionLimit} is reached, wrapped in a L{ProtocolWrapper} so
        this factory is told when the connection goes away.

        @return: The wrapper, or C{None} to refuse the connection.
        """
        if (self.connectionLimit is None or
            self.connectionCount < self.connectionLimit):
            # Build the normal protocol
            wrappedProtocol = self.protocol()
        elif self.overflowProtocol is None:
            # Just drop the connection
            return None
        else:
            # Too many connections, so build the overflow protocol
            wrappedProtocol = self.overflowProtocol()

        wrappedProtocol.factory = self
        protocol = ProtocolWrapper(self, wrappedProtocol)
        # Overflow connections are counted too; unregisterProtocol decrements
        # for every wrapper, so the count stays balanced.
        self.connectionCount += 1
        return protocol

    def registerProtocol(self, p):
        """
        Nothing to do: connections are counted in L{buildProtocol}.
        """

    def unregisterProtocol(self, p):
        """
        Called by the L{ProtocolWrapper} when its connection is lost.
        """
        self.connectionCount -= 1
388
class TimeoutProtocol(ProtocolWrapper):
    """
    Protocol that automatically disconnects when the connection is idle.

    @ivar timeoutPeriod: Number of seconds of inactivity allowed.
    @ivar timeoutCall: The pending call returned by the factory's
        C{callLater}, or C{None} when no timeout is pending.
    """

    def __init__(self, factory, wrappedProtocol, timeoutPeriod):
        """
        Constructor.

        @param factory: An L{IFactory}.  It must also provide a C{callLater}
            method (as L{TimeoutFactory} does).
        @param wrappedProtocol: A L{Protocol} to wrap.
        @param timeoutPeriod: Number of seconds to wait for activity before
            timing out.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.timeoutCall = None
        self.setTimeout(timeoutPeriod)

    def setTimeout(self, timeoutPeriod=None):
        """
        Set a timeout.

        This will cancel any existing timeouts.

        @param timeoutPeriod: If not C{None}, change the timeout period.
            Otherwise, use the existing value.
        """
        self.cancelTimeout()
        if timeoutPeriod is not None:
            self.timeoutPeriod = timeoutPeriod
        self.timeoutCall = self.factory.callLater(self.timeoutPeriod,
                                                  self.timeoutFunc)

    def cancelTimeout(self):
        """
        Cancel the timeout.

        If the timeout was already cancelled, this does nothing.
        """
        if self.timeoutCall is not None:
            try:
                self.timeoutCall.cancel()
            except error.AlreadyCalled:
                # The timeout already fired; nothing left to cancel.
                pass
            self.timeoutCall = None

    def resetTimeout(self):
        """
        Reset the timeout, usually because some activity just happened.

        Does nothing if no timeout is pending or if it has already fired.
        """
        if self.timeoutCall is not None:
            try:
                self.timeoutCall.reset(self.timeoutPeriod)
            except error.AlreadyCalled:
                # Activity after the timeout fired (e.g. while the connection
                # is still closing) must not raise out of write() or
                # dataReceived().
                pass

    def write(self, data):
        self.resetTimeout()
        ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        self.resetTimeout()
        ProtocolWrapper.writeSequence(self, seq)

    def dataReceived(self, data):
        self.resetTimeout()
        ProtocolWrapper.dataReceived(self, data)

    def connectionLost(self, reason):
        self.cancelTimeout()
        ProtocolWrapper.connectionLost(self, reason)

    def timeoutFunc(self):
        """
        This method is called when the timeout is triggered.

        By default it calls L{loseConnection}. Override this if you want
        something else to happen.
        """
        self.loseConnection()
475
class TimeoutFactory(WrappingFactory):
    """
    Factory for L{TimeoutProtocol}.

    @ivar timeoutPeriod: Idle timeout, in seconds, given to every protocol
        built by this factory.
    """

    protocol = TimeoutProtocol

    def __init__(self, wrappedFactory, timeoutPeriod=30*60):
        self.timeoutPeriod = timeoutPeriod
        WrappingFactory.__init__(self, wrappedFactory)

    def buildProtocol(self, addr):
        """
        Wrap a protocol built by the wrapped factory in C{self.protocol}.

        @return: The wrapper, or C{None} if the wrapped factory refused the
            connection.  No timeout is scheduled for a refused connection.
        """
        wrappedProtocol = self.wrappedFactory.buildProtocol(addr)
        if wrappedProtocol is None:
            return None
        return self.protocol(self, wrappedProtocol,
                             timeoutPeriod=self.timeoutPeriod)

    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        from twisted.internet import reactor
        return reactor.callLater(period, func)
501
class TrafficLoggingProtocol(ProtocolWrapper):
    """
    Protocol wrapper that writes one line to C{logfile} for every connection
    event, for data received (C{C} lines) and for data sent (C{S}/C{SV}
    lines).
    """

    # Marker appended to data truncated by _mungeData.
    _elided = '<... elided>'

    def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None,
                 number=0):
        """
        @param factory: factory which created this protocol.
        @type factory: C{protocol.Factory}.
        @param wrappedProtocol: the underlying protocol.
        @type wrappedProtocol: C{protocol.Protocol}.
        @param logfile: file opened for writing used to write log messages.
        @type logfile: C{file}
        @param lengthLimit: maximum size of the logged data.
        @type lengthLimit: C{int}
        @param number: identifier of the connection.
        @type number: C{int}.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.logfile = logfile
        self.lengthLimit = lengthLimit
        self._number = number

    def _log(self, line):
        self.logfile.write(line + '\n')
        # Flush per line so the log is complete even if the process dies.
        self.logfile.flush()

    def _mungeData(self, data):
        """
        Truncate C{data} for logging if it is longer than C{lengthLimit}.

        As much of the start of C{data} is kept as fits in C{lengthLimit}
        together with the elision marker.  A limit shorter than the marker
        yields just the marker, instead of the negative slice the naive
        computation would produce.
        """
        if self.lengthLimit and len(data) > self.lengthLimit:
            keep = max(self.lengthLimit - len(self._elided), 0)
            data = data[:keep] + self._elided
        return data

    # IProtocol
    def connectionMade(self):
        self._log('*')
        return ProtocolWrapper.connectionMade(self)

    def dataReceived(self, data):
        self._log('C %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.dataReceived(self, data)

    def connectionLost(self, reason):
        self._log('C %d: %r' % (self._number, reason))
        return ProtocolWrapper.connectionLost(self, reason)

    # ITransport
    def write(self, data):
        self._log('S %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.write(self, data)

    def writeSequence(self, iovec):
        # Materialize once so a one-shot iterator is not exhausted by the
        # logging before it reaches the transport.
        iovec = list(iovec)
        self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec]))
        return ProtocolWrapper.writeSequence(self, iovec)

    def loseConnection(self):
        self._log('S %d: *' % (self._number,))
        return ProtocolWrapper.loseConnection(self)
567
class TrafficLoggingFactory(WrappingFactory):
    """
    Factory for L{TrafficLoggingProtocol}.

    Each accepted connection gets its own log file, named
    C{<logfilePrefix>-<n>} where C{n} is the connection's number.  Numbers
    start at 1 and can be restarted with L{resetCounter}.
    """

    protocol = TrafficLoggingProtocol

    # Number of connections logged so far.
    _counter = 0

    def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
        self.logfilePrefix = logfilePrefix
        self.lengthLimit = lengthLimit
        WrappingFactory.__init__(self, wrappedFactory)

    def open(self, name):
        """
        Open the log file C{name} for writing.
        """
        # The builtin (not this method): method names are not in scope here.
        return open(name, 'w')

    def buildProtocol(self, addr):
        wrappedProtocol = self.wrappedFactory.buildProtocol(addr)
        if wrappedProtocol is None:
            # Refused connection: don't consume a number or open a log file
            # that nothing would ever close.
            return None
        self._counter += 1
        logfile = self.open(self.logfilePrefix + '-' + str(self._counter))
        return self.protocol(self, wrappedProtocol,
                             logfile, self.lengthLimit, self._counter)

    def resetCounter(self):
        """
        Reset the value of the counter used to identify connections.
        """
        self._counter = 0
class TimeoutMixin:
    """
    Mixin for protocols which wish to timeout connections.

    Call L{setTimeout} to arm or disarm the timeout and L{resetTimeout} to
    restart the count down.  If the count down elapses,
    L{timeoutConnection} is called.

    @cvar timeOut: The number of seconds after which to timeout the
        connection, or C{None} when no timeout is set.
    """

    timeOut = None

    # Pending delayed call for the timeout, or None when none is armed.
    __timeoutCall = None

    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater}, overridable for tests.
        """
        from twisted.internet import reactor
        return reactor.callLater(period, func)

    def resetTimeout(self):
        """
        Reset the timeout count down.

        Does nothing when no timeout is armed.
        """
        if self.__timeoutCall is not None and self.timeOut is not None:
            self.__timeoutCall.reset(self.timeOut)

    def setTimeout(self, period):
        """
        Change the timeout period.

        @type period: C{int} or C{NoneType}
        @param period: The period, in seconds, to change the timeout to, or
            C{None} to disable the timeout.
        @return: The previous timeout period.
        """
        prev = self.timeOut
        self.timeOut = period

        if self.__timeoutCall is not None:
            if period is None:
                # Disarm the pending timeout.
                self.__timeoutCall.cancel()
                self.__timeoutCall = None
            else:
                # Re-arm the pending timeout with the new period.
                self.__timeoutCall.reset(period)
        elif period is not None:
            self.__timeoutCall = self.callLater(period, self.__timedOut)

        return prev

    def __timedOut(self):
        # The call has fired, so forget it before notifying; this lets
        # timeoutConnection() safely call setTimeout() again.
        self.__timeoutCall = None
        self.timeoutConnection()

    def timeoutConnection(self):
        """
        Called when the connection times out.

        Override to define behavior other than dropping the connection.
        """
        self.transport.loseConnection()