~ubuntu-branches/debian/lenny/bzr/lenny

« back to all changes in this revision

Viewing changes to bzrlib/transport/smart.py

  • Committer: Bazaar Package Importer
  • Author(s): Thomas Viehmann
  • Date: 2008-08-22 20:06:37 UTC
  • mfrom: (3.1.63 intrepid)
  • Revision ID: james.westby@ubuntu.com-20080822200637-kxobfsnjlzojhqra
Tags: 1.5-1.1
* Non-maintainer upload.
* Apply patch from upstream VCS to fix FTBFS in tools/rst2html.py
  with older docutils. Thanks to Olivier Tétard for digging it
  up.
  Closes: #494246.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006 Canonical Ltd
2
 
#
3
 
# This program is free software; you can redistribute it and/or modify
4
 
# it under the terms of the GNU General Public License as published by
5
 
# the Free Software Foundation; either version 2 of the License, or
6
 
# (at your option) any later version.
7
 
#
8
 
# This program is distributed in the hope that it will be useful,
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
 
# GNU General Public License for more details.
12
 
#
13
 
# You should have received a copy of the GNU General Public License
14
 
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16
 
 
17
 
"""Smart-server protocol, client and server.
18
 
 
19
 
Requests are sent as a command and list of arguments, followed by optional
20
 
bulk body data.  Responses are similarly a response and list of arguments,
21
 
followed by bulk body data. ::
22
 
 
23
 
  SEP := '\001'
24
 
    Fields are separated by Ctrl-A.
25
 
  BULK_DATA := CHUNK+ TRAILER
26
 
    Chunks can be repeated as many times as necessary.
27
 
  CHUNK := CHUNK_LEN CHUNK_BODY
28
 
  CHUNK_LEN := DIGIT+ NEWLINE
29
 
    Gives the number of bytes in the following chunk.
30
 
  CHUNK_BODY := BYTE[chunk_len]
31
 
  TRAILER := SUCCESS_TRAILER | ERROR_TRAILER
32
 
  SUCCESS_TRAILER := 'done' NEWLINE
33
 
  ERROR_TRAILER := 
34
 
 
35
 
Paths are passed across the network.  The client needs to see a namespace that
36
 
includes any repository that might need to be referenced, and the client needs
37
 
to know about a root directory beyond which it cannot ascend.
38
 
 
39
 
Servers run over ssh will typically want to be able to access any path the user 
40
 
can access.  Public servers on the other hand (which might be over http, ssh
41
 
or tcp) will typically want to restrict access to only a particular directory 
42
 
and its children, so will want to do a software virtual root at that level.
43
 
In other words they'll want to rewrite incoming paths to be under that level
44
 
(and prevent escaping using ../ tricks.)
45
 
 
46
 
URLs that include ~ should probably be passed across to the server verbatim
47
 
and the server can expand them.  This will proably not be meaningful when 
48
 
limited to a directory?
49
 
"""
50
 
 
51
 
 
52
 
# TODO: _translate_error should be on the client, not the transport because
53
 
#     error coding is wire protocol specific.
54
 
 
55
 
# TODO: A plain integer from query_version is too simple; should give some
56
 
# capabilities too?
57
 
 
58
 
# TODO: Server should probably catch exceptions within itself and send them
59
 
# back across the network.  (But shouldn't catch KeyboardInterrupt etc)
60
 
# Also needs to somehow report protocol errors like bad requests.  Need to
61
 
# consider how we'll handle error reporting, e.g. if we get halfway through a
62
 
# bulk transfer and then something goes wrong.
63
 
 
64
 
# TODO: Standard marker at start of request/response lines?
65
 
 
66
 
# TODO: Make each request and response self-validatable, e.g. with checksums.
67
 
#
68
 
# TODO: get/put objects could be changed to gradually read back the data as it
69
 
# comes across the network
70
 
#
71
 
# TODO: What should the server do if it hits an error and has to terminate?
72
 
#
73
 
# TODO: is it useful to allow multiple chunks in the bulk data?
74
 
#
75
 
# TODO: If we get an exception during transmission of bulk data we can't just
76
 
# emit the exception because it won't be seen.
77
 
#   John proposes:  I think it would be worthwhile to have a header on each
78
 
#   chunk, that indicates it is another chunk. Then you can send an 'error'
79
 
#   chunk as long as you finish the previous chunk.
80
 
#
81
 
# TODO: Clone method on Transport; should work up towards parent directory;
82
 
# unclear how this should be stored or communicated to the server... maybe
83
 
# just pass it on all relevant requests?
84
 
#
85
 
# TODO: Better name than clone() for changing between directories.  How about
86
 
# open_dir or change_dir or chdir?
87
 
#
88
 
# TODO: Is it really good to have the notion of current directory within the
89
 
# connection?  Perhaps all Transports should factor out a common connection
90
 
# from the thing that has the directory context?
91
 
#
92
 
# TODO: Pull more things common to sftp and ssh to a higher level.
93
 
#
94
 
# TODO: The server that manages a connection should be quite small and retain
95
 
# minimum state because each of the requests are supposed to be stateless.
96
 
# Then we can write another implementation that maps to http.
97
 
#
98
 
# TODO: What to do when a client connection is garbage collected?  Maybe just
99
 
# abruptly drop the connection?
100
 
#
101
 
# TODO: Server in some cases will need to restrict access to files outside of
102
 
# a particular root directory.  LocalTransport doesn't do anything to stop you
103
 
# ascending above the base directory, so we need to prevent paths
104
 
# containing '..' in either the server or transport layers.  (Also need to
105
 
# consider what happens if someone creates a symlink pointing outside the 
106
 
# directory tree...)
107
 
#
108
 
# TODO: Server should rebase absolute paths coming across the network to put
109
 
# them under the virtual root, if one is in use.  LocalTransport currently
110
 
# doesn't do that; if you give it an absolute path it just uses it.
111
 
