1
# ubuntuone.u1sync.client
3
# Client/protocol end of u1sync
5
# Author: Lucio Torre <lucio.torre@canonical.com>
6
# Author: Tim Cole <tim.cole@canonical.com>
8
# Copyright 2009 Canonical Ltd.
10
# This program is free software: you can redistribute it and/or modify it
11
# under the terms of the GNU General Public License version 3, as published
12
# by the Free Software Foundation.
14
# This program is distributed in the hope that it will be useful, but
15
# WITHOUT ANY WARRANTY; without even the implied warranties of
16
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR
17
# PURPOSE. See the GNU General Public License for more details.
19
# You should have received a copy of the GNU General Public License along
20
# with this program. If not, see <http://www.gnu.org/licenses/>.
21
"""Pretty API for protocol client."""
23
from __future__ import with_statement
28
from Queue import Queue
29
from threading import Lock
31
from cStringIO import StringIO
33
from twisted.internet import reactor, defer
34
from twisted.internet.defer import inlineCallbacks, returnValue
35
from ubuntuone.logger import LOGFOLDER
36
from ubuntuone.storageprotocol.content_hash import crc32
37
from ubuntuone.storageprotocol.context import get_ssl_context
38
from ubuntuone.u1sync.genericmerge import MergeNode
39
from ubuntuone.u1sync.utils import should_sync
41
CONSUMER_KEY = "ubuntuone"
42
CONSUMER_SECRET = "hammertime"
44
from oauth.oauth import OAuthConsumer
45
from ubuntuone.storageprotocol.client import (
46
StorageClientFactory, StorageClient)
47
from ubuntuone.storageprotocol import request, volumes
48
from ubuntuone.storageprotocol.dircontent_pb2 import \
49
DirectoryContent, DIRECTORY
52
from logging.handlers import RotatingFileHandler
55
from ubuntuone.platform import (
61
def share_str(share_uuid):
    """Convert a share UUID into the wire form the protocol expects.

    Returns the stringified UUID, or request.ROOT when no share is given.
    """
    if share_uuid is None:
        return request.ROOT
    return str(share_uuid)
65
# Timing log for u1sync, rotated, written under the standard log folder.
# (Garbled extraction artifacts — bare integer statements — removed.)
LOGFILENAME = os.path.join(LOGFOLDER, 'u1sync.log')
u1_logger = logging.getLogger("u1sync.timing.log")
handler = RotatingFileHandler(LOGFILENAME)
u1_logger.addHandler(handler)
71
def wrapper(*arg, **kwargs):
73
ent = func(*arg, **kwargs)
75
u1_logger.debug('for %s %0.5f ms elapsed' % (func.func_name, \
81
class ForcedShutdown(Exception):
    """Raised when the client is forcibly shut down."""
86
"""Wait object for blocking waits."""
89
"""Initializes the wait object."""
92
def wake(self, result):
93
"""Wakes the waiter with a result."""
94
self.queue.put((result, None))
96
def wakeAndRaise(self, exc_info):
97
"""Wakes the waiter, raising the given exception in it."""
98
self.queue.put((None, exc_info))
100
def wakeWithResult(self, func, *args, **kw):
101
"""Wakes the waiter with the result of the given function."""
103
result = func(*args, **kw)
105
self.wakeAndRaise(sys.exc_info())
110
"""Waits for wakeup."""
111
(result, exc_info) = self.queue.get()
114
raise exc_info[0], exc_info[1], exc_info[2]
121
class SyncStorageClient(StorageClient):
    """Storage client that reports (dis)connection to the factory observer."""

    def connectionMade(self):
        """Record this protocol as current and notify the observer."""
        StorageClient.connectionMade(self)
        factory = self.factory
        # Drop any previously-active protocol before taking over.
        if factory.current_protocol not in (None, self):
            factory.current_protocol.transport.loseConnection()
        factory.current_protocol = self
        factory.observer.connected()

    def connectionLost(self, reason=None):
        """Clear the current protocol (if it is us) and notify the observer."""
        StorageClient.connectionLost(self, reason)
        factory = self.factory
        if factory.current_protocol is self:
            factory.current_protocol = None
        factory.observer.disconnected(reason)
142
class SyncClientFactory(StorageClientFactory):
    """Factory producing SyncStorageClient protocols for one observer."""

    protocol = SyncStorageClient

    def __init__(self, observer):
        """Remember the observer; no protocol is active yet."""
        self.observer = observer
        self.current_protocol = None

    def clientConnectionFailed(self, connector, reason):
        """Reset state and tell the observer the connection attempt failed."""
        self.current_protocol = None
        self.observer.connection_failed(reason)
161
class UnsupportedOperationError(Exception):
    """The protocol version does not support the requested operation."""
165
class ConnectionError(Exception):
    """An error occurred while establishing or holding a connection."""
169
class AuthenticationError(Exception):
    """Authentication with the server failed."""
173
class NoSuchShareError(Exception):
    """No share matching the request is available."""
177
class CapabilitiesError(Exception):
    """Setting or querying server capabilities failed."""
180
class Client(object):
181
"""U1 storage client facade."""
182
required_caps = frozenset(["no-content", "fix462230"])
184
def __init__(self, realm=None, reactor=reactor):
185
"""Create the instance.
187
'realm' is no longer used, but is left as param for API compatibility.
190
self.reactor = reactor
191
self.factory = SyncClientFactory(self)
193
self._status_lock = Lock()
194
self._status = "disconnected"
195
self._status_reason = None
196
self._status_waiting = []
197
self._active_waiters = set()
199
self.consumer_key = CONSUMER_KEY
200
self.consumer_secret = CONSUMER_SECRET
202
def force_shutdown(self):
203
"""Forces the client to shut itself down."""
204
with self._status_lock:
205
self._status = "forced_shutdown"
207
for waiter in self._active_waiters:
208
waiter.wakeAndRaise((ForcedShutdown("Forced shutdown"),
210
self._active_waiters.clear()
212
def _get_waiter_locked(self):
213
"""Gets a wait object for blocking waits. Should be called with the
217
if self._status == "forced_shutdown":
218
raise ForcedShutdown("Forced shutdown")
219
self._active_waiters.add(waiter)
222
def _get_waiter(self):
223
"""Get a wait object for blocking waits. Acquires the status lock."""
224
with self._status_lock:
225
return self._get_waiter_locked()
227
def _wait(self, waiter):
228
"""Waits for the waiter."""
232
with self._status_lock:
233
if waiter in self._active_waiters:
234
self._active_waiters.remove(waiter)
237
def _change_status(self, status, reason=None):
238
"""Changes the client status. Usually called from the reactor
242
with self._status_lock:
243
if self._status == "forced_shutdown":
245
self._status = status
246
self._status_reason = reason
247
waiting = self._status_waiting
248
self._status_waiting = []
249
for waiter in waiting:
250
waiter.wake((status, reason))
253
def _await_status_not(self, *ignore_statuses):
254
"""Blocks until the client status changes, returning the new status.
255
Should never be called from the reactor thread.
258
with self._status_lock:
259
status = self._status
260
reason = self._status_reason
261
while status in ignore_statuses:
262
waiter = self._get_waiter_locked()
263
self._status_waiting.append(waiter)
264
self._status_lock.release()
266
status, reason = self._wait(waiter)
268
self._status_lock.acquire()
269
if status == "forced_shutdown":
270
raise ForcedShutdown("Forced shutdown.")
271
return (status, reason)
273
def connection_failed(self, reason):
274
"""Notification that connection failed."""
275
self._change_status("disconnected", reason)
278
"""Notification that connection succeeded."""
279
self._change_status("connected")
281
def disconnected(self, reason):
282
"""Notification that we were disconnected."""
283
self._change_status("disconnected", reason)
285
def defer_from_thread(self, function, *args, **kwargs):
286
"""Do twisted defer magic to get results and show exceptions."""
287
waiter = self._get_waiter()
291
# we do want to catch all
292
# no init: pylint: disable-msg=W0703
294
d = function(*args, **kwargs)
295
if isinstance(d, defer.Deferred):
296
d.addCallbacks(lambda r: waiter.wake((r, None, None)),
297
lambda f: waiter.wake((None, None, f)))
299
waiter.wake((d, None, None))
301
waiter.wake((None, sys.exc_info(), None))
303
self.reactor.callFromThread(runner)
304
result, exc_info, failure = self._wait(waiter)
307
raise exc_info[0], exc_info[1], exc_info[2]
311
failure.raiseException()
316
def connect(self, host, port):
317
"""Connect to host/port."""
320
self.reactor.connectTCP(host, port, self.factory)
321
self._connect_inner(_connect)
324
def connect_ssl(self, host, port, no_verify):
325
"""Connect to host/port using ssl."""
328
ctx = get_ssl_context(no_verify)
329
self.reactor.connectSSL(host, port, self.factory, ctx)
330
self._connect_inner(_connect)
333
def _connect_inner(self, _connect):
334
"""Helper function for connecting."""
335
self._change_status("connecting")
336
self.reactor.callFromThread(_connect)
337
status, reason = self._await_status_not("connecting")
338
if status != "connected":
339
raise ConnectionError(reason.value)
342
def disconnect(self):
344
if self.factory.current_protocol is not None:
345
self.reactor.callFromThread(
346
self.factory.current_protocol.transport.loseConnection)
347
self._await_status_not("connecting", "connected", "authenticated")
350
def oauth_from_token(self, token, consumer=None):
351
"""Perform OAuth authorisation using an existing token."""
354
consumer = OAuthConsumer(self.consumer_key, self.consumer_secret)
356
def _auth_successful(value):
357
"""Callback for successful auth. Changes status to
359
self._change_status("authenticated")
362
def _auth_failed(value):
363
"""Callback for failed auth. Disconnects."""
364
self.factory.current_protocol.transport.loseConnection()
367
def _wrapped_authenticate():
368
"""Wrapped authenticate."""
369
d = self.factory.current_protocol.oauth_authenticate(consumer,
371
d.addCallbacks(_auth_successful, _auth_failed)
375
self.defer_from_thread(_wrapped_authenticate)
376
except request.StorageProtocolError, e:
377
raise AuthenticationError(e)
378
status, reason = self._await_status_not("connected")
379
if status != "authenticated":
380
raise AuthenticationError(reason.value)
383
def set_capabilities(self):
384
"""Set the capabilities with the server"""
386
client = self.factory.current_protocol
388
def set_caps_callback(req):
389
"Caps query succeeded"
391
de = defer.fail("The server denied setting %s capabilities" % \
396
def query_caps_callback(req):
397
"Caps query succeeded"
399
set_d = client.set_caps(self.required_caps)
400
set_d.addCallback(set_caps_callback)
403
# the server don't have the requested capabilities.
404
# return a failure for now, in the future we might want
405
# to reconnect to another server
406
de = defer.fail("The server don't have the requested"
407
" capabilities: %s" % str(req.caps))
411
def _wrapped_set_capabilities():
412
"""Wrapped set_capabilities """
413
d = client.query_caps(self.required_caps)
414
d.addCallback(query_caps_callback)
418
self.defer_from_thread(_wrapped_set_capabilities)
419
except request.StorageProtocolError, e:
420
raise CapabilitiesError(e)
423
def get_root_info(self, volume_uuid):
424
"""Returns the UUID of the applicable share root."""
425
if volume_uuid is None:
426
_get_root = self.factory.current_protocol.get_root
427
root = self.defer_from_thread(_get_root)
428
return (uuid.UUID(root), True)
430
str_volume_uuid = str(volume_uuid)
431
volume = self._match_volume(lambda v: \
432
str(v.volume_id) == str_volume_uuid)
433
if isinstance(volume, volumes.ShareVolume):
434
modify = volume.access_level == "Modify"
435
if isinstance(volume, volumes.UDFVolume):
437
return (uuid.UUID(str(volume.node_id)), modify)
440
def resolve_path(self, share_uuid, root_uuid, path):
441
"""Resolve path relative to the given root node."""
444
def _resolve_worker():
445
"""Path resolution worker."""
446
node_uuid = root_uuid
447
local_path = path.strip('/')
449
while local_path != '':
450
local_path, name = os.path.split(local_path)
451
hashes = yield self._get_node_hashes(share_uuid, [root_uuid])
452
content_hash = hashes.get(root_uuid, None)
453
if content_hash is None:
454
raise KeyError, "Content hash not available"
455
entries = yield self._get_raw_dir_entries(share_uuid,
458
match_name = name.decode('utf-8')
460
for entry in entries:
461
if match_name == entry.name:
466
raise KeyError, "Path not found"
468
node_uuid = uuid.UUID(match.node)
470
returnValue(node_uuid)
472
return self.defer_from_thread(_resolve_worker)
475
def find_volume(self, volume_spec):
476
"""Finds a share matching the given UUID. Looks at both share UUIDs
477
and root node UUIDs."""
478
volume = self._match_volume(lambda s: \
479
str(s.volume_id) == volume_spec or \
480
str(s.node_id) == volume_spec)
481
return uuid.UUID(str(volume.volume_id))
484
def _match_volume(self, predicate):
485
"""Finds a volume matching the given predicate."""
486
_list_shares = self.factory.current_protocol.list_volumes
487
r = self.defer_from_thread(_list_shares)
488
for volume in r.volumes:
489
if predicate(volume):
491
raise NoSuchShareError()
494
def build_tree(self, share_uuid, root_uuid):
495
"""Builds and returns a tree representing the metadata for the given
496
subtree in the given share.
498
@param share_uuid: the share UUID or None for the user's volume
499
@param root_uuid: the root UUID of the subtree (must be a directory)
500
@return: a MergeNode tree
503
root = MergeNode(node_type=DIRECTORY, uuid=root_uuid)
507
def _get_root_content_hash():
508
"""Obtain the content hash for the root node."""
509
result = yield self._get_node_hashes(share_uuid, [root_uuid])
510
returnValue(result.get(root_uuid, None))
512
root.content_hash = self.defer_from_thread(_get_root_content_hash)
513
if root.content_hash is None:
514
raise ValueError("No content available for node %s" % root_uuid)
518
def _get_children(parent_uuid, parent_content_hash):
519
"""Obtain a sequence of MergeNodes corresponding to a node's
523
entries = yield self._get_raw_dir_entries(share_uuid,
527
for entry in entries:
528
if should_sync(entry.name):
529
child = MergeNode(node_type=entry.node_type,
530
uuid=uuid.UUID(entry.node))
531
children[entry.name] = child
533
child_uuids = [child.uuid for child in children.itervalues()]
534
content_hashes = yield self._get_node_hashes(share_uuid,
536
for child in children.itervalues():
537
child.content_hash = content_hashes.get(child.uuid, None)
539
returnValue(children)
541
need_children = [root]
543
node = need_children.pop()
544
if node.content_hash is not None:
545
children = self.defer_from_thread(_get_children, node.uuid,
547
node.children = children
548
for child in children.itervalues():
549
if child.node_type == DIRECTORY:
550
need_children.append(child)
555
def _get_raw_dir_entries(self, share_uuid, node_uuid, content_hash):
556
"""Gets raw dir entries for the given directory."""
557
d = self.factory.current_protocol.get_content(share_str(share_uuid),
560
d.addCallback(lambda c: zlib.decompress(c.data))
562
def _parse_content(raw_content):
563
"""Parses directory content into a list of entry objects."""
564
unserialized_content = DirectoryContent()
565
unserialized_content.ParseFromString(raw_content)
566
return list(unserialized_content.entries)
568
d.addCallback(_parse_content)
572
def download_string(self, share_uuid, node_uuid, content_hash):
573
"""Reads a file from the server into a string."""
575
self._download_inner(share_uuid=share_uuid, node_uuid=node_uuid,
576
content_hash=content_hash, output=output)
577
return output.getValue()
580
def download_file(self, share_uuid, node_uuid, content_hash, filename):
581
"""Downloads a file from the server."""
582
partial_filename = "%s.u1partial" % filename
583
output = open_file(partial_filename, "w")
587
"""Renames the temporary file to the final name."""
589
rename(partial_filename, filename)
593
"""Deletes the temporary file."""
595
remove_file(partial_filename)
597
self._download_inner(share_uuid=share_uuid, node_uuid=node_uuid,
598
content_hash=content_hash, output=output,
599
on_success=rename_file, on_failure=delete_file)
602
def _download_inner(self, share_uuid, node_uuid, content_hash, output,
603
on_success=lambda: None, on_failure=lambda: None):
604
"""Helper function for content downloads."""
605
dec = zlib.decompressobj()
608
def write_data(data):
609
"""Helper which writes data to the output file."""
610
uncompressed_data = dec.decompress(data)
611
output.write(uncompressed_data)
614
def finish_download(value):
615
"""Helper which finishes the download."""
616
uncompressed_data = dec.flush()
617
output.write(uncompressed_data)
622
def abort_download(value):
623
"""Helper which aborts the download."""
630
_get_content = self.factory.current_protocol.get_content
631
d = _get_content(share_str(share_uuid), str(node_uuid),
632
content_hash, callback=write_data)
633
d.addCallbacks(finish_download, abort_download)
636
self.defer_from_thread(_download)
639
def create_directory(self, share_uuid, parent_uuid, name):
640
"""Creates a directory on the server."""
641
r = self.defer_from_thread(self.factory.current_protocol.make_dir,
642
share_str(share_uuid), str(parent_uuid),
644
return uuid.UUID(r.new_id)
647
def create_file(self, share_uuid, parent_uuid, name):
648
"""Creates a file on the server."""
649
r = self.defer_from_thread(self.factory.current_protocol.make_file,
650
share_str(share_uuid), str(parent_uuid),
652
return uuid.UUID(r.new_id)
655
def create_symlink(self, share_uuid, parent_uuid, name, target):
656
"""Creates a symlink on the server."""
657
raise UnsupportedOperationError("Protocol does not support symlinks")
660
def upload_string(self, share_uuid, node_uuid, old_content_hash,
661
content_hash, content):
662
"""Uploads a string to the server as file content."""
663
crc = crc32(content, 0)
664
compressed_content = zlib.compress(content, 9)
665
compressed = StringIO(compressed_content)
666
self.defer_from_thread(self.factory.current_protocol.put_content,
667
share_str(share_uuid), str(node_uuid),
668
old_content_hash, content_hash,
669
crc, len(content), len(compressed_content),
673
def upload_file(self, share_uuid, node_uuid, old_content_hash,
674
content_hash, filename):
675
"""Uploads a file to the server."""
676
parent_dir = os.path.split(filename)[0]
677
unique_filename = os.path.join(parent_dir, "." + str(uuid.uuid4()))
680
class StagingFile(object):
681
"""An object which tracks data being compressed for staging."""
682
def __init__(self, stream):
683
"""Initialize a compression object."""
685
self.enc = zlib.compressobj(9)
687
self.compressed_size = 0
690
def write(self, bytes):
691
"""Compress bytes, keeping track of length and crc32."""
692
self.size += len(bytes)
693
self.crc32 = crc32(bytes, self.crc32)
694
compressed_bytes = self.enc.compress(bytes)
695
self.compressed_size += len(compressed_bytes)
696
self.stream.write(compressed_bytes)
699
"""Finish staging compressed data."""
700
compressed_bytes = self.enc.flush()
701
self.compressed_size += len(compressed_bytes)
702
self.stream.write(compressed_bytes)
704
with open_file(unique_filename, "w+") as compressed:
705
remove_file(unique_filename)
706
with open_file(filename, "r") as original:
707
staging = StagingFile(compressed)
708
shutil.copyfileobj(original, staging)
711
self.defer_from_thread(self.factory.current_protocol.put_content,
712
share_str(share_uuid), str(node_uuid),
713
old_content_hash, content_hash,
715
staging.size, staging.compressed_size,
719
def move(self, share_uuid, parent_uuid, name, node_uuid):
720
"""Moves a file on the server."""
721
self.defer_from_thread(self.factory.current_protocol.move,
722
share_str(share_uuid), str(node_uuid),
723
str(parent_uuid), name)
726
def unlink(self, share_uuid, node_uuid):
727
"""Unlinks a file on the server."""
728
self.defer_from_thread(self.factory.current_protocol.unlink,
729
share_str(share_uuid), str(node_uuid))
732
def _get_node_hashes(self, share_uuid, node_uuids):
733
"""Fetches hashes for the given nodes."""
734
share = share_str(share_uuid)
735
queries = [(share, str(node_uuid), request.UNKNOWN_HASH) \
736
for node_uuid in node_uuids]
737
d = self.factory.current_protocol.query(queries)
740
def _collect_hashes(multi_result):
741
"""Accumulate hashes from query replies."""
743
for (success, value) in multi_result:
745
for node_state in value.response:
746
node_uuid = uuid.UUID(node_state.node)
747
hashes[node_uuid] = node_state.hash
750
d.addCallback(_collect_hashes)
754
def get_incoming_shares(self):
755
"""Returns a list of incoming shares as (name, uuid, accepted)
759
_list_shares = self.factory.current_protocol.list_shares
760
r = self.defer_from_thread(_list_shares)
761
return [(s.name, s.id, s.other_visible_name,
762
s.accepted, s.access_level) \
763
for s in r.shares if s.direction == "to_me"]