~ubuntu-branches/ubuntu/trusty/python3.4/trusty-proposed

« back to all changes in this revision

Viewing changes to Lib/asyncio/streams.py

  • Committer: Package Import Robot
  • Author(s): Matthias Klose
  • Date: 2013-11-25 09:44:27 UTC
  • Revision ID: package-import@ubuntu.com-20131125094427-lzxj8ap5w01lmo7f
Tags: upstream-3.4~b1
ImportĀ upstreamĀ versionĀ 3.4~b1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
"""Stream-related things."""
 
2
 
 
3
__all__ = ['StreamReader', 'StreamReaderProtocol',
 
4
           'open_connection', 'start_server',
 
5
           ]
 
6
 
 
7
import collections
 
8
 
 
9
from . import events
 
10
from . import futures
 
11
from . import protocols
 
12
from . import tasks
 
13
 
 
14
 
 
15
_DEFAULT_LIMIT = 2**16
 
16
 
 
17
 
 
18
@tasks.coroutine
 
19
def open_connection(host=None, port=None, *,
 
20
                    loop=None, limit=_DEFAULT_LIMIT, **kwds):
 
21
    """A wrapper for create_connection() returning a (reader, writer) pair.
 
22
 
 
23
    The reader returned is a StreamReader instance; the writer is a
 
24
    Transport.
 
25
 
 
26
    The arguments are all the usual arguments to create_connection()
 
27
    except protocol_factory; most common are positional host and port,
 
28
    with various optional keyword arguments following.
 
29
 
 
30
    Additional optional keyword arguments are loop (to set the event loop
 
31
    instance to use) and limit (to set the buffer limit passed to the
 
32
    StreamReader).
 
33
 
 
34
    (If you want to customize the StreamReader and/or
 
35
    StreamReaderProtocol classes, just copy the code -- there's
 
36
    really nothing special here except some convenience.)
 
37
    """
 
38
    if loop is None:
 
39
        loop = events.get_event_loop()
 
40
    reader = StreamReader(limit=limit, loop=loop)
 
41
    protocol = StreamReaderProtocol(reader)
 
42
    transport, _ = yield from loop.create_connection(
 
43
        lambda: protocol, host, port, **kwds)
 
44
    writer = StreamWriter(transport, protocol, reader, loop)
 
45
    return reader, writer
 
46
 
 
47
 
 
48
@tasks.coroutine
 
49
def start_server(client_connected_cb, host=None, port=None, *,
 
50
                 loop=None, limit=_DEFAULT_LIMIT, **kwds):
 
51
    """Start a socket server, call back for each client connected.
 
52
 
 
53
    The first parameter, `client_connected_cb`, takes two parameters:
 
54
    client_reader, client_writer.  client_reader is a StreamReader
 
55
    object, while client_writer is a StreamWriter object.  This
 
56
    parameter can either be a plain callback function or a coroutine;
 
57
    if it is a coroutine, it will be automatically converted into a
 
58
    Task.
 
59
 
 
60
    The rest of the arguments are all the usual arguments to
 
61
    loop.create_server() except protocol_factory; most common are
 
62
    positional host and port, with various optional keyword arguments
 
63
    following.  The return value is the same as loop.create_server().
 
64
 
 
65
    Additional optional keyword arguments are loop (to set the event loop
 
66
    instance to use) and limit (to set the buffer limit passed to the
 
67
    StreamReader).
 
68
 
 
69
    The return value is the same as loop.create_server(), i.e. a
 
70
    Server object which can be used to stop the service.
 
71
    """
 
72
    if loop is None:
 
73
        loop = events.get_event_loop()
 
74
 
 
75
    def factory():
 
76
        reader = StreamReader(limit=limit, loop=loop)
 
77
        protocol = StreamReaderProtocol(reader, client_connected_cb,
 
78
                                        loop=loop)
 
79
        return protocol
 
80
 
 
81
    return (yield from loop.create_server(factory, host, port, **kwds))
 
82
 
 
83
 
 
84
class StreamReaderProtocol(protocols.Protocol):
 
85
    """Trivial helper class to adapt between Protocol and StreamReader.
 
86
 
 
87
    (This is a helper class instead of making StreamReader itself a
 
88
    Protocol subclass, because the StreamReader has other potential
 
89
    uses, and to prevent the user of the StreamReader to accidentally
 
90
    call inappropriate methods of the protocol.)
 
91
    """
 
92
 
 
93
    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
 
94
        self._stream_reader = stream_reader
 
95
        self._stream_writer = None
 
96
        self._drain_waiter = None
 
97
        self._paused = False
 
98
        self._client_connected_cb = client_connected_cb
 
99
        self._loop = loop  # May be None; we may never need it.
 
100
 
 
101
    def connection_made(self, transport):
 
102
        self._stream_reader.set_transport(transport)
 
103
        if self._client_connected_cb is not None:
 
104
            self._stream_writer = StreamWriter(transport, self,
 
105
                                               self._stream_reader,
 
106
                                               self._loop)
 
107
            res = self._client_connected_cb(self._stream_reader,
 
108
                                            self._stream_writer)
 
