# canonical.ubuntuone.storage.u1sync.client
#
# Client/protocol end of u1sync
#
# Author: Lucio Torre <lucio.torre@canonical.com>
# Author: Tim Cole <tim.cole@canonical.com>
#
# Copyright 2009 Canonical Ltd.
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3, as published
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranties of
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR
# PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program.  If not, see <http://www.gnu.org/licenses/>.
"""Pretty API for protocol client."""
from __future__ import with_statement

import ConfigParser
import os
import shutil
import sys
import urlparse
import uuid
import zlib

from Queue import Queue, Empty
from cStringIO import StringIO
from threading import Thread, Lock

from twisted.internet import reactor, defer, ssl
from twisted.internet.defer import inlineCallbacks, returnValue

from canonical.ubuntuone.storage.protocol.hash import crc32
# NOTE(review): continuation of the truncated "import get_config \" line
# restored; verify the alias name against version control.
from canonical.ubuntuone.oauthdesktop.config import get_config \
    as get_oauth_config
from canonical.ubuntuone.oauthdesktop.auth import AuthorisationClient
from canonical.ubuntuone.storage.u1sync.genericmerge import MergeNode
from canonical.ubuntuone.storage.u1sync.utils import should_sync

CONSUMER_KEY = "ubuntuone"

from canonical.ubuntuone.storage.protocol.oauth import OAuthConsumer
from canonical.ubuntuone.storage.protocol.client import (
    StorageClientFactory, StorageClient)
from canonical.ubuntuone.storage.protocol import request
from canonical.ubuntuone.storage.protocol.dircontent_pb2 import \
    DirectoryContent, DIRECTORY
54
def share_str(share_uuid):
    """Converts a share UUID to a form the protocol likes.

    @param share_uuid: a share UUID, or None for the user's own volume
    @return: str(share_uuid), or request.ROOT when share_uuid is None
    """
    return str(share_uuid) if share_uuid is not None else request.ROOT
59
class SyncStorageClient(StorageClient):
    """Simple client that calls a callback on connection."""

    def connectionMade(self):
        """Setup and call callback."""
        StorageClient.connectionMade(self)
        # Drop any stale protocol instance the factory still tracks, so
        # the factory only ever points at the live connection.
        if self.factory.current_protocol not in (None, self):
            self.factory.current_protocol.transport.loseConnection()
        self.factory.current_protocol = self
        self.factory.observer.connected()

    def connectionLost(self, reason=None):
        """Callback for established connection lost."""
        # Only clear the factory pointer if it still refers to us; a newer
        # connection may already have replaced it.
        if self.factory.current_protocol is self:
            self.factory.current_protocol = None
        self.factory.observer.disconnected(reason)
77
class SyncClientFactory(StorageClientFactory):
    """A cmd protocol factory."""
    # no init: pylint: disable-msg=W0232

    protocol = SyncStorageClient

    def __init__(self, observer):
        """Create the factory.

        @param observer: object notified of connection state changes
            (must provide connected/disconnected/connection_failed)
        """
        self.observer = observer
        self.current_protocol = None

    def clientConnectionFailed(self, connector, reason):
        """We failed at connecting."""
        self.current_protocol = None
        self.observer.connection_failed(reason)
94
class UnsupportedOperationError(Exception):
    """The operation is unsupported by the protocol version."""
98
class ConnectionError(Exception):
    """A connection error."""
102
class AuthenticationError(Exception):
    """An authentication error."""
106
class NoSuchShareError(Exception):
    """Error when there is no such share available."""
