# Copyright 2009-2015 Canonical
# Copyright 2015-2018 Chicharreros (https://launchpad.net/~chicharreros)
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3, as published
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranties of
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program. If not, see <http://www.gnu.org/licenses/>.

"""Pretty API for protocol client."""
from __future__ import with_statement

import logging
import os
import shutil
import sys
import time
import uuid
import zlib

from cStringIO import StringIO
from logging.handlers import RotatingFileHandler
from Queue import Queue
from threading import Lock

from dirspec.basedir import xdg_cache_home
from twisted.internet import reactor, defer
from twisted.internet.defer import inlineCallbacks, returnValue

from ubuntuone.storageprotocol import request, volumes
from ubuntuone.storageprotocol.client import (
    StorageClientFactory, StorageClient)
from ubuntuone.storageprotocol.content_hash import crc32
from ubuntuone.storageprotocol.context import get_ssl_context
from ubuntuone.storageprotocol.delta import DIRECTORY as delta_DIR
from ubuntuone.storageprotocol.dircontent_pb2 import DIRECTORY, FILE

from u1sync.genericmerge import MergeNode
from u1sync.utils import should_sync
# OAuth-style consumer credentials expected by the Ubuntu One service.
CONSUMER_KEY = "ubuntuone"
CONSUMER_SECRET = "hammertime"

# Timing log lives under the XDG cache dir; create the directory on import.
u1sync_log_dir = os.path.join(xdg_cache_home, 'u1sync', 'log')
LOGFILENAME = os.path.join(u1sync_log_dir, 'u1sync.log')
if not os.path.exists(u1sync_log_dir):
    os.makedirs(u1sync_log_dir)

# Module-level logger used by the timing decorator below.
u1_logger = logging.getLogger("u1sync.timing.log")
handler = RotatingFileHandler(LOGFILENAME)
u1_logger.addHandler(handler)
def share_str(share_uuid):
    """Converts a share UUID to a form the protocol likes.

    None maps to the protocol's ROOT sentinel; anything else is stringified.
    """
    if share_uuid is None:
        return request.ROOT
    return str(share_uuid)
def log_timing(func):
    """Decorator that logs the elapsed wall-clock time of each call.

    NOTE(review): the enclosing decorator lines were lost in the paste;
    reconstructed from the visible wrapper body -- verify against upstream.
    """
    def wrapper(*arg, **kwargs):
        """Time the wrapped call and log it to the u1sync timing logger."""
        start = time.time()
        ent = func(*arg, **kwargs)
        stop = time.time()
        # BUG FIX: the original logged `stop - start * 1000.0`, which by
        # operator precedence is `stop - (start * 1000.0)`; parenthesize so
        # the seconds delta is actually converted to milliseconds.
        u1_logger.debug('for %s %0.5f ms elapsed',
                        func.func_name, (stop - start) * 1000.0)
        return ent
    return wrapper
76
class ForcedShutdown(Exception):
    """Raised when the client is forced to shut itself down."""
81
"""Wait object for blocking waits."""
84
"""Initializes the wait object."""
87
def wake(self, result):
88
"""Wakes the waiter with a result."""
89
self.queue.put((result, None))
91
def wakeAndRaise(self, exc_info):
92
"""Wakes the waiter, raising the given exception in it."""
93
self.queue.put((None, exc_info))
95
def wakeWithResult(self, func, *args, **kw):
96
"""Wakes the waiter with the result of the given function."""
98
result = func(*args, **kw)
100
self.wakeAndRaise(sys.exc_info())
105
"""Waits for wakeup."""
106
(result, exc_info) = self.queue.get()
109
raise exc_info[0], exc_info[1], exc_info[2]
116
class SyncStorageClient(StorageClient):
    """Simple client that calls a callback on connection."""

    def connectionMade(self):
        """Setup and call callback.

        Drops any previously-tracked protocol so the factory only ever
        points at the live connection, then notifies the observer.
        """
        StorageClient.connectionMade(self)
        if self.factory.current_protocol not in (None, self):
            self.factory.current_protocol.transport.loseConnection()
        self.factory.current_protocol = self
        self.factory.observer.connected()

    def connectionLost(self, reason=None):
        """Callback for established connection lost."""
        StorageClient.connectionLost(self, reason)
        # Only clear the factory slot if we are still the tracked protocol.
        if self.factory.current_protocol is self:
            self.factory.current_protocol = None
        self.factory.observer.disconnected(reason)
