~ubuntu-branches/ubuntu/karmic/tahoe-lafs/karmic

« back to all changes in this revision

Viewing changes to src/allmydata/immutable/download.py

  • Committer: Bazaar Package Importer
  • Author(s): Zooko O'Whielacronx (Hacker)
  • Date: 2009-09-24 00:00:05 UTC
  • Revision ID: james.westby@ubuntu.com-20090924000005-ixe2n4yngmk49ysz
Tags: upstream-1.5.0
Import upstream version 1.5.0

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
import os, random, weakref, itertools, time
 
2
from zope.interface import implements
 
3
from twisted.internet import defer
 
4
from twisted.internet.interfaces import IPushProducer, IConsumer
 
5
from twisted.application import service
 
6
from foolscap.api import DeadReferenceError, RemoteException, eventually
 
7
 
 
8
from allmydata.util import base32, deferredutil, hashutil, log, mathutil, idlib
 
9
from allmydata.util.assertutil import _assert, precondition
 
10
from allmydata import codec, hashtree, uri
 
11
from allmydata.interfaces import IDownloadTarget, IDownloader, \
 
12
     IFileURI, IVerifierURI, \
 
13
     IDownloadStatus, IDownloadResults, IValidatedThingProxy, \
 
14
     IStorageBroker, NotEnoughSharesError, NoSharesError, NoServersError, \
 
15
     UnableToFetchCriticalDownloadDataError
 
16
from allmydata.immutable import layout
 
17
from allmydata.monitor import Monitor
 
18
from pycryptopp.cipher.aes import AES
 
19
 
 
20
class IntegrityCheckReject(Exception):
    """Base class for failures where data arrived but failed an integrity
    check, so it must be rejected."""
    pass

class BadURIExtensionHashValue(IntegrityCheckReject):
    """The fetched URI extension block did not hash to the value named in
    the verify-cap."""
    pass
class BadURIExtension(IntegrityCheckReject):
    """The URI extension block was internally inconsistent with the
    already-known facts about the file."""
    pass
class UnsupportedErasureCodec(BadURIExtension):
    """The UEB named an erasure codec other than the supported "crs"."""
    pass
class BadCrypttextHashValue(IntegrityCheckReject):
    """A ciphertext hash check failed."""
    pass
class BadOrMissingHash(IntegrityCheckReject):
    """A required hash-tree node was absent, out of range, or failed
    validation."""
    pass
 
33
 
 
34
class DownloadStopped(Exception):
    """The download was stopped before it completed."""
    pass
 
36
 
 
37
class DownloadResults:
    """Mutable record of statistics gathered during a single download (see
    IDownloadResults); exposed to status displays via
    DownloadStatus.get_results()."""
    implements(IDownloadResults)

    def __init__(self):
        self.servers_used = set()   # peerids of servers we fetched blocks from
        self.server_problems = {}   # peerid -> stringified Failure for bad servers
        self.servermap = {}
        self.timings = {}           # timing-category name -> measurements
                                    # (e.g. "fetch_per_server": peerid -> [seconds])
        self.file_size = None       # size of the file, once known -- set elsewhere
 
46
 
 
47
class DecryptingTarget(log.PrefixingLogMixin):
    """An IDownloadTarget wrapper that runs everything written through it
    through an AES decryptor before handing the plaintext to the wrapped
    target."""
    implements(IDownloadTarget, IConsumer)

    def __init__(self, target, key, _log_msg_id=None):
        precondition(IDownloadTarget.providedBy(target), target)
        self.target = target
        self._decryptor = AES(key)
        log.PrefixingLogMixin.__init__(self, "allmydata.immutable.download",
                                       _log_msg_id, prefix=str(target))

    # IConsumer interface: forward producer (un)registration, but only if the
    # wrapped target is itself a consumer.
    def registerProducer(self, producer, streaming):
        if IConsumer.providedBy(self.target):
            self.target.registerProducer(producer, streaming)

    def unregisterProducer(self):
        if IConsumer.providedBy(self.target):
            self.target.unregisterProducer()

    def write(self, ciphertext):
        # decrypt in-stream, then pass the plaintext downstream
        self.target.write(self._decryptor.process(ciphertext))

    def open(self, size):
        self.target.open(size)

    def close(self):
        self.target.close()

    def finish(self):
        return self.target.finish()

    # The following methods just pass through to the next target, because that
    # target might be a repairer.DownUpConnector, and because the current
    # CHKUpload object expects to find the storage index and encoding params
    # in its Uploadable.
    def set_storageindex(self, storageindex):
        self.target.set_storageindex(storageindex)

    def set_encodingparams(self, encodingparams):
        self.target.set_encodingparams(encodingparams)
 
78
 
 
79
class ValidatedThingObtainer:
    """I hold a list of IValidatedThingProxy instances and try each in turn
    (via its start() method) until one of them produces a validated thing.
    If every proxy fails I raise UnableToFetchCriticalDownloadDataError."""

    def __init__(self, validatedthingproxies, debugname, log_id):
        self._validatedthingproxies = validatedthingproxies
        self._debugname = debugname   # used only in log messages
        self._log_id = log_id

    def _bad(self, f, validatedthingproxy):
        """Errback: log the failure and retry with the next proxy, or raise
        if none remain.  trap() re-raises any failure type we don't expect."""
        f.trap(RemoteException, DeadReferenceError,
               IntegrityCheckReject, layout.LayoutInvalid,
               layout.ShareVersionIncompatible)
        # A vanished server is merely unusual; a server answering with
        # corrupted data is scary.
        if f.check(DeadReferenceError):
            level = log.UNUSUAL
        elif f.check(RemoteException):
            level = log.WEIRD
        else:
            level = log.SCARY
        log.msg(parent=self._log_id, facility="tahoe.immutable.download",
                format="operation %(op)s from validatedthingproxy %(validatedthingproxy)s failed",
                op=self._debugname, validatedthingproxy=str(validatedthingproxy),
                failure=f, level=level, umid="JGXxBA")
        if not self._validatedthingproxies:
            raise UnableToFetchCriticalDownloadDataError("ran out of peers, last error was %s" % (f,))
        # try again with a different one
        return self._try_the_next_one()

    def _try_the_next_one(self):
        vtp = self._validatedthingproxies.pop(0)
        # start() obtains, validates, and callsback-with the thing or else errbacks
        d = vtp.start()
        d.addErrback(self._bad, vtp)
        return d

    def start(self):
        """Return a deferred firing with the first validated thing obtained."""
        return self._try_the_next_one()
 
114
 
 
115
class ValidatedCrypttextHashTreeProxy:
 
116
    implements(IValidatedThingProxy)
 
117
    """ I am a front-end for a remote crypttext hash tree using a local ReadBucketProxy -- I use
 
118
    its get_crypttext_hashes() method and offer the Validated Thing protocol (i.e., I have a
 
119
    start() method that fires with self once I get a valid one). """
 
