~didrocks/ubuntuone-client/use_result_var

« back to all changes in this revision

Viewing changes to ubuntuone/u1sync/client.py

  • Committer: Bazaar Package Importer
  • Author(s): Rodney Dawes
  • Date: 2011-02-11 16:18:11 UTC
  • mto: This revision was merged to the branch mainline in revision 67.
  • Revision ID: james.westby@ubuntu.com-20110211161811-n18dj9lde7dxqjzr
Tags: upstream-1.5.4
Import upstream version 1.5.4

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# ubuntuone.u1sync.client
2
 
#
3
 
# Client/protocol end of u1sync
4
 
#
5
 
# Author: Lucio Torre <lucio.torre@canonical.com>
6
 
# Author: Tim Cole <tim.cole@canonical.com>
7
 
#
8
 
# Copyright 2009 Canonical Ltd.
9
 
#
10
 
# This program is free software: you can redistribute it and/or modify it
11
 
# under the terms of the GNU General Public License version 3, as published
12
 
# by the Free Software Foundation.
13
 
#
14
 
# This program is distributed in the hope that it will be useful, but
15
 
# WITHOUT ANY WARRANTY; without even the implied warranties of
16
 
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR
17
 
# PURPOSE.  See the GNU General Public License for more details.
18
 
#
19
 
# You should have received a copy of the GNU General Public License along
20
 
# with this program.  If not, see <http://www.gnu.org/licenses/>.
21
 
"""Pretty API for protocol client."""
22
 
 
23
 
from __future__ import with_statement
24
 
 
25
 
import os
26
 
import sys
27
 
import shutil
28
 
from Queue import Queue
29
 
from threading import Lock
30
 
import zlib
31
 
from cStringIO import StringIO
32
 
 
33
 
from twisted.internet import reactor, defer
34
 
from twisted.internet.defer import inlineCallbacks, returnValue
35
 
from ubuntuone.logger import LOGFOLDER
36
 
from ubuntuone.storageprotocol.content_hash import crc32
37
 
from ubuntuone.storageprotocol.context import get_ssl_context
38
 
from ubuntuone.u1sync.genericmerge import MergeNode
39
 
from ubuntuone.u1sync.utils import should_sync
40
 
 
41
 
# OAuth consumer credentials used when building the OAuthConsumer for
# authentication (see Client.oauth_from_token).
CONSUMER_KEY = "ubuntuone"
CONSUMER_SECRET = "hammertime"
43
 
 
44
 
from oauth.oauth import OAuthConsumer
45
 
from ubuntuone.storageprotocol.client import (
46
 
    StorageClientFactory, StorageClient)
47
 
from ubuntuone.storageprotocol import request, volumes
48
 
from ubuntuone.storageprotocol.dircontent_pb2 import \
49
 
    DirectoryContent, DIRECTORY
50
 
import uuid
51
 
import logging
52
 
from logging.handlers import RotatingFileHandler
53
 
import time
54
 
 
55
 
from ubuntuone.platform import (
56
 
    remove_file,
57
 
    open_file,
58
 
    rename,
59
 
)
60
 
 
61
 
def share_str(share_uuid):
    """Converts a share UUID to a form the protocol likes.

    A None share means the user's root volume, spelled request.ROOT.
    """
    if share_uuid is None:
        return request.ROOT
    return str(share_uuid)
64
 
 
65
 
# Timing log: every @log_timing-decorated call appends a line here.
u1_logger = logging.getLogger("u1sync.timing.log")
LOGFILENAME = os.path.join(LOGFOLDER, 'u1sync.log')
handler = RotatingFileHandler(LOGFILENAME)
u1_logger.addHandler(handler)
69
 
 
70
 
def log_timing(func):
    """Decorator that logs each call's elapsed wall-clock time.

    The duration, in milliseconds, is written to the module's timing
    logger (u1_logger) along with the wrapped function's name.
    """
    def wrapper(*args, **kwargs):
        """Time the wrapped call, log it, and pass the result through."""
        begin = time.time()
        result = func(*args, **kwargs)
        elapsed_ms = (time.time() - begin) * 1000.0
        u1_logger.debug('for %s %0.5f ms elapsed' % (func.func_name,
                                                     elapsed_ms))
        return result
    return wrapper