137
class SyncClientFactory(StorageClientFactory):
    """A cmd protocol factory."""

    protocol = SyncStorageClient

    def __init__(self, observer):
        """Create the factory.

        @param observer: object notified of connected/disconnected/
            connection_failed events (the Client facade below).
        """
        self.observer = observer
        self.current_protocol = None

    def clientConnectionFailed(self, connector, reason):
        """We failed at connecting."""
        self.current_protocol = None
        self.observer.connection_failed(reason)
155
class UnsupportedOperationError(Exception):
    """Raised when the protocol version does not support an operation."""
159
class ConnectionError(Exception):
    """Raised when connecting to the server fails."""
163
class AuthenticationError(Exception):
    """Raised when authenticating with the server fails."""
167
class NoSuchShareError(Exception):
    """Raised when the requested share is not available."""
171
class CapabilitiesError(Exception):
    """Raised for capabilities set/query related failures."""
175
class Client(object):
    """U1 storage client facade."""

    # Capabilities this client requires of the server; the closing bracket
    # was lost in the paste and has been restored.
    required_caps = frozenset([
        "no-content", "account-info", "resumable-uploads",
        "fix462230", "volumes", "generations",
    ])
182
def __init__(self, realm=None, reactor=reactor):
183
"""Create the instance.
185
'realm' is no longer used, but is left as param for API compatibility.
188
self.reactor = reactor
189
self.factory = SyncClientFactory(self)
191
self._status_lock = Lock()
192
self._status = "disconnected"
193
self._status_reason = None
194
self._status_waiting = []
195
self._active_waiters = set()
197
self.consumer_key = CONSUMER_KEY
198
self.consumer_secret = CONSUMER_SECRET
200
def force_shutdown(self):
201
"""Forces the client to shut itself down."""
202
with self._status_lock:
203
self._status = "forced_shutdown"
205
for waiter in self._active_waiters:
206
waiter.wakeAndRaise((ForcedShutdown("Forced shutdown"),
208
self._active_waiters.clear()
210
def _get_waiter_locked(self):
211
"""Gets a wait object for blocking waits. Should be called with the
215
if self._status == "forced_shutdown":
216
raise ForcedShutdown("Forced shutdown")
217
self._active_waiters.add(waiter)
220
def _get_waiter(self):
221
"""Get a wait object for blocking waits. Acquires the status lock."""
222
with self._status_lock:
223
return self._get_waiter_locked()
225
def _wait(self, waiter):
226
"""Waits for the waiter."""
230
with self._status_lock:
231
if waiter in self._active_waiters:
232
self._active_waiters.remove(waiter)
235
def _change_status(self, status, reason=None):
236
"""Changes the client status. Usually called from the reactor
240
with self._status_lock:
241
if self._status == "forced_shutdown":
243
self._status = status
244
self._status_reason = reason
245
waiting = self._status_waiting
246
self._status_waiting = []
247
for waiter in waiting:
248
waiter.wake((status, reason))
251
def _await_status_not(self, *ignore_statuses):
252
"""Blocks until the client status changes, returning the new status.
253
Should never be called from the reactor thread.
256
with self._status_lock:
257
status = self._status
258
reason = self._status_reason
259
while status in ignore_statuses:
260
waiter = self._get_waiter_locked()
261
self._status_waiting.append(waiter)
262
self._status_lock.release()
264
status, reason = self._wait(waiter)
266
self._status_lock.acquire()
267
if status == "forced_shutdown":
268
raise ForcedShutdown("Forced shutdown.")
269
return (status, reason)
271
def connection_failed(self, reason):
272
"""Notification that connection failed."""
273
self._change_status("disconnected", reason)
276
"""Notification that connection succeeded."""
277
self._change_status("connected")
279
def disconnected(self, reason):
280
"""Notification that we were disconnected."""
281
self._change_status("disconnected", reason)
283
def defer_from_thread(self, function, *args, **kwargs):
284
"""Do twisted defer magic to get results and show exceptions."""
285
waiter = self._get_waiter()
291
d = function(*args, **kwargs)
292
if isinstance(d, defer.Deferred):
293
d.addCallbacks(lambda r: waiter.wake((r, None, None)),
294
lambda f: waiter.wake((None, None, f)))
296
waiter.wake((d, None, None))
298
waiter.wake((None, sys.exc_info(), None))
300
self.reactor.callFromThread(runner)
301
result, exc_info, failure = self._wait(waiter)
304
raise exc_info[0], exc_info[1], exc_info[2]
308
failure.raiseException()
313
def connect(self, host, port):
314
"""Connect to host/port."""
317
self.reactor.connectTCP(host, port, self.factory)
318
self._connect_inner(_connect)
321
def connect_ssl(self, host, port, no_verify):
322
"""Connect to host/port using ssl."""
325
ctx = get_ssl_context(no_verify, host)
326
self.reactor.connectSSL(host, port, self.factory, ctx)
327
self._connect_inner(_connect)
330
def _connect_inner(self, _connect):
331
"""Helper function for connecting."""
332
self._change_status("connecting")
333
self.reactor.callFromThread(_connect)
334
status, reason = self._await_status_not("connecting")
335
if status != "connected":
336
raise ConnectionError(reason.value)
339
def disconnect(self):
341
if self.factory.current_protocol is not None:
342
self.reactor.callFromThread(
343
self.factory.current_protocol.transport.loseConnection)
344
self._await_status_not("connecting", "connected", "authenticated")
347
def simple_auth(self, username, password):
348
"""Perform simple authorisation."""
351
def _wrapped_authenticate():
352
"""Wrapped authenticate."""
354
yield self.factory.current_protocol.simple_authenticate(
357
self.factory.current_protocol.transport.loseConnection()
359
self._change_status("authenticated")
362
self.defer_from_thread(_wrapped_authenticate)
363
except request.StorageProtocolError as e:
364
raise AuthenticationError(e)
365
status, reason = self._await_status_not("connected")
366
if status != "authenticated":
367
raise AuthenticationError(reason.value)
370
def set_capabilities(self):
371
"""Set the capabilities with the server"""
373
client = self.factory.current_protocol
376
def set_caps_callback(req):
377
"Caps query succeeded"
379
de = defer.fail("The server denied setting %s capabilities" %
384
def query_caps_callback(req):
385
"Caps query succeeded"
387
set_d = client.set_caps(self.required_caps)
388
set_d.addCallback(set_caps_callback)
391
# the server don't have the requested capabilities.
392
# return a failure for now, in the future we might want
393
# to reconnect to another server
394
de = defer.fail("The server don't have the requested"
395
" capabilities: %s" % str(req.caps))
399
def _wrapped_set_capabilities():
400
"""Wrapped set_capabilities """
401
d = client.query_caps(self.required_caps)
402
d.addCallback(query_caps_callback)
406
self.defer_from_thread(_wrapped_set_capabilities)
407
except request.StorageProtocolError as e:
408
raise CapabilitiesError(e)
411
def get_root_info(self, volume_uuid):
412
"""Returns the UUID of the applicable share root."""
413
if volume_uuid is None:
414
_get_root = self.factory.current_protocol.get_root
415
root = self.defer_from_thread(_get_root)
416
return (uuid.UUID(root), True)
418
str_volume_uuid = str(volume_uuid)
419
volume = self._match_volume(
420
lambda v: str(v.volume_id) == str_volume_uuid)
421
if isinstance(volume, volumes.ShareVolume):
422
modify = volume.access_level == "Modify"
423
if isinstance(volume, volumes.UDFVolume):
425
return (uuid.UUID(str(volume.node_id)), modify)
428
def resolve_path(self, share_uuid, root_uuid, path):
429
"""Resolve path relative to the given root node."""
432
def _resolve_worker():
433
"""Path resolution worker."""
434
node_uuid = root_uuid
435
local_path = path.strip('/')
437
while local_path != '':
438
local_path, name = os.path.split(local_path)
439
hashes = yield self._get_node_hashes(share_uuid)
440
content_hash = hashes.get(root_uuid, None)
441
if content_hash is None:
442
raise KeyError("Content hash not available")
443
entries = yield self._get_dir_entries(share_uuid, root_uuid)
444
match_name = name.decode('utf-8')
446
for entry in entries:
447
if match_name == entry.name:
452
raise KeyError("Path not found")
454
node_uuid = uuid.UUID(match.node)
456
returnValue(node_uuid)
458
return self.defer_from_thread(_resolve_worker)
461
def find_volume(self, volume_spec):
462
"""Finds a share matching the given UUID. Looks at both share UUIDs
463
and root node UUIDs."""
466
return (str(s.volume_id) == volume_spec or
467
str(s.node_id) == volume_spec)
469
volume = self._match_volume(match)
470
return uuid.UUID(str(volume.volume_id))
473
def _match_volume(self, predicate):
474
"""Finds a volume matching the given predicate."""
475
_list_shares = self.factory.current_protocol.list_volumes
476
r = self.defer_from_thread(_list_shares)
477
for volume in r.volumes:
478
if predicate(volume):
480
raise NoSuchShareError()
483
def build_tree(self, share_uuid, root_uuid):
484
"""Builds and returns a tree representing the metadata for the given
485
subtree in the given share.
487
@param share_uuid: the share UUID or None for the user's volume
488
@param root_uuid: the root UUID of the subtree (must be a directory)
489
@return: a MergeNode tree
492
root = MergeNode(node_type=DIRECTORY, uuid=root_uuid)
496
def _get_root_content_hash():
497
"""Obtain the content hash for the root node."""
498
result = yield self._get_node_hashes(share_uuid)
499
returnValue(result.get(root_uuid, None))
501
root.content_hash = self.defer_from_thread(_get_root_content_hash)
502
if root.content_hash is None:
503
raise ValueError("No content available for node %s" % root_uuid)
507
def _get_children(parent_uuid, parent_content_hash):
508
"""Obtain a sequence of MergeNodes corresponding to a node's
512
entries = yield self._get_dir_entries(share_uuid, parent_uuid)
514
for entry in entries:
515
if should_sync(entry.name):
516
child = MergeNode(node_type=entry.node_type,
517
uuid=uuid.UUID(entry.node))
518
children[entry.name] = child
520
content_hashes = yield self._get_node_hashes(share_uuid)
521
for child in children.itervalues():
522
child.content_hash = content_hashes.get(child.uuid, None)
524
returnValue(children)
526
need_children = [root]
528
node = need_children.pop()
529
if node.content_hash is not None:
530
children = self.defer_from_thread(_get_children, node.uuid,
532
node.children = children
533
for child in children.itervalues():
534
if child.node_type == DIRECTORY:
535
need_children.append(child)
540
@defer.inlineCallbacks
541
def _get_dir_entries(self, share_uuid, node_uuid):
542
"""Get raw dir entries for the given directory."""
543
result = yield self.factory.current_protocol.get_delta(
544
share_str(share_uuid), from_scratch=True)
545
node_uuid = share_str(node_uuid)
547
for n in result.response:
548
if n.parent_id == node_uuid:
549
# adapt here some attrs so we don't need to change ALL the code
550
n.node_type = DIRECTORY if n.file_type == delta_DIR else FILE
553
defer.returnValue(children)
556
def download_string(self, share_uuid, node_uuid, content_hash):
557
"""Reads a file from the server into a string."""
559
self._download_inner(share_uuid=share_uuid, node_uuid=node_uuid,
560
content_hash=content_hash, output=output)
561
return output.getValue()
564
def download_file(self, share_uuid, node_uuid, content_hash, filename):
565
"""Downloads a file from the server."""
566
partial_filename = "%s.u1partial" % filename
567
output = open(partial_filename, "w")
571
"""Renames the temporary file to the final name."""
573
os.rename(partial_filename, filename)
577
"""Deletes the temporary file."""
579
os.remove(partial_filename)
581
self._download_inner(share_uuid=share_uuid, node_uuid=node_uuid,
582
content_hash=content_hash, output=output,
583
on_success=rename_file, on_failure=delete_file)
586
def _download_inner(self, share_uuid, node_uuid, content_hash, output,
587
on_success=lambda: None, on_failure=lambda: None):
588
"""Helper function for content downloads."""
589
dec = zlib.decompressobj()
592
def write_data(data):
593
"""Helper which writes data to the output file."""
594
uncompressed_data = dec.decompress(data)
595
output.write(uncompressed_data)
598
def finish_download(value):
599
"""Helper which finishes the download."""
600
uncompressed_data = dec.flush()
601
output.write(uncompressed_data)
606
def abort_download(value):
607
"""Helper which aborts the download."""
614
_get_content = self.factory.current_protocol.get_content
615
d = _get_content(share_str(share_uuid), str(node_uuid),
616
content_hash, callback=write_data)
617
d.addCallbacks(finish_download, abort_download)
620
self.defer_from_thread(_download)
623
def create_directory(self, share_uuid, parent_uuid, name):
624
"""Creates a directory on the server."""
625
r = self.defer_from_thread(self.factory.current_protocol.make_dir,
626
share_str(share_uuid), str(parent_uuid),
628
return uuid.UUID(r.new_id)
631
def create_file(self, share_uuid, parent_uuid, name):
632
"""Creates a file on the server."""
633
r = self.defer_from_thread(self.factory.current_protocol.make_file,
634
share_str(share_uuid), str(parent_uuid),
636
return uuid.UUID(r.new_id)
639
def create_symlink(self, share_uuid, parent_uuid, name, target):
640
"""Creates a symlink on the server."""
641
raise UnsupportedOperationError("Protocol does not support symlinks")
644
def upload_string(self, share_uuid, node_uuid, old_content_hash,
645
content_hash, content):
646
"""Uploads a string to the server as file content."""
647
crc = crc32(content, 0)
648
compressed_content = zlib.compress(content, 9)
649
compressed = StringIO(compressed_content)
650
self.defer_from_thread(self.factory.current_protocol.put_content,
651
share_str(share_uuid), str(node_uuid),
652
old_content_hash, content_hash,
653
crc, len(content), len(compressed_content),
657
def upload_file(self, share_uuid, node_uuid, old_content_hash,
658
content_hash, filename):
659
"""Uploads a file to the server."""
660
parent_dir = os.path.split(filename)[0]
661
unique_filename = os.path.join(parent_dir, "." + str(uuid.uuid4()))
663
class StagingFile(object):
664
"""An object which tracks data being compressed for staging."""
665
def __init__(self, stream):
666
"""Initialize a compression object."""
668
self.enc = zlib.compressobj(9)
670
self.compressed_size = 0
673
def write(self, bytes):
674
"""Compress bytes, keeping track of length and crc32."""
675
self.size += len(bytes)
676
self.crc32 = crc32(bytes, self.crc32)
677
compressed_bytes = self.enc.compress(bytes)
678
self.compressed_size += len(compressed_bytes)
679
self.stream.write(compressed_bytes)
682
"""Finish staging compressed data."""
683
compressed_bytes = self.enc.flush()
684
self.compressed_size += len(compressed_bytes)
685
self.stream.write(compressed_bytes)
687
with open(unique_filename, "w+") as compressed:
688
os.remove(unique_filename)
689
with open(filename, "r") as original:
690
staging = StagingFile(compressed)
691
shutil.copyfileobj(original, staging)
694
self.defer_from_thread(self.factory.current_protocol.put_content,
695
share_str(share_uuid), str(node_uuid),
696
old_content_hash, content_hash,
698
staging.size, staging.compressed_size,
702
def move(self, share_uuid, parent_uuid, name, node_uuid):
703
"""Moves a file on the server."""
704
self.defer_from_thread(self.factory.current_protocol.move,
705
share_str(share_uuid), str(node_uuid),
706
str(parent_uuid), name)
709
def unlink(self, share_uuid, node_uuid):
710
"""Unlinks a file on the server."""
711
self.defer_from_thread(self.factory.current_protocol.unlink,
712
share_str(share_uuid), str(node_uuid))
715
@defer.inlineCallbacks
716
def _get_node_hashes(self, share_uuid):
717
"""Fetches hashes for the given nodes."""
718
result = yield self.factory.current_protocol.get_delta(
719
share_str(share_uuid), from_scratch=True)
721
for fid in result.response:
722
node_uuid = uuid.UUID(fid.node_id)
723
hashes[node_uuid] = fid.content_hash
724
defer.returnValue(hashes)
727
def get_incoming_shares(self):
728
"""Returns a list of incoming shares as (name, uuid, accepted)
732
_list_shares = self.factory.current_protocol.list_shares
733
r = self.defer_from_thread(_list_shares)
734
return [(s.name, s.id, s.other_visible_name,
735
s.accepted, s.access_level)
736
for s in r.shares if s.direction == "to_me"]