120
    def __init__(self, readbucketproxy, crypttext_hash_tree, num_segments, fetch_failures=None):
 
121
        # fetch_failures is for debugging -- see test_encode.py
 
122
        self._readbucketproxy = readbucketproxy
 
123
        self._num_segments = num_segments
 
124
        self._fetch_failures = fetch_failures
 
125
        self._crypttext_hash_tree = crypttext_hash_tree
 
126
 
 
127
    def _validate(self, proposal):
 
128
        ct_hashes = dict(list(enumerate(proposal)))
 
129
        try:
 
130
            self._crypttext_hash_tree.set_hashes(ct_hashes)
 
131
        except (hashtree.BadHashError, hashtree.NotEnoughHashesError), le:
 
132
            if self._fetch_failures is not None:
 
133
                self._fetch_failures["crypttext_hash_tree"] += 1
 
134
            raise BadOrMissingHash(le)
 
135
        # If we now have enough of the crypttext hash tree to integrity-check *any* segment of ciphertext, then we are done.
 
136
        # TODO: It would have better alacrity if we downloaded only part of the crypttext hash tree at a time.
 
137
        for segnum in range(self._num_segments):
 
138
            if self._crypttext_hash_tree.needed_hashes(segnum):
 
139
                raise BadOrMissingHash("not enough hashes to validate segment number %d" % (segnum,))
 
140
        return self
 
141
 
 
142
    def start(self):
 
143
        d = self._readbucketproxy.get_crypttext_hashes()
 
144
        d.addCallback(self._validate)
 
145
        return d
 
146
 
 
147
class ValidatedExtendedURIProxy:
    """ I am a front-end for a remote UEB (using a local ReadBucketProxy), responsible for
    retrieving and validating the elements from the UEB. """
    # NOTE: the docstring above was originally placed after the implements()
    # call, where it silently failed to become the class __doc__.
    implements(IValidatedThingProxy)

    def __init__(self, readbucketproxy, verifycap, fetch_failures=None):
        # fetch_failures is for debugging -- see test_encode.py
        self._fetch_failures = fetch_failures
        self._readbucketproxy = readbucketproxy
        precondition(IVerifierURI.providedBy(verifycap), verifycap)
        self._verifycap = verifycap

        # required -- learned from the UEB itself
        self.segment_size = None
        self.crypttext_root_hash = None
        self.share_root_hash = None

        # computed -- derived from the UEB plus the verifycap
        self.block_size = None
        self.share_size = None
        self.num_segments = None
        self.tail_data_size = None
        self.tail_segment_size = None

        # optional
        self.crypttext_hash = None

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self._verifycap.to_string())

    def _check_integrity(self, data):
        """Return data unchanged if it hashes to the verifycap's
        uri_extension_hash; otherwise count the failure (for tests) and raise
        BadURIExtensionHashValue."""
        h = hashutil.uri_extension_hash(data)
        if h != self._verifycap.uri_extension_hash:
            msg = ("The copy of uri_extension we received from %s was bad: wanted %s, got %s" %
                   (self._readbucketproxy, base32.b2a(self._verifycap.uri_extension_hash), base32.b2a(h)))
            if self._fetch_failures is not None:
                self._fetch_failures["uri_extension"] += 1
            raise BadURIExtensionHashValue(msg)
        else:
            return data

    def _parse_and_validate(self, data):
        """Unpack the UEB and cross-check each field against the verifycap,
        raising BadURIExtension (or a subclass) on inconsistency.  Returns
        self on success."""
        self.share_size = mathutil.div_ceil(self._verifycap.size, self._verifycap.needed_shares)

        d = uri.unpack_extension(data)

        # There are several kinds of things that can be found in a UEB.  First, things that we
        # really need to learn from the UEB in order to do this download. Next: things which are
        # optional but not redundant -- if they are present in the UEB they will get used. Next,
        # things that are optional and redundant. These things are required to be consistent:
        # they don't have to be in the UEB, but if they are in the UEB then they will be checked
        # for consistency with the already-known facts, and if they are inconsistent then an
        # exception will be raised. These things aren't actually used -- they are just tested
        # for consistency and ignored. Finally: things which are deprecated -- they ought not be
        # in the UEB at all, and if they are present then a warning will be logged but they are
        # otherwise ignored.

        # First, things that we really need to learn from the UEB: segment_size,
        # crypttext_root_hash, and share_root_hash.
        self.segment_size = d['segment_size']

        self.block_size = mathutil.div_ceil(self.segment_size, self._verifycap.needed_shares)
        self.num_segments = mathutil.div_ceil(self._verifycap.size, self.segment_size)

        self.tail_data_size = self._verifycap.size % self.segment_size
        if not self.tail_data_size:
            self.tail_data_size = self.segment_size
        # padding for erasure code
        self.tail_segment_size = mathutil.next_multiple(self.tail_data_size, self._verifycap.needed_shares)

        # Ciphertext hash tree root is mandatory, so that there is at most one ciphertext that
        # matches this read-cap or verify-cap.  The integrity check on the shares is not
        # sufficient to prevent the original encoder from creating some shares of file A and
        # other shares of file B.
        self.crypttext_root_hash = d['crypttext_root_hash']

        self.share_root_hash = d['share_root_hash']

        # Next: things that are optional and not redundant: crypttext_hash
        if 'crypttext_hash' in d:
            self.crypttext_hash = d['crypttext_hash']
            if len(self.crypttext_hash) != hashutil.CRYPTO_VAL_SIZE:
                raise BadURIExtension('crypttext_hash is required to be hashutil.CRYPTO_VAL_SIZE bytes, not %s bytes' % (len(self.crypttext_hash),))

        # Next: things that are optional, redundant, and required to be consistent: codec_name,
        # codec_params, tail_codec_params, num_segments, size, needed_shares, total_shares
        if 'codec_name' in d:
            if d['codec_name'] != "crs":
                raise UnsupportedErasureCodec(d['codec_name'])

        if 'codec_params' in d:
            ucpss, ucpns, ucpts = codec.parse_params(d['codec_params'])
            if ucpss != self.segment_size:
                raise BadURIExtension("inconsistent erasure code params: ucpss: %s != "
                                      "self.segment_size: %s" % (ucpss, self.segment_size))
            if ucpns != self._verifycap.needed_shares:
                raise BadURIExtension("inconsistent erasure code params: ucpns: %s != "
                                      "self._verifycap.needed_shares: %s" % (ucpns,
                                                                             self._verifycap.needed_shares))
            if ucpts != self._verifycap.total_shares:
                raise BadURIExtension("inconsistent erasure code params: ucpts: %s != "
                                      "self._verifycap.total_shares: %s" % (ucpts,
                                                                            self._verifycap.total_shares))

        if 'tail_codec_params' in d:
            utcpss, utcpns, utcpts = codec.parse_params(d['tail_codec_params'])
            if utcpss != self.tail_segment_size:
                raise BadURIExtension("inconsistent erasure code params: utcpss: %s != "
                                      "self.tail_segment_size: %s, self._verifycap.size: %s, "
                                      "self.segment_size: %s, self._verifycap.needed_shares: %s"
                                      % (utcpss, self.tail_segment_size, self._verifycap.size,
                                         self.segment_size, self._verifycap.needed_shares))
            if utcpns != self._verifycap.needed_shares:
                raise BadURIExtension("inconsistent erasure code params: utcpns: %s != "
                                      "self._verifycap.needed_shares: %s" % (utcpns,
                                                                             self._verifycap.needed_shares))
            if utcpts != self._verifycap.total_shares:
                raise BadURIExtension("inconsistent erasure code params: utcpts: %s != "
                                      "self._verifycap.total_shares: %s" % (utcpts,
                                                                            self._verifycap.total_shares))

        if 'num_segments' in d:
            if d['num_segments'] != self.num_segments:
                raise BadURIExtension("inconsistent num_segments: size: %s, "
                                      "segment_size: %s, computed_num_segments: %s, "
                                      "ueb_num_segments: %s" % (self._verifycap.size,
                                                                self.segment_size,
                                                                self.num_segments, d['num_segments']))

        if 'size' in d:
            if d['size'] != self._verifycap.size:
                raise BadURIExtension("inconsistent size: URI size: %s, UEB size: %s" %
                                      (self._verifycap.size, d['size']))

        if 'needed_shares' in d:
            if d['needed_shares'] != self._verifycap.needed_shares:
                # bug fix: the original interpolated total_shares here, so the
                # error message reported the wrong URI value
                raise BadURIExtension("inconsistent needed shares: URI needed shares: %s, UEB "
                                      "needed shares: %s" % (self._verifycap.needed_shares,
                                                             d['needed_shares']))

        if 'total_shares' in d:
            if d['total_shares'] != self._verifycap.total_shares:
                raise BadURIExtension("inconsistent total shares: URI total shares: %s, UEB "
                                      "total shares: %s" % (self._verifycap.total_shares,
                                                            d['total_shares']))

        # Finally, things that are deprecated and ignored: plaintext_hash, plaintext_root_hash
        if d.get('plaintext_hash'):
            log.msg("Found plaintext_hash in UEB. This field is deprecated for security reasons "
                    "and is no longer used.  Ignoring.  %s" % (self,))
        if d.get('plaintext_root_hash'):
            log.msg("Found plaintext_root_hash in UEB. This field is deprecated for security "
                    "reasons and is no longer used.  Ignoring.  %s" % (self,))

        return self

    def start(self):
        """ Fetch the UEB from bucket, compare its hash to the hash from verifycap, then parse
        it.  Returns a deferred which is called back with self once the fetch is successful, or
        is erred back if it fails. """
        d = self._readbucketproxy.get_uri_extension()
        d.addCallback(self._check_integrity)
        d.addCallback(self._parse_and_validate)
        return d
 
