~hadware/magicicada-server/trusty-support

« back to all changes in this revision

Viewing changes to dev-scripts/cmd_client.py

  • Committer: Facundo Batista
  • Date: 2015-08-05 13:10:02 UTC
  • Revision ID: facundo@taniquetil.com.ar-20150805131002-he7b7k704d8o7js6
First released version.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#!/usr/bin/env python
 
2
 
 
3
# Copyright 2008-2015 Canonical
 
4
#
 
5
# This program is free software: you can redistribute it and/or modify
 
6
# it under the terms of the GNU Affero General Public License as
 
7
# published by the Free Software Foundation, either version 3 of the
 
8
# License, or (at your option) any later version.
 
9
#
 
10
# This program is distributed in the hope that it will be useful,
 
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
13
# GNU Affero General Public License for more details.
 
14
#
 
15
# You should have received a copy of the GNU Affero General Public License
 
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
 
17
#
 
18
# For further info, check  http://launchpad.net/filesync-server
 
19
 
 
20
"""A simple ping client."""
 
21
 
 
22
import warnings
 
23
warnings.simplefilter("ignore")
 
24
 
 
25
import Queue
 
26
import cmd
 
27
import os
 
28
import shlex
 
29
import tempfile
 
30
import time
 
31
import traceback
 
32
import uuid
 
33
import zlib
 
34
 
 
35
from threading import Thread
 
36
from optparse import OptionParser
 
37
 
 
38
try:
 
39
    from twisted.internet import gireactor
 
40
    gireactor.install()
 
41
except ImportError:
 
42
    from twisted.internet import glib2reactor
 
43
    glib2reactor.install()
 
44
 
 
45
from dbus.mainloop.glib import DBusGMainLoop
 
46
DBusGMainLoop(set_as_default=True)
 
47
 
 
48
from twisted.internet import reactor, defer, ssl
 
49
from twisted.python.failure import Failure
 
50
from twisted.python.util import mergeFunctionMetadata
 
51
try:
 
52
    import gobject
 
53
    gobject.set_application_name('cmd_client')
 
54
except ImportError:
 
55
    pass
 
56
 
 
57
import _pythonpath  # NOQA
 
58
 
 
59
from ubuntuone.storageprotocol.client import (
 
60
    StorageClientFactory, StorageClient)
 
61
from ubuntuone.storageprotocol import request, dircontent_pb2, volumes
 
62
from ubuntuone.storageprotocol.dircontent_pb2 import \
 
63
    DirectoryContent, DIRECTORY
 
64
from ubuntuone.storageprotocol.content_hash import content_hash_factory, crc32
 
65
 
 
66
 
 
67
def show_volume(volume):
 
68
    """Show a volume."""
 
69
    if isinstance(volume, volumes.ShareVolume):
 
70
        print "Share %r (other: %s, access: %s, id: %s)" % (
 
71
            volume.share_name, volume.other_username,
 
72
            volume.access_level, volume.volume_id)
 
73
    elif isinstance(volume, volumes.UDFVolume):
 
74
        print "UDF %r (id: %s)" % (volume.suggested_path, volume.volume_id)
 
75
 
 
76
 
 
77
class CmdStorageClient(StorageClient):
 
78
    """Simple client that calls a callback on connection."""
 
79
 
 
80
    def connectionMade(self):
 
81
        """Setup and call callback."""
 
82
 
 
83
        StorageClient.connectionMade(self)
 
84
        if self.factory.current_protocol not in (None, self):
 
85
            self.factory.current_protocol.transport.loseConnection()
 
86
        self.factory.current_protocol = self
 
87
        self.factory.cmd.status = "connected"
 
88
        print "Connected."
 
89
 
 
90
    def connectionLost(self, reason=None):
 
91
        """Callback for connection lost"""
 
92
 
 
93
        if self.factory.current_protocol is self:
 
94
            self.factory.current_protocol = None
 
95
            self.factory.cmd.status = "disconnected"
 
96
            if reason is not None:
 
97
                print "Disconnected: %s" % reason.value
 
98
            else:
 
99
                print "Disconnected: no reason"
 
100
 
 
101
 
 
102
class CmdClientFactory(StorageClientFactory):
 