79
 
 
80
 
 
81
 
class ForcedShutdown(Exception):
    """Raised in waiters when the client is forcibly shut down."""
83
 
 
84
 
 
85
 
class Waiter(object):
86
 
    """Wait object for blocking waits."""
87
 
 
88
 
    def __init__(self):
89
 
        """Initializes the wait object."""
90
 
        self.queue = Queue()
91
 
 
92
 
    def wake(self, result):
93
 
        """Wakes the waiter with a result."""
94
 
        self.queue.put((result, None))
95
 
 
96
 
    def wakeAndRaise(self, exc_info):
97
 
        """Wakes the waiter, raising the given exception in it."""
98
 
        self.queue.put((None, exc_info))
99
 
 
100
 
    def wakeWithResult(self, func, *args, **kw):
101
 
        """Wakes the waiter with the result of the given function."""
102
 
        try:
103
 
            result = func(*args, **kw)
104
 
        except Exception:
105
 
            self.wakeAndRaise(sys.exc_info())
106
 
        else:
107
 
            self.wake(result)
108
 
 
109
 
    def wait(self):
110
 
        """Waits for wakeup."""
111
 
        (result, exc_info) = self.queue.get()
112
 
        if exc_info:
113
 
            try:
114
 
                raise exc_info[0], exc_info[1], exc_info[2]
115
 
            finally:
116
 
                exc_info = None
117
 
        else:
118
 
            return result
119
 
 
120
 
 
121
 
class SyncStorageClient(StorageClient):
    """Storage protocol that reports connect/disconnect to an observer."""

    @log_timing
    def connectionMade(self):
        """Record ourselves as the current protocol and notify."""
        StorageClient.connectionMade(self)
        factory = self.factory
        # Only one live protocol at a time: drop any stale predecessor.
        if factory.current_protocol not in (None, self):
            factory.current_protocol.transport.loseConnection()
        factory.current_protocol = self
        factory.observer.connected()

    @log_timing
    def connectionLost(self, reason=None):
        """Callback for established connection lost."""
        StorageClient.connectionLost(self, reason)
        # Ignore loss of a connection we already superseded.
        if self.factory.current_protocol is not self:
            return
        self.factory.current_protocol = None
        self.factory.observer.disconnected(reason)
140
 
 
141
 
 
142
 
class SyncClientFactory(StorageClientFactory):
    """Factory producing SyncStorageClient protocols for one observer."""
    # no init: pylint: disable-msg=W0232

    protocol = SyncStorageClient

    @log_timing
    def __init__(self, observer):
        """Create the factory; observer receives connection callbacks."""
        self.observer = observer
        # The one live protocol instance, or None when disconnected.
        self.current_protocol = None

    @log_timing
    def clientConnectionFailed(self, connector, reason):
        """The TCP/SSL connection attempt itself failed."""
        self.current_protocol = None
        self.observer.connection_failed(reason)
159
 
 
160
 
 
161
 
class UnsupportedOperationError(Exception):
    """The operation is not supported by this protocol version."""
163
 
 
164
 
 
165
 
class ConnectionError(Exception):
    """Raised when establishing a connection fails."""
167
 
 
168
 
 
169
 
class AuthenticationError(Exception):
    """Raised when OAuth authentication fails."""
171
 
 
172
 
 
173
 
class NoSuchShareError(Exception):
    """Raised when no share matches the requested criteria."""
175
 
 
176
 
 
177
 
class CapabilitiesError(Exception):
    """Raised when setting or querying capabilities fails."""
179
 
 
180
 
class Client(object):
181
 
    """U1 storage client facade."""
182
 
    required_caps = frozenset(["no-content", "fix462230"])
183
 
 
184
 
    def __init__(self, realm=None, reactor=reactor):
185
 
        """Create the instance.
186
 
 
187
 
        'realm' is no longer used, but is left as param for API compatibility.
188
 
 
189
 
        """