109
            if tasks.iscoroutine(res):
 
110
                tasks.Task(res, loop=self._loop)
 
111
 
 
112
    def connection_lost(self, exc):
 
113
        if exc is None:
 
114
            self._stream_reader.feed_eof()
 
115
        else:
 
116
            self._stream_reader.set_exception(exc)
 
117
        # Also wake up the writing side.
 
118
        if self._paused:
 
119
            waiter = self._drain_waiter
 
120
            if waiter is not None:
 
121
                self._drain_waiter = None
 
122
                if not waiter.done():
 
123
                    if exc is None:
 
124
                        waiter.set_result(None)
 
125
                    else:
 
126
                        waiter.set_exception(exc)
 
127
 
 
128
    def data_received(self, data):
 
129
        self._stream_reader.feed_data(data)
 
130
 
 
131
    def eof_received(self):
 
132
        self._stream_reader.feed_eof()
 
133
 
 
134
    def pause_writing(self):
 
135
        assert not self._paused
 
136
        self._paused = True
 
137
 
 
138
    def resume_writing(self):
 
139
        assert self._paused
 
140
        self._paused = False
 
141
        waiter = self._drain_waiter
 
142
        if waiter is not None:
 
143
            self._drain_waiter = None
 
144
            if not waiter.done():
 
145
                waiter.set_result(None)
 
146
 
 
147
 
 
148
class StreamWriter:
 
149
    """Wraps a Transport.
 
150
 
 
151
    This exposes write(), writelines(), [can_]write_eof(),
 
152
    get_extra_info() and close().  It adds drain() which returns an
 
153
    optional Future on which you can wait for flow control.  It also
 
154
    adds a transport attribute which references the Transport
 
155
    directly.
 
156
    """
 
157
 
 
158
    def __init__(self, transport, protocol, reader, loop):
 
159
        self._transport = transport
 
160
        self._protocol = protocol
 
161
        self._reader = reader
 
162
        self._loop = loop
 
163
 
 
164
    @property
 
165
    def transport(self):
 
166
        return self._transport
 
167
 
 
168
    def write(self, data):
 
169
        self._transport.write(data)
 
170
 
 
171
    def writelines(self, data):
 
172
        self._transport.writelines(data)
 
173
 
 
174
    def write_eof(self):
 
175
        return self._transport.write_eof()
 
176
 
 
177
    def can_write_eof(self):
 
178
        return self._transport.can_write_eof()
 
179
 
 
180
    def close(self):
 
181
        return self._transport.close()
 
182
 
 
183
    def get_extra_info(self, name, default=None):
 
184
        return self._transport.get_extra_info(name, default)
 
185
 
 
186
    def drain(self):
 
187
        """This method has an unusual return value.
 
188
 
 
189
        The intended use is to write
 
190
 
 
191
          w.write(data)
 
192
          yield from w.drain()
 
193
 
 
194
        When there's nothing to wait for, drain() returns (), and the
 
195
        yield-from continues immediately.  When the transport buffer
 
196
        is full (the protocol is paused), drain() creates and returns
 
197
        a Future and the yield-from will block until that Future is
 
198
        completed, which will happen when the buffer is (partially)
 
199
        drained and the protocol is resumed.
 
200
        """
 
201
        if self._reader._exception is not None:
 
202
            raise self._writer._exception
 
203
        if self._transport._conn_lost:  # Uses private variable.
 
204
            raise ConnectionResetError('Connection lost')
 
205
        if not self._protocol._paused:
 
206
            return ()
 
207
        waiter = self._protocol._drain_waiter
 
208
        assert waiter is None or waiter.cancelled()
 
209
        waiter = futures.Future(loop=self._loop)
 
210
        self._protocol._drain_waiter = waiter
 
211
        return waiter
 
212
 
 
213
 
 
214
class StreamReader:
 
215
 
 
216
    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
 
217
        # The line length limit is  a security feature;
 
218
        # it also doubles as half the buffer limit.
 
219
        self._limit = limit
 
220
        if loop is None:
 
221
            loop = events.get_event_loop()
 
222
        self._loop = loop
 
223
        self._buffer = collections.deque()  # Deque of bytes objects.
 
224
        self._byte_count = 0  # Bytes in buffer.
 
225
        self._eof = False  # Whether we're done.
 
226
        self._waiter = None  # A future.
 
227
        self._exception = None
 
228
        self._transport = None
 
229
        self._paused = False
 
230
 
 
231
    def exception(self):
 
232
        return self._exception
 
233
 
 
234
    def set_exception(self, exc):
 
235
        self._exception = exc
 
236
 
 
237
        waiter = self._waiter
 
238
        if waiter is not None:
 
239
            self._waiter = None
 
240
            if not waiter.cancelled():
 
241
                waiter.set_exception(exc)
 
242
 
 
243
    def set_transport(self, transport):
 
244
        assert self._transport is None, 'Transport already set'
 
245
        self._transport = transport
 
246
 
 
247
    def _maybe_resume_transport(self):
 
248
        if self._paused and self._byte_count <= self._limit:
 