313
 
 
314
class ValidatedReadBucketProxy(log.PrefixingLogMixin):
    """I am a front-end for a remote storage bucket, responsible for retrieving and validating
    data from that bucket.

    My get_block() method is used by BlockDownloaders.
    """

    def __init__(self, sharenum, bucket, share_hash_tree, num_blocks, block_size, share_size):
        """ share_hash_tree is required to have already been initialized with the root hash
        (the number-0 hash), using the share_root_hash from the UEB """
        precondition(share_hash_tree[0] is not None, share_hash_tree)
        prefix = "%d-%s-%s" % (sharenum, bucket, base32.b2a_l(share_hash_tree[0][:8], 60))
        log.PrefixingLogMixin.__init__(self, facility="tahoe.immutable.download", prefix=prefix)
        self.sharenum = sharenum
        self.bucket = bucket
        self.share_hash_tree = share_hash_tree
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.share_size = share_size
        # per-share block hash tree; starts empty and is filled in lazily by
        # _got_data() as hashes arrive from the server
        self.block_hash_tree = hashtree.IncompleteHashTree(self.num_blocks)

    def get_block(self, blocknum):
        """Fetch block `blocknum` plus whatever hash-tree nodes are still
        needed to validate it; returns a deferred firing with the validated
        block data, or erring back (e.g. with BadOrMissingHash)."""
        # the first time we use this bucket, we need to fetch enough elements
        # of the share hash tree to validate it from our share hash up to the
        # hashroot.
        if self.share_hash_tree.needed_hashes(self.sharenum):
            d1 = self.bucket.get_share_hashes()
        else:
            d1 = defer.succeed([])

        # We might need to grab some elements of our block hash tree, to
        # validate the requested block up to the share hash.
        blockhashesneeded = self.block_hash_tree.needed_hashes(blocknum, include_leaf=True)
        # We don't need the root of the block hash tree, as that comes in the share tree.
        blockhashesneeded.discard(0)
        d2 = self.bucket.get_block_hashes(blockhashesneeded)

        # the final block of the share may be shorter than block_size
        if blocknum < self.num_blocks-1:
            thisblocksize = self.block_size
        else:
            thisblocksize = self.share_size % self.block_size
            if thisblocksize == 0:
                thisblocksize = self.block_size
        d3 = self.bucket.get_block_data(blocknum, self.block_size, thisblocksize)

        dl = deferredutil.gatherResults([d1, d2, d3])
        dl.addCallback(self._got_data, blocknum)
        return dl

    def _got_data(self, results, blocknum):
        """Validate the fetched block against the share and block hash trees
        and return the block data; raise BadOrMissingHash (after extensive
        diagnostic logging) on any hash failure."""
        precondition(blocknum < self.num_blocks, self, blocknum, self.num_blocks)
        sharehashes, blockhashes, blockdata = results
        try:
            sharehashes = dict(sharehashes)
        except ValueError, le:
            # attach the offending data to the exception for easier debugging
            le.args = tuple(le.args + (sharehashes,))
            raise
        blockhashes = dict(enumerate(blockhashes))

        candidate_share_hash = None # in case we log it in the except block below
        blockhash = None # in case we log it in the except block below

        try:
            if self.share_hash_tree.needed_hashes(self.sharenum):
                # This will raise exception if the values being passed do not match the root
                # node of self.share_hash_tree.
                try:
                    self.share_hash_tree.set_hashes(sharehashes)
                except IndexError, le:
                    # Weird -- sharehashes contained index numbers outside of the range that fit
                    # into this hash tree.
                    raise BadOrMissingHash(le)

            # To validate a block we need the root of the block hash tree, which is also one of
            # the leafs of the share hash tree, and is called "the share hash".
            if not self.block_hash_tree[0]: # empty -- no root node yet
                # Get the share hash from the share hash tree.
                share_hash = self.share_hash_tree.get_leaf(self.sharenum)
                if not share_hash:
                    raise hashtree.NotEnoughHashesError # No root node in block_hash_tree and also the share hash wasn't sent by the server.
                self.block_hash_tree.set_hashes({0: share_hash})

            if self.block_hash_tree.needed_hashes(blocknum):
                self.block_hash_tree.set_hashes(blockhashes)

            # finally, hash the block itself and install it as a leaf; this
            # is where a corrupted block is detected
            blockhash = hashutil.block_hash(blockdata)
            self.block_hash_tree.set_hashes(leaves={blocknum: blockhash})
            #self.log("checking block_hash(shareid=%d, blocknum=%d) len=%d "
            #        "%r .. %r: %s" %
            #        (self.sharenum, blocknum, len(blockdata),
            #         blockdata[:50], blockdata[-50:], base32.b2a(blockhash)))

        except (hashtree.BadHashError, hashtree.NotEnoughHashesError), le:
            # log.WEIRD: indicates undetected disk/network error, or more
            # likely a programming error
            self.log("hash failure in block=%d, shnum=%d on %s" %
                    (blocknum, self.sharenum, self.bucket))
            if self.block_hash_tree.needed_hashes(blocknum):
                self.log(""" failure occurred when checking the block_hash_tree.
                This suggests that either the block data was bad, or that the
                block hashes we received along with it were bad.""")
            else:
                self.log(""" the failure probably occurred when checking the
                share_hash_tree, which suggests that the share hashes we
                received from the remote peer were bad.""")
            self.log(" have candidate_share_hash: %s" % bool(candidate_share_hash))
            self.log(" block length: %d" % len(blockdata))
            self.log(" block hash: %s" % base32.b2a_or_none(blockhash))
            if len(blockdata) < 100:
                self.log(" block data: %r" % (blockdata,))
            else:
                self.log(" block data start/end: %r .. %r" %
                        (blockdata[:50], blockdata[-50:]))
            self.log(" share hash tree:\n" + self.share_hash_tree.dump())
            self.log(" block hash tree:\n" + self.block_hash_tree.dump())
            lines = []
            for i,h in sorted(sharehashes.items()):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            self.log(" sharehashes:\n" + "\n".join(lines) + "\n")
            lines = []
            for i,h in blockhashes.items():
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            # NOTE(review): this uses module-level log.msg while the rest of
            # this dump uses self.log -- looks like an oversight; confirm
            # before changing.
            log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")
            raise BadOrMissingHash(le)

        # If we made it here, the block is good. If the hash trees didn't
        # like what they saw, they would have raised a BadHashError, causing
        # our caller to see a Failure and thus ignore this block (as well as
        # dropping this bucket).
        return blockdata
 