103
    """A cmd protocol factory."""
 
104
 
 
105
    protocol = CmdStorageClient
 
106
 
 
107
    def __init__(self, cmd):
 
108
        """Create the factory"""
 
109
        self.cmd = cmd
 
110
        self.current_protocol = None
 
111
 
 
112
    def clientConnectionFailed(self, connector, reason):
 
113
        """We failed at connecting."""
 
114
        print 'ERROR: Connection failed. Reason:', reason.value
 
115
        self.current_protocol = None
 
116
        self.cmd.status = "disconnected"
 
117
 
 
118
 
 
119
def split_args(args):
 
120
    """Split a string using shlex."""
 
121
    sh = shlex.shlex(args, "args", True)
 
122
    sh.wordchars = sh.wordchars + '-./'
 
123
    result = []
 
124
    part = sh.get_token()
 
125
    while part is not None:
 
126
        result.append(part)
 
127
        part = sh.get_token()
 
128
    return result
 
129
 
 
130
 
 
131
def parse_args(*args, **kwargs):
 
132
    """Decorates a method so that we can parse its arguments:
 
133
    Example:
 
134
    @parse_args(int, int):
 
135
    def p(self, one, two):
 
136
        print one + two
 
137
    o.p("10 10")
 
138
    will print 20.
 
139
    """
 
140
    def inner(method):
 
141
        """the actual decorator"""
 
142
        def parser(self, rest):
 
143
            """the parser"""
 
144
            parts = split_args(rest)
 
145
 
 
146
            if len(parts) != len(args):
 
147
                print (
 
148
                    "ERROR: Wrong number of arguments. Expected %i, got %i" % (
 
149
                        len(args), len(parts)))
 
150
                return
 
151
 
 
152
            result = []
 
153
            for i, (constructor, part) in enumerate(zip(args, parts)):
 
154
                try:
 
155
                    value = constructor(part)
 
156
                except ValueError:
 
157
                    print "ERROR: cant convert argument %i to %s" % (
 
158
                        i, constructor)
 
159
                    return
 
160
                result.append(value)
 
161
 
 
162
            return method(self, *result)
 
163
        return mergeFunctionMetadata(method, parser)
 
164
    return inner
 
165
 
 
166
 
 
167
def require_connection(method):
 
168
    """This decorator ensures functions that require a connection dont
 
169
    get called without one"""
 
170
 
 
171
    def decorator(self, *args):
 
172
        """inner"""
 
173
        if self.status != "connected":
 
174
            print "ERROR: Must be connected."
 
175
            return
 
176
        else:
 
177
            return method(self, *args)
 
178
    return mergeFunctionMetadata(method, decorator)
 
179
 
 
180
 
 
181
def show_exception(function):
 
182
    """Trap exceptions and print them."""
 
183
    def decorator(*args, **kwargs):
 
184
        """inner"""
 
185
        # we do want to catch all
 
186
        try:
 
187
            function(*args, **kwargs)
 
188
        except Exception:
 
189
            traceback.print_exc()
 
190
    return mergeFunctionMetadata(function, decorator)
 
191
 
 
192
 
 
193
class ClientCmd(cmd.Cmd):
 
194
    """An interactive shell to manipulate the server."""
 
195
 
 
196
    use_rawinput = False
 
197
 
 
198
    def __init__(self, username, password):
 
199
        """Create the instance."""
 
200
 
 
201
        cmd.Cmd.__init__(self)
 
202
        self.thread = Thread(target=self._run)
 
203
        self.thread.setDaemon(True)
 
204
        self.thread.start()
 
205
        self.factory = CmdClientFactory(self)
 
206
        self.connected = False
 
207
        self.status = "disconnected"
 
208
        self.cwd = "/"
 
209
        self.volume = request.ROOT
 
210
        self.volume_root = None
 
211
        self.queue = Queue.Queue()
 
212
 
 
213
        self.username = username
 
214
        self.password = password
 
215
 
 
216
        self.volumes = set()
 
217
        self.shares = set()
 
218
 
 
219
    def _run(self):
 
220
        """Run the reactor in bg."""
 
221
        reactor.run(installSignalHandlers=False)
 
222
 
 
223
    @property
 