249
            self._paused = False
 
250
            self._transport.resume_reading()
 
251
 
 
252
    def feed_eof(self):
 
253
        self._eof = True
 
254
        waiter = self._waiter
 
255
        if waiter is not None:
 
256
            self._waiter = None
 
257
            if not waiter.cancelled():
 
258
                waiter.set_result(True)
 
259
 
 
260
    def feed_data(self, data):
 
261
        if not data:
 
262
            return
 
263
 
 
264
        self._buffer.append(data)
 
265
        self._byte_count += len(data)
 
266
 
 
267
        waiter = self._waiter
 
268
        if waiter is not None:
 
269
            self._waiter = None
 
270
            if not waiter.cancelled():
 
271
                waiter.set_result(False)
 
272
 
 
273
        if (self._transport is not None and
 
274
            not self._paused and
 
275
            self._byte_count > 2*self._limit):
 
276
            try:
 
277
                self._transport.pause_reading()
 
278
            except NotImplementedError:
 
279
                # The transport can't be paused.
 
280
                # We'll just have to buffer all data.
 
281
                # Forget the transport so we don't keep trying.
 
282
                self._transport = None
 
283
            else:
 
284
                self._paused = True
 
285
 
 
286
    @tasks.coroutine
 
287
    def readline(self):
 
288
        if self._exception is not None:
 
289
            raise self._exception
 
290
 
 
291
        parts = []
 
292
        parts_size = 0
 
293
        not_enough = True
 
294
 
 
295
        while not_enough:
 
296
            while self._buffer and not_enough:
 
297
                data = self._buffer.popleft()
 
298
                ichar = data.find(b'\n')
 
299
                if ichar < 0:
 
300
                    parts.append(data)
 
301
                    parts_size += len(data)
 
302
                else:
 
303
                    ichar += 1
 
304
                    head, tail = data[:ichar], data[ichar:]
 
305
                    if tail:
 
306
                        self._buffer.appendleft(tail)
 
307
                    not_enough = False
 
308
                    parts.append(head)
 
309
                    parts_size += len(head)
 
310
 
 
311
                if parts_size > self._limit:
 
312
                    self._byte_count -= parts_size
 
313
                    self._maybe_resume_transport()
 
314
                    raise ValueError('Line is too long')
 
315
 
 
316
            if self._eof:
 
317
                break
 
318
 
 
319
            if not_enough:
 
320
                assert self._waiter is None
 
321
                self._waiter = futures.Future(loop=self._loop)
 
322
                try:
 
323
                    yield from self._waiter
 
324
                finally:
 
325
                    self._waiter = None
 
326
 
 
327
        line = b''.join(parts)
 
328
        self._byte_count -= parts_size
 
329
        self._maybe_resume_transport()
 
330
 
 
331
        return line
 
332
 
 
333
    @tasks.coroutine
 
334
    def read(self, n=-1):
 
335
        if self._exception is not None:
 
336
            raise self._exception
 
337
 
 
338
        if not n:
 
339
            return b''
 
340
 
 
341
        if n < 0:
 
342
            while not self._eof:
 
343
                assert not self._waiter
 
344
                self._waiter = futures.Future(loop=self._loop)
 
345
                try:
 
346
                    yield from self._waiter
 
347
                finally:
 
348
                    self._waiter = None
 
349
        else:
 
350
            if not self._byte_count and not self._eof:
 
351
                assert not self._waiter
 
352
                self._waiter = futures.Future(loop=self._loop)
 
353
                try:
 
354
                    yield from self._waiter
 
355
                finally:
 
356
                    self._waiter = None
 
357
 
 
358
        if n < 0 or self._byte_count <= n:
 
359
            data = b''.join(self._buffer)
 
360
            self._buffer.clear()
 
361
            self._byte_count = 0
 
362
            self._maybe_resume_transport()
 
363
            return data
 
364
 
 
365
        parts = []
 
366
        parts_bytes = 0
 
367
        while self._buffer and parts_bytes < n:
 
368
            data = self._buffer.popleft()
 
369
            data_bytes = len(data)
 
370
            if n < parts_bytes + data_bytes:
 
371
                data_bytes = n - parts_bytes
 
372
                data, rest = data[:data_bytes], data[data_bytes:]
 
373
                self._buffer.appendleft(rest)
 
374
 
 
375
            parts.append(data)
 
376
            parts_bytes += data_bytes
 
377
            self._byte_count -= data_bytes
 
378
            self._maybe_resume_transport()
 
379
 
 
380
        return b''.join(parts)
 
381
 
 
382
    @tasks.coroutine
 
383
    def readexactly(self, n):
 
384
        if self._exception is not None:
 
385
            raise self._exception
 
386
 
 
387
        if n <= 0:
 
388
            return b''
 
389
 
 
390
        while self._byte_count < n and not self._eof:
 
391
            assert not self._waiter
 
392
            self._waiter = futures.Future(loop=self._loop)
 
393
            try:
 
394
                yield from self._waiter
 
395
            finally:
 
396
                self._waiter = None
 
397
 
 
398
        return (yield from self.read(n))