444
 
 
445
 
 
446
 
 
447
class BlockDownloader(log.PrefixingLogMixin):
    """I am responsible for downloading a single block (from a single bucket)
    for a single segment.

    I am a child of the SegmentDownloader.
    """

    def __init__(self, vbucket, blocknum, parent, results):
        precondition(isinstance(vbucket, ValidatedReadBucketProxy), vbucket)
        prefix = "%s-%d" % (vbucket, blocknum)
        log.PrefixingLogMixin.__init__(self, facility="tahoe.immutable.download", prefix=prefix)
        self.vbucket = vbucket
        self.blocknum = blocknum
        self.parent = parent
        self.results = results   # DownloadResults, or None when stats are not wanted

    def start(self, segnum):
        """Fetch and validate our block of segment `segnum`; hand the data
        (or the failure) to our parent SegmentDownloader."""
        self.log("get_block(segnum=%d)" % segnum)
        started = time.time()
        d = self.vbucket.get_block(segnum)
        d.addCallbacks(self._hold_block, self._got_block_error,
                       callbackArgs=(started,))
        return d

    def _hold_block(self, data, started):
        if self.results:
            # record per-server fetch latency
            elapsed = time.time() - started
            peerid = self.vbucket.bucket.get_peerid()
            self.results.timings["fetch_per_server"].setdefault(peerid, []).append(elapsed)
        self.log("got block")
        self.parent.hold_block(self.blocknum, data)

    def _got_block_error(self, f):
        # trap() re-raises any failure type we don't expect.  (The original
        # bound its return value to an unused local.)
        f.trap(RemoteException, DeadReferenceError,
               IntegrityCheckReject,
               layout.LayoutInvalid, layout.ShareVersionIncompatible)
        if f.check(RemoteException, DeadReferenceError):
            level = log.UNUSUAL
        else:
            level = log.WEIRD
        self.log("failure to get block", level=level, umid="5Z4uHQ")
        if self.results:
            peerid = self.vbucket.bucket.get_peerid()
            self.results.server_problems[peerid] = str(f)
        self.parent.bucket_failed(self.vbucket)
 
494
 
 
495
class SegmentDownloader:
    """I am responsible for downloading all the blocks for a single segment
    of data.

    I am a child of the CiphertextDownloader.
    """

    def __init__(self, parent, segmentnumber, needed_shares, results):
        self.parent = parent
        self.segmentnumber = segmentnumber
        self.needed_blocks = needed_shares
        self.blocks = {} # k: blocknum, v: data
        self.results = results
        self._log_number = self.parent.log("starting segment %d" %
                                           segmentnumber)

    def log(self, *args, **kwargs):
        # default the log parent to this segment's log entry
        kwargs.setdefault("parent", self._log_number)
        return self.parent.log(*args, **kwargs)

    def start(self):
        return self._download()

    def _download(self):
        d = self._try()
        def _check_progress(res):
            if len(self.blocks) < self.needed_blocks:
                # not enough blocks yet: go around again
                return self._download()
            # we only need self.needed_blocks blocks; take the smallest
            # blockids, because they are more likely to be fast
            # "primary blocks"
            blockids = sorted(self.blocks.keys())[:self.needed_blocks]
            blocks = [self.blocks[blocknum] for blocknum in blockids]
            return (blocks, blockids)
        d.addCallback(_check_progress)
        return d

    def _try(self):
        # fill our set of active buckets, maybe raising NotEnoughSharesError
        active_buckets = self.parent._activate_enough_buckets()
        # Now we have enough buckets, in self.parent.active_buckets.

        # in test cases, bd.start might mutate active_buckets right away, so
        # we need to put off calling start() until we've iterated all the way
        # through it.
        downloaders = []
        for blocknum, vbucket in active_buckets.iteritems():
            assert isinstance(vbucket, ValidatedReadBucketProxy), vbucket
            downloaders.append(BlockDownloader(vbucket, blocknum, self, self.results))
            if self.results:
                self.results.servers_used.add(vbucket.bucket.get_peerid())
        started = [bd.start(self.segmentnumber) for bd in downloaders]
        return defer.DeferredList(started, fireOnOneErrback=True)

    def hold_block(self, blocknum, data):
        self.blocks[blocknum] = data

    def bucket_failed(self, vbucket):
        self.parent.bucket_failed(vbucket)
 