224
    def prompt(self):
 
225
        """Our prompt is our path."""
 
226
        return "%s $ " % self.cwd
 
227
 
 
228
    def emptyline(self):
 
229
        """We do nothing on an empty line."""
 
230
        return
 
231
 
 
232
    def defer_from_thread(self, function, *args, **kwargs):
 
233
        """Do twisted defer magic to get results and show exceptions."""
 
234
 
 
235
        queue = Queue.Queue()
 
236
 
 
237
        def runner():
 
238
            """inner."""
 
239
            # we do want to catch all
 
240
            try:
 
241
                d = function(*args, **kwargs)
 
242
                if isinstance(d, defer.Deferred):
 
243
                    d.addBoth(queue.put)
 
244
                else:
 
245
                    queue.put(d)
 
246
            except Exception, e:
 
247
                queue.put(e)
 
248
 
 
249
        reactor.callFromThread(runner)
 
250
        result = queue.get()
 
251
        if isinstance(result, Exception):
 
252
            raise result
 
253
        elif isinstance(result, Failure):
 
254
            result.raiseException()
 
255
        else:
 
256
            return result
 
257
 
 
258
    def get_cwd_id(self):
 
259
        """Get the id of the current working directory."""
 
260
 
 
261
        parts = [part for part in self.cwd.split("/") if part]
 
262
        # this will block forever if we didnt authenticate
 
263
        parent_id = self.get_root()
 
264
        for part in parts:
 
265
            if not self.is_dir(parent_id, part):
 
266
                raise ValueError("cwd is not a directory")
 
267
            parent_id = self.get_child_id(parent_id, part)
 
268
 
 
269
        return parent_id
 
270
 
 
271
    def get_root(self):
 
272
        """Get the root id."""
 
273
        if self.volume_root:
 
274
            return self.volume_root
 
275
        else:
 
276
            return self.defer_from_thread(
 
277
                self.factory.current_protocol.get_root)
 
278
 
 
279
    def get_id_from_filename(self, filename):
 
280
        """Get a node id from a filename."""
 
281
 
 
282
        root = self.cwd
 
283
        parent_id = self.get_root()
 
284
        if filename and filename[0] == "/":
 
285
            newdir = os.path.normpath(filename)
 
286
        else:
 
287
            newdir = os.path.normpath(os.path.join(root, filename))
 
288
        parts = [part for part in newdir.split("/") if part]
 
289
        if not parts:
 
290
            return parent_id
 
291
        file = parts[-1]
 
292
        parts = parts[:-1]
 
293
 
 
294
        for part in parts:
 
295
            if not self.is_dir(parent_id, part):
 
296
                raise ValueError("not a directory")
 
297
 
 
298
            parent_id = self.get_child_id(parent_id, part)
 
299
        return self.get_child_id(parent_id, file)
 
300
 
 
301
    def is_dir(self, parent_id, name):
 
302
        """Is name inside of parent_id a directory?"""
 
303
 
 
304
        content = self.get_content(parent_id)
 
305
        unserialized_content = DirectoryContent()
 
306
        unserialized_content.ParseFromString(content)
 
307
        for entry in unserialized_content.entries:
 
308
            if entry.name == name and entry.node_type == DIRECTORY:
 
309
                return True
 
310
        return False
 
311
 
 
312
    def get_child_id(self, parent_id, name):
 
313
        """Get the node id of name inside of parent_id."""
 
314
        content = self.get_content(parent_id)
 
315
        unserialized_content = DirectoryContent()
 
316
        unserialized_content.ParseFromString(content)
 
317
        for entry in unserialized_content.entries:
 
318
            if entry.name == name:
 
319
                return entry.node
 
320
        raise ValueError("not found")
 
321
 
 
322
    def get_file(self, filename):
 
323
        """Get the content of filename."""
 
324
 
 
325
        node_id = self.get_id_from_filename(filename)
 
326
        content = self.get_content(node_id)
 
327
        return content
 
328
 
 
329
    def get_hash(self, node_id):
 
330
        """Get the hash of node_id."""
 
331
        def _got_query(query):
 
332
            """deferred part."""
 
333
            message = query[0][1].response[0]
 
334
            return message.hash
 