190
 
        self.reactor = reactor
191
 
        self.factory = SyncClientFactory(self)
192
 
 
193
 
        self._status_lock = Lock()
194
 
        self._status = "disconnected"
195
 
        self._status_reason = None
196
 
        self._status_waiting = []
197
 
        self._active_waiters = set()
198
 
 
199
 
        self.consumer_key = CONSUMER_KEY
200
 
        self.consumer_secret = CONSUMER_SECRET
201
 
 
202
 
    def force_shutdown(self):
203
 
        """Forces the client to shut itself down."""
204
 
        with self._status_lock:
205
 
            self._status = "forced_shutdown"
206
 
            self._reason = None
207
 
            for waiter in self._active_waiters:
208
 
                waiter.wakeAndRaise((ForcedShutdown("Forced shutdown"),
209
 
                                     None, None))
210
 
            self._active_waiters.clear()
211
 
 
212
 
    def _get_waiter_locked(self):
213
 
        """Gets a wait object for blocking waits.  Should be called with the
214
 
        status lock held.
215
 
        """
216
 
        waiter = Waiter()
217
 
        if self._status == "forced_shutdown":
218
 
            raise ForcedShutdown("Forced shutdown")
219
 
        self._active_waiters.add(waiter)
220
 
        return waiter
221
 
 
222
 
    def _get_waiter(self):
223
 
        """Get a wait object for blocking waits.  Acquires the status lock."""
224
 
        with self._status_lock:
225
 
            return self._get_waiter_locked()
226
 
 
227
 
    def _wait(self, waiter):
228
 
        """Waits for the waiter."""
229
 
        try:
230
 
            return waiter.wait()
231
 
        finally:
232
 
            with self._status_lock:
233
 
                if waiter in self._active_waiters:
234
 
                    self._active_waiters.remove(waiter)
235
 
 
236
 
    @log_timing
237
 
    def _change_status(self, status, reason=None):
238
 
        """Changes the client status.  Usually called from the reactor
239
 
        thread.
240
 
 
241
 
        """
242
 
        with self._status_lock:
243
 
            if self._status == "forced_shutdown":
244
 
                return
245
 
            self._status = status
246
 
            self._status_reason = reason
247
 
            waiting = self._status_waiting
248
 
            self._status_waiting = []
249
 
        for waiter in waiting:
250
 
            waiter.wake((status, reason))
251
 
 
252
 
    @log_timing
253
 
    def _await_status_not(self, *ignore_statuses):
254
 
        """Blocks until the client status changes, returning the new status.
255
 
        Should never be called from the reactor thread.
256
 
 
257
 
        """
258
 
        with self._status_lock:
259
 
            status = self._status
260
 
            reason = self._status_reason
261
 
            while status in ignore_statuses:
262
 
                waiter = self._get_waiter_locked()
263
 
                self._status_waiting.append(waiter)
264
 
                self._status_lock.release()
265
 
                try:
266
 
                    status, reason = self._wait(waiter)
267
 
                finally:
268
 
                    self._status_lock.acquire()
269
 
            if status == "forced_shutdown":
270
 
                raise ForcedShutdown("Forced shutdown.")
271
 
            return (status, reason)
272
 
 
273
 
    def connection_failed(self, reason):
274
 
        """Notification that connection failed."""
275
 
        self._change_status("disconnected", reason)
276
 
 
277
 
    def connected(self):
278
 
        """Notification that connection succeeded."""
279
 
        self._change_status("connected")
280
 
 
281
 
    def disconnected(self, reason):
282
 
        """Notification that we were disconnected."""
283
 
        self._change_status("disconnected", reason)
284
 
 
285
 
    def defer_from_thread(self, function, *args, **kwargs):
286
 
        """Do twisted defer magic to get results and show exceptions."""
287
 
        waiter = self._get_waiter()
288
 
        @log_timing
289
 
        def runner():
290
 
            """inner."""
291
 
            # we do want to catch all
292
 
            # no init: pylint: disable-msg=W0703