559
 
 
560
class DownloadStatus:
    implements(IDownloadStatus)

    # class-wide counter handing each instance a unique id
    statusid_counter = itertools.count(0)

    def __init__(self):
        self.started = time.time()
        self.storage_index = None
        self.size = None
        self.helper = False
        self.status = "Not started"
        self.progress = 0.0
        self.paused = False
        self.stopped = False
        self.active = True
        self.results = None
        self.counter = self.statusid_counter.next()

    # -- accessors --

    def get_started(self):
        return self.started

    def get_storage_index(self):
        return self.storage_index

    def get_size(self):
        return self.size

    def using_helper(self):
        return self.helper

    def get_status(self):
        # base status plus any paused/stopped annotations
        parts = [self.status]
        if self.paused:
            parts.append(" (output paused)")
        if self.stopped:
            parts.append(" (output stopped)")
        return "".join(parts)

    def get_progress(self):
        return self.progress

    def get_active(self):
        return self.active

    def get_results(self):
        return self.results

    def get_counter(self):
        return self.counter

    # -- mutators --

    def set_storage_index(self, si):
        self.storage_index = si

    def set_size(self, size):
        self.size = size

    def set_helper(self, helper):
        self.helper = helper

    def set_status(self, status):
        self.status = status

    def set_paused(self, paused):
        self.paused = paused

    def set_stopped(self, stopped):
        self.stopped = stopped

    def set_progress(self, value):
        self.progress = value

    def set_active(self, value):
        self.active = value

    def set_results(self, value):
        self.results = value
 
