source file: /home/buildslave/tahoe/edgy/build/src/allmydata/immutable/download.py
file stats: 795 lines, 753 executed: 94.7% covered
   1. 
   2. import os, random, weakref, itertools, time
   3. from zope.interface import implements
   4. from twisted.internet import defer
   5. from twisted.internet.interfaces import IPushProducer, IConsumer
   6. from twisted.application import service
   7. from foolscap import DeadReferenceError
   8. from foolscap.eventual import eventually
   9. 
  10. from allmydata.util import base32, mathutil, hashutil, log
  11. from allmydata.util.assertutil import _assert
  12. from allmydata import codec, hashtree, storage, uri
  13. from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, \
  14.      IDownloadStatus, IDownloadResults
  15. from allmydata.immutable.encode import NotEnoughSharesError
  16. from pycryptopp.cipher.aes import AES
  17. 
  18. class HaveAllPeersError(Exception):
  19.     # we use this to jump out of the loop
  20.     pass
  21. 
  22. class IntegrityCheckError(Exception):
  23.     pass
  24. 
  25. class BadURIExtensionHashValue(IntegrityCheckError):
  26.     pass
  27. class BadURIExtension(IntegrityCheckError):
  28.     pass
  29. class BadPlaintextHashValue(IntegrityCheckError):
  30.     pass
  31. class BadCrypttextHashValue(IntegrityCheckError):
  32.     pass
  33. 
  34. class DownloadStopped(Exception):
  35.     pass
  36. 
  37. class DownloadResults:
  38.     implements(IDownloadResults)
  39. 
  40.     def __init__(self):
  41.         self.servers_used = set()
  42.         self.server_problems = {}
  43.         self.servermap = {}
  44.         self.timings = {}
  45.         self.file_size = None
  46. 
  47. class Output:
  48.     def __init__(self, downloadable, key, total_length, log_parent,
  49.                  download_status):
  50.         self.downloadable = downloadable
  51.         self._decryptor = AES(key)
  52.         self._crypttext_hasher = hashutil.crypttext_hasher()
  53.         self._plaintext_hasher = hashutil.plaintext_hasher()
  54.         self.length = 0
  55.         self.total_length = total_length
  56.         self._segment_number = 0
  57.         self._plaintext_hash_tree = None
  58.         self._crypttext_hash_tree = None
  59.         self._opened = False
  60.         self._log_parent = log_parent
  61.         self._status = download_status
  62.         self._status.set_progress(0.0)
  63. 
  64.     def log(self, *args, **kwargs):
  65.         if "parent" not in kwargs:
  66.             kwargs["parent"] = self._log_parent
  67.         if "facility" not in kwargs:
  68.             kwargs["facility"] = "download.output"
  69.         return log.msg(*args, **kwargs)
  70. 
  71.     def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
  72.         self._plaintext_hash_tree = plaintext_hashtree
  73.         self._crypttext_hash_tree = crypttext_hashtree
  74. 
  75.     def write_segment(self, crypttext):
  76.         self.length += len(crypttext)
  77.         self._status.set_progress( float(self.length) / self.total_length )
  78. 
  79.         # memory footprint: 'crypttext' is the only segment_size usage
  80.         # outstanding. While we decrypt it into 'plaintext', we hit
  81.         # 2*segment_size.
  82.         self._crypttext_hasher.update(crypttext)
  83.         if self._crypttext_hash_tree:
  84.             ch = hashutil.crypttext_segment_hasher()
  85.             ch.update(crypttext)
  86.             crypttext_leaves = {self._segment_number: ch.digest()}
  87.             self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
  88.                      bytes=len(crypttext),
  89.                      segnum=self._segment_number, hash=base32.b2a(ch.digest()),
  90.                      level=log.NOISY)
  91.             self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)
  92. 
  93.         plaintext = self._decryptor.process(crypttext)
  94.         del crypttext
  95. 
  96.         # now we're back down to 1*segment_size.
  97. 
  98.         self._plaintext_hasher.update(plaintext)
  99.         if self._plaintext_hash_tree:
 100.             ph = hashutil.plaintext_segment_hasher()
 101.             ph.update(plaintext)
 102.             plaintext_leaves = {self._segment_number: ph.digest()}
 103.             self.log(format="plaintext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
 104.                      bytes=len(plaintext),
 105.                      segnum=self._segment_number, hash=base32.b2a(ph.digest()),
 106.                      level=log.NOISY)
 107.             self._plaintext_hash_tree.set_hashes(leaves=plaintext_leaves)
 108. 
 109.         self._segment_number += 1
 110.         # We're still at 1*segment_size. The Downloadable is responsible for
 111.         # any memory usage beyond this.
 112.         if not self._opened:
 113.             self._opened = True
 114.             self.downloadable.open(self.total_length)
 115.         self.downloadable.write(plaintext)
 116. 
 117.     def fail(self, why):
 118.         # this is really unusual, and deserves maximum forensics
 119.         if why.check(DownloadStopped):
 120.             # except DownloadStopped just means the consumer aborted the
 121.             # download, not so scary
 122.             self.log("download stopped", level=log.UNUSUAL)
 123.         else:
 124.             self.log("download failed!", failure=why,
 125.                      level=log.SCARY, umid="lp1vaQ")
 126.         self.downloadable.fail(why)
 127. 
 128.     def close(self):
 129.         self.crypttext_hash = self._crypttext_hasher.digest()
 130.         self.plaintext_hash = self._plaintext_hasher.digest()
 131.         self.log("download finished, closing IDownloadable", level=log.NOISY)
 132.         self.downloadable.close()
 133. 
 134.     def finish(self):
 135.         return self.downloadable.finish()
 136. 
 137. class ValidatedBucket:
 138.     """I am a front-end for a remote storage bucket, responsible for
 139.     retrieving and validating data from that bucket.
 140. 
 141.     My get_block() method is used by BlockDownloaders.
 142.     """
 143. 
 144.     def __init__(self, sharenum, bucket,
 145.                  share_hash_tree, roothash,
 146.                  num_blocks):
 147.         self.sharenum = sharenum
 148.         self.bucket = bucket
 149.         self._share_hash = None # None means not validated yet
 150.         self.share_hash_tree = share_hash_tree
 151.         self._roothash = roothash
 152.         self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
 153.         self.started = False
 154. 
 155.     def get_block(self, blocknum):
 156.         if not self.started:
 157.             d = self.bucket.start()
 158.             def _started(res):
 159.                 self.started = True
 160.                 return self.get_block(blocknum)
 161.             d.addCallback(_started)
 162.             return d
 163. 
 164.         # the first time we use this bucket, we need to fetch enough elements
 165.         # of the share hash tree to validate it from our share hash up to the
 166.         # hashroot.
 167.         if not self._share_hash:
 168.             d1 = self.bucket.get_share_hashes()
 169.         else:
 170.             d1 = defer.succeed([])
 171. 
 172.         # we might need to grab some elements of our block hash tree, to
 173.         # validate the requested block up to the share hash
 174.         needed = self.block_hash_tree.needed_hashes(blocknum)
 175.         if needed:
 176.             # TODO: get fewer hashes, use get_block_hashes(needed)
 177.             d2 = self.bucket.get_block_hashes()
 178.         else:
 179.             d2 = defer.succeed([])
 180. 
 181.         d3 = self.bucket.get_block(blocknum)
 182. 
 183.         d = defer.gatherResults([d1, d2, d3])
 184.         d.addCallback(self._got_data, blocknum)
 185.         return d
 186. 
 187.     def _got_data(self, res, blocknum):
 188.         sharehashes, blockhashes, blockdata = res
 189.         blockhash = None # to make logging it safe
 190. 
 191.         try:
 192.             if not self._share_hash:
 193.                 sh = dict(sharehashes)
 194.                 sh[0] = self._roothash # always use our own root, from the URI
 195.                 sht = self.share_hash_tree
 196.                 if sht.get_leaf_index(self.sharenum) not in sh:
 197.                     raise hashtree.NotEnoughHashesError
 198.                 sht.set_hashes(sh)
 199.                 self._share_hash = sht.get_leaf(self.sharenum)
 200. 
 201.             blockhash = hashutil.block_hash(blockdata)
 202.             #log.msg("checking block_hash(shareid=%d, blocknum=%d) len=%d "
 203.             #        "%r .. %r: %s" %
 204.             #        (self.sharenum, blocknum, len(blockdata),
 205.             #         blockdata[:50], blockdata[-50:], base32.b2a(blockhash)))
 206. 
 207.             # we always validate the blockhash
 208.             bh = dict(enumerate(blockhashes))
 209.             # replace blockhash root with validated value
 210.             bh[0] = self._share_hash
 211.             self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})
 212. 
 213.         except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
 214.             # log.WEIRD: indicates undetected disk/network error, or more
 215.             # likely a programming error
 216.             log.msg("hash failure in block=%d, shnum=%d on %s" %
 217.                     (blocknum, self.sharenum, self.bucket))
 218.             if self._share_hash:
 219.                 log.msg(""" failure occurred when checking the block_hash_tree.
 220.                 This suggests that either the block data was bad, or that the
 221.                 block hashes we received along with it were bad.""")
 222.             else:
 223.                 log.msg(""" the failure probably occurred when checking the
 224.                 share_hash_tree, which suggests that the share hashes we
 225.                 received from the remote peer were bad.""")
 226.             log.msg(" have self._share_hash: %s" % bool(self._share_hash))
 227.             log.msg(" block length: %d" % len(blockdata))
 228.             log.msg(" block hash: %s" % base32.b2a_or_none(blockhash))
 229.             if len(blockdata) < 100:
 230.                 log.msg(" block data: %r" % (blockdata,))
 231.             else:
 232.                 log.msg(" block data start/end: %r .. %r" %
 233.                         (blockdata[:50], blockdata[-50:]))
 234.             log.msg(" root hash: %s" % base32.b2a(self._roothash))
 235.             log.msg(" share hash tree:\n" + self.share_hash_tree.dump())
 236.             log.msg(" block hash tree:\n" + self.block_hash_tree.dump())
 237.             lines = []
 238.             for i,h in sorted(sharehashes):
 239.                 lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
 240.             log.msg(" sharehashes:\n" + "\n".join(lines) + "\n")
 241.             lines = []
 242.             for i,h in enumerate(blockhashes):
 243.                 lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
 244.             log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")
 245.             raise
 246. 
 247.         # If we made it here, the block is good. If the hash trees didn't
 248.         # like what they saw, they would have raised a BadHashError, causing
 249.         # our caller to see a Failure and thus ignore this block (as well as
 250.         # dropping this bucket).
 251.         return blockdata
 252. 
 253. 
 254. 
 255. class BlockDownloader:
 256.     """I am responsible for downloading a single block (from a single bucket)
 257.     for a single segment.
 258. 
 259.     I am a child of the SegmentDownloader.
 260.     """
 261. 
 262.     def __init__(self, vbucket, blocknum, parent, results):
 263.         self.vbucket = vbucket
 264.         self.blocknum = blocknum
 265.         self.parent = parent
 266.         self.results = results
 267.         self._log_number = self.parent.log("starting block %d" % blocknum)
 268. 
 269.     def log(self, *args, **kwargs):
 270.         if "parent" not in kwargs:
 271.             kwargs["parent"] = self._log_number
 272.         return self.parent.log(*args, **kwargs)
 273. 
 274.     def start(self, segnum):
 275.         lognum = self.log("get_block(segnum=%d)" % segnum)
 276.         started = time.time()
 277.         d = self.vbucket.get_block(segnum)
 278.         d.addCallbacks(self._hold_block, self._got_block_error,
 279.                        callbackArgs=(started, lognum,), errbackArgs=(lognum,))
 280.         return d
 281. 
 282.     def _hold_block(self, data, started, lognum):
 283.         if self.results:
 284.             elapsed = time.time() - started
 285.             peerid = self.vbucket.bucket.get_peerid()
 286.             if peerid not in self.results.timings["fetch_per_server"]:
 287.                 self.results.timings["fetch_per_server"][peerid] = []
 288.             self.results.timings["fetch_per_server"][peerid].append(elapsed)
 289.         self.log("got block", parent=lognum)
 290.         self.parent.hold_block(self.blocknum, data)
 291. 
 292.     def _got_block_error(self, f, lognum):
 293.         level = log.WEIRD
 294.         if f.check(DeadReferenceError):
 295.             level = log.UNUSUAL
 296.         self.log("BlockDownloader[%d] got error" % self.blocknum,
 297.                  failure=f, level=level, parent=lognum, umid="5Z4uHQ")
 298.         if self.results:
 299.             peerid = self.vbucket.bucket.get_peerid()
 300.             self.results.server_problems[peerid] = str(f)
 301.         self.parent.bucket_failed(self.vbucket)
 302. 
 303. class SegmentDownloader:
 304.     """I am responsible for downloading all the blocks for a single segment
 305.     of data.
 306. 
 307.     I am a child of the FileDownloader.
 308.     """
 309. 
 310.     def __init__(self, parent, segmentnumber, needed_shares, results):
 311.         self.parent = parent
 312.         self.segmentnumber = segmentnumber
 313.         self.needed_blocks = needed_shares
 314.         self.blocks = {} # k: blocknum, v: data
 315.         self.results = results
 316.         self._log_number = self.parent.log("starting segment %d" %
 317.                                            segmentnumber)
 318. 
 319.     def log(self, *args, **kwargs):
 320.         if "parent" not in kwargs:
 321.             kwargs["parent"] = self._log_number
 322.         return self.parent.log(*args, **kwargs)
 323. 
 324.     def start(self):
 325.         return self._download()
 326. 
 327.     def _download(self):
 328.         d = self._try()
 329.         def _done(res):
 330.             if len(self.blocks) >= self.needed_blocks:
 331.                 # we only need self.needed_blocks blocks
 332.                 # we want to get the smallest blockids, because they are
 333.                 # more likely to be fast "primary blocks"
 334.                 blockids = sorted(self.blocks.keys())[:self.needed_blocks]
 335.                 blocks = []
 336.                 for blocknum in blockids:
 337.                     blocks.append(self.blocks[blocknum])
 338.                 return (blocks, blockids)
 339.             else:
 340.                 return self._download()
 341.         d.addCallback(_done)
 342.         return d
 343. 
 344.     def _try(self):
 345.         # fill our set of active buckets, maybe raising NotEnoughSharesError
 346.         active_buckets = self.parent._activate_enough_buckets()
 347.         # Now we have enough buckets, in self.parent.active_buckets.
 348. 
 349.         # in test cases, bd.start might mutate active_buckets right away, so
 350.         # we need to put off calling start() until we've iterated all the way
 351.         # through it.
 352.         downloaders = []
 353.         for blocknum, vbucket in active_buckets.iteritems():
 354.             bd = BlockDownloader(vbucket, blocknum, self, self.results)
 355.             downloaders.append(bd)
 356.             if self.results:
 357.                 self.results.servers_used.add(vbucket.bucket.get_peerid())
 358.         l = [bd.start(self.segmentnumber) for bd in downloaders]
 359.         return defer.DeferredList(l, fireOnOneErrback=True)
 360. 
 361.     def hold_block(self, blocknum, data):
 362.         self.blocks[blocknum] = data
 363. 
 364.     def bucket_failed(self, vbucket):
 365.         self.parent.bucket_failed(vbucket)
 366. 
 367. class DownloadStatus:
 368.     implements(IDownloadStatus)
 369.     statusid_counter = itertools.count(0)
 370. 
 371.     def __init__(self):
 372.         self.storage_index = None
 373.         self.size = None
 374.         self.helper = False
 375.         self.status = "Not started"
 376.         self.progress = 0.0
 377.         self.paused = False
 378.         self.stopped = False
 379.         self.active = True
 380.         self.results = None
 381.         self.counter = self.statusid_counter.next()
 382.         self.started = time.time()
 383. 
 384.     def get_started(self):
 385.         return self.started
 386.     def get_storage_index(self):
 387.         return self.storage_index
 388.     def get_size(self):
 389.         return self.size
 390.     def using_helper(self):
 391.         return self.helper
 392.     def get_status(self):
 393.         status = self.status
 394.         if self.paused:
 395.             status += " (output paused)"
 396.         if self.stopped:
 397.             status += " (output stopped)"
 398.         return status
 399.     def get_progress(self):
 400.         return self.progress
 401.     def get_active(self):
 402.         return self.active
 403.     def get_results(self):
 404.         return self.results
 405.     def get_counter(self):
 406.         return self.counter
 407. 
 408.     def set_storage_index(self, si):
 409.         self.storage_index = si
 410.     def set_size(self, size):
 411.         self.size = size
 412.     def set_helper(self, helper):
 413.         self.helper = helper
 414.     def set_status(self, status):
 415.         self.status = status
 416.     def set_paused(self, paused):
 417.         self.paused = paused
 418.     def set_stopped(self, stopped):
 419.         self.stopped = stopped
 420.     def set_progress(self, value):
 421.         self.progress = value
 422.     def set_active(self, value):
 423.         self.active = value
 424.     def set_results(self, value):
 425.         self.results = value
 426. 
 427. class FileDownloader:
 428.     implements(IPushProducer)
 429.     check_crypttext_hash = True
 430.     check_plaintext_hash = True
 431.     _status = None
 432. 
 433.     def __init__(self, client, u, downloadable):
 434.         self._client = client
 435. 
 436.         u = IFileURI(u)
 437.         self._storage_index = u.storage_index
 438.         self._uri_extension_hash = u.uri_extension_hash
 439.         self._total_shares = u.total_shares
 440.         self._size = u.size
 441.         self._num_needed_shares = u.needed_shares
 442. 
 443.         self._si_s = storage.si_b2a(self._storage_index)
 444.         self.init_logging()
 445. 
 446.         self._started = time.time()
 447.         self._status = s = DownloadStatus()
 448.         s.set_status("Starting")
 449.         s.set_storage_index(self._storage_index)
 450.         s.set_size(self._size)
 451.         s.set_helper(False)
 452.         s.set_active(True)
 453. 
 454.         self._results = DownloadResults()
 455.         s.set_results(self._results)
 456.         self._results.file_size = self._size
 457.         self._results.timings["servers_peer_selection"] = {}
 458.         self._results.timings["fetch_per_server"] = {}
 459.         self._results.timings["cumulative_fetch"] = 0.0
 460.         self._results.timings["cumulative_decode"] = 0.0
 461.         self._results.timings["cumulative_decrypt"] = 0.0
 462.         self._results.timings["paused"] = 0.0
 463. 
 464.         self._paused = False
 465.         self._stopped = False
 466.         if IConsumer.providedBy(downloadable):
 467.             downloadable.registerProducer(self, True)
 468.         self._downloadable = downloadable
 469.         self._output = Output(downloadable, u.key, self._size, self._log_number,
 470.                               self._status)
 471. 
 472.         self.active_buckets = {} # k: shnum, v: bucket
 473.         self._share_buckets = [] # list of (sharenum, bucket) tuples
 474.         self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
 475.         self._uri_extension_sources = []
 476. 
 477.         self._uri_extension_data = None
 478. 
 479.         self._fetch_failures = {"uri_extension": 0,
 480.                                 "plaintext_hashroot": 0,
 481.                                 "plaintext_hashtree": 0,
 482.                                 "crypttext_hashroot": 0,
 483.                                 "crypttext_hashtree": 0,
 484.                                 }
 485. 
 486.     def init_logging(self):
 487.         self._log_prefix = prefix = storage.si_b2a(self._storage_index)[:5]
 488.         num = self._client.log(format="FileDownloader(%(si)s): starting",
 489.                                si=storage.si_b2a(self._storage_index))
 490.         self._log_number = num
 491. 
 492.     def log(self, *args, **kwargs):
 493.         if "parent" not in kwargs:
 494.             kwargs["parent"] = self._log_number
 495.         if "facility" not in kwargs:
 496.             kwargs["facility"] = "tahoe.download"
 497.         return log.msg(*args, **kwargs)
 498. 
 499.     def pauseProducing(self):
 500.         if self._paused:
 501.             return
 502.         self._paused = defer.Deferred()
 503.         self._paused_at = time.time()
 504.         if self._status:
 505.             self._status.set_paused(True)
 506. 
 507.     def resumeProducing(self):
 508.         if self._paused:
 509.             paused_for = time.time() - self._paused_at
 510.             self._results.timings['paused'] += paused_for
 511.             p = self._paused
 512.             self._paused = None
 513.             eventually(p.callback, None)
 514.             if self._status:
 515.                 self._status.set_paused(False)
 516. 
 517.     def stopProducing(self):
 518.         self.log("Download.stopProducing")
 519.         self._stopped = True
 520.         self.resumeProducing()
 521.         if self._status:
 522.             self._status.set_stopped(True)
 523.             self._status.set_active(False)
 524. 
 525.     def start(self):
 526.         self.log("starting download")
 527. 
 528.         # first step: who should we download from?
 529.         d = defer.maybeDeferred(self._get_all_shareholders)
 530.         d.addCallback(self._got_all_shareholders)
 531.         # now get the uri_extension block from somebody and validate it
 532.         d.addCallback(self._obtain_uri_extension)
 533.         d.addCallback(self._got_uri_extension)
 534.         d.addCallback(self._get_hashtrees)
 535.         d.addCallback(self._create_validated_buckets)
 536.         # once we know that, we can download blocks from everybody
 537.         d.addCallback(self._download_all_segments)
 538.         def _finished(res):
 539.             if self._status:
 540.                 self._status.set_status("Finished")
 541.                 self._status.set_active(False)
 542.                 self._status.set_paused(False)
 543.             if IConsumer.providedBy(self._downloadable):
 544.                 self._downloadable.unregisterProducer()
 545.             return res
 546.         d.addBoth(_finished)
 547.         def _failed(why):
 548.             if self._status:
 549.                 self._status.set_status("Failed")
 550.                 self._status.set_active(False)
 551.             self._output.fail(why)
 552.             return why
 553.         d.addErrback(_failed)
 554.         d.addCallback(self._done)
 555.         return d
 556. 
 557.     def _get_all_shareholders(self):
 558.         dl = []
 559.         for (peerid,ss) in self._client.get_permuted_peers("storage",
 560.                                                            self._storage_index):
 561.             d = ss.callRemote("get_buckets", self._storage_index)
 562.             d.addCallbacks(self._got_response, self._got_error,
 563.                            callbackArgs=(peerid,))
 564.             dl.append(d)
 565.         self._responses_received = 0
 566.         self._queries_sent = len(dl)
 567.         if self._status:
 568.             self._status.set_status("Locating Shares (%d/%d)" %
 569.                                     (self._responses_received,
 570.                                      self._queries_sent))
 571.         return defer.DeferredList(dl)
 572. 
 573.     def _got_response(self, buckets, peerid):
 574.         self._responses_received += 1
 575.         if self._results:
 576.             elapsed = time.time() - self._started
 577.             self._results.timings["servers_peer_selection"][peerid] = elapsed
 578.         if self._status:
 579.             self._status.set_status("Locating Shares (%d/%d)" %
 580.                                     (self._responses_received,
 581.                                      self._queries_sent))
 582.         for sharenum, bucket in buckets.iteritems():
 583.             b = storage.ReadBucketProxy(bucket, peerid, self._si_s)
 584.             self.add_share_bucket(sharenum, b)
 585.             self._uri_extension_sources.append(b)
 586.             if self._results:
 587.                 if peerid not in self._results.servermap:
 588.                     self._results.servermap[peerid] = set()
 589.                 self._results.servermap[peerid].add(sharenum)
 590. 
 591.     def add_share_bucket(self, sharenum, bucket):
 592.         # this is split out for the benefit of test_encode.py
 593.         self._share_buckets.append( (sharenum, bucket) )
 594. 
 595.     def _got_error(self, f):
 596.         level = log.WEIRD
 597.         if f.check(DeadReferenceError):
 598.             level = log.UNUSUAL
 599.         self._client.log("Error during get_buckets", failure=f, level=level,
 600.                          umid="3uuBUQ")
 601. 
 602.     def bucket_failed(self, vbucket):
 603.         shnum = vbucket.sharenum
 604.         del self.active_buckets[shnum]
 605.         s = self._share_vbuckets[shnum]
 606.         # s is a set of ValidatedBucket instances
 607.         s.remove(vbucket)
 608.         # ... which might now be empty
 609.         if not s:
 610.             # there are no more buckets which can provide this share, so
 611.             # remove the key. This may prompt us to use a different share.
 612.             del self._share_vbuckets[shnum]
 613. 
 614.     def _got_all_shareholders(self, res):
 615.         if self._results:
 616.             now = time.time()
 617.             self._results.timings["peer_selection"] = now - self._started
 618. 
 619.         if len(self._share_buckets) < self._num_needed_shares:
 620.             raise NotEnoughSharesError
 621. 
 622.         #for s in self._share_vbuckets.values():
 623.         #    for vb in s:
 624.         #        assert isinstance(vb, ValidatedBucket), \
 625.         #               "vb is %s but should be a ValidatedBucket" % (vb,)
 626. 
 627.     def _unpack_uri_extension_data(self, data):
 628.         return uri.unpack_extension(data)
 629. 
 630.     def _obtain_uri_extension(self, ignored):
 631.         # all shareholders are supposed to have a copy of uri_extension, and
 632.         # all are supposed to be identical. We compute the hash of the data
 633.         # that comes back, and compare it against the version in our URI. If
 634.         # they don't match, ignore their data and try someone else.
 635.         if self._status:
 636.             self._status.set_status("Obtaining URI Extension")
 637. 
 638.         self._uri_extension_fetch_started = time.time()
 639.         def _validate(proposal, bucket):
 640.             h = hashutil.uri_extension_hash(proposal)
 641.             if h != self._uri_extension_hash:
 642.                 self._fetch_failures["uri_extension"] += 1
 643.                 msg = ("The copy of uri_extension we received from "
 644.                        "%s was bad: wanted %s, got %s" %
 645.                        (bucket,
 646.                         base32.b2a(self._uri_extension_hash),
 647.                         base32.b2a(h)))
 648.                 self.log(msg, level=log.SCARY, umid="jnkTtQ")
 649.                 raise BadURIExtensionHashValue(msg)
 650.             return self._unpack_uri_extension_data(proposal)
 651.         return self._obtain_validated_thing(None,
 652.                                             self._uri_extension_sources,
 653.                                             "uri_extension",
 654.                                             "get_uri_extension", (), _validate)
 655. 
 656.     def _obtain_validated_thing(self, ignored, sources, name, methname, args,
 657.                                 validatorfunc):
 658.         if not sources:
 659.             raise NotEnoughSharesError("started with zero peers while fetching "
 660.                                       "%s" % name)
 661.         bucket = sources[0]
 662.         sources = sources[1:]
 663.         #d = bucket.callRemote(methname, *args)
 664.         d = bucket.startIfNecessary()
 665.         d.addCallback(lambda res: getattr(bucket, methname)(*args))
 666.         d.addCallback(validatorfunc, bucket)
 667.         def _bad(f):
 668.             level = log.WEIRD
 669.             if f.check(DeadReferenceError):
 670.                 level = log.UNUSUAL
 671.             self.log(format="operation %(op)s from vbucket %(vbucket)s failed",
 672.                      op=name, vbucket=str(bucket),
 673.                      failure=f, level=level, umid="JGXxBA")
 674.             if not sources:
 675.                 raise NotEnoughSharesError("ran out of peers, last error was %s"
 676.                                           % (f,))
 677.             # try again with a different one
 678.             return self._obtain_validated_thing(None, sources, name,
 679.                                                 methname, args, validatorfunc)
 680.         d.addErrback(_bad)
 681.         return d
 682. 
 683.     def _got_uri_extension(self, uri_extension_data):
 684.         if self._results:
 685.             elapsed = time.time() - self._uri_extension_fetch_started
 686.             self._results.timings["uri_extension"] = elapsed
 687. 
 688.         d = self._uri_extension_data = uri_extension_data
 689. 
 690.         self._codec = codec.get_decoder_by_name(d['codec_name'])
 691.         self._codec.set_serialized_params(d['codec_params'])
 692.         self._tail_codec = codec.get_decoder_by_name(d['codec_name'])
 693.         self._tail_codec.set_serialized_params(d['tail_codec_params'])
 694. 
 695.         crypttext_hash = d.get('crypttext_hash', None) # optional
 696.         if crypttext_hash:
 697.             assert isinstance(crypttext_hash, str)
 698.             assert len(crypttext_hash) == 32
 699.         self._crypttext_hash = crypttext_hash
 700.         self._plaintext_hash = d.get('plaintext_hash', None) # optional
 701. 
 702.         self._roothash = d['share_root_hash']
 703. 
 704.         self._segment_size = segment_size = d['segment_size']
 705.         self._total_segments = mathutil.div_ceil(self._size, segment_size)
 706.         self._current_segnum = 0
 707. 
 708.         self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
 709.         self._share_hashtree.set_hashes({0: self._roothash})
 710. 
 711.     def _get_hashtrees(self, res):
 712.         self._get_hashtrees_started = time.time()
 713.         if self._status:
 714.             self._status.set_status("Retrieving Hash Trees")
 715.         d = defer.maybeDeferred(self._get_plaintext_hashtrees)
 716.         d.addCallback(self._get_crypttext_hashtrees)
 717.         d.addCallback(self._setup_hashtrees)
 718.         return d
 719. 
 720.     def _get_plaintext_hashtrees(self):
 721.         # plaintext hashes are optional. If the root isn't in the UEB, then
 722.         # the share will be holding an empty list. We don't even bother
 723.         # fetching it.
 724.         if "plaintext_root_hash" not in self._uri_extension_data:
 725.             self._plaintext_hashtree = None
 726.             return
 727.         def _validate_plaintext_hashtree(proposal, bucket):
 728.             if proposal[0] != self._uri_extension_data['plaintext_root_hash']:
 729.                 self._fetch_failures["plaintext_hashroot"] += 1
 730.                 msg = ("The copy of the plaintext_root_hash we received from"
 731.                        " %s was bad" % bucket)
 732.                 raise BadPlaintextHashValue(msg)
 733.             pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
 734.             pt_hashes = dict(list(enumerate(proposal)))
 735.             try:
 736.                 pt_hashtree.set_hashes(pt_hashes)
 737.             except hashtree.BadHashError:
 738.                 # the hashes they gave us were not self-consistent, even
 739.                 # though the root matched what we saw in the uri_extension
 740.                 # block
 741.                 self._fetch_failures["plaintext_hashtree"] += 1
 742.                 raise
 743.             self._plaintext_hashtree = pt_hashtree
 744.         d = self._obtain_validated_thing(None,
 745.                                          self._uri_extension_sources,
 746.                                          "plaintext_hashes",
 747.                                          "get_plaintext_hashes", (),
 748.                                          _validate_plaintext_hashtree)
 749.         return d
 750. 
 751.     def _get_crypttext_hashtrees(self, res):
 752.         # Ciphertext hash tree root is mandatory, so that there is at
 753.         # most one ciphertext that matches this read-cap or
 754.         # verify-cap.  The integrity check on the shares is not
 755.         # sufficient to prevent the original encoder from creating
 756.         # some shares of file A and other shares of file B.
 757.         if "crypttext_root_hash" not in self._uri_extension_data:
 758.             raise BadURIExtension("URI Extension block did not have the ciphertext hash tree root")
 759.         def _validate_crypttext_hashtree(proposal, bucket):
 760.             if proposal[0] != self._uri_extension_data['crypttext_root_hash']:
 761.                 self._fetch_failures["crypttext_hashroot"] += 1
 762.                 msg = ("The copy of the crypttext_root_hash we received from"
 763.                        " %s was bad" % bucket)
 764.                 raise BadCrypttextHashValue(msg)
 765.             ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
 766.             ct_hashes = dict(list(enumerate(proposal)))
 767.             try:
 768.                 ct_hashtree.set_hashes(ct_hashes)
 769.             except hashtree.BadHashError:
 770.                 self._fetch_failures["crypttext_hashtree"] += 1
 771.                 raise
 772.             ct_hashtree.set_hashes(ct_hashes)
 773.             self._crypttext_hashtree = ct_hashtree
 774.         d = self._obtain_validated_thing(None,
 775.                                          self._uri_extension_sources,
 776.                                          "crypttext_hashes",
 777.                                          "get_crypttext_hashes", (),
 778.                                          _validate_crypttext_hashtree)
 779.         return d
 780. 
 781.     def _setup_hashtrees(self, res):
 782.         self._output.setup_hashtrees(self._plaintext_hashtree,
 783.                                      self._crypttext_hashtree)
 784.         if self._results:
 785.             elapsed = time.time() - self._get_hashtrees_started
 786.             self._results.timings["hashtrees"] = elapsed
 787. 
 788.     def _create_validated_buckets(self, ignored=None):
 789.         self._share_vbuckets = {}
 790.         for sharenum, bucket in self._share_buckets:
 791.             vbucket = ValidatedBucket(sharenum, bucket,
 792.                                       self._share_hashtree,
 793.                                       self._roothash,
 794.                                       self._total_segments)
 795.             s = self._share_vbuckets.setdefault(sharenum, set())
 796.             s.add(vbucket)
 797. 
 798.     def _activate_enough_buckets(self):
 799.         """either return a mapping from shnum to a ValidatedBucket that can
 800.         provide data for that share, or raise NotEnoughSharesError"""
 801. 
 802.         while len(self.active_buckets) < self._num_needed_shares:
 803.             # need some more
 804.             handled_shnums = set(self.active_buckets.keys())
 805.             available_shnums = set(self._share_vbuckets.keys())
 806.             potential_shnums = list(available_shnums - handled_shnums)
 807.             if not potential_shnums:
 808.                 raise NotEnoughSharesError
 809.             # choose a random share
 810.             shnum = random.choice(potential_shnums)
 811.             # and a random bucket that will provide it
 812.             validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
 813.             self.active_buckets[shnum] = validated_bucket
 814.         return self.active_buckets
 815. 
 816. 
 817.     def _download_all_segments(self, res):
 818.         # the promise: upon entry to this function, self._share_vbuckets
 819.         # contains enough buckets to complete the download, and some extra
 820.         # ones to tolerate some buckets dropping out or having errors.
 821.         # self._share_vbuckets is a dictionary that maps from shnum to a set
 822.         # of ValidatedBuckets, which themselves are wrappers around
 823.         # RIBucketReader references.
 824.         self.active_buckets = {} # k: shnum, v: ValidatedBucket instance
 825. 
 826.         self._started_fetching = time.time()
 827. 
 828.         d = defer.succeed(None)
 829.         for segnum in range(self._total_segments-1):
 830.             d.addCallback(self._download_segment, segnum)
 831.             # this pause, at the end of write, prevents pre-fetch from
 832.             # happening until the consumer is ready for more data.
 833.             d.addCallback(self._check_for_pause)
 834.         d.addCallback(self._download_tail_segment, self._total_segments-1)
 835.         return d
 836. 
 837.     def _check_for_pause(self, res):
 838.         if self._paused:
 839.             d = defer.Deferred()
 840.             self._paused.addCallback(lambda ignored: d.callback(res))
 841.             return d
 842.         if self._stopped:
 843.             raise DownloadStopped("our Consumer called stopProducing()")
 844.         return res
 845. 
 846.     def _download_segment(self, res, segnum):
 847.         if self._status:
 848.             self._status.set_status("Downloading segment %d of %d" %
 849.                                     (segnum+1, self._total_segments))
 850.         self.log("downloading seg#%d of %d (%d%%)"
 851.                  % (segnum, self._total_segments,
 852.                     100.0 * segnum / self._total_segments))
 853.         # memory footprint: when the SegmentDownloader finishes pulling down
 854.         # all shares, we have 1*segment_size of usage.
 855.         segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
 856.                                         self._results)
 857.         started = time.time()
 858.         d = segmentdler.start()
 859.         def _finished_fetching(res):
 860.             elapsed = time.time() - started
 861.             self._results.timings["cumulative_fetch"] += elapsed
 862.             return res
 863.         if self._results:
 864.             d.addCallback(_finished_fetching)
 865.         # pause before using more memory
 866.         d.addCallback(self._check_for_pause)
 867.         # while the codec does its job, we hit 2*segment_size
 868.         def _started_decode(res):
 869.             self._started_decode = time.time()
 870.             return res
 871.         if self._results:
 872.             d.addCallback(_started_decode)
 873.         d.addCallback(lambda (shares, shareids):
 874.                       self._codec.decode(shares, shareids))
 875.         # once the codec is done, we drop back to 1*segment_size, because
 876.         # 'shares' goes out of scope. The memory usage is all in the
 877.         # plaintext now, spread out into a bunch of tiny buffers.
 878.         def _finished_decode(res):
 879.             elapsed = time.time() - self._started_decode
 880.             self._results.timings["cumulative_decode"] += elapsed
 881.             return res
 882.         if self._results:
 883.             d.addCallback(_finished_decode)
 884. 
 885.         # pause/check-for-stop just before writing, to honor stopProducing
 886.         d.addCallback(self._check_for_pause)
 887.         def _done(buffers):
 888.             # we start by joining all these buffers together into a single
 889.             # string. This makes Output.write easier, since it wants to hash
 890.             # data one segment at a time anyways, and doesn't impact our
 891.             # memory footprint since we're already peaking at 2*segment_size
 892.             # inside the codec a moment ago.
 893.             segment = "".join(buffers)
 894.             del buffers
 895.             # we're down to 1*segment_size right now, but write_segment()
 896.             # will decrypt a copy of the segment internally, which will push
 897.             # us up to 2*segment_size while it runs.
 898.             started_decrypt = time.time()
 899.             self._output.write_segment(segment)
 900.             if self._results:
 901.                 elapsed = time.time() - started_decrypt
 902.                 self._results.timings["cumulative_decrypt"] += elapsed
 903.         d.addCallback(_done)
 904.         return d
 905. 
 906.     def _download_tail_segment(self, res, segnum):
 907.         self.log("downloading seg#%d of %d (%d%%)"
 908.                  % (segnum, self._total_segments,
 909.                     100.0 * segnum / self._total_segments))
 910.         segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
 911.                                         self._results)
 912.         started = time.time()
 913.         d = segmentdler.start()
 914.         def _finished_fetching(res):
 915.             elapsed = time.time() - started
 916.             self._results.timings["cumulative_fetch"] += elapsed
 917.             return res
 918.         if self._results:
 919.             d.addCallback(_finished_fetching)
 920.         # pause before using more memory
 921.         d.addCallback(self._check_for_pause)
 922.         def _started_decode(res):
 923.             self._started_decode = time.time()
 924.             return res
 925.         if self._results:
 926.             d.addCallback(_started_decode)
 927.         d.addCallback(lambda (shares, shareids):
 928.                       self._tail_codec.decode(shares, shareids))
 929.         def _finished_decode(res):
 930.             elapsed = time.time() - self._started_decode
 931.             self._results.timings["cumulative_decode"] += elapsed
 932.             return res
 933.         if self._results:
 934.             d.addCallback(_finished_decode)
 935.         # pause/check-for-stop just before writing, to honor stopProducing
 936.         d.addCallback(self._check_for_pause)
 937.         def _done(buffers):
 938.             # trim off any padding added by the upload side
 939.             segment = "".join(buffers)
 940.             del buffers
 941.             # we never send empty segments. If the data was an exact multiple
 942.             # of the segment size, the last segment will be full.
 943.             pad_size = mathutil.pad_size(self._size, self._segment_size)
 944.             tail_size = self._segment_size - pad_size
 945.             segment = segment[:tail_size]
 946.             started_decrypt = time.time()
 947.             self._output.write_segment(segment)
 948.             if self._results:
 949.                 elapsed = time.time() - started_decrypt
 950.                 self._results.timings["cumulative_decrypt"] += elapsed
 951.         d.addCallback(_done)
 952.         return d
 953. 
 954.     def _done(self, res):
 955.         self.log("download done")
 956.         if self._results:
 957.             now = time.time()
 958.             self._results.timings["total"] = now - self._started
 959.             self._results.timings["segments"] = now - self._started_fetching
 960.         self._output.close()
 961.         if self.check_crypttext_hash and self._crypttext_hash:
 962.             _assert(self._crypttext_hash == self._output.crypttext_hash,
 963.                     "bad crypttext_hash: computed=%s, expected=%s" %
 964.                     (base32.b2a(self._output.crypttext_hash),
 965.                      base32.b2a(self._crypttext_hash)))
 966.         if self.check_plaintext_hash and self._plaintext_hash:
 967.             _assert(self._plaintext_hash == self._output.plaintext_hash,
 968.                     "bad plaintext_hash: computed=%s, expected=%s" %
 969.                     (base32.b2a(self._output.plaintext_hash),
 970.                      base32.b2a(self._plaintext_hash)))
 971.         _assert(self._output.length == self._size,
 972.                 got=self._output.length, expected=self._size)
 973.         return self._output.finish()
 974. 
 975.     def get_download_status(self):
 976.         return self._status
 977. 
 978. 
 979. class FileName:
 980.     implements(IDownloadTarget)
 981.     def __init__(self, filename):
 982.         self._filename = filename
 983.         self.f = None
 984.     def open(self, size):
 985.         self.f = open(self._filename, "wb")
 986.         return self.f
 987.     def write(self, data):
 988.         self.f.write(data)
 989.     def close(self):
 990.         if self.f:
 991.             self.f.close()
 992.     def fail(self, why):
 993.         if self.f:
 994.             self.f.close()
 995.             os.unlink(self._filename)
 996.     def register_canceller(self, cb):
 997.         pass # we won't use it
 998.     def finish(self):
 999.         pass