293
 
            try:
294
 
                d = function(*args, **kwargs)
295
 
                if isinstance(d, defer.Deferred):
296
 
                    d.addCallbacks(lambda r: waiter.wake((r, None, None)),
297
 
                                   lambda f: waiter.wake((None, None, f)))
298
 
                else:
299
 
                    waiter.wake((d, None, None))
300
 
            except Exception:
301
 
                waiter.wake((None, sys.exc_info(), None))
302
 
 
303
 
        self.reactor.callFromThread(runner)
304
 
        result, exc_info, failure = self._wait(waiter)
305
 
        if exc_info:
306
 
            try:
307
 
                raise exc_info[0], exc_info[1], exc_info[2]
308
 
            finally:
309
 
                exc_info = None
310
 
        elif failure:
311
 
            failure.raiseException()
312
 
        else:
313
 
            return result
314
 
 
315
 
    @log_timing
316
 
    def connect(self, host, port):
317
 
        """Connect to host/port."""
318
 
        def _connect():
319
 
            """Deferred part."""
320
 
            self.reactor.connectTCP(host, port, self.factory)
321
 
        self._connect_inner(_connect)
322
 
 
323
 
    @log_timing
324
 
    def connect_ssl(self, host, port, no_verify):
325
 
        """Connect to host/port using ssl."""
326
 
        def _connect():
327
 
            """deferred part."""
328
 
            ctx = get_ssl_context(no_verify)
329
 
            self.reactor.connectSSL(host, port, self.factory, ctx)
330
 
        self._connect_inner(_connect)
331
 
 
332
 
    @log_timing
333
 
    def _connect_inner(self, _connect):
334
 
        """Helper function for connecting."""
335
 
        self._change_status("connecting")
336
 
        self.reactor.callFromThread(_connect)
337
 
        status, reason = self._await_status_not("connecting")
338
 
        if status != "connected":
339
 
            raise ConnectionError(reason.value)
340
 
 
341
 
    @log_timing
342
 
    def disconnect(self):
343
 
        """Disconnect."""
344
 
        if self.factory.current_protocol is not None:
345
 
            self.reactor.callFromThread(
346
 
                self.factory.current_protocol.transport.loseConnection)
347
 
        self._await_status_not("connecting", "connected", "authenticated")
348
 
 
349
 
    @log_timing
350
 
    def oauth_from_token(self, token, consumer=None):
351
 
        """Perform OAuth authorisation using an existing token."""
352
 
 
353
 
        if consumer is None:
354
 
            consumer = OAuthConsumer(self.consumer_key, self.consumer_secret)
355
 
 
356
 
        def _auth_successful(value):
357
 
            """Callback for successful auth.  Changes status to
358
 
            authenticated."""
359
 
            self._change_status("authenticated")
360
 
            return value
361
 
 
362
 
        def _auth_failed(value):
363
 
            """Callback for failed auth.  Disconnects."""
364
 
            self.factory.current_protocol.transport.loseConnection()
365
 
            return value
366
 
 
367
 
        def _wrapped_authenticate():
368
 
            """Wrapped authenticate."""
369
 
            d = self.factory.current_protocol.oauth_authenticate(consumer,
370
 
                                                                 token)
371
 
            d.addCallbacks(_auth_successful, _auth_failed)
372
 
            return d
373
 
 
374
 
        try:
375
 
            self.defer_from_thread(_wrapped_authenticate)
376
 
        except request.StorageProtocolError, e:
377
 
            raise AuthenticationError(e)
378
 
        status, reason = self._await_status_not("connected")
379
 
        if status != "authenticated":
380
 
            raise AuthenticationError(reason.value)
381
 
 
382
 
    @log_timing
383
 
    def set_capabilities(self):
384
 
        """Set the capabilities with the server"""
385
 
 
386
 
        client = self.factory.current_protocol
387
 
        @log_timing
388
 
        def set_caps_callback(req):
389
 
            "Caps query succeeded"
390
 
            if not req.accepted:
391
 
                de = defer.fail("The server denied setting %s capabilities" % \
392
 
                                req.caps)