619
 
 
620
class CiphertextDownloader(log.PrefixingLogMixin):
    """ I download shares, check their integrity, then decode them, check the
    integrity of the resulting ciphertext, then and write it to my target.
    Before I send any new request to a server, I always ask the 'monitor'
    object that was passed into my constructor whether this task has been
    cancelled (by invoking its raise_if_cancelled() method)."""
    implements(IPushProducer)
    _status = None

    def __init__(self, storage_broker, v, target, monitor):
        # storage_broker: IStorageBroker used to locate servers for this SI.
        # v: IVerifierURI naming the file (storage index, UEB hash, size, k/N).
        # target: IDownloadTarget that receives the downloaded bytes.
        # monitor: consulted (raise_if_cancelled) before each new request.

        precondition(IStorageBroker.providedBy(storage_broker), storage_broker)
        precondition(IVerifierURI.providedBy(v), v)
        precondition(IDownloadTarget.providedBy(target), target)

        prefix=base32.b2a_l(v.storage_index[:8], 60)
        log.PrefixingLogMixin.__init__(self, facility="tahoe.immutable.download", prefix=prefix)
        self._storage_broker = storage_broker

        self._verifycap = v
        self._storage_index = v.storage_index
        self._uri_extension_hash = v.uri_extension_hash

        # status/results objects are updated throughout the download so the
        # web status pages can observe progress
        self._started = time.time()
        self._status = s = DownloadStatus()
        s.set_status("Starting")
        s.set_storage_index(self._storage_index)
        s.set_size(self._verifycap.size)
        s.set_helper(False)
        s.set_active(True)

        self._results = DownloadResults()
        s.set_results(self._results)
        self._results.file_size = self._verifycap.size
        self._results.timings["servers_peer_selection"] = {}
        self._results.timings["fetch_per_server"] = {}
        self._results.timings["cumulative_fetch"] = 0.0
        self._results.timings["cumulative_decode"] = 0.0
        self._results.timings["cumulative_decrypt"] = 0.0
        self._results.timings["paused"] = 0.0

        self._paused = False    # False, or a Deferred that fires on resume
        self._stopped = False
        # if the target is also an IConsumer, wire ourselves up as its
        # (streaming) push producer so it can pause/stop us
        if IConsumer.providedBy(target):
            target.registerProducer(self, True)
        self._target = target
        self._target.set_storageindex(self._storage_index) # Repairer (uploader) needs the storageindex.
        self._monitor = monitor
        self._opened = False    # becomes True once target.open() is called

        self.active_buckets = {} # k: shnum, v: bucket
        self._share_buckets = [] # list of (sharenum, bucket) tuples
        self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets

        # per-category counters of failed fetches, passed to the validating
        # proxies below
        self._fetch_failures = {"uri_extension": 0, "crypttext_hash_tree": 0, }

        # running hash over all ciphertext, checked at the end in _done()
        self._ciphertext_hasher = hashutil.crypttext_hasher()

        self._bytes_done = 0
        self._status.set_progress(float(self._bytes_done)/self._verifycap.size)

        # _got_uri_extension() will create the following:
        # self._crypttext_hash_tree
        # self._share_hash_tree
        # self._current_segnum = 0
        # self._vup # ValidatedExtendedURIProxy

    def pauseProducing(self):
        # IPushProducer: suspend work. self._paused becomes a Deferred that
        # _check_for_pause() waits on; it is fired by resumeProducing().
        if self._paused:
            return
        self._paused = defer.Deferred()
        self._paused_at = time.time()
        if self._status:
            self._status.set_paused(True)

    def resumeProducing(self):
        # IPushProducer: account the paused interval, then fire the pause
        # Deferred (via eventually, to stay off the caller's stack).
        if self._paused:
            paused_for = time.time() - self._paused_at
            self._results.timings['paused'] += paused_for
            p = self._paused
            self._paused = None
            eventually(p.callback, None)
            if self._status:
                self._status.set_paused(False)

    def stopProducing(self):
        # IPushProducer: mark the download stopped; the next
        # _check_for_pause() will raise DownloadStopped.
        self.log("Download.stopProducing")
        self._stopped = True
        self.resumeProducing()
        if self._status:
            self._status.set_stopped(True)
            self._status.set_active(False)

    def start(self):
        """Run the whole download pipeline. Returns a Deferred that fires
        with the target's finish() value (see _done), or errbacks with the
        failure."""
        self.log("starting download")

        # first step: who should we download from?
        d = defer.maybeDeferred(self._get_all_shareholders)
        d.addCallback(self._got_all_shareholders)
        # now get the uri_extension block from somebody and integrity check it and parse and validate its contents
        d.addCallback(self._obtain_uri_extension)
        d.addCallback(self._get_crypttext_hash_tree)
        # once we know that, we can download blocks from everybody
        d.addCallback(self._download_all_segments)
        def _finished(res):
            # runs on success OR failure (addBoth): update status and detach
            # from the consumer either way
            if self._status:
                self._status.set_status("Finished")
                self._status.set_active(False)
                self._status.set_paused(False)
            if IConsumer.providedBy(self._target):
                self._target.unregisterProducer()
            return res
        d.addBoth(_finished)
        def _failed(why):
            # note: this overwrites the "Finished" status set by _finished
            if self._status:
                self._status.set_status("Failed")
                self._status.set_active(False)
            if why.check(DownloadStopped):
                # DownloadStopped just means the consumer aborted the download; not so scary.
                self.log("download stopped", level=log.UNUSUAL)
            else:
                # This is really unusual, and deserves maximum forensics.
                self.log("download failed!", failure=why, level=log.SCARY, umid="lp1vaQ")
            return why
        d.addErrback(_failed)
        d.addCallback(self._done)
        return d

    def _get_all_shareholders(self):
        # Send a "do you have block" (get_buckets) query to every server the
        # broker knows for this storage index; responses/errors arrive in
        # _got_response/_got_error. Returns a DeferredList of the queries.
        dl = []
        sb = self._storage_broker
        servers = sb.get_servers_for_index(self._storage_index)
        if not servers:
            raise NoServersError("broker gave us no servers!")
        for (peerid,ss) in servers:
            self.log(format="sending DYHB to [%(peerid)s]",
                     peerid=idlib.shortnodeid_b2a(peerid),
                     level=log.NOISY, umid="rT03hg")
            d = ss.callRemote("get_buckets", self._storage_index)
            d.addCallbacks(self._got_response, self._got_error,
                           callbackArgs=(peerid,))
            dl.append(d)
        self._responses_received = 0
        self._queries_sent = len(dl)
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        return defer.DeferredList(dl)

    def _got_response(self, buckets, peerid):
        # One server answered: wrap each returned bucket in a
        # ReadBucketProxy and remember it, plus record timing/servermap info.
        self.log(format="got results from [%(peerid)s]: shnums %(shnums)s",
                 peerid=idlib.shortnodeid_b2a(peerid),
                 shnums=sorted(buckets.keys()),
                 level=log.NOISY, umid="o4uwFg")
        self._responses_received += 1
        if self._results:
            elapsed = time.time() - self._started
            self._results.timings["servers_peer_selection"][peerid] = elapsed
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        for sharenum, bucket in buckets.iteritems():
            b = layout.ReadBucketProxy(bucket, peerid, self._storage_index)
            self.add_share_bucket(sharenum, b)

            if self._results:
                if peerid not in self._results.servermap:
                    self._results.servermap[peerid] = set()
                self._results.servermap[peerid].add(sharenum)

    def add_share_bucket(self, sharenum, bucket):
        # this is split out for the benefit of test_encode.py
        self._share_buckets.append( (sharenum, bucket) )

    def _got_error(self, f):
        # A get_buckets query failed. DeadReferenceError (server went away)
        # is merely unusual; anything else is weird. Either way we just log
        # and carry on with the servers that did answer.
        level = log.WEIRD
        if f.check(DeadReferenceError):
            level = log.UNUSUAL
        self.log("Error during get_buckets", failure=f, level=level,
                         umid="3uuBUQ")

    def bucket_failed(self, vbucket):
        # Called (e.g. by SegmentDownloader) when a validated bucket stops
        # working; forget it so _activate_enough_buckets can pick another.
        shnum = vbucket.sharenum
        del self.active_buckets[shnum]
        s = self._share_vbuckets[shnum]
        # s is a set of ValidatedReadBucketProxy instances
        s.remove(vbucket)
        # ... which might now be empty
        if not s:
            # there are no more buckets which can provide this share, so
            # remove the key. This may prompt us to use a different share.
            del self._share_vbuckets[shnum]

    def _got_all_shareholders(self, res):
        # All DYHB queries are done; fail fast if we cannot possibly reach
        # 'needed_shares' distinct shares.
        if self._results:
            now = time.time()
            self._results.timings["peer_selection"] = now - self._started

        if len(self._share_buckets) < self._verifycap.needed_shares:
            msg = "Failed to get enough shareholders: have %d, need %d" \
                  % (len(self._share_buckets), self._verifycap.needed_shares)
            if self._share_buckets:
                raise NotEnoughSharesError(msg)
            else:
                raise NoSharesError(msg)

        #for s in self._share_vbuckets.values():
        #    for vb in s:
        #        assert isinstance(vb, ValidatedReadBucketProxy), \
        #               "vb is %s but should be a ValidatedReadBucketProxy" % (vb,)

    def _obtain_uri_extension(self, ignored):
        # all shareholders are supposed to have a copy of uri_extension, and
        # all are supposed to be identical. We compute the hash of the data
        # that comes back, and compare it against the version in our URI. If
        # they don't match, ignore their data and try someone else.
        if self._status:
            self._status.set_status("Obtaining URI Extension")

        uri_extension_fetch_started = time.time()

        vups = []
        for sharenum, bucket in self._share_buckets:
            vups.append(ValidatedExtendedURIProxy(bucket, self._verifycap, self._fetch_failures))
        vto = ValidatedThingObtainer(vups, debugname="vups", log_id=self._parentmsgid)
        d = vto.start()

        def _got_uri_extension(vup):
            # The UEB tells us the encoding geometry; set up both decoders
            # (full-size segments and the possibly-shorter tail segment) and
            # the two incomplete hash trees rooted at the validated roots.
            precondition(isinstance(vup, ValidatedExtendedURIProxy), vup)
            if self._results:
                elapsed = time.time() - uri_extension_fetch_started
                self._results.timings["uri_extension"] = elapsed

            self._vup = vup
            self._codec = codec.CRSDecoder()
            self._codec.set_params(self._vup.segment_size, self._verifycap.needed_shares, self._verifycap.total_shares)
            self._tail_codec = codec.CRSDecoder()
            self._tail_codec.set_params(self._vup.tail_segment_size, self._verifycap.needed_shares, self._verifycap.total_shares)

            self._current_segnum = 0

            self._share_hash_tree = hashtree.IncompleteHashTree(self._verifycap.total_shares)
            self._share_hash_tree.set_hashes({0: vup.share_root_hash})

            self._crypttext_hash_tree = hashtree.IncompleteHashTree(self._vup.num_segments)
            self._crypttext_hash_tree.set_hashes({0: self._vup.crypttext_root_hash})

            # Repairer (uploader) needs the encodingparams.
            self._target.set_encodingparams((
                self._verifycap.needed_shares,
                self._verifycap.total_shares, # I don't think the target actually cares about "happy".
                self._verifycap.total_shares,
                self._vup.segment_size
                ))
        d.addCallback(_got_uri_extension)
        return d

    def _get_crypttext_hash_tree(self, res):
        # Fetch (from any one shareholder) the crypttext hash tree, which
        # lets us verify each downloaded segment in _got_segment().
        vchtps = []
        for sharenum, bucket in self._share_buckets:
            vchtp = ValidatedCrypttextHashTreeProxy(bucket, self._crypttext_hash_tree, self._vup.num_segments, self._fetch_failures)
            vchtps.append(vchtp)

        _get_crypttext_hash_tree_started = time.time()
        if self._status:
            self._status.set_status("Retrieving crypttext hash tree")

        vto = ValidatedThingObtainer(vchtps , debugname="vchtps", log_id=self._parentmsgid)
        d = vto.start()

        def _got_crypttext_hash_tree(res):
            # Good -- the self._crypttext_hash_tree that we passed to vchtp is now populated
            # with hashes.
            if self._results:
                elapsed = time.time() - _get_crypttext_hash_tree_started
                self._results.timings["hashtrees"] = elapsed
        d.addCallback(_got_crypttext_hash_tree)
        return d

    def _activate_enough_buckets(self):
        """either return a mapping from shnum to a ValidatedReadBucketProxy that can
        provide data for that share, or raise NotEnoughSharesError"""

        while len(self.active_buckets) < self._verifycap.needed_shares:
            # need some more
            handled_shnums = set(self.active_buckets.keys())
            available_shnums = set(self._share_vbuckets.keys())
            potential_shnums = list(available_shnums - handled_shnums)
            if len(potential_shnums) < (self._verifycap.needed_shares - len(self.active_buckets)):
                have = len(potential_shnums) + len(self.active_buckets)
                msg = "Unable to activate enough shares: have %d, need %d" \
                      % (have, self._verifycap.needed_shares)
                if have:
                    raise NotEnoughSharesError(msg)
                else:
                    raise NoSharesError(msg)
            # For the next share, choose a primary share if available, else a randomly chosen
            # secondary share.
            potential_shnums.sort()
            if potential_shnums[0] < self._verifycap.needed_shares:
                shnum = potential_shnums[0]
            else:
                shnum = random.choice(potential_shnums)
            # and a random bucket that will provide it
            validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
            self.active_buckets[shnum] = validated_bucket
        return self.active_buckets


    def _download_all_segments(self, res):
        # Wrap every known bucket in a ValidatedReadBucketProxy, then chain
        # one _download_segment call per segment onto a single Deferred so
        # segments are fetched strictly in order.
        for sharenum, bucket in self._share_buckets:
            vbucket = ValidatedReadBucketProxy(sharenum, bucket, self._share_hash_tree, self._vup.num_segments, self._vup.block_size, self._vup.share_size)
            self._share_vbuckets.setdefault(sharenum, set()).add(vbucket)

        # after the above code, self._share_vbuckets contains enough
        # buckets to complete the download, and some extra ones to
        # tolerate some buckets dropping out or having
        # errors. self._share_vbuckets is a dictionary that maps from
        # shnum to a set of ValidatedBuckets, which themselves are
        # wrappers around RIBucketReader references.
        self.active_buckets = {} # k: shnum, v: ValidatedReadBucketProxy instance

        self._started_fetching = time.time()

        d = defer.succeed(None)
        for segnum in range(self._vup.num_segments):
            d.addCallback(self._download_segment, segnum)
            # this pause, at the end of write, prevents pre-fetch from
            # happening until the consumer is ready for more data.
            d.addCallback(self._check_for_pause)
        return d

    def _check_for_pause(self, res):
        # Passthrough callback inserted between pipeline stages: wait here
        # while paused, abort if stopped or cancelled, else forward 'res'.
        if self._paused:
            d = defer.Deferred()
            self._paused.addCallback(lambda ignored: d.callback(res))
            return d
        if self._stopped:
            raise DownloadStopped("our Consumer called stopProducing()")
        self._monitor.raise_if_cancelled()
        return res

    def _download_segment(self, res, segnum):
        # Fetch one segment's blocks, decode them back into ciphertext
        # buffers, and hand those to _got_segment. The tail segment uses its
        # own codec since it may be shorter.
        if self._status:
            self._status.set_status("Downloading segment %d of %d" %
                                    (segnum+1, self._vup.num_segments))
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._vup.num_segments,
                    100.0 * segnum / self._vup.num_segments))
        # memory footprint: when the SegmentDownloader finishes pulling down
        # all shares, we have 1*segment_size of usage.
        segmentdler = SegmentDownloader(self, segnum, self._verifycap.needed_shares,
                                        self._results)
        started = time.time()
        d = segmentdler.start()
        def _finished_fetching(res):
            elapsed = time.time() - started
            self._results.timings["cumulative_fetch"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_fetching)
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        # while the codec does its job, we hit 2*segment_size
        def _started_decode(res):
            self._started_decode = time.time()
            return res
        if self._results:
            d.addCallback(_started_decode)
        if segnum + 1 == self._vup.num_segments:
            codec = self._tail_codec
        else:
            codec = self._codec
        # NOTE: tuple-parameter lambda is Python-2-only syntax
        d.addCallback(lambda (shares, shareids): codec.decode(shares, shareids))
        # once the codec is done, we drop back to 1*segment_size, because
        # 'shares' goes out of scope. The memory usage is all in the
        # plaintext now, spread out into a bunch of tiny buffers.
        def _finished_decode(res):
            elapsed = time.time() - self._started_decode
            self._results.timings["cumulative_decode"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_decode)

        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        d.addCallback(self._got_segment)
        return d

    def _got_segment(self, buffers):
        # 'buffers' is the list of decoded ciphertext chunks for the current
        # segment. Trim tail padding, verify against the crypttext hash
        # tree, then write the chunks to the target.
        precondition(self._crypttext_hash_tree)
        started_decrypt = time.time()
        # NOTE(review): this divides a segment *count* by a *byte* size, so
        # the ratio is near zero; the meaningful progress update happens
        # below using self._bytes_done -- confirm this line is intended.
        self._status.set_progress(float(self._current_segnum)/self._verifycap.size)

        if self._current_segnum + 1 == self._vup.num_segments:
            # This is the last segment.
            # Trim off any padding added by the upload side.  We never send empty segments. If
            # the data was an exact multiple of the segment size, the last segment will be full.
            tail_buf_size = mathutil.div_ceil(self._vup.tail_segment_size, self._verifycap.needed_shares)
            num_buffers_used = mathutil.div_ceil(self._vup.tail_data_size, tail_buf_size)
            # Remove buffers which don't contain any part of the tail.
            del buffers[num_buffers_used:]
            # Remove the past-the-tail-part of the last buffer.
            tail_in_last_buf = self._vup.tail_data_size % tail_buf_size
            if tail_in_last_buf == 0:
                tail_in_last_buf = tail_buf_size
            buffers[-1] = buffers[-1][:tail_in_last_buf]

        # First compute the hash of this segment and check that it fits.
        ch = hashutil.crypttext_segment_hasher()
        for buffer in buffers:
            self._ciphertext_hasher.update(buffer)
            ch.update(buffer)
        # set_hashes raises (BadHashError/NotEnoughHashesError) if the leaf
        # doesn't verify against the tree
        self._crypttext_hash_tree.set_hashes(leaves={self._current_segnum: ch.digest()})

        # Then write this segment to the target.
        if not self._opened:
            self._opened = True
            self._target.open(self._verifycap.size)

        for buffer in buffers:
            self._target.write(buffer)
            self._bytes_done += len(buffer)

        self._status.set_progress(float(self._bytes_done)/self._verifycap.size)
        self._current_segnum += 1

        if self._results:
            elapsed = time.time() - started_decrypt
            self._results.timings["cumulative_decrypt"] += elapsed

    def _done(self, res):
        # Final callback: record totals, double-check the whole-file
        # ciphertext hash (when the UEB provided one) and the byte count,
        # then close the target and return its finish() value.
        self.log("download done")
        if self._results:
            now = time.time()
            self._results.timings["total"] = now - self._started
            self._results.timings["segments"] = now - self._started_fetching
        if self._vup.crypttext_hash:
            _assert(self._vup.crypttext_hash == self._ciphertext_hasher.digest(),
                    "bad crypttext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._ciphertext_hasher.digest()),
                     base32.b2a(self._vup.crypttext_hash)))
        _assert(self._bytes_done == self._verifycap.size, self._bytes_done, self._verifycap.size)
        self._status.set_progress(1)
        self._target.close()
        return self._target.finish()
    def get_download_status(self):
        # expose the DownloadStatus object (e.g. for the history/status page)
        return self._status
 