112
 
# XXX: Arguments can't contain newlines or ascii; possibly we should e.g.
113
 
# urlescape them instead.  Indeed possibly this should just literally be
114
 
# http-over-ssh.
115
 
#
116
 
# FIXME: This transport, with several others, has imperfect handling of paths
117
 
# within urls.  It'd probably be better for ".." from a root to raise an error
118
 
# rather than return the same directory as we do at present.
119
 
#
120
 
# TODO: Rather than working at the Transport layer we want a Branch,
121
 
# Repository or BzrDir objects that talk to a server.
122
 
#
123
 
# TODO: Probably want some way for server commands to gradually produce body
124
 
# data rather than passing it as a string; they could perhaps pass an
125
 
# iterator-like callback that will gradually yield data; it probably needs a
126
 
# close() method that will always be closed to do any necessary cleanup.
127
 
#
128
 
# TODO: Split the actual smart server from the ssh encoding of it.
129
 
#
130
 
# TODO: Perhaps support file-level readwrite operations over the transport
131
 
# too.
132
 
#
133
 
# TODO: SmartBzrDir class, proxying all Branch etc methods across to another
134
 
# branch doing file-level operations.
135
 
#
136
 
# TODO: jam 20060915 _decode_tuple is acting directly on input over
137
 
#       the socket, and it assumes everything is UTF8 sections separated
138
 
#       by \001. Which means a request like '\002' Will abort the connection
139
 
#       because of a UnicodeDecodeError. It does look like invalid data will
140
 
#       kill the SmartStreamServer, but only with an abort + exception, and 
141
 
#       the overall server shouldn't die.
142
 
 
143
 
from cStringIO import StringIO
144
 
import errno
145
 
import os
146
 
import socket
147
 
import sys
148
 
import tempfile
149
 
import threading
150
 
import urllib
151
 
import urlparse
152
 
 
153
 
from bzrlib import (
154
 
    bzrdir,
155
 
    errors,
156
 
    revision,
157
 
    transport,
158
 
    trace,
159
 
    urlutils,
160
 
    )
161
 
from bzrlib.bundle.serializer import write_bundle
162
 
from bzrlib.trace import mutter
163
 
from bzrlib.transport import local
164
 
 
165
 
# must do this otherwise urllib can't parse the urls properly :(
166
 
for scheme in ['ssh', 'bzr', 'bzr+loopback', 'bzr+ssh']:
167
 
    transport.register_urlparse_netloc_protocol(scheme)
168
 
del scheme
169
 
 
170
 
 
171
 
def _recv_tuple(from_file):
172
 
    req_line = from_file.readline()
173
 
    return _decode_tuple(req_line)
174
 
 
175
 
 
176
 
def _decode_tuple(req_line):
177
 
    if req_line == None or req_line == '':
178
 
        return None
179
 
    if req_line[-1] != '\n':
180
 
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
181
 
    return tuple((a.decode('utf-8') for a in req_line[:-1].split('\x01')))
182
 
 
183
 
 
184
 
def _encode_tuple(args):
185
 
    """Encode the tuple args to a bytestream."""
186
 
    return '\x01'.join((a.encode('utf-8') for a in args)) + '\n'
187
 
 
188
 
 
189
 
class SmartProtocolBase(object):
190
 
    """Methods common to client and server"""
191
 
 
192
 
    def _send_bulk_data(self, body):
193
 
        """Send chunked body data"""
194
 
        assert isinstance(body, str)
195
 
        bytes = ''.join(('%d\n' % len(body), body, 'done\n'))
196
 
        self._write_and_flush(bytes)
197
 
 
198
 
    # TODO: this only actually accomodates a single block; possibly should support
199
 
    # multiple chunks?
200
 
    def _recv_bulk(self):
201
 
        chunk_len = self._in.readline()
202
 
        try:
203
 
            chunk_len = int(chunk_len)
204
 
        except ValueError:
205
 
            raise errors.SmartProtocolError("bad chunk length line %r" % chunk_len)
206
 
        bulk = self._in.read(chunk_len)
207
 
        if len(bulk) != chunk_len:
208
 
            raise errors.SmartProtocolError("short read fetching bulk data chunk")
209
 
        self._recv_trailer()
210
 
        return bulk
211
 
 
212
 
    def _recv_tuple(self):
213
 
        return _recv_tuple(self._in)
214
 
 
215
 
    def _recv_trailer(self):
216
 
        resp = self._recv_tuple()
217
 
        if resp == ('done', ):
218
 
            return
219
 
        else:
220
 
            self._translate_error(resp)
221
 
 
222
 
    def _serialise_offsets(self, offsets):
223
 
        """Serialise a readv offset list."""
224
 
        txt = []
225
 
        for start, length in offsets:
226
 
            txt.append('%d,%d' % (start, length))
227
 
        return '\n'.join(txt)
228
 
 
229
 
    def _write_and_flush(self, bytes):
230
 
        """Write bytes to self._out and flush it."""
231
 
        # XXX: this will be inefficient.  Just ask Robert.
232
 
        self._out.write(bytes)
233
 
        self._out.flush()
234
 
 
235
 
 
236
 
class SmartStreamServer(SmartProtocolBase):
237
 
    """Handles smart commands coming over a stream.
238
 
 
239
 
    The stream may be a pipe connected to sshd, or a tcp socket, or an
240
 
    in-process fifo for testing.
241
 
 
242
 
    One instance is created for each connected client; it can serve multiple
243
 
    requests in the lifetime of the connection.
244
 
 
245
 
    The server passes requests through to an underlying backing transport, 
246
 
    which will typically be a LocalTransport looking at the server's filesystem.
247
 
    """
248
 
 
249
 
    def __init__(self, in_file, out_file, backing_transport):