393
 
                return de
394
 
 
395
 
        @log_timing
396
 
        def query_caps_callback(req):
397
 
            "Caps query succeeded"
398
 
            if req.accepted:
399
 
                set_d = client.set_caps(self.required_caps)
400
 
                set_d.addCallback(set_caps_callback)
401
 
                return set_d
402
 
            else:
403
 
                # the server don't have the requested capabilities.
404
 
                # return a failure for now, in the future we might want
405
 
                # to reconnect to another server
406
 
                de = defer.fail("The server don't have the requested"
407
 
                                " capabilities: %s" % str(req.caps))
408
 
                return de
409
 
 
410
 
        @log_timing
411
 
        def _wrapped_set_capabilities():
412
 
            """Wrapped set_capabilities """
413
 
            d = client.query_caps(self.required_caps)
414
 
            d.addCallback(query_caps_callback)
415
 
            return d
416
 
 
417
 
        try:
418
 
            self.defer_from_thread(_wrapped_set_capabilities)
419
 
        except request.StorageProtocolError, e:
420
 
            raise CapabilitiesError(e)
421
 
 
422
 
    @log_timing
423
 
    def get_root_info(self, volume_uuid):
424
 
        """Returns the UUID of the applicable share root."""
425
 
        if volume_uuid is None:
426
 
            _get_root = self.factory.current_protocol.get_root
427
 
            root = self.defer_from_thread(_get_root)
428
 
            return (uuid.UUID(root), True)
429
 
        else:
430
 
            str_volume_uuid = str(volume_uuid)
431
 
            volume = self._match_volume(lambda v: \
432
 
                                       str(v.volume_id) == str_volume_uuid)
433
 
            if isinstance(volume, volumes.ShareVolume):
434
 
                modify = volume.access_level == "Modify"
435
 
            if isinstance(volume, volumes.UDFVolume):
436
 
                modify = True
437
 
            return (uuid.UUID(str(volume.node_id)), modify)
438
 
 
439
 
    @log_timing
440
 
    def resolve_path(self, share_uuid, root_uuid, path):
441
 
        """Resolve path relative to the given root node."""
442
 
 
443
 
        @inlineCallbacks
444
 
        def _resolve_worker():
445
 
            """Path resolution worker."""
446
 
            node_uuid = root_uuid
447
 
            local_path = path.strip('/')
448
 
 
449
 
            while local_path != '':
450
 
                local_path, name = os.path.split(local_path)
451
 
                hashes = yield self._get_node_hashes(share_uuid, [root_uuid])
452
 
                content_hash = hashes.get(root_uuid, None)
453
 
                if content_hash is None:
454
 
                    raise KeyError, "Content hash not available"
455
 
                entries = yield self._get_raw_dir_entries(share_uuid,
456
 
                                                          root_uuid,
457
 
                                                          content_hash)
458
 
                match_name = name.decode('utf-8')
459
 
                match = None
460
 
                for entry in entries:
461
 
                    if match_name == entry.name:
462
 
                        match = entry
463
 
                        break
464
 
 
465
 
                if match is None:
466
 
                    raise KeyError, "Path not found"
467
 
 
468
 
                node_uuid = uuid.UUID(match.node)
469
 
 
470
 
            returnValue(node_uuid)
471
 
 
472
 
        return self.defer_from_thread(_resolve_worker)
473
 
 
474
 
    @log_timing
475
 
    def find_volume(self, volume_spec):
476
 
        """Finds a share matching the given UUID.  Looks at both share UUIDs
477
 
        and root node UUIDs."""
478
 
        volume = self._match_volume(lambda s: \
479
 
                                    str(s.volume_id) == volume_spec or \
480
 
                                    str(s.node_id) == volume_spec)
481
 
        return uuid.UUID(str(volume.volume_id))
482
 
 
483
 
    @log_timing
484
 
    def _match_volume(self, predicate):
485
 
        """Finds a volume matching the given predicate."""
486
 
        _list_shares = self.factory.current_protocol.list_volumes