1070
 
 
1071
 
 
1072
class FileName:
    """IDownloadTarget that writes the downloaded bytes into a named local
    file, and deletes that file if the download fails."""
    implements(IDownloadTarget)

    def __init__(self, filename):
        self._filename = filename
        self.f = None

    def open(self, size):
        # 'size' is advisory; we simply open the file for binary writing
        self.f = open(self._filename, "wb")
        return self.f

    def write(self, data):
        self.f.write(data)

    def close(self):
        if self.f is None:
            return
        self.f.close()

    def fail(self, why):
        # discard the partially-written file on failure
        if self.f is None:
            return
        self.f.close()
        os.unlink(self._filename)

    def register_canceller(self, cb):
        pass # we won't use it

    def finish(self):
        pass

    # The following methods are just because the target might be a
    # repairer.DownUpConnector, and just because the current CHKUpload object
    # expects to find the storage index and encoding parameters in its
    # Uploadable.
    def set_storageindex(self, storageindex):
        pass

    def set_encodingparams(self, encodingparams):
        pass
 
1100
 
 
1101
class Data:
    """IDownloadTarget that accumulates the downloaded bytes in memory;
    finish() returns them as a single string."""
    implements(IDownloadTarget)

    def __init__(self):
        self._data = []

    def open(self, size):
        pass

    def write(self, data):
        # collect chunks; they are joined exactly once, at close time
        self._data.append(data)

    def close(self):
        self.data = "".join(self._data)
        del self._data

    def fail(self, why):
        # drop whatever was buffered so far
        del self._data

    def register_canceller(self, cb):
        pass # we won't use it

    def finish(self):
        return self.data

    # The following methods are just because the target might be a
    # repairer.DownUpConnector, and just because the current CHKUpload object
    # expects to find the storage index and encoding parameters in its
    # Uploadable.
    def set_storageindex(self, storageindex):
        pass

    def set_encodingparams(self, encodingparams):
        pass
 