335
 
 
336
        def _query():
 
337
            """deferred part."""
 
338
            d = self.factory.current_protocol.query(
 
339
                [(self.volume, node_id, request.UNKNOWN_HASH)]
 
340
            )
 
341
            d.addCallback(_got_query)
 
342
            return d
 
343
        return self.defer_from_thread(_query)
 
344
 
 
345
    def get_content(self, node_id):
 
346
        """Get the content of node_id."""
 
347
 
 
348
        hash = self.get_hash(node_id)
 
349
 
 
350
        def _get_content():
 
351
            """deferred part."""
 
352
            d = self.factory.current_protocol.get_content(self.volume,
 
353
                                                          node_id, hash)
 
354
            return d
 
355
 
 
356
        content = self.defer_from_thread(_get_content)
 
357
        return zlib.decompress(content.data)
 
358
 
 
359
    def unlink(self, node_id):
 
360
        """unlink a node."""
 
361
 
 
362
        def _unlink():
 
363
            """deferred part."""
 
364
            d = self.factory.current_protocol.unlink(self.volume, node_id)
 
365
            return d
 
366
 
 
367
        return self.defer_from_thread(_unlink)
 
368
 
 
369
    def move(self, node_id, new_parent_id, new_name):
 
370
        """move a node."""
 
371
 
 
372
        def _move():
 
373
            """deferred part."""
 
374
            d = self.factory.current_protocol.move(
 
375
                self.volume, node_id, new_parent_id, new_name)
 
376
            return d
 
377
 
 
378
        return self.defer_from_thread(_move)
 
379
 
 
380
    @parse_args(str, int)
 
381
    def do_connect(self, host, port):
 
382
        """Connect to host/port."""
 
383
        def _connect():
 
384
            """deferred part."""
 
385
            reactor.connectTCP(host, port, self.factory)
 
386
        self.status = "connecting"
 
387
        reactor.callFromThread(_connect)
 
388
 
 
389
    @parse_args(str, int)
 
390
    def do_connect_ssl(self, host, port):
 
391
        """Connect to host/port using ssl."""
 
392
        def _connect():
 
393
            """deferred part."""
 
394
            reactor.connectSSL(host, port, self.factory,
 
395
                               ssl.ClientContextFactory())
 
396
        self.status = "connecting"
 
397
        reactor.callFromThread(_connect)
 
398
 
 
399
    @parse_args()
 
400
    def do_status(self):
 
401
        """Print the status string."""
 
402
        print "STATUS: %s" % self.status
 
403
 
 
404
    @parse_args()
 
405
    def do_disconnect(self):
 
406
        """Disconnect."""
 
407
        if self.status != "connected":
 
408
            print "ERROR: Not connecting."
 
409
            return
 
410
        reactor.callFromThread(
 
411
            self.factory.current_protocol.transport.loseConnection)
 
412
 
 
413
    @parse_args(str)
 
414
    @require_connection
 
415
    @show_exception
 
416
    def do_dummy_auth(self, token):
 
417
        """Perform dummy authentication."""
 
418
        self.defer_from_thread(
 
419
            self.factory.current_protocol.dummy_authenticate, token)
 
420
 
 
421
    @parse_args()
 
422
    @require_connection
 
423
    @show_exception
 
424
    def do_shares(self):
 
425
        """Perform dummy authentication."""
 
426
        r = self.defer_from_thread(
 
427
            self.factory.current_protocol.list_shares)
 
428
        for share in r.shares:
 
429
            print share
 
430
            if share.accepted and share.direction == 'to_me':
 
431
                self.shares.add(str(share.id))
 
432
 
 
433
    @parse_args(str)
 
434
    @require_connection
 
435
    @show_exception
 
436
    def do_set_share(self, sharename):
 
437
        """Perform dummy authentication."""
 
438
        r = self.defer_from_thread(
 
439
            self.factory.current_protocol.list_shares)
 
440
        for share in r.shares:
 
441
            if str(share.id) == sharename:
 
442
                self.volume_root = share.subtree
 
443
                break
 
444
        else:
 
445
            print "BAD SHARE NAME"
 
446
            return
 
447
        self.volume = sharename
 