487
 
        r = self.defer_from_thread(_list_shares)
488
 
        for volume in r.volumes:
489
 
            if predicate(volume):
490
 
                return volume
491
 
        raise NoSuchShareError()
492
 
 
493
 
    @log_timing
494
 
    def build_tree(self, share_uuid, root_uuid):
495
 
        """Builds and returns a tree representing the metadata for the given
496
 
        subtree in the given share.
497
 
 
498
 
        @param share_uuid: the share UUID or None for the user's volume
499
 
        @param root_uuid: the root UUID of the subtree (must be a directory)
500
 
        @return: a MergeNode tree
501
 
 
502
 
        """
503
 
        root = MergeNode(node_type=DIRECTORY, uuid=root_uuid)
504
 
 
505
 
        @log_timing
506
 
        @inlineCallbacks
507
 
        def _get_root_content_hash():
508
 
            """Obtain the content hash for the root node."""
509
 
            result = yield self._get_node_hashes(share_uuid, [root_uuid])
510
 
            returnValue(result.get(root_uuid, None))
511
 
 
512
 
        root.content_hash = self.defer_from_thread(_get_root_content_hash)
513
 
        if root.content_hash is None:
514
 
            raise ValueError("No content available for node %s" % root_uuid)
515
 
 
516
 
        @log_timing
517
 
        @inlineCallbacks
518
 
        def _get_children(parent_uuid, parent_content_hash):
519
 
            """Obtain a sequence of MergeNodes corresponding to a node's
520
 
            immediate children.
521
 
 
522
 
            """
523
 
            entries = yield self._get_raw_dir_entries(share_uuid,
524
 
                                                      parent_uuid,
525
 
                                                      parent_content_hash)
526
 
            children = {}
527
 
            for entry in entries:
528
 
                if should_sync(entry.name):
529
 
                    child = MergeNode(node_type=entry.node_type,
530
 
                                      uuid=uuid.UUID(entry.node))
531
 
                    children[entry.name] = child
532
 
 
533
 
            child_uuids = [child.uuid for child in children.itervalues()]
534
 
            content_hashes = yield self._get_node_hashes(share_uuid,
535
 
                                                         child_uuids)
536
 
            for child in children.itervalues():
537
 
                child.content_hash = content_hashes.get(child.uuid, None)
538
 
 
539
 
            returnValue(children)
540
 
 
541
 
        need_children = [root]
542
 
        while need_children:
543
 
            node = need_children.pop()
544
 
            if node.content_hash is not None:
545
 
                children = self.defer_from_thread(_get_children, node.uuid,
546
 
                                                  node.content_hash)
547
 
                node.children = children
548
 
                for child in children.itervalues():
549
 
                    if child.node_type == DIRECTORY:
550
 
                        need_children.append(child)
551
 
 
552
 
        return root
553
 
 
554
 
    @log_timing
555
 
    def _get_raw_dir_entries(self, share_uuid, node_uuid, content_hash):
556
 
        """Gets raw dir entries for the given directory."""
557
 
        d = self.factory.current_protocol.get_content(share_str(share_uuid),
558
 
                                                      str(node_uuid),
559
 
                                                      content_hash)
560
 
        d.addCallback(lambda c: zlib.decompress(c.data))
561
 
 
562
 
        def _parse_content(raw_content):
563
 
            """Parses directory content into a list of entry objects."""
564
 
            unserialized_content = DirectoryContent()
565
 
            unserialized_content.ParseFromString(raw_content)
566
 
            return list(unserialized_content.entries)
567
 
 
568
 
        d.addCallback(_parse_content)
569
 
        return d
570
 
 
571
 
    @log_timing
572
 
    def download_string(self, share_uuid, node_uuid, content_hash):
573
 
        """Reads a file from the server into a string."""
574
 
        output = StringIO()
575
 
        self._download_inner(share_uuid=share_uuid, node_uuid=node_uuid,
576
 
                             content_hash=content_hash, output=output)
577
 
        return output.getValue()
578
 
 
579
 
    @log_timing