250
 
        """Construct new server.
251
 
 
252
 
        :param in_file: Python file from which requests can be read.
253
 
        :param out_file: Python file to write responses.
254
 
        :param backing_transport: Transport for the directory served.
255
 
        """
256
 
        self._in = in_file
257
 
        self._out = out_file
258
 
        self.smart_server = SmartServer(backing_transport)
259
 
        # server can call back to us to get bulk data - this is not really
260
 
        # ideal, they should get it per request instead
261
 
        self.smart_server._recv_body = self._recv_bulk
262
 
 
263
 
    def _recv_tuple(self):
264
 
        """Read a request from the client and return as a tuple.
265
 
        
266
 
        Returns None at end of file (if the client closed the connection.)
267
 
        """
268
 
        return _recv_tuple(self._in)
269
 
 
270
 
    def _send_tuple(self, args):
271
 
        """Send response header"""
272
 
        return self._write_and_flush(_encode_tuple(args))
273
 
 
274
 
    def _send_error_and_disconnect(self, exception):
275
 
        self._send_tuple(('error', str(exception)))
276
 
        ## self._out.close()
277
 
        ## self._in.close()
278
 
 
279
 
    def _serve_one_request(self):
280
 
        """Read one request from input, process, send back a response.
281
 
        
282
 
        :return: False if the server should terminate, otherwise None.
283
 
        """
284
 
        req_args = self._recv_tuple()
285
 
        if req_args == None:
286
 
            # client closed connection
287
 
            return False  # shutdown server
288
 
        try:
289
 
            response = self.smart_server.dispatch_command(req_args[0], req_args[1:])
290
 
            self._send_tuple(response.args)
291
 
            if response.body is not None:
292
 
                self._send_bulk_data(response.body)
293
 
        except KeyboardInterrupt:
294
 
            raise
295
 
        except Exception, e:
296
 
            # everything else: pass to client, flush, and quit
297
 
            self._send_error_and_disconnect(e)
298
 
            return False
299
 
 
300
 
    def serve(self):
301
 
        """Serve requests until the client disconnects."""
302
 
        # Keep a reference to stderr because the sys module's globals get set to
303
 
        # None during interpreter shutdown.
304
 
        from sys import stderr
305
 
        try:
306
 
            while self._serve_one_request() != False:
307
 
                pass
308
 
        except Exception, e:
309
 
            stderr.write("%s terminating on exception %s\n" % (self, e))
310
 
            raise
311
 
 
312
 
 
313
 
class SmartServerResponse(object):
314
 
    """Response generated by SmartServer."""
315
 
 
316
 
    def __init__(self, args, body=None):
317
 
        self.args = args
318
 
        self.body = body
319
 
 
320
 
# XXX: TODO: Create a SmartServerRequest which will take the responsibility
321
 
# for delivering the data for a request. This could be done with as the
322
 
# StreamServer, though that would create conflation between request and response
323
 
# which may be undesirable.
324
 
 
325
 
 
326
 
class SmartServer(object):
327
 
    """Protocol logic for smart server.
328
 
    
329
 
    This doesn't handle serialization at all, it just processes requests and
330
 
    creates responses.
331
 
    """
332
 
 
333
 
    # IMPORTANT FOR IMPLEMENTORS: It is important that SmartServer not contain
334
 
    # encoding or decoding logic to allow the wire protocol to vary from the
335
 
    # object protocol: we will want to tweak the wire protocol separate from
336
 
    # the object model, and ideally we will be able to do that without having
337
 
    # a SmartServer subclass for each wire protocol, rather just a Protocol
338
 
    # subclass.
339
 
 
340
 
    # TODO: Better way of representing the body for commands that take it,
341
 
    # and allow it to be streamed into the server.
342
 
    
343
 
    def __init__(self, backing_transport):
344
 
        self._backing_transport = backing_transport
345
 
        
346
 
    def do_hello(self):
347
 
        """Answer a version request with my version."""
348
 
        return SmartServerResponse(('ok', '1'))
349
 
 
350
 
    def do_has(self, relpath):
351
 
        r = self._backing_transport.has(relpath) and 'yes' or 'no'
352
 
        return SmartServerResponse((r,))
353
 
 
354
 
    def do_get(self, relpath):
355
 
        backing_bytes = self._backing_transport.get_bytes(relpath)
356
 
        return SmartServerResponse(('ok',), backing_bytes)
357
 
 
358
 
    def _deserialise_optional_mode(self, mode):
359
 
        # XXX: FIXME this should be on the protocol object.
360
 
        if mode == '':
361
 
            return None
362
 
        else:
363
 
            return int(mode)
364
 
 
365
 
    def do_append(self, relpath, mode):
366
 
        old_length = self._backing_transport.append_bytes(
367
 
            relpath, self._recv_body(), self._deserialise_optional_mode(mode))
368
 
        return SmartServerResponse(('appended', '%d' % old_length))
369
 
 
370
 
    def do_delete(self, relpath):
371
 
        self._backing_transport.delete(relpath)
372
 
 
373
 
    def do_iter_files_recursive(self, abspath):
374
 
        # XXX: the path handling needs some thought.
375
 
        #relpath = self._backing_transport.relpath(abspath)
376
 
        transport = self._backing_transport.clone(abspath)
377
 
        filenames = transport.iter_files_recursive()
378
 
        return SmartServerResponse(('names',) + tuple(filenames))
379
 
 
380
 
    def do_list_dir(self, relpath):
381
 
        filenames = self._backing_transport.list_dir(relpath)
382
 
        return SmartServerResponse(('names',) + tuple(filenames))
383
 
 
384
 
    def do_mkdir(self, relpath, mode):
385
 
        self._backing_transport.mkdir(relpath,
386
 
                                      self._deserialise_optional_mode(mode))
387
 
 
388
 
    def do_move(self, rel_from, rel_to):
389
 
        self._backing_transport.move(rel_from, rel_to)
390
 
 
391
 
    def do_put(self, relpath, mode):