110
class Client(object):
111
"""U1 storage client facade."""
113
def __init__(self, realm):
114
"""Create the instance."""
116
self.thread = Thread(target=self._run)
117
self.thread.setDaemon(True)
118
self.factory = SyncClientFactory(self)
120
self._status_lock = Lock()
121
self._status = "disconnected"
122
self._status_reason = None
123
self._status_waiting = []
127
oauth_config = get_oauth_config()
128
if oauth_config.has_section(realm):
129
config_section = realm
130
elif self.realm.startswith("http://localhost") and \
131
oauth_config.has_section("http://localhost"):
132
config_section = "http://localhost"
134
config_section = "default"
136
def get_oauth_option(option):
137
"""Retrieves an option from oauth config."""
139
return oauth_config.get(config_section, option)
140
except ConfigParser.NoOptionError:
141
return oauth_config.get("default", option)
143
def get_oauth_url(option):
144
"""Retrieves an absolutized URL from the OAuth config."""
145
suffix = get_oauth_option(option)
146
return urlparse.urljoin(realm, suffix)
148
self.consumer_key = CONSUMER_KEY
149
self.consumer_secret = get_oauth_option("consumer_secret")
151
self.request_token_url = get_oauth_url("request_token_url")
152
self.user_authorisation_url = get_oauth_url("user_authorisation_url")
153
self.access_token_url = get_oauth_url("access_token_url")
155
def obtain_oauth_token(self, create_token):
156
"""Obtains an oauth token, optionally creating one if requried."""
157
token_result = Queue()
159
def have_token(token):
160
"""When a token is available."""
161
token_result.put(token)
164
"""When no token is available."""
165
token_result.put(None)
167
oauth_client = AuthorisationClient(realm=self.realm,
169
self.request_token_url,
170
user_authorisation_url=
171
self.user_authorisation_url,
173
self.access_token_url,
174
consumer_key=self.consumer_key,
176
self.consumer_secret,
177
callback_parent=have_token,
178
callback_denied=no_token,
179
do_login=create_token)
182
"""Obtains or creates a token."""
184
oauth_client.clear_token()
185
oauth_client.ensure_access_token()
187
reactor.callFromThread(_obtain_token)
188
token = token_result.get()
190
raise AuthenticationError("Unable to obtain OAuth token.")
193
def _change_status(self, status, reason=None):
194
"""Changes the client status. Usually called from the reactor
198
with self._status_lock:
199
self._status = status
200
self._status_reason = reason
201
waiting = self._status_waiting
203
self._status_waiting = []
204
for waiter in waiting:
205
waiter.put((status, reason))
207
def _await_status_not(self, *ignore_statuses):
208
"""Blocks until the client status changes, returning the new status.
209
Should never be called from the reactor thread.
212
with self._status_lock:
213
status = self._status
214
reason = self._status_reason
215
while status in ignore_statuses:
217
self._status_waiting.append(waiter)
218
self._status_lock.release()
220
status, reason = waiter.get()
222
self._status_lock.acquire()
223
return (status, reason)
225
def connection_failed(self, reason):
226
"""Notification that connection failed."""
227
self._change_status("disconnected", reason)
230
"""Notification that connection succeeded."""
231
self._change_status("connected")
233
def disconnected(self, reason):
234
"""Notification that we were disconnected."""
235
self._change_status("disconnected", reason)
238
"""Run the reactor in bg."""
239
reactor.run(installSignalHandlers=False)
242
"""Start the reactor thread."""
246
"""Shut down the reactor."""
247
reactor.callWhenRunning(reactor.stop)
248
self.thread.join(1.0)
250
def defer_from_thread(self, function, *args, **kwargs):
251
"""Do twisted defer magic to get results and show exceptions."""
256
# we do want to catch all
257
# no init: pylint: disable-msg=W0703
259
d = function(*args, **kwargs)
260
if isinstance(d, defer.Deferred):
261
d.addCallbacks(lambda r: queue.put((r, None, None)),
262
lambda f: queue.put((None, None, f)))
264
queue.put((d, None, None))
266
queue.put((None, sys.exc_info(), None))
268
reactor.callFromThread(runner)
271
# poll with a timeout so that interrupts are still serviced
272
result, exc_info, failure = queue.get(True, 1)
274
except Empty: # pylint: disable-msg=W0704
277
raise exc_info[1], None, exc_info[2]
279
failure.raiseException()
283
def connect(self, host, port):
284
"""Connect to host/port."""
287
reactor.connectTCP(host, port, self.factory)
288
self._connect_inner(_connect)
290
def connect_ssl(self, host, port):
291
"""Connect to host/port using ssl."""
294
reactor.connectSSL(host, port, self.factory,
295
ssl.ClientContextFactory())
296
self._connect_inner(_connect)
298
def _connect_inner(self, _connect):
299
"""Helper function for connecting."""
300
self._change_status("connecting")
301
reactor.callFromThread(_connect)
302
status, reason = self._await_status_not("connecting")
303
if status != "connected":
304
raise ConnectionError(reason.value)
306
def disconnect(self):
308
if self.factory.current_protocol is not None:
309
reactor.callFromThread(
310
self.factory.current_protocol.transport.loseConnection)
311
self._await_status_not("connecting", "connected", "authenticated")
313
def oauth_from_token(self, token):
314
"""Perform OAuth authorisation using an existing token."""
316
consumer = OAuthConsumer(self.consumer_key, self.consumer_secret)
318
def _auth_successful(value):
319
"""Callback for successful auth. Changes status to
321
self._change_status("authenticated")
324
def _auth_failed(value):
325
"""Callback for failed auth. Disconnects."""
326
self.factory.current_protocol.transport.loseConnection()
329
def _wrapped_authenticate():
330
"""Wrapped authenticate."""
331
d = self.factory.current_protocol.oauth_authenticate(consumer,
333
d.addCallbacks(_auth_successful, _auth_failed)
337
self.defer_from_thread(_wrapped_authenticate)
338
except request.StorageProtocolError, e:
339
raise AuthenticationError(e)
340
status, reason = self._await_status_not("connected")
341
if status != "authenticated":
342
raise AuthenticationError(reason.value)
344
def get_root_info(self, share_uuid):
345
"""Returns the UUID of the applicable share root."""
346
if share_uuid is None:
347
_get_root = self.factory.current_protocol.get_root
348
root = self.defer_from_thread(_get_root)
349
return (uuid.UUID(root), True)
351
str_share_uuid = str(share_uuid)
352
share = self._match_share(lambda s: str(s.id) == str_share_uuid)
353
return (uuid.UUID(str(share.subtree)),
354
share.access_level == "Modify")
356
def find_share(self, share_spec):
357
"""Finds a share matching the given UUID. Looks at both share UUIDs
358
and root node UUIDs."""
359
share = self._match_share(lambda s: str(s.id) == share_spec or \
360
str(s.subtree) == share_spec)
361
return uuid.UUID(str(share.id))
363
def _match_share(self, predicate):
364
"""Finds a share matching the given predicate."""
365
_list_shares = self.factory.current_protocol.list_shares
366
r = self.defer_from_thread(_list_shares)
367
for share in r.shares:
368
if predicate(share) and share.direction == "to_me":
370
raise NoSuchShareError()
372
def build_tree(self, share_uuid, root_uuid):
373
"""Builds and returns a tree representing the metadata for the given
374
subtree in the given share.
376
@param share_uuid: the share UUID or None for the user's volume
377
@param root_uuid: the root UUID of the subtree (must be a directory)
378
@return: a MergeNode tree
381
root = MergeNode(node_type=DIRECTORY, uuid=root_uuid)
384
def _get_root_content_hash():
385
"""Obtain the content hash for the root node."""
386
result = yield self._get_node_hashes(share_uuid, [root_uuid])
387
returnValue(result.get(root_uuid, None))
389
root.content_hash = self.defer_from_thread(_get_root_content_hash)
392
def _get_children(parent_uuid, parent_content_hash):
393
"""Obtain a sequence of MergeNodes corresponding to a node's
397
entries = yield self._get_raw_dir_entries(share_uuid,
401
for entry in entries:
402
if should_sync(entry.name):
403
child = MergeNode(node_type=entry.node_type,
404
uuid=uuid.UUID(entry.node))
405
children[entry.name] = child
407
child_uuids = [child.uuid for child in children.itervalues()]
408
content_hashes = yield self._get_node_hashes(share_uuid,
410
for child in children.itervalues():
411
child.content_hash = content_hashes.get(child.uuid, None)
413
returnValue(children)
415
need_children = [root]
417
node = need_children.pop()
418
if node.content_hash is not None:
419
children = self.defer_from_thread(_get_children, node.uuid,
421
node.children = children
422
for child in children.itervalues():
423
if child.node_type == DIRECTORY:
424
need_children.append(child)
428
def _get_raw_dir_entries(self, share_uuid, node_uuid, content_hash):
429
"""Gets raw dir entries for the given directory."""
430
d = self.factory.current_protocol.get_content(share_str(share_uuid),
433
d.addCallback(lambda c: zlib.decompress(c.data))
435
def _parse_content(raw_content):
436
"""Parses directory content into a list of entry objects."""
437
unserialized_content = DirectoryContent()
438
unserialized_content.ParseFromString(raw_content)
439
return list(unserialized_content.entries)
441
d.addCallback(_parse_content)
444
def download_string(self, share_uuid, node_uuid, content_hash):
445
"""Reads a file from the server into a string."""
447
self._download_inner(share_uuid=share_uuid, node_uuid=node_uuid,
448
content_hash=content_hash, output=output)
449
return output.getValue()
451
def download_file(self, share_uuid, node_uuid, content_hash, filename):
452
"""Downloads a file from the server."""
453
partial_filename = "%s.u1partial" % filename
454
output = open(partial_filename, "w")
457
"""Renames the temporary file to the final name."""
459
os.rename(partial_filename, filename)
462
"""Deletes the temporary file."""
464
os.unlink(partial_filename)
466
self._download_inner(share_uuid=share_uuid, node_uuid=node_uuid,
467
content_hash=content_hash, output=output,
468
on_success=rename_file, on_failure=delete_file)
470
def _download_inner(self, share_uuid, node_uuid, content_hash, output,
471
on_success=lambda: None, on_failure=lambda: None):
472
"""Helper function for content downloads."""
473
dec = zlib.decompressobj()
475
def write_data(data):
476
"""Helper which writes data to the output file."""
477
uncompressed_data = dec.decompress(data)
478
output.write(uncompressed_data)
480
def finish_download(value):
481
"""Helper which finishes the download."""
482
uncompressed_data = dec.flush()
483
output.write(uncompressed_data)
487
def abort_download(value):
488
"""Helper which aborts the download."""
494
_get_content = self.factory.current_protocol.get_content
495
d = _get_content(share_str(share_uuid), str(node_uuid),
496
content_hash, callback=write_data)
497
d.addCallbacks(finish_download, abort_download)
500
self.defer_from_thread(_download)
502
def create_directory(self, share_uuid, parent_uuid, name):
503
"""Creates a directory on the server."""
504
r = self.defer_from_thread(self.factory.current_protocol.make_dir,
505
share_str(share_uuid), str(parent_uuid),
507
return uuid.UUID(r.new_id)
509
def create_file(self, share_uuid, parent_uuid, name):
510
"""Creates a file on the server."""
511
r = self.defer_from_thread(self.factory.current_protocol.make_file,
512
share_str(share_uuid), str(parent_uuid),
514
return uuid.UUID(r.new_id)
516
def create_symlink(self, share_uuid, parent_uuid, name, target):
517
"""Creates a symlink on the server."""
518
raise UnsupportedOperationError("Protocol does not support symlinks")
520
def upload_string(self, share_uuid, node_uuid, old_content_hash,
521
content_hash, content):
522
"""Uploads a string to the server as file content."""
523
crc32 = crc32(content, 0)
524
compressed_content = zlib.compress(content, 9)
525
compressed = StringIO(compressed_content)
526
self.defer_from_thread(self.factory.current_protocol.put_content,
527
share_str(share_uuid), str(node_uuid),
528
old_content_hash, content_hash,
529
crc32, len(content), len(compressed_content),
532
def upload_file(self, share_uuid, node_uuid, old_content_hash,
533
content_hash, filename):
534
"""Uploads a file to the server."""
535
parent_dir = os.path.split(filename)[0]
536
unique_filename = os.path.join(parent_dir, "." + str(uuid.uuid4()))
539
class StagingFile(object):
540
"""An object which tracks data being compressed for staging."""
541
def __init__(self, stream):
542
"""Initialize a compression object."""
544
self.enc = zlib.compressobj(9)
546
self.compressed_size = 0
549
def write(self, bytes):
550
"""Compress bytes, keeping track of length and crc32."""
551
self.size += len(bytes)
552
self.crc32 = crc32(bytes, self.crc32)
553
compressed_bytes = self.enc.compress(bytes)
554
self.compressed_size += len(compressed_bytes)
555
self.stream.write(compressed_bytes)
558
"""Finish staging compressed data."""
559
compressed_bytes = self.enc.flush()
560
self.compressed_size += len(compressed_bytes)
561
self.stream.write(compressed_bytes)
563
with open(unique_filename, "w+") as compressed:
564
os.unlink(unique_filename)
565
with open(filename, "r") as original:
566
staging = StagingFile(compressed)
567
shutil.copyfileobj(original, staging)
570
self.defer_from_thread(self.factory.current_protocol.put_content,
571
share_str(share_uuid), str(node_uuid),
572
old_content_hash, content_hash,
574
staging.size, staging.compressed_size,
577
def move(self, share_uuid, parent_uuid, name, node_uuid):
578
"""Moves a file on the server."""
579
self.defer_from_thread(self.factory.current_protocol.move,
580
share_str(share_uuid), str(node_uuid),
581
str(parent_uuid), name)
583
def unlink(self, share_uuid, node_uuid):
584
"""Unlinks a file on the server."""
585
self.defer_from_thread(self.factory.current_protocol.unlink,
586
share_str(share_uuid), str(node_uuid))
588
def _get_node_hashes(self, share_uuid, node_uuids):
589
"""Fetches hashes for the given nodes."""
590
share = share_str(share_uuid)
591
queries = [(share, str(node_uuid), request.UNKNOWN_HASH) \
592
for node_uuid in node_uuids]
593
d = self.factory.current_protocol.query(queries)
595
def _collect_hashes(multi_result):
596
"""Accumulate hashes from query replies."""
598
for (success, value) in multi_result:
600
for node_state in value.response:
601
node_uuid = uuid.UUID(node_state.node)
602
hashes[node_uuid] = node_state.hash
605
d.addCallback(_collect_hashes)
608
def get_incoming_shares(self):
609
"""Returns a list of incoming shares as (name, uuid, accepted)
613
_list_shares = self.factory.current_protocol.list_shares
614
r = self.defer_from_thread(_list_shares)
615
return [(s.name, s.id, s.other_visible_name,
616
s.accepted, s.access_level) \
617
for s in r.shares if s.direction == "to_me"]