1000. 
class Data:
    """Download target that accumulates the data in memory.

    Chunks are collected in a list; close() joins them into self.data,
    which finish() returns.
    """
    implements(IDownloadTarget)

    def __init__(self):
        self._data = []

    def open(self, size):
        """Nothing to prepare; `size` is not used."""

    def write(self, data):
        self._data.append(data)

    def close(self):
        chunks = self._data
        self.data = "".join(chunks)
        del self._data

    def fail(self, why):
        # `why` is not used; just drop whatever was collected
        del self._data

    def register_canceller(self, cb):
        """Accept the canceller callback, but never use it."""

    def finish(self):
        return self.data
1018. 
class FileHandle:
    """Use me to download data to a pre-defined filehandle-like object. I
    will use the target's write() method. I will *not* close the filehandle:
    I leave that up to the originator of the filehandle. The download process
    will return the filehandle when it completes.
    """
    implements(IDownloadTarget)

    def __init__(self, filehandle):
        self._filehandle = filehandle

    def open(self, size):
        """Nothing to do: the filehandle was supplied by the caller."""

    def write(self, data):
        target = self._filehandle
        target.write(data)

    def close(self):
        """Do nothing: the originator of the filehandle reserves the right
        to close it."""

    def fail(self, why):
        """Do nothing; the filehandle is left untouched."""

    def register_canceller(self, cb):
        """Accept the canceller callback, but never use it."""

    def finish(self):
        target = self._filehandle
        return target