392
 
        self._backing_transport.put_bytes(relpath,
393
 
                self._recv_body(),
394
 
                self._deserialise_optional_mode(mode))
395
 
 
396
 
    def _deserialise_offsets(self, text):
397
 
        # XXX: FIXME this should be on the protocol object.
398
 
        offsets = []
399
 
        for line in text.split('\n'):
400
 
            if not line:
401
 
                continue
402
 
            start, length = line.split(',')
403
 
            offsets.append((int(start), int(length)))
404
 
        return offsets
405
 
 
406
 
    def do_put_non_atomic(self, relpath, mode, create_parent, dir_mode):
407
 
        create_parent_dir = (create_parent == 'T')
408
 
        self._backing_transport.put_bytes_non_atomic(relpath,
409
 
                self._recv_body(),
410
 
                mode=self._deserialise_optional_mode(mode),
411
 
                create_parent_dir=create_parent_dir,
412
 
                dir_mode=self._deserialise_optional_mode(dir_mode))
413
 
 
414
 
    def do_readv(self, relpath):
415
 
        offsets = self._deserialise_offsets(self._recv_body())
416
 
        backing_bytes = ''.join(bytes for offset, bytes in
417
 
                             self._backing_transport.readv(relpath, offsets))
418
 
        return SmartServerResponse(('readv',), backing_bytes)
419
 
        
420
 
    def do_rename(self, rel_from, rel_to):
421
 
        self._backing_transport.rename(rel_from, rel_to)
422
 
 
423
 
    def do_rmdir(self, relpath):
424
 
        self._backing_transport.rmdir(relpath)
425
 
 
426
 
    def do_stat(self, relpath):
427
 
        stat = self._backing_transport.stat(relpath)
428
 
        return SmartServerResponse(('stat', str(stat.st_size), oct(stat.st_mode)))
429
 
        
430
 
    def do_get_bundle(self, path, revision_id):
431
 
        # open transport relative to our base
432
 
        t = self._backing_transport.clone(path)
433
 
        control, extra_path = bzrdir.BzrDir.open_containing_from_transport(t)
434
 
        repo = control.open_repository()
435
 
        tmpf = tempfile.TemporaryFile()
436
 
        base_revision = revision.NULL_REVISION
437
 
        write_bundle(repo, revision_id, base_revision, tmpf)
438
 
        tmpf.seek(0)
439
 
        return SmartServerResponse((), tmpf.read())
440
 
 
441
 
    def dispatch_command(self, cmd, args):
442
 
        func = getattr(self, 'do_' + cmd, None)
443
 
        if func is None:
444
 
            raise errors.SmartProtocolError("bad request %r" % (cmd,))
445
 
        try:
446
 
            result = func(*args)
447
 
            if result is None: 
448
 
                result = SmartServerResponse(('ok',))
449
 
            return result
450
 
        except errors.NoSuchFile, e:
451
 
            return SmartServerResponse(('NoSuchFile', e.path))
452
 
        except errors.FileExists, e:
453
 
            return SmartServerResponse(('FileExists', e.path))
454
 
        except errors.DirectoryNotEmpty, e:
455
 
            return SmartServerResponse(('DirectoryNotEmpty', e.path))
456
 
        except errors.ShortReadvError, e:
457
 
            return SmartServerResponse(('ShortReadvError',
458
 
                e.path, str(e.offset), str(e.length), str(e.actual)))
459
 
        except UnicodeError, e:
460
 
            # If it is a DecodeError, than most likely we are starting
461
 
            # with a plain string
462
 
            str_or_unicode = e.object
463
 
            if isinstance(str_or_unicode, unicode):
464
 
                val = u'u:' + str_or_unicode
465
 
            else:
466
 
                val = u's:' + str_or_unicode.encode('base64')
467
 
            # This handles UnicodeEncodeError or UnicodeDecodeError
468
 
            return SmartServerResponse((e.__class__.__name__,
469
 
                    e.encoding, val, str(e.start), str(e.end), e.reason))
470
 
        except errors.TransportNotPossible, e:
471
 
            if e.msg == "readonly transport":
472
 
                return SmartServerResponse(('ReadOnlyError', ))
473
 
            else:
474
 
                raise
475
 
 
476
 
 
477
 
class SmartTCPServer(object):
478
 
    """Listens on a TCP socket and accepts connections from smart clients"""
479
 
 
480
 
    def __init__(self, backing_transport=None, host='127.0.0.1', port=0):
481
 
        """Construct a new server.
482
 
 
483
 
        To actually start it running, call either start_background_thread or
484
 
        serve.
485
 
 
486
 
        :param host: Name of the interface to listen on.
487
 
        :param port: TCP port to listen on, or 0 to allocate a transient port.
488
 
        """
489
 
        if backing_transport is None:
490
 
            backing_transport = memory.MemoryTransport()
491
 
        self._server_socket = socket.socket()
492
 
        self._server_socket.bind((host, port))
493
 
        self.port = self._server_socket.getsockname()[1]
494
 
        self._server_socket.listen(1)
495
 
        self._server_socket.settimeout(1)
496
 
        self.backing_transport = backing_transport
497
 
 
498
 
    def serve(self):
499
 
        # let connections timeout so that we get a chance to terminate
500
 
        # Keep a reference to the exceptions we want to catch because the socket
501
 
        # module's globals get set to None during interpreter shutdown.
502
 
        from socket import timeout as socket_timeout
503
 
        from socket import error as socket_error
504
 
        self._should_terminate = False
505
 
        while not self._should_terminate:
506
 
            try:
507
 
                self.accept_and_serve()
508
 
            except socket_timeout:
509
 
                # just check if we're asked to stop
510
 
                pass
511
 
            except socket_error, e:
512
 
                trace.warning("client disconnected: %s", e)
513
 
                pass
514
 
 
515
 
    def get_url(self):