448
        self.cwd = '/'
 
449
 
 
450
    @parse_args()
 
451
    @require_connection
 
452
    @show_exception
 
453
    def do_volumes(self):
 
454
        """Perform dummy authentication."""
 
455
        r = self.defer_from_thread(
 
456
            self.factory.current_protocol.list_volumes)
 
457
        for volume in r.volumes:
 
458
            show_volume(volume)
 
459
            if not isinstance(volume, volumes.RootVolume):
 
460
                self.volumes.add(str(volume.volume_id))
 
461
 
 
462
    @parse_args(str)
 
463
    @require_connection
 
464
    @show_exception
 
465
    def do_set_volume(self, volume_id):
 
466
        """Perform dummy authentication."""
 
467
        r = self.defer_from_thread(
 
468
            self.factory.current_protocol.list_volumes)
 
469
        for volume in r.volumes:
 
470
            if str(volume.volume_id) == volume_id:
 
471
                self.volume_root = volume.node_id
 
472
                break
 
473
        else:
 
474
            print "BAD Volume ID"
 
475
            return
 
476
        self.volume = volume_id
 
477
        self.cwd = '/'
 
478
 
 
479
    @parse_args()
 
480
    @require_connection
 
481
    @show_exception
 
482
    def do_root(self):
 
483
        """Perform dummy authentication."""
 
484
        self.volume = request.ROOT
 
485
        self.cwd = '/'
 
486
        self.volume_root = None
 
487
        root = self.get_root()
 
488
        print "root is", root
 
489
 
 
490
    @require_connection
 
491
    def _list_dir(self, node_id):
 
492
        """Return the content of a directory."""
 
493
        content = self.get_content(node_id)
 
494
        unserialized_content = DirectoryContent()
 
495
        # TODO: what exceptions can protobuf's parser raise?
 
496
        unserialized_content.ParseFromString(content)
 
497
        return unserialized_content.entries
 
498
 
 
499
    @parse_args()
 
500
    @require_connection
 
501
    @show_exception
 
502
    def do_ls(self):
 
503
        """Get a listing of the current working directory."""
 
504
        node_id = self.get_cwd_id()
 
505
        entries = self._list_dir(node_id)
 
506
        for entry in entries:
 
507
            node_type = dircontent_pb2._NODETYPE. \
 
508
                values_by_number[entry.node_type].name
 
509
            print "%s %10s %s" % (entry.node, node_type, entry.name)
 
510
 
 
511
    @parse_args(str)
 
512
    @require_connection
 
513
    @show_exception
 
514
    def do_mkfile(self, name):
 
515
        """Create a file named name on the current working directory."""
 
516
        node_id = self.get_cwd_id()
 
517
        self.defer_from_thread(
 
518
            self.factory.current_protocol.make_file,
 
519
            self.volume, node_id, name)
 
520
 
 
521
    @parse_args(str)
 
522
    @show_exception
 
523
    def do_mkdir(self, name):
 
524
        """Create a directory named name on the current working directory."""
 
525
        self.mkdir(name)
 
526
 
 
527
    @require_connection
 
528
    def mkdir(self, name):
 
529
        """Create a directory named name on the current working directory."""
 
530
        node_id = self.get_cwd_id()
 
531
        self.defer_from_thread(
 
532
            self.factory.current_protocol.make_dir,
 
533
            self.volume, node_id, name)
 
534
 
 
535
    @parse_args(int)
 
536
    @show_exception
 
537
    def do_storm(self, intensity):
 
538
        """Storm operations to the server.
 
539
 
 
540
        It creates N directories, with N (empty) files in each dir.
 
541
        """
 
542
        self.storm(intensity)
 
543
 
 
544
    @require_connection
 
545
    def storm(self, intensity):
 
546
        """Storm operations to the server."""
 
547
        subroot_id = self.get_cwd_id()
 
548
 
 
549
        make_dir = self.factory.current_protocol.make_dir
 
550
        make_file = self.factory.current_protocol.make_file
 
551
 
 
552
        @defer.inlineCallbacks
 
553
        def go():
 
554
            """Actually do it."""
 
555
            tini = time.time()
 
556
            for _ in xrange(intensity):
 