1041. 
1042. class Downloader(service.MultiService):
1043.     """I am a service that allows file downloading.
1044.     """
1045.     # TODO: in fact, this service only downloads immutable files (URI:CHK:).
1046.     # It is scheduled to go away, to be replaced by filenode.download()
1047.     implements(IDownloader)
1048.     name = "downloader"
1049.     MAX_DOWNLOAD_STATUSES = 10
1050. 
1051.     def __init__(self, stats_provider=None):
1052.         service.MultiService.__init__(self)
1053.         self.stats_provider = stats_provider
1054.         self._all_downloads = weakref.WeakKeyDictionary() # for debugging
1055.         self._all_download_statuses = weakref.WeakKeyDictionary()
1056.         self._recent_download_statuses = []
1057. 
1058.     def download(self, u, t):
1059.         assert self.parent
1060.         assert self.running
1061.         u = IFileURI(u)
1062.         t = IDownloadTarget(t)
1063.         assert t.write
1064.         assert t.close
1065. 
1066.         assert isinstance(u, uri.CHKFileURI)
1067.         if self.stats_provider:
1068.             # these counters are meant for network traffic, and don't
1069.             # include LIT files
1070.             self.stats_provider.count('downloader.files_downloaded', 1)
1071.             self.stats_provider.count('downloader.bytes_downloaded', u.get_size())
1072.         dl = FileDownloader(self.parent, u, t)
1073.         self._add_download(dl)
1074.         d = dl.start()
1075.         return d
1076. 
1077.     # utility functions
1078.     def download_to_data(self, uri):
1079.         return self.download(uri, Data())
1080.     def download_to_filename(self, uri, filename):
1081.         return self.download(uri, FileName(filename))
1082.     def download_to_filehandle(self, uri, filehandle):
1083.         return self.download(uri, FileHandle(filehandle))
1084. 
1085.     def _add_download(self, downloader):
1086.         self._all_downloads[downloader] = None
1087.         s = downloader.get_download_status()
1088.         self._all_download_statuses[s] = None
1089.         self._recent_download_statuses.append(s)
1090.         while len(self._recent_download_statuses) > self.MAX_DOWNLOAD_STATUSES:
1091.             self._recent_download_statuses.pop(0)
1092. 
1093.     def list_all_download_statuses(self):
1094.         for ds in self._all_download_statuses:
1095.             yield ds