516
 
        """Return the url of the server"""
517
 
        return "bzr://%s:%d/" % self._server_socket.getsockname()
518
 
 
519
 
    def accept_and_serve(self):
520
 
        conn, client_addr = self._server_socket.accept()
521
 
        # For WIN32, where the timeout value from the listening socket
522
 
        # propogates to the newly accepted socket.
523
 
        conn.setblocking(True)
524
 
        conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
525
 
        from_client = conn.makefile('r')
526
 
        to_client = conn.makefile('w')
527
 
        handler = SmartStreamServer(from_client, to_client,
528
 
                self.backing_transport)
529
 
        connection_thread = threading.Thread(None, handler.serve, name='smart-server-child')
530
 
        connection_thread.setDaemon(True)
531
 
        connection_thread.start()
532
 
 
533
 
    def start_background_thread(self):
534
 
        self._server_thread = threading.Thread(None,
535
 
                self.serve,
536
 
                name='server-' + self.get_url())
537
 
        self._server_thread.setDaemon(True)
538
 
        self._server_thread.start()
539
 
 
540
 
    def stop_background_thread(self):
541
 
        self._should_terminate = True
542
 
        # self._server_socket.close()
543
 
        # we used to join the thread, but it's not really necessary; it will
544
 
        # terminate in time
545
 
        ## self._server_thread.join()
546
 
 
547
 
 
548
 
class SmartTCPServer_for_testing(SmartTCPServer):
549
 
    """Server suitable for use by transport tests.
550
 
    
551
 
    This server is backed by the process's cwd.
552
 
    """
553
 
 
554
 
    def __init__(self):
555
 
        self._homedir = urlutils.local_path_to_url(os.getcwd())[7:]
556
 
        # The server is set up by default like for ssh access: the client
557
 
        # passes filesystem-absolute paths; therefore the server must look
558
 
        # them up relative to the root directory.  it might be better to act
559
 
        # a public server and have the server rewrite paths into the test
560
 
        # directory.
561
 
        SmartTCPServer.__init__(self,
562
 
            transport.get_transport(urlutils.local_path_to_url('/')))
563
 
        
564
 
    def setUp(self):
565
 
        """Set up server for testing"""
566
 
        self.start_background_thread()
567
 
 
568
 
    def tearDown(self):
569
 
        self.stop_background_thread()
570
 
 
571
 
    def get_url(self):
572
 
        """Return the url of the server"""
573
 
        host, port = self._server_socket.getsockname()
574
 
        return "bzr://%s:%d%s" % (host, port, urlutils.escape(self._homedir))
575
 
 
576
 
    def get_bogus_url(self):
577
 
        """Return a URL which will fail to connect"""
578
 
        return 'bzr://127.0.0.1:1/'
579
 
 
580
 
 
581
 
class SmartStat(object):
582
 
 
583
 
    def __init__(self, size, mode):
584
 
        self.st_size = size
585
 
        self.st_mode = mode
586
 
 
587
 
 
588
 
class SmartTransport(transport.Transport):
589
 
    """Connection to a smart server.
590
 
 
591
 
    The connection holds references to pipes that can be used to send requests
592
 
    to the server.
593
 
 
594
 
    The connection has a notion of the current directory to which it's
595
 
    connected; this is incorporated in filenames passed to the server.
596
 
    
597
 
    This supports some higher-level RPC operations and can also be treated 
598
 
    like a Transport to do file-like operations.
599
 
 
600
 
    The connection can be made over a tcp socket, or (in future) an ssh pipe
601
 
    or a series of http requests.  There are concrete subclasses for each
602
 
    type: SmartTCPTransport, etc.
603
 
    """
604
 
 
605
 
    # IMPORTANT FOR IMPLEMENTORS: SmartTransport MUST NOT be given encoding
606
 
    # responsibilities: Put those on SmartClient or similar. This is vital for
607
 
    # the ability to support multiple versions of the smart protocol over time:
608
 
    # SmartTransport is an adapter from the Transport object model to the 
609
 
    # SmartClient model, not an encoder.
610
 
 
611
 
    def __init__(self, url, clone_from=None, client=None):
612
 
        """Constructor.
613
 
 
614
 
        :param client: ignored when clone_from is not None.
615
 
        """
616
 
        ### Technically super() here is faulty because Transport's __init__
617
 
        ### fails to take 2 parameters, and if super were to choose a silly
618
 
        ### initialisation order things would blow up. 
619
 
        if not url.endswith('/'):
620
 
            url += '/'
621
 
        super(SmartTransport, self).__init__(url)
622
 
        self._scheme, self._username, self._password, self._host, self._port, self._path = \
623
 
                transport.split_url(url)
624
 
        if clone_from is None:
625
 
            if client is None:
626
 
                self._client = SmartStreamClient(self._connect_to_server)
627
 
            else:
628
 
                self._client = client
629
 
        else:
630
 
            # credentials may be stripped from the base in some circumstances
631
 
            # as yet to be clearly defined or documented, so copy them.
632
 
            self._username = clone_from._username
633
 
            # reuse same connection
634
 
            self._client = clone_from._client
635
 
 
636
 
    def abspath(self, relpath):
637
 
        """Return the full url to the given relative path.
638
 
        
639
 
        @param relpath: the relative path or path components
640
 
        @type relpath: str or list
641
 
        """
642
 
        return self._unparse_url(self._remote_path(relpath))
643
 
    
644
 
    def clone(self, relative_url):
645
 
        """Make a new SmartTransport related to me, sharing the same connection.
646
 
 
647
 
        This essentially opens a handle on a different remote directory.
648
 
        """
649
 
        if relative_url is None:
650
 
            return self.__class__(self.base, self)
651
 
        else:
652
 
            return self.__class__(self.abspath(relative_url), self)
653
 
 
654
 
    def is_readonly(self):
655
 
        """Smart server transport can do read/write file operations."""
