~landscape/zope3/newer-from-ztk

« back to all changes in this revision

Viewing changes to src/twisted/protocols/basic.py

  • Committer: Thomas Hervé
  • Date: 2009-07-08 13:52:04 UTC
  • Revision ID: thomas@canonical.com-20090708135204-df5eesrthifpylf8
Remove twisted copy

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# -*- test-case-name: twisted.test.test_protocols -*-
2
 
 
3
 
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
4
 
# See LICENSE for details.
5
 
 
6
 
 
7
 
"""Basic protocols, such as line-oriented, netstring, and 32-bit-int prefixed strings.
8
 
 
9
 
API Stability: semi-stable.
10
 
 
11
 
Maintainer: U{Itamar Shtull-Trauring<mailto:twisted@itamarst.org>}
12
 
"""
13
 
 
14
 
# System imports
15
 
import re
16
 
import struct
17
 
 
18
 
# Twisted imports
19
 
from twisted.internet import protocol, defer, interfaces, error
20
 
from twisted.python import log
21
 
from zope.interface import implements
22
 
 
23
 
LENGTH, DATA, COMMA = range(3)
24
 
NUMBER = re.compile('(\d*)(:?)')
25
 
DEBUG = 0
26
 
 
27
 
class NetstringParseError(ValueError):
28
 
    """The incoming data is not in valid Netstring format."""
29
 
    pass
30
 
 
31
 
 
32
 
class NetstringReceiver(protocol.Protocol):
33
 
    """This uses djb's Netstrings protocol to break up the input into strings.
34
 
 
35
 
    Each string makes a callback to stringReceived, with a single
36
 
    argument of that string.
37
 
 
38
 
    Security features:
39
 
        1. Messages are limited in size, useful if you don't want someone
40
 
           sending you a 500MB netstring (change MAX_LENGTH to the maximum
41
 
           length you wish to accept).
42
 
        2. The connection is lost if an illegal message is received.
43
 
    """
44
 
 
45
 
    MAX_LENGTH = 99999
46
 
    brokenPeer = 0
47
 
    _readerState = LENGTH
48
 
    _readerLength = 0
49
 
 
50
 
    def stringReceived(self, line):
51
 
        """
52
 
        Override this.
53
 
        """
54
 
        raise NotImplementedError
55
 
 
56
 
    def doData(self):
57
 
        buffer,self.__data = self.__data[:int(self._readerLength)],self.__data[int(self._readerLength):]
58
 
        self._readerLength = self._readerLength - len(buffer)
59
 
        self.__buffer = self.__buffer + buffer
60
 
        if self._readerLength != 0:
61
 
            return
62
 
        self.stringReceived(self.__buffer)
63
 
        self._readerState = COMMA
64
 
 
65
 
    def doComma(self):
66
 
        self._readerState = LENGTH
67
 
        if self.__data[0] != ',':
68
 
            if DEBUG:
69
 
                raise NetstringParseError(repr(self.__data))
70
 
            else:
71
 
                raise NetstringParseError
72
 
        self.__data = self.__data[1:]
73
 
 
74
 
 
75
 
    def doLength(self):
76
 
        m = NUMBER.match(self.__data)
77
 
        if not m.end():
78
 
            if DEBUG:
79
 
                raise NetstringParseError(repr(self.__data))
80
 
            else:
81
 
                raise NetstringParseError
82
 
        self.__data = self.__data[m.end():]
83
 
        if m.group(1):
84
 
            try:
85
 
                self._readerLength = self._readerLength * (10**len(m.group(1))) + long(m.group(1))
86
 
            except OverflowError:
87
 
                raise NetstringParseError, "netstring too long"
88
 
            if self._readerLength > self.MAX_LENGTH:
89
 
                raise NetstringParseError, "netstring too long"
90
 
        if m.group(2):
91
 
            self.__buffer = ''
92
 
            self._readerState = DATA
93
 
 
94
 
    def dataReceived(self, data):
95
 
        self.__data = data
96
 
        try:
97
 
            while self.__data:
98
 
                if self._readerState == DATA:
99
 
                    self.doData()
100
 
                elif self._readerState == COMMA:
101
 
                    self.doComma()
102
 
                elif self._readerState == LENGTH:
103
 
                    self.doLength()
104
 
                else:
105
 
                    raise RuntimeError, "mode is not DATA, COMMA or LENGTH"
106
 
        except NetstringParseError:
107
 
            self.transport.loseConnection()
108
 
            self.brokenPeer = 1
109
 
 
110
 
    def sendString(self, data):
111
 
        self.transport.write('%d:%s,' % (len(data), data))