580
 
    def download_file(self, share_uuid, node_uuid, content_hash, filename):
581
 
        """Downloads a file from the server."""
582
 
        partial_filename = "%s.u1partial" % filename
583
 
        output = open_file(partial_filename, "w")
584
 
 
585
 
        @log_timing
586
 
        def rename_file():
587
 
            """Renames the temporary file to the final name."""
588
 
            output.close()
589
 
            rename(partial_filename, filename)
590
 
 
591
 
        @log_timing
592
 
        def delete_file():
593
 
            """Deletes the temporary file."""
594
 
            output.close()
595
 
            remove_file(partial_filename)
596
 
 
597
 
        self._download_inner(share_uuid=share_uuid, node_uuid=node_uuid,
598
 
                             content_hash=content_hash, output=output,
599
 
                             on_success=rename_file, on_failure=delete_file)
600
 
 
601
 
    @log_timing
602
 
    def _download_inner(self, share_uuid, node_uuid, content_hash, output,
603
 
                        on_success=lambda: None, on_failure=lambda: None):
604
 
        """Helper function for content downloads."""
605
 
        dec = zlib.decompressobj()
606
 
 
607
 
        @log_timing
608
 
        def write_data(data):
609
 
            """Helper which writes data to the output file."""
610
 
            uncompressed_data = dec.decompress(data)
611
 
            output.write(uncompressed_data)
612
 
 
613
 
        @log_timing
614
 
        def finish_download(value):
615
 
            """Helper which finishes the download."""
616
 
            uncompressed_data = dec.flush()
617
 
            output.write(uncompressed_data)
618
 
            on_success()
619
 
            return value
620
 
 
621
 
        @log_timing
622
 
        def abort_download(value):
623
 
            """Helper which aborts the download."""
624
 
            on_failure()
625
 
            return value
626
 
 
627
 
        @log_timing
628
 
        def _download():
629
 
            """Async helper."""
630
 
            _get_content = self.factory.current_protocol.get_content
631
 
            d = _get_content(share_str(share_uuid), str(node_uuid),
632
 
                             content_hash, callback=write_data)
633
 
            d.addCallbacks(finish_download, abort_download)
634
 
            return d
635
 
 
636
 
        self.defer_from_thread(_download)
637
 
 
638
 
    @log_timing
639
 
    def create_directory(self, share_uuid, parent_uuid, name):
640
 
        """Creates a directory on the server."""
641
 
        r = self.defer_from_thread(self.factory.current_protocol.make_dir,
642
 
                                   share_str(share_uuid), str(parent_uuid),
643
 
                                   name)
644
 
        return uuid.UUID(r.new_id)
645
 
 
646
 
    @log_timing
647
 
    def create_file(self, share_uuid, parent_uuid, name):
648
 
        """Creates a file on the server."""
649
 
        r = self.defer_from_thread(self.factory.current_protocol.make_file,
650
 
                                   share_str(share_uuid), str(parent_uuid),
651
 
                                   name)
652
 
        return uuid.UUID(r.new_id)
653
 
 
654
 
    @log_timing
655
 
    def create_symlink(self, share_uuid, parent_uuid, name, target):
656
 
        """Creates a symlink on the server."""
657
 
        raise UnsupportedOperationError("Protocol does not support symlinks")
658
 
 
659
 
    @log_timing
660
 
    def upload_string(self, share_uuid, node_uuid, old_content_hash,
661
 
                      content_hash, content):
662
 
        """Uploads a string to the server as file content."""
663
 
        crc = crc32(content, 0)
664
 
        compressed_content = zlib.compress(content, 9)
665
 
        compressed = StringIO(compressed_content)
666
 
        self.defer_from_thread(self.factory.current_protocol.put_content,
667
 
                               share_str(share_uuid), str(node_uuid),
668
 
                               old_content_hash, content_hash,
669
 
                               crc, len(content), len(compressed_content),
670
 
                               compressed)
671
 
 
672
 
    @log_timing
673
 
    def upload_file(self, share_uuid, node_uuid, old_content_hash,
674
 
                    content_hash, filename):