656
 
        return False
657
 
                                                   
658
 
    def get_smart_client(self):
659
 
        return self._client
660
 
                                                   
661
 
    def _unparse_url(self, path):
662
 
        """Return URL for a path.
663
 
 
664
 
        :see: SFTPUrlHandling._unparse_url
665
 
        """
666
 
        # TODO: Eventually it should be possible to unify this with
667
 
        # SFTPUrlHandling._unparse_url?
668
 
        if path == '':
669
 
            path = '/'
670
 
        path = urllib.quote(path)
671
 
        netloc = urllib.quote(self._host)
672
 
        if self._username is not None:
673
 
            netloc = '%s@%s' % (urllib.quote(self._username), netloc)
674
 
        if self._port is not None:
675
 
            netloc = '%s:%d' % (netloc, self._port)
676
 
        return urlparse.urlunparse((self._scheme, netloc, path, '', '', ''))
677
 
 
678
 
    def _remote_path(self, relpath):
679
 
        """Returns the Unicode version of the absolute path for relpath."""
680
 
        return self._combine_paths(self._path, relpath)
681
 
 
682
 
    def has(self, relpath):
683
 
        """Indicate whether a remote file of the given name exists or not.
684
 
 
685
 
        :see: Transport.has()
686
 
        """
687
 
        resp = self._client._call('has', self._remote_path(relpath))
688
 
        if resp == ('yes', ):
689
 
            return True
690
 
        elif resp == ('no', ):
691
 
            return False
692
 
        else:
693
 
            self._translate_error(resp)
694
 
 
695
 
    def get(self, relpath):
696
 
        """Return file-like object reading the contents of a remote file.
697
 
        
698
 
        :see: Transport.get_bytes()/get_file()
699
 
        """
700
 
        remote = self._remote_path(relpath)
701
 
        resp = self._client._call('get', remote)
702
 
        if resp != ('ok', ):
703
 
            self._translate_error(resp, relpath)
704
 
        return StringIO(self._client._recv_bulk())
705
 
 
706
 
    def _serialise_optional_mode(self, mode):
707
 
        if mode is None:
708
 
            return ''
709
 
        else:
710
 
            return '%d' % mode
711
 
 
712
 
    def mkdir(self, relpath, mode=None):
713
 
        resp = self._client._call('mkdir', 
714
 
                                  self._remote_path(relpath), 
715
 
                                  self._serialise_optional_mode(mode))
716
 
        self._translate_error(resp)
717
 
 
718
 
    def put_bytes(self, relpath, upload_contents, mode=None):
719
 
        # FIXME: upload_file is probably not safe for non-ascii characters -
720
 
        # should probably just pass all parameters as length-delimited
721
 
        # strings?
722
 
        resp = self._client._call_with_upload(
723
 
            'put',
724
 
            (self._remote_path(relpath), self._serialise_optional_mode(mode)),
725
 
            upload_contents)
726
 
        self._translate_error(resp)
727
 
 
728
 
    def put_bytes_non_atomic(self, relpath, bytes, mode=None,
729
 
                             create_parent_dir=False,
730
 
                             dir_mode=None):
731
 
        """See Transport.put_bytes_non_atomic."""
732
 
        # FIXME: no encoding in the transport!
733
 
        create_parent_str = 'F'
734
 
        if create_parent_dir:
735
 
            create_parent_str = 'T'
736
 
 
737
 
        resp = self._client._call_with_upload(
738
 
            'put_non_atomic',
739
 
            (self._remote_path(relpath), self._serialise_optional_mode(mode),
740
 
             create_parent_str, self._serialise_optional_mode(dir_mode)),
741
 
            bytes)
742
 
        self._translate_error(resp)
743
 
 
744
 
    def put_file(self, relpath, upload_file, mode=None):
745
 
        # its not ideal to seek back, but currently put_non_atomic_file depends
746
 
        # on transports not reading before failing - which is a faulty
747
 
        # assumption I think - RBC 20060915
748
 
        pos = upload_file.tell()
749
 
        try:
750
 
            return self.put_bytes(relpath, upload_file.read(), mode)
751
 
        except:
752
 
            upload_file.seek(pos)
753
 
            raise
754
 
 
755
 
    def put_file_non_atomic(self, relpath, f, mode=None,
756
 
                            create_parent_dir=False,
757
 
                            dir_mode=None):
758
 
        return self.put_bytes_non_atomic(relpath, f.read(), mode=mode,
759
 
                                         create_parent_dir=create_parent_dir,
760
 
                                         dir_mode=dir_mode)
761
 
 
762
 
    def append_file(self, relpath, from_file, mode=None):
763
 
        return self.append_bytes(relpath, from_file.read(), mode)
764
 
        
765
 
    def append_bytes(self, relpath, bytes, mode=None):
766
 
        resp = self._client._call_with_upload(
767
 
            'append',
768
 
            (self._remote_path(relpath), self._serialise_optional_mode(mode)),
769
 
            bytes)
770
 
        if resp[0] == 'appended':
771
 
            return int(resp[1])
772
 
        self._translate_error(resp)
773
 
 
774
 
    def delete(self, relpath):
775
 
        resp = self._client._call('delete', self._remote_path(relpath))
776
 
        self._translate_error(resp)
777
 
 
778
 
    def readv(self, relpath, offsets):
779
 
        if not offsets:
780
 
            return
781
 
 
782
 
        offsets = list(offsets)
783
 
 
784
 
        sorted_offsets = sorted(offsets)
785
 
        # turn the list of offsets into a stack
786
 
        offset_stack = iter(offsets)
787
 
        cur_offset_and_size = offset_stack.next()
788
 
        coalesced = list(self._coalesce_offsets(sorted_offsets,
789
 
                               limit=self._max_readv_combine,
790
 
                               fudge_factor=self._bytes_to_read_before_seek))