112
 
 
113
 
 
114
 
class SafeNetstringReceiver(NetstringReceiver):
115
 
    """This class is deprecated, use NetstringReceiver instead.
116
 
    """
117
 
 
118
 
 
119
 
class LineOnlyReceiver(protocol.Protocol):
120
 
    """A protocol that receives only lines.
121
 
 
122
 
    This is purely a speed optimisation over LineReceiver, for the
123
 
    cases that raw mode is known to be unnecessary.
124
 
 
125
 
    @cvar delimiter: The line-ending delimiter to use. By default this is
126
 
                     '\\r\\n'.
127
 
    @cvar MAX_LENGTH: The maximum length of a line to allow (If a
128
 
                      sent line is longer than this, the connection is dropped).
129
 
                      Default is 16384.
130
 
    """
131
 
    _buffer = ''
132
 
    delimiter = '\r\n'
133
 
    MAX_LENGTH = 16384
134
 
 
135
 
    def dataReceived(self, data):
136
 
        """Translates bytes into lines, and calls lineReceived."""
137
 
        lines  = (self._buffer+data).split(self.delimiter)
138
 
        self._buffer = lines.pop(-1)
139
 
        for line in lines:
140
 
            if self.transport.disconnecting:
141
 
                # this is necessary because the transport may be told to lose
142
 
                # the connection by a line within a larger packet, and it is
143
 
                # important to disregard all the lines in that packet following
144
 
                # the one that told it to close.
145
 
                return
146
 
            if len(line) > self.MAX_LENGTH:
147
 
                return self.lineLengthExceeded(line)
148
 
            else:
149
 
                self.lineReceived(line)
150
 
        if len(self._buffer) > self.MAX_LENGTH:
151
 
            return self.lineLengthExceeded(self._buffer)
152
 
 
153
 
    def lineReceived(self, line):
154
 
        """Override this for when each line is received.
155
 
        """
156
 
        raise NotImplementedError
157
 
 
158
 
    def sendLine(self, line):
159
 
        """Sends a line to the other end of the connection.
160
 
        """
161
 
        return self.transport.writeSequence((line,self.delimiter))
162
 
 
163
 
    def lineLengthExceeded(self, line):
164
 
        """Called when the maximum line length has been reached.
165
 
        Override if it needs to be dealt with in some special way.
166
 
        """
167
 
        return error.ConnectionLost('Line length exceeded')
168
 
 
169
 
 
170
 
class _PauseableMixin:
171
 
    paused = False
172
 
 
173
 
    def pauseProducing(self):
174
 
        self.paused = True
175
 
        self.transport.pauseProducing()
176
 
 
177
 
    def resumeProducing(self):
178
 
        self.paused = False
179
 
        self.transport.resumeProducing()
180
 
        self.dataReceived('')
181
 
 
182
 
    def stopProducing(self):
183
 
        self.paused = True
184
 
        self.transport.stopProducing()
185
 
 
186
 
 
187
 
class LineReceiver(protocol.Protocol, _PauseableMixin):
188
 
    """A protocol that receives lines and/or raw data, depending on mode.
189
 
 
190
 
    In line mode, each line that's received becomes a callback to
191
 
    L{lineReceived}.  In raw data mode, each chunk of raw data becomes a
192
 
    callback to L{rawDataReceived}.  The L{setLineMode} and L{setRawMode}
193
 
    methods switch between the two modes.
194
 
 
195
 
    This is useful for line-oriented protocols such as IRC, HTTP, POP, etc.
196
 
 
197
 
    @cvar delimiter: The line-ending delimiter to use. By default this is
198
 
                     '\\r\\n'.
199
 
    @cvar MAX_LENGTH: The maximum length of a line to allow (If a
200
 
                      sent line is longer than this, the connection is dropped).
201
 
                      Default is 16384.
202
 
    """
203
 
    line_mode = 1
204
 
    __buffer = ''
205
 
    delimiter = '\r\n'
206
 
    MAX_LENGTH = 16384
207
 
 
208
 
    def clearLineBuffer(self):
209
 
        """Clear buffered data."""
210
 
        self.__buffer = ""
211
 
 
212
 
    def dataReceived(self, data):
213
 
        """Protocol.dataReceived.
214
 
        Translates bytes into lines, and calls lineReceived (or
215
 
        rawDataReceived, depending on mode.)
216
 
        """
217
 
        self.__buffer = self.__buffer+data
218
 
        while self.line_mode and not self.paused:
219
 
            try:
220
 
                line, self.__buffer = self.__buffer.split(self.delimiter, 1)
221
 
            except ValueError:
222
 
                if len(self.__buffer) > self.MAX_LENGTH:
223
 
                    line, self.__buffer = self.__buffer, ''
224
 
                    return self.lineLengthExceeded(line)
225
 
                break
226
 
            else:
227
 
                linelength = len(line)
228
 
                if linelength > self.MAX_LENGTH:
229
 
                    exceeded = line + self.__buffer
230
 
                    self.__buffer = ''
231
 
                    return self.lineLengthExceeded(exceeded)
232
 
                why = self.lineReceived(line)
233
 
                if why or self.transport and self.transport.disconnecting:
234
 
                    return why
235
 
        else:
236
 
            if not self.paused:
237
 
                data=self.__buffer
238
 
                self.__buffer=''
239
 
                if data:
240
 
                    return self.rawDataReceived(data)
241
 
 
242
 
    def setLineMode(self, extra=''):
243
 
        """Sets the line-mode of this receiver.
244
 
 
245
 
        If you are calling this from a rawDataReceived callback,
246
 
        you can pass in extra unhandled data, and that data will
247
 
        be parsed for lines.  Further data received will be sent
248
 
        to lineReceived rather than rawDataReceived.
249
 
 
250
 
        Do not pass extra data if calling this function from
251
 
        within a lineReceived callback.
252
 
        """
253
 
        self.line_mode = 1
254
 
        if extra:
255
 
            return self.dataReceived(extra)
256
 
 
257
 
    def setRawMode(self):
258
 
        """Sets the raw mode of this receiver.
259
 
        Further data received will be sent to rawDataReceived rather
260
 
        than lineReceived.
261
 
        """
262
 
        self.line_mode = 0
263
 
 
264
 
    def rawDataReceived(self, data):
265
 
        """Override this for when raw data is received.
266
 
        """
267
 
        raise NotImplementedError
268
 
 
269
 
    def lineReceived(self, line):
270
 
        """Override this for when each line is received.
271
 
        """
272
 
        raise NotImplementedError
273
 
 
274
 
    def sendLine(self, line):
275
 
        """Sends a line to the other end of the connection.
276
 
        """
277
 
        return self.transport.write(line + self.delimiter)
278
 
 
279
 
    def lineLengthExceeded(self, line):
280
 
        """Called when the maximum line length has been reached.
281
 
        Override if it needs to be dealt with in some special way.
282
 
 
283
 
        The argument 'line' contains the remainder of the buffer, starting
284
 
        with (at least some part) of the line which is too long. This may
285
 
        be more than one line, or may be only the initial portion of the
286
 
        line.
287
 
        """
288
 
        return self.transport.loseConnection()
289
 
 
290
 
 
291
 
class Int32StringReceiver(protocol.Protocol, _PauseableMixin):
292
 
    """A receiver for int32-prefixed strings.
293
 
 
294
 
    An int32 string is a string prefixed by 4 bytes, the 32-bit length of
295
 
    the string encoded in network byte order.
296
 
 
297
 
    This class publishes the same interface as NetstringReceiver.
298
 
    """
299
 
 
300
 
    MAX_LENGTH = 99999
301
 
    recvd = ""
302
 
 
303
 
    def stringReceived(self, msg):
304
 
        """Override this.
305
 
        """
306
 
        raise NotImplementedError
307
 
 
308
 
    def dataReceived(self, recd):
309
 
        """Convert int32 prefixed strings into calls to stringReceived.
310
 
        """
311
 
        self.recvd = self.recvd + recd
312
 
        while len(self.recvd) > 3 and not self.paused:
313
 
            length ,= struct.unpack("!i",self.recvd[:4])
314
 
            if length > self.MAX_LENGTH:
315
 
                self.transport.loseConnection()
316
 
                return
317
 
            if len(self.recvd) < length+4:
318
 
                break
319
 
            packet = self.recvd[4:length+4]
320
 
            self.recvd = self.recvd[length+4:]
321
 
            self.stringReceived(packet)
322
 
 
323
 
    def sendString(self, data):
324
 
        """Send an int32-prefixed string to the other end of the connection.
325
 
        """
326
 
        self.transport.write(struct.pack("!i",len(data))+data)
327
 
 
328
 
 
329
 
class Int16StringReceiver(protocol.Protocol, _PauseableMixin):
330
 
    """A receiver for int16-prefixed strings.
331
 
 
332
 
    An int16 string is a string prefixed by 2 bytes, the 16-bit length of
333
 
    the string encoded in network byte order.
334
 
 
335
 
    This class publishes the same interface as NetstringReceiver.
336
 
    """
337
 
 
338
 
    recvd = ""