557
                name = u"testdir-" + unicode(uuid.uuid4())
 
558
                req = yield make_dir(self.volume, subroot_id, name)
 
559
                for _ in xrange(intensity):
 
560
                    name = u"testfile-" + unicode(uuid.uuid4())
 
561
                    yield make_file(self.volume, req.new_id, name)
 
562
            tend = time.time()
 
563
            print "%d dirs and %d files created in %.2f seconds" % (
 
564
                intensity, intensity ** 2, tend - tini)
 
565
 
 
566
        self.defer_from_thread(go)
 
567
 
 
568
    @parse_args(str)
 
569
    @show_exception
 
570
    def do_cd(self, name):
 
571
        """CD to name."""
 
572
        self.cd(name)
 
573
 
 
574
    @require_connection
 
575
    def cd(self, name):
 
576
        """CD to name."""
 
577
 
 
578
        root = self.cwd
 
579
        newdir = os.path.normpath(os.path.join(root, name))
 
580
        parts = [part for part in newdir.split("/") if part]
 
581
        parent_id = self.get_root()
 
582
 
 
583
        for part in parts:
 
584
            if not self.is_dir(parent_id, part):
 
585
                print "ERROR: Not a directory"
 
586
                return
 
587
 
 
588
            parent_id = self.get_child_id(parent_id, part)
 
589
 
 
590
        self.cwd = newdir
 
591
 
 
592
    @parse_args(str, str)
 
593
    @show_exception
 
594
    def do_put(self, local, remote):
 
595
        """Put local file into remote file."""
 
596
        self.put(local, remote)
 
597
 
 
598
    @require_connection
 
599
    @show_exception
 
600
    def put(self, local, remote):
 
601
        """Put local file into remote file."""
 
602
        try:
 
603
            node_id = self.get_id_from_filename(remote)
 
604
        except ValueError:
 
605
            parent_id = self.get_cwd_id()
 
606
            r = self.defer_from_thread(
 
607
                self.factory.current_protocol.make_file,
 
608
                self.volume, parent_id, remote.split("/")[-1])
 
609
            node_id = r.new_id
 
610
 
 
611
        old_hash = self.get_hash(node_id)
 
612
 
 
613
        ho = content_hash_factory()
 
614
        zipper = zlib.compressobj()
 
615
        crc32_value = 0
 
616
        size = 0
 
617
        deflated_size = 0
 
618
        temp_file_name = None
 
619
        with open(local) as fh:
 
620
            with tempfile.NamedTemporaryFile(mode='w', prefix='cmd_client-',
 
621
                                             delete=False) as dest:
 
622
                temp_file_name = dest.name
 
623
                while True:
 
624
                    cont = fh.read(1024 ** 2)
 
625
                    if not cont:
 
626
                        dest.write(zipper.flush())
 
627
                        deflated_size = dest.tell()
 
628
                        break
 
629
                    ho.update(cont)
 
630
                    crc32_value = crc32(cont, crc32_value)
 
631
                    size += len(cont)
 
632
                    dest.write(zipper.compress(cont))
 
633
        hash_value = ho.content_hash()
 
634
        try:
 
635
            self.defer_from_thread(
 
636
                self.factory.current_protocol.put_content,
 
637
                self.volume, node_id, old_hash, hash_value,
 
638
                crc32_value, size, deflated_size, open(temp_file_name, 'r'))
 
639
        finally:
 
640
            if os.path.exists(temp_file_name):
 
641
                os.unlink(temp_file_name)
 
642
 
 
643
    @parse_args(str, str)
 
644
    @require_connection
 
645
    @show_exception
 
646
    def do_rput(self, local, remote):
 
647
        """Put local directory and it's files into remote directory."""
 
648
        def get_server_path(path):
 
649
            """ returns the server relative path """
 
650
            return path.rpartition(os.path.dirname(local))[2].lstrip('/')
 
651
        cwd = self.cwd
 
652
        for dirpath, dirnames, fnames in os.walk(local):
 
653
            server_path = get_server_path(dirpath)
 
654
            self.cd(os.path.dirname(server_path))
 
655
            leaf = os.path.basename(server_path)
 
656
            self.mkdir(leaf)
 
657
            self.cd(leaf)
 