791
 
 
792
 
 
793
 
        resp = self._client._call_with_upload(
794
 
            'readv',
795
 
            (self._remote_path(relpath),),
796
 
            self._client._serialise_offsets((c.start, c.length) for c in coalesced))
797
 
 
798
 
        if resp[0] != 'readv':
799
 
            # This should raise an exception
800
 
            self._translate_error(resp)
801
 
            return
802
 
 
803
 
        data = self._client._recv_bulk()
804
 
        # Cache the results, but only until they have been fulfilled
805
 
        data_map = {}
806
 
        for c_offset in coalesced:
807
 
            if len(data) < c_offset.length:
808
 
                raise errors.ShortReadvError(relpath, c_offset.start,
809
 
                            c_offset.length, actual=len(data))
810
 
            for suboffset, subsize in c_offset.ranges:
811
 
                key = (c_offset.start+suboffset, subsize)
812
 
                data_map[key] = data[suboffset:suboffset+subsize]
813
 
            data = data[c_offset.length:]
814
 
 
815
 
            # Now that we've read some data, see if we can yield anything back
816
 
            while cur_offset_and_size in data_map:
817
 
                this_data = data_map.pop(cur_offset_and_size)
818
 
                yield cur_offset_and_size[0], this_data
819
 
                cur_offset_and_size = offset_stack.next()
820
 
 
821
 
    def rename(self, rel_from, rel_to):
822
 
        self._call('rename', 
823
 
                   self._remote_path(rel_from),
824
 
                   self._remote_path(rel_to))
825
 
 
826
 
    def move(self, rel_from, rel_to):
827
 
        self._call('move', 
828
 
                   self._remote_path(rel_from),
829
 
                   self._remote_path(rel_to))
830
 
 
831
 
    def rmdir(self, relpath):
832
 
        resp = self._call('rmdir', self._remote_path(relpath))
833
 
 
834
 
    def _call(self, method, *args):
835
 
        resp = self._client._call(method, *args)
836
 
        self._translate_error(resp)
837
 
 
838
 
    def _translate_error(self, resp, orig_path=None):
839
 
        """Raise an exception from a response"""
840
 
        if resp is None:
841
 
            what = None
842
 
        else:
843
 
            what = resp[0]
844
 
        if what == 'ok':
845
 
            return
846
 
        elif what == 'NoSuchFile':
847
 
            if orig_path is not None:
848
 
                error_path = orig_path
849
 
            else:
850
 
                error_path = resp[1]
851
 
            raise errors.NoSuchFile(error_path)
852
 
        elif what == 'error':
853
 
            raise errors.SmartProtocolError(unicode(resp[1]))
854
 
        elif what == 'FileExists':
855
 
            raise errors.FileExists(resp[1])
856
 
        elif what == 'DirectoryNotEmpty':
857
 
            raise errors.DirectoryNotEmpty(resp[1])
858
 
        elif what == 'ShortReadvError':
859
 
            raise errors.ShortReadvError(resp[1], int(resp[2]),
860
 
                                         int(resp[3]), int(resp[4]))
861
 
        elif what in ('UnicodeEncodeError', 'UnicodeDecodeError'):
862
 
            encoding = str(resp[1]) # encoding must always be a string
863
 
            val = resp[2]
864
 
            start = int(resp[3])
865
 
            end = int(resp[4])
866
 
            reason = str(resp[5]) # reason must always be a string
867
 
            if val.startswith('u:'):
868
 
                val = val[2:]
869
 
            elif val.startswith('s:'):
870
 
                val = val[2:].decode('base64')
871
 
            if what == 'UnicodeDecodeError':
872
 
                raise UnicodeDecodeError(encoding, val, start, end, reason)
873
 
            elif what == 'UnicodeEncodeError':
874
 
                raise UnicodeEncodeError(encoding, val, start, end, reason)
875
 
        elif what == "ReadOnlyError":
876
 
            raise errors.TransportNotPossible('readonly transport')
877
 
        else:
878
 
            raise errors.SmartProtocolError('unexpected smart server error: %r' % (resp,))
879
 
 
880
 
    def _send_tuple(self, args):
881
 
        self._client._send_tuple(args)
882
 
 
883
 
    def _recv_tuple(self):
884
 
        return self._client._recv_tuple()
885
 
 
886
 
    def disconnect(self):
887
 
        self._client.disconnect()
888
 
 
889
 
    def delete_tree(self, relpath):
890
 
        raise errors.TransportNotPossible('readonly transport')
891
 
 
892
 
    def stat(self, relpath):
893
 
        resp = self._client._call('stat', self._remote_path(relpath))
894
 
        if resp[0] == 'stat':
895
 
            return SmartStat(int(resp[1]), int(resp[2], 8))
896
 
        else:
897
 
            self._translate_error(resp)
898
 
 
899
 
    ## def lock_read(self, relpath):
900
 
    ##     """Lock the given file for shared (read) access.
901
 
    ##     :return: A lock object, which should be passed to Transport.unlock()
902
 
    ##     """
903
 
    ##     # The old RemoteBranch ignore lock for reading, so we will
904
 
    ##     # continue that tradition and return a bogus lock object.
905
 
    ##     class BogusLock(object):
906
 
    ##         def __init__(self, path):
907
 
    ##             self.path = path
908
 
    ##         def unlock(self):
909
 
    ##             pass
910
 
    ##     return BogusLock(relpath)
911
 
 
912
 
    def listable(self):
913
 
        return True
914
 
 
915
 
    def list_dir(self, relpath):
916
 
        resp = self._client._call('list_dir',
917
 
                                  self._remote_path(relpath))
918
 
        if resp[0] == 'names':
919
 
            return [name.encode('ascii') for name in resp[1:]]
920
 
        else:
921
 
            self._translate_error(resp)
922
 
 
923
 
    def iter_files_recursive(self):