675
 
        """Uploads a file to the server."""
676
 
        parent_dir = os.path.split(filename)[0]
677
 
        unique_filename = os.path.join(parent_dir, "." + str(uuid.uuid4()))
678
 
 
679
 
 
680
 
        class StagingFile(object):
681
 
            """An object which tracks data being compressed for staging."""
682
 
            def __init__(self, stream):
683
 
                """Initialize a compression object."""
684
 
                self.crc32 = 0
685
 
                self.enc = zlib.compressobj(9)
686
 
                self.size = 0
687
 
                self.compressed_size = 0
688
 
                self.stream = stream
689
 
 
690
 
            def write(self, bytes):
691
 
                """Compress bytes, keeping track of length and crc32."""
692
 
                self.size += len(bytes)
693
 
                self.crc32 = crc32(bytes, self.crc32)
694
 
                compressed_bytes = self.enc.compress(bytes)
695
 
                self.compressed_size += len(compressed_bytes)
696
 
                self.stream.write(compressed_bytes)
697
 
 
698
 
            def finish(self):
699
 
                """Finish staging compressed data."""
700
 
                compressed_bytes = self.enc.flush()
701
 
                self.compressed_size += len(compressed_bytes)
702
 
                self.stream.write(compressed_bytes)
703
 
 
704
 
        with open_file(unique_filename, "w+") as compressed:
705
 
            remove_file(unique_filename)
706
 
            with open_file(filename, "r") as original:
707
 
                staging = StagingFile(compressed)
708
 
                shutil.copyfileobj(original, staging)
709
 
            staging.finish()
710
 
            compressed.seek(0)
711
 
            self.defer_from_thread(self.factory.current_protocol.put_content,
712
 
                                   share_str(share_uuid), str(node_uuid),
713
 
                                   old_content_hash, content_hash,
714
 
                                   staging.crc32,
715
 
                                   staging.size, staging.compressed_size,
716
 
                                   compressed)
717
 
 
718
 
    @log_timing
719
 
    def move(self, share_uuid, parent_uuid, name, node_uuid):
720
 
        """Moves a file on the server."""
721
 
        self.defer_from_thread(self.factory.current_protocol.move,
722
 
                               share_str(share_uuid), str(node_uuid),
723
 
                               str(parent_uuid), name)
724
 
 
725
 
    @log_timing
726
 
    def unlink(self, share_uuid, node_uuid):
727
 
        """Unlinks a file on the server."""
728
 
        self.defer_from_thread(self.factory.current_protocol.unlink,
729
 
                               share_str(share_uuid), str(node_uuid))
730
 
 
731
 
    @log_timing
732
 
    def _get_node_hashes(self, share_uuid, node_uuids):
733
 
        """Fetches hashes for the given nodes."""
734
 
        share = share_str(share_uuid)
735
 
        queries = [(share, str(node_uuid), request.UNKNOWN_HASH) \
736
 
                   for node_uuid in node_uuids]
737
 
        d = self.factory.current_protocol.query(queries)
738
 
 
739
 
        @log_timing
740
 
        def _collect_hashes(multi_result):
741
 
            """Accumulate hashes from query replies."""
742
 
            hashes = {}
743
 
            for (success, value) in multi_result:
744
 
                if success:
745
 
                    for node_state in value.response:
746
 
                        node_uuid = uuid.UUID(node_state.node)
747
 
                        hashes[node_uuid] = node_state.hash
748
 
            return hashes
749
 
 
750
 
        d.addCallback(_collect_hashes)
751
 
        return d
752
 
 
753
 
    @log_timing
754
 
    def get_incoming_shares(self):
755
 
        """Returns a list of incoming shares as (name, uuid, accepted)
756
 
        tuples.
757
 
 
758
 
        """
759
 
        _list_shares = self.factory.current_protocol.list_shares
760
 
        r = self.defer_from_thread(_list_shares)
761
 
        return [(s.name, s.id, s.other_visible_name,
762
 
                 s.accepted, s.access_level) \
763
 
                for s in r.shares if s.direction == "to_me"]