339
 
 
340
 
    def stringReceived(self, msg):
341
 
        """Override this.
342
 
        """
343
 
        raise NotImplementedError
344
 
 
345
 
    def dataReceived(self, recd):
346
 
        """Convert int16 prefixed strings into calls to stringReceived.
347
 
        """
348
 
        self.recvd = self.recvd + recd
349
 
        while len(self.recvd) > 1 and not self.paused:
350
 
            length = (ord(self.recvd[0]) * 256) + ord(self.recvd[1])
351
 
            if len(self.recvd) < length+2:
352
 
                break
353
 
            packet = self.recvd[2:length+2]
354
 
            self.recvd = self.recvd[length+2:]
355
 
            self.stringReceived(packet)
356
 
 
357
 
    def sendString(self, data):
358
 
        """Send an int16-prefixed string to the other end of the connection.
359
 
        """
360
 
        assert len(data) < 65536, "message too long"
361
 
        self.transport.write(struct.pack("!h",len(data)) + data)
362
 
 
363
 
 
364
 
class StatefulStringProtocol:
365
 
    """A stateful string protocol.
366
 
 
367
 
    This is a mixin for string protocols (Int32StringReceiver,
368
 
    NetstringReceiver) which translates stringReceived into a callback
369
 
    (prefixed with 'proto_') depending on state."""
370
 
 
371
 
    state = 'init'
372
 
 
373
 
    def stringReceived(self,string):
374
 
        """Choose a protocol phase function and call it.
375
 
 
376
 
        Call back to the appropriate protocol phase; this begins with
377
 
        the function proto_init and moves on to proto_* depending on
378
 
        what each proto_* function returns.  (For example, if
379
 
        self.proto_init returns 'foo', then self.proto_foo will be the
380
 
        next function called when a protocol message is received.
381
 
        """
382
 
        try:
383
 
            pto = 'proto_'+self.state
384
 
            statehandler = getattr(self,pto)
385
 
        except AttributeError:
386
 
            log.msg('callback',self.state,'not found')
387
 
        else:
388
 
            self.state = statehandler(string)
389
 
            if self.state == 'done':
390
 
                self.transport.loseConnection()
391
 
 
392
 
class FileSender:
393
 
    """A producer that sends the contents of a file to a consumer.
394
 
 
395
 
    This is a helper for protocols that, at some point, will take a
396
 
    file-like object, read its contents, and write them out to the network,
397
 
    optionally performing some transformation on the bytes in between.
398
 
 
399
 
    This API is unstable.
400
 
    """
401
 
    implements(interfaces.IProducer)
402
 
 
403
 
    CHUNK_SIZE = 2 ** 14
404
 
 
405
 
    lastSent = ''
406
 
    deferred = None
407
 
 
408
 
    def beginFileTransfer(self, file, consumer, transform = None):
409
 
        """Begin transferring a file
410
 
 
411
 
        @type file: Any file-like object
412
 
        @param file: The file object to read data from
413
 
 
414
 
        @type consumer: Any implementor of IConsumer
415
 
        @param consumer: The object to write data to
416
 
 
417
 
        @param transform: A callable taking one string argument and returning
418
 
        the same.  All bytes read from the file are passed through this before
419
 
        being written to the consumer.
420
 
 
421
 
        @rtype: C{Deferred}
422
 
        @return: A deferred whose callback will be invoked when the file has been
423
 
        completely written to the consumer.  The last byte written to the consumer
424
 
        is passed to the callback.
425
 
        """
426
 
        self.file = file
427
 
        self.consumer = consumer
428
 
        self.transform = transform
429
 
 
430
 
        self.deferred = deferred = defer.Deferred()
431
 
        self.consumer.registerProducer(self, False)
432
 
        return deferred
433
 
 
434
 
    def resumeProducing(self):
435
 
        chunk = ''
436
 
        if self.file:
437
 
            chunk = self.file.read(self.CHUNK_SIZE)
438
 
        if not chunk:
439
 
            self.file = None
440
 
            self.consumer.unregisterProducer()
441
 
            if self.deferred:
442
 
                self.deferred.callback(self.lastSent)
443
 
                self.deferred = None
444
 
            return
445
 
 
446
 
        if self.transform:
447
 
            chunk = self.transform(chunk)
448
 
        self.consumer.write(chunk)
449
 
        self.lastSent = chunk[-1]
450
 
 
451
 
    def pauseProducing(self):
452
 
        pass
453
 
 
454
 
    def stopProducing(self):
455
 
        if self.deferred:
456
 
            self.deferred.errback(Exception("Consumer asked us to stop producing"))
457
 
            self.deferred = None