924
 
        resp = self._client._call('iter_files_recursive',
925
 
                                  self._remote_path(''))
926
 
        if resp[0] == 'names':
927
 
            return resp[1:]
928
 
        else:
929
 
            self._translate_error(resp)
930
 
 
931
 
 
932
 
class SmartStreamClient(SmartProtocolBase):
933
 
    """Connection to smart server over two streams"""
934
 
 
935
 
    def __init__(self, connect_func):
936
 
        self._connect_func = connect_func
937
 
        self._connected = False
938
 
 
939
 
    def __del__(self):
940
 
        self.disconnect()
941
 
 
942
 
    def _ensure_connection(self):
943
 
        if not self._connected:
944
 
            self._in, self._out = self._connect_func()
945
 
            self._connected = True
946
 
 
947
 
    def _send_tuple(self, args):
948
 
        self._ensure_connection()
949
 
        return self._write_and_flush(_encode_tuple(args))
950
 
 
951
 
    def _send_bulk_data(self, body):
952
 
        self._ensure_connection()
953
 
        SmartProtocolBase._send_bulk_data(self, body)
954
 
        
955
 
    def _recv_bulk(self):
956
 
        self._ensure_connection()
957
 
        return SmartProtocolBase._recv_bulk(self)
958
 
 
959
 
    def _recv_tuple(self):
960
 
        self._ensure_connection()
961
 
        return SmartProtocolBase._recv_tuple(self)
962
 
 
963
 
    def _recv_trailer(self):
964
 
        self._ensure_connection()
965
 
        return SmartProtocolBase._recv_trailer(self)
966
 
 
967
 
    def disconnect(self):
968
 
        """Close connection to the server"""
969
 
        if self._connected:
970
 
            self._out.close()
971
 
            self._in.close()
972
 
 
973
 
    def _call(self, *args):
974
 
        self._send_tuple(args)
975
 
        return self._recv_tuple()
976
 
 
977
 
    def _call_with_upload(self, method, args, body):
978
 
        """Call an rpc, supplying bulk upload data.
979
 
 
980
 
        :param method: method name to call
981
 
        :param args: parameter args tuple
982
 
        :param body: upload body as a byte string
983
 
        """
984
 
        self._send_tuple((method,) + args)
985
 
        self._send_bulk_data(body)
986
 
        return self._recv_tuple()
987
 
 
988
 
    def query_version(self):
989
 
        """Return protocol version number of the server."""
990
 
        # XXX: should make sure it's empty
991
 
        self._send_tuple(('hello',))
992
 
        resp = self._recv_tuple()
993
 
        if resp == ('ok', '1'):
994
 
            return 1
995
 
        else:
996
 
            raise errors.SmartProtocolError("bad response %r" % (resp,))
997
 
 
998
 
 
999
 
class SmartTCPTransport(SmartTransport):
1000
 
    """Connection to smart server over plain tcp"""
1001
 
 
1002
 
    def __init__(self, url, clone_from=None):
1003
 
        super(SmartTCPTransport, self).__init__(url, clone_from)
1004
 
        try:
1005
 
            self._port = int(self._port)
1006
 
        except (ValueError, TypeError), e:
1007
 
            raise errors.InvalidURL(path=url, extra="invalid port %s" % self._port)
1008
 
        self._socket = None
1009
 
 
1010
 
    def _connect_to_server(self):
1011
 
        self._socket = socket.socket()
1012
 
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
1013
 
        result = self._socket.connect_ex((self._host, int(self._port)))
1014
 
        if result:
1015
 
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
1016
 
                    (self._host, self._port, os.strerror(result)))
1017
 
        # TODO: May be more efficient to just treat them as sockets
1018
 
        # throughout?  But what about pipes to ssh?...
1019
 
        to_server = self._socket.makefile('w')
1020
 
        from_server = self._socket.makefile('r')
1021
 
        return from_server, to_server
1022
 
 
1023
 
    def disconnect(self):
1024
 
        super(SmartTCPTransport, self).disconnect()
1025
 
        # XXX: Is closing the socket as well as closing the files really
1026
 
        # necessary?
1027
 
        if self._socket is not None:
1028
 
            self._socket.close()
1029
 
 
1030
 
try:
1031
 
    from bzrlib.transport import sftp, ssh
1032
 
except errors.ParamikoNotPresent:
1033
 
    # no paramiko, no SSHTransport.
1034
 
    pass
1035
 
else:
1036
 
    class SmartSSHTransport(SmartTransport):
1037
 
        """Connection to smart server over SSH."""
1038
 
 
1039
 
        def __init__(self, url, clone_from=None):
1040
 
            # TODO: all this probably belongs in the parent class.
1041
 
            super(SmartSSHTransport, self).__init__(url, clone_from)
1042
 
            try:
1043
 
                if self._port is not None:
1044
 
                    self._port = int(self._port)
1045
 
            except (ValueError, TypeError), e:
1046
 
                raise errors.InvalidURL(path=url, extra="invalid port %s" % self._port)
1047
 
 
1048
 
        def _connect_to_server(self):
1049
 
            executable = os.environ.get('BZR_REMOTE_PATH', 'bzr')
1050
 
            vendor = ssh._get_ssh_vendor()
1051
 
            self._ssh_connection = vendor.connect_ssh(self._username,
1052
 
                    self._password, self._host, self._port,
1053
 
                    command=[executable, 'serve', '--inet', '--directory=/',
1054
 
                             '--allow-writes'])
1055
 
            return self._ssh_connection.get_filelike_channels()
1056
 
 
1057
 
        def disconnect(self):
1058
 
            super(SmartSSHTransport, self).disconnect()
1059
 
            self._ssh_connection.close()
1060
 
 
1061
 
 
1062
 
def get_test_permutations():
1063
 
    """Return (transport, server) permutations for testing"""
1064
 
    return [(SmartTCPTransport, SmartTCPServer_for_testing)]