658
            for filename in fnames:
 
659
                local_path = os.path.join(dirpath, filename)
 
660
                self.put(local_path, filename)
 
661
            self.cd(cwd)
 
662
 
 
663
    @parse_args(str, str)
 
664
    @require_connection
 
665
    @show_exception
 
666
    def do_get(self, remote, local):
 
667
        """Get remote file into local file."""
 
668
        data = self.get_file(remote)
 
669
        f = open(local, "w")
 
670
        f.write(data)
 
671
        f.close()
 
672
 
 
673
    @parse_args(str)
 
674
    @require_connection
 
675
    @show_exception
 
676
    def do_cat(self, remote):
 
677
        """Show the contents of remote file on screen."""
 
678
        data = self.get_file(remote)
 
679
        print data
 
680
 
 
681
    @parse_args(str)
 
682
    @require_connection
 
683
    @show_exception
 
684
    def do_hash(self, filename):
 
685
        """Print the hash of filename."""
 
686
        node_id = self.get_id_from_filename(filename)
 
687
        hash_value = self.get_hash(node_id)
 
688
        print hash_value
 
689
 
 
690
    @parse_args(str)
 
691
    @require_connection
 
692
    @show_exception
 
693
    def do_unlink(self, filename):
 
694
        """Print the hash of filename."""
 
695
        node_id = self.get_id_from_filename(filename)
 
696
        self.unlink(node_id)
 
697
 
 
698
    @parse_args(str, str)
 
699
    @require_connection
 
700
    @show_exception
 
701
    def do_move(self, source, dest):
 
702
        """Move file source to dest."""
 
703
        source_node_id = self.get_id_from_filename(source)
 
704
        try:
 
705
            dest_node_id = self.get_id_from_filename(dest)
 
706
        except ValueError:
 
707
            parent_name, node_name = os.path.split(dest)
 
708
            dest_node_id = self.get_id_from_filename(parent_name)
 
709
            self.move(source_node_id, dest_node_id, node_name)
 
710
        else:
 
711
            parent_name, node_name = os.path.split(source)
 
712
            self.move(source_node_id, dest_node_id, node_name)
 
713
 
 
714
    @defer.inlineCallbacks
 
715
    def _auth(self, consumer, token):
 
716
        """Really authenticate, and show the session id."""
 
717
        auth_method = self.factory.current_protocol.simple_authenticate
 
718
        req = yield auth_method(self.username, self.password)
 
719
        print "Authenticated ok, session:", req.session_id
 
720
 
 
721
    @parse_args(str, str)
 
722
    @require_connection
 
723
    @show_exception
 
724
    def do_oauth(self):
 
725
        """Perform authorisation."""
 
726
        self.defer_from_thread(self._auth)
 
727
 
 
728
    def do_shell(self, cmd):
 
729
        """Execute a shell command."""
 
730
        os.system(cmd)
 
731
 
 
732
    def do_quit(self, rest):
 
733
        """Exit the shell."""
 
734
        print "Goodbye", rest
 
735
        return True
 
736
    do_EOF = do_quit
 
737
 
 
738
    @require_connection
 
739
    def complete_set_volume(self, text, line, begidx, endidx):
 
740
        """Completion for set_volume."""
 
741
        if not self.volumes:
 
742
            r = self.defer_from_thread(
 
743
                self.factory.current_protocol.list_volumes)
 
744
            for volume in r.volumes:
 
745
                if not isinstance(volume, volumes.RootVolume):
 
746
                    self.volumes.add(str(volume.volume_id))
 
747
        return [vol_id for vol_id in sorted(self.volumes)
 
748
                if vol_id.startswith(text)]
 
749
 
 
750
    @require_connection
 
751
    def complete_set_share(self, text, line, begidx, endidx):
 
752
        """Completion for set_share."""
 
753
        if not self.shares:
 
754
            r = self.defer_from_thread(
 
755
                self.factory.current_protocol.list_shares)
 
756
            for share in r.shares:
 
757
                if share.accepted and share.direction == 'to_me':
 
758
                    self.shares.add(str(share.id))
 
759
        return [share_id for share_id in sorted(self.shares)
 
760
                if share_id.startswith(text)]
 