1125
 
 
1126
class FileHandle:
    """Use me to download data to a pre-defined filehandle-like object. I
    will use the target's write() method. I will *not* close the filehandle:
    I leave that up to the originator of the filehandle. The download process
    will return the filehandle when it completes.
    """
    implements(IDownloadTarget)

    def __init__(self, filehandle):
        self._filehandle = filehandle

    def open(self, size):
        # the filehandle was opened by our caller; nothing to do here
        pass

    def write(self, data):
        self._filehandle.write(data)

    def close(self):
        # the originator of the filehandle reserves the right to close it
        pass

    def fail(self, why):
        pass

    def register_canceller(self, cb):
        pass

    def finish(self):
        return self._filehandle

    # The following methods are just because the target might be a
    # repairer.DownUpConnector, and just because the current CHKUpload object
    # expects to find the storage index and encoding parameters in its
    # Uploadable.
    def set_storageindex(self, storageindex):
        pass

    def set_encodingparams(self, encodingparams):
        pass
 
1155
 
 
1156
class ConsumerAdapter:
    """Adapt a twisted IConsumer so it can be used as an IDownloadTarget:
    writes and producer registration are forwarded to the wrapped consumer,
    and finish() returns the consumer itself."""
    implements(IDownloadTarget, IConsumer)

    def __init__(self, consumer):
        self._consumer = consumer

    # IConsumer: forward producer registration straight through
    def registerProducer(self, producer, streaming):
        self._consumer.registerProducer(producer, streaming)

    def unregisterProducer(self):
        self._consumer.unregisterProducer()

    # IDownloadTarget
    def open(self, size):
        pass

    def write(self, data):
        self._consumer.write(data)

    def close(self):
        pass

    def fail(self, why):
        pass

    def register_canceller(self, cb):
        pass

    def finish(self):
        return self._consumer

    # The following methods are just because the target might be a
    # repairer.DownUpConnector, and just because the current CHKUpload object
    # expects to find the storage index and encoding parameters in its
    # Uploadable.
    def set_storageindex(self, storageindex):
        pass

    def set_encodingparams(self, encodingparams):
        pass
 
1186
 
 
1187
 
 
1188
class Downloader(service.MultiService):
    """I am a service that allows file downloading.
    """
    # TODO: in fact, this service only downloads immutable files (URI:CHK:).
    # It is scheduled to go away, to be replaced by filenode.download()
    implements(IDownloader)
    name = "downloader"

    def __init__(self, stats_provider=None):
        service.MultiService.__init__(self)
        self.stats_provider = stats_provider
        self._all_downloads = weakref.WeakKeyDictionary() # for debugging

    def download(self, u, t, _log_msg_id=None, monitor=None, history=None):
        """Download the CHK file named by readcap 'u' into IDownloadTarget
        't'. Returns a Deferred that fires when the download completes."""
        assert self.parent
        assert self.running
        u = IFileURI(u)
        t = IDownloadTarget(t)
        assert t.write
        assert t.close

        assert isinstance(u, uri.CHKFileURI)
        if self.stats_provider:
            # these counters are meant for network traffic, and don't
            # include LIT files
            self.stats_provider.count('downloader.files_downloaded', 1)
            self.stats_provider.count('downloader.bytes_downloaded', u.get_size())
        storage_broker = self.parent.get_storage_broker()

        # wrap the caller's target in a DecryptingTarget keyed with u.key,
        # so the CiphertextDownloader only ever handles ciphertext
        target = DecryptingTarget(t, u.key, _log_msg_id=_log_msg_id)
        monitor = monitor or Monitor()
        ctd = CiphertextDownloader(storage_broker, u.get_verify_cap(), target,
                                   monitor=monitor)
        self._all_downloads[ctd] = None
        if history:
            history.add_download(ctd.get_download_status())
        return ctd.start()

    # utility functions

    def download_to_data(self, uri, _log_msg_id=None, history=None):
        return self.download(uri, Data(), _log_msg_id=_log_msg_id, history=history)

    def download_to_filename(self, uri, filename, _log_msg_id=None):
        return self.download(uri, FileName(filename), _log_msg_id=_log_msg_id)

    def download_to_filehandle(self, uri, filehandle, _log_msg_id=None):
        return self.download(uri, FileHandle(filehandle), _log_msg_id=_log_msg_id)