761
 
 
762
    @require_connection
 
763
    def _complete_single_filename(self, text, line, begidx, endidx):
 
764
        """Completion for remote filename for single argument commands."""
 
765
        node_id = self.get_cwd_id()
 
766
        entries = self._list_dir(node_id)
 
767
        return [entry.name for entry in entries
 
768
                if entry.name.startswith(text)]
 
769
 
 
770
    complete_cat = complete_unlink = _complete_single_filename
 
771
 
 
772
    @require_connection
 
773
    def complete_get(self, text, line, begidx, endidx):
 
774
        """Completion for get command."""
 
775
        if len(line.split(' ')) < 3:
 
776
            node_id = self.get_cwd_id()
 
777
            entries = self._list_dir(node_id)
 
778
            return [entry.name for entry in entries
 
779
                    if entry.node_type != dircontent_pb2.DIRECTORY
 
780
                    and entry.name.startswith(text)]
 
781
 
 
782
    def _complete_local(self, text, include_dirs=False):
 
783
        """Return the list of possible local filenames."""
 
784
        isdir = os.path.isdir
 
785
 
 
786
        def filter_files(files):
 
787
            """Firlter files/dirs."""
 
788
            return [f for f in files if include_dirs or not isdir(f)]
 
789
        if not os.path.exists(text):
 
790
            head, tail = os.path.split(text)
 
791
            while head and tail and not os.path.exists(head):
 
792
                head, tail = os.path.split(text)
 
793
            dirs = os.listdir(head or '.')
 
794
            return filter_files([d for d in dirs if d.startswith(tail)])
 
795
        elif os.path.exists(text) and isdir(text):
 
796
            return filter_files(os.listdir(text))
 
797
        else:
 
798
            return []
 
799
 
 
800
    @require_connection
 
801
    def complete_put(self, text, line, begidx, endidx):
 
802
        """Completion for put command."""
 
803
        if len(line.split(' ')) < 3:
 
804
            # local
 
805
            return self._complete_local(text)
 
806
        else:  # remote
 
807
            node_id = self.get_cwd_id()
 
808
            entries = self._list_dir(node_id)
 
809
            return [entry.name for entry in entries
 
810
                    if entry.node_type != dircontent_pb2.DIRECTORY
 
811
                    and entry.name.startswith(text)]
 
812
 
 
813
    @require_connection
 
814
    def complete_cd(self, text, line, begidx, endidx):
 
815
        """Completion for cd command."""
 
816
        node_id = self.get_cwd_id()
 
817
        entries = self._list_dir(node_id)
 
818
        return [entry.name for entry in entries
 
819
                if entry.node_type == dircontent_pb2.DIRECTORY
 
820
                and entry.name.startswith(text)]
 
821
 
 
822
 
 
823
def main():
 
824
    """run the cmd_client parsing cmd line options"""
 
825
    usage = "usage: %prog [options] [CMD]"
 
826
    parser = OptionParser(usage=usage)
 
827
    parser.add_option("--port", dest="port", metavar="PORT",
 
828
                      default=443,
 
829
                      help="The port on which to connect to the server")
 
830
    parser.add_option("--host", dest="host", metavar="HOST",
 
831
                      default='localhost',
 
832
                      help="The server address")
 
833
    parser.add_option("--username", dest="username", metavar="USERNAME",
 
834
                      help="The username")
 
835
    parser.add_option("--password", dest="password", metavar="PASSWORD",
 
836
                      help="The password")
 
837
    parser.add_option("-f", "--file", dest="filename",
 
838
                      help="write report to FILE", metavar="FILE")
 
839
 
 
840
    (options, args) = parser.parse_args()
 
841
 
 
842
    client = ClientCmd(options.username, options.password)
 
843
    client.onecmd('connect_ssl "%s" %s' % (options.host, options.port))
 
844
 
 
845
    while client.status != 'connected':
 
846
        time.sleep(.5)
 
847
 
 
848
    client.onecmd("auth")
 
849
 
 
850
    if args:
 
851
        client.onecmd(" ".join(args))
 
852
    else:
 
853
        client.cmdloop()
 
854
 
 
855
if __name__ == "__main__":
 
856
    main()