source file: /home/buildslave/tahoe/edgy/build/src/allmydata/mutable/retrieve.py
file stats: 360 lines, 348 executed: 96.7% covered
   1. 
   2. import struct, time
   3. from itertools import count
   4. from zope.interface import implements
   5. from twisted.internet import defer
   6. from twisted.python import failure
   7. from foolscap import DeadReferenceError
   8. from foolscap.eventual import eventually, fireEventually
   9. from allmydata.interfaces import IRetrieveStatus
  10. from allmydata.util import hashutil, idlib, log
  11. from allmydata import hashtree, codec, storage
  12. from allmydata.immutable.encode import NotEnoughSharesError
  13. from pycryptopp.cipher.aes import AES
  14. from pycryptopp.publickey import rsa
  15. 
  16. from common import DictOfSets, CorruptShareError, UncoordinatedWriteError
  17. from layout import SIGNED_PREFIX, unpack_share_data
  18. 
  19. class RetrieveStatus:
  20.     implements(IRetrieveStatus)
  21.     statusid_counter = count(0)
  22.     def __init__(self):
  23.         self.timings = {}
  24.         self.timings["fetch_per_server"] = {}
  25.         self.timings["cumulative_verify"] = 0.0
  26.         self.problems = {}
  27.         self.active = True
  28.         self.storage_index = None
  29.         self.helper = False
  30.         self.encoding = ("?","?")
  31.         self.size = None
  32.         self.status = "Not started"
  33.         self.progress = 0.0
  34.         self.counter = self.statusid_counter.next()
  35.         self.started = time.time()
  36. 
  37.     def get_started(self):
  38.         return self.started
  39.     def get_storage_index(self):
  40.         return self.storage_index
  41.     def get_encoding(self):
  42.         return self.encoding
  43.     def using_helper(self):
  44.         return self.helper
  45.     def get_size(self):
  46.         return self.size
  47.     def get_status(self):
  48.         return self.status
  49.     def get_progress(self):
  50.         return self.progress
  51.     def get_active(self):
  52.         return self.active
  53.     def get_counter(self):
  54.         return self.counter
  55. 
  56.     def add_fetch_timing(self, peerid, elapsed):
  57.         if peerid not in self.timings["fetch_per_server"]:
  58.             self.timings["fetch_per_server"][peerid] = []
  59.         self.timings["fetch_per_server"][peerid].append(elapsed)
  60.     def set_storage_index(self, si):
  61.         self.storage_index = si
  62.     def set_helper(self, helper):
  63.         self.helper = helper
  64.     def set_encoding(self, k, n):
  65.         self.encoding = (k, n)
  66.     def set_size(self, size):
  67.         self.size = size
  68.     def set_status(self, status):
  69.         self.status = status
  70.     def set_progress(self, value):
  71.         self.progress = value
  72.     def set_active(self, value):
  73.         self.active = value
  74. 
  75. class Marker:
  76.     pass
  77. 
  78. class Retrieve:
  79.     # this class is currently single-use. Eventually (in MDMF) we will make
  80.     # it multi-use, in which case you can call download(range) multiple
  81.     # times, and each will have a separate response chain. However the
  82.     # Retrieve object will remain tied to a specific version of the file, and
  83.     # will use a single ServerMap instance.
  84. 
  85.     def __init__(self, filenode, servermap, verinfo, fetch_privkey=False):
  86.         self._node = filenode
  87.         assert self._node._pubkey
  88.         self._storage_index = filenode.get_storage_index()
  89.         assert self._node._readkey
  90.         self._last_failure = None
  91.         prefix = storage.si_b2a(self._storage_index)[:5]
  92.         self._log_number = log.msg("Retrieve(%s): starting" % prefix)
  93.         self._outstanding_queries = {} # maps (peerid,shnum) to start_time
  94.         self._running = True
  95.         self._decoding = False
  96.         self._bad_shares = set()
  97. 
  98.         self.servermap = servermap
  99.         assert self._node._pubkey
 100.         self.verinfo = verinfo
 101.         # during repair, we may be called upon to grab the private key, since
 102.         # it wasn't picked up during a verify=False checker run, and we'll
 103.         # need it for repair to generate the a new version.
 104.         self._need_privkey = fetch_privkey
 105.         if self._node._privkey:
 106.             self._need_privkey = False
 107. 
 108.         self._status = RetrieveStatus()
 109.         self._status.set_storage_index(self._storage_index)
 110.         self._status.set_helper(False)
 111.         self._status.set_progress(0.0)
 112.         self._status.set_active(True)
 113.         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
 114.          offsets_tuple) = self.verinfo
 115.         self._status.set_size(datalength)
 116.         self._status.set_encoding(k, N)
 117. 
 118.     def get_status(self):
 119.         return self._status
 120. 
 121.     def log(self, *args, **kwargs):
 122.         if "parent" not in kwargs:
 123.             kwargs["parent"] = self._log_number
 124.         if "facility" not in kwargs:
 125.             kwargs["facility"] = "tahoe.mutable.retrieve"
 126.         return log.msg(*args, **kwargs)
 127. 
 128.     def download(self):
 129.         self._done_deferred = defer.Deferred()
 130.         self._started = time.time()
 131.         self._status.set_status("Retrieving Shares")
 132. 
 133.         # first, which servers can we use?
 134.         versionmap = self.servermap.make_versionmap()
 135.         shares = versionmap[self.verinfo]
 136.         # this sharemap is consumed as we decide to send requests
 137.         self.remaining_sharemap = DictOfSets()
 138.         for (shnum, peerid, timestamp) in shares:
 139.             self.remaining_sharemap.add(shnum, peerid)
 140. 
 141.         self.shares = {} # maps shnum to validated blocks
 142. 
 143.         # how many shares do we need?
 144.         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
 145.          offsets_tuple) = self.verinfo
 146.         assert len(self.remaining_sharemap) >= k
 147.         # we start with the lowest shnums we have available, since FEC is
 148.         # faster if we're using "primary shares"
 149.         self.active_shnums = set(sorted(self.remaining_sharemap.keys())[:k])
 150.         for shnum in self.active_shnums:
 151.             # we use an arbitrary peer who has the share. If shares are
 152.             # doubled up (more than one share per peer), we could make this
 153.             # run faster by spreading the load among multiple peers. But the
 154.             # algorithm to do that is more complicated than I want to write
 155.             # right now, and a well-provisioned grid shouldn't have multiple
 156.             # shares per peer.
 157.             peerid = list(self.remaining_sharemap[shnum])[0]
 158.             self.get_data(shnum, peerid)
 159. 
 160.         # control flow beyond this point: state machine. Receiving responses
 161.         # from queries is the input. We might send out more queries, or we
 162.         # might produce a result.
 163. 
 164.         return self._done_deferred
 165. 
 166.     def get_data(self, shnum, peerid):
 167.         self.log(format="sending sh#%(shnum)d request to [%(peerid)s]",
 168.                  shnum=shnum,
 169.                  peerid=idlib.shortnodeid_b2a(peerid),
 170.                  level=log.NOISY)
 171.         ss = self.servermap.connections[peerid]
 172.         started = time.time()
 173.         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
 174.          offsets_tuple) = self.verinfo
 175.         offsets = dict(offsets_tuple)
 176. 
 177.         # we read the checkstring, to make sure that the data we grab is from
 178.         # the right version.
 179.         readv = [ (0, struct.calcsize(SIGNED_PREFIX)) ]
 180. 
 181.         # We also read the data, and the hashes necessary to validate them
 182.         # (share_hash_chain, block_hash_tree, share_data). We don't read the
 183.         # signature or the pubkey, since that was handled during the
 184.         # servermap phase, and we'll be comparing the share hash chain
 185.         # against the roothash that was validated back then.
 186. 
 187.         readv.append( (offsets['share_hash_chain'],
 188.                        offsets['enc_privkey'] - offsets['share_hash_chain'] ) )
 189. 
 190.         # if we need the private key (for repair), we also fetch that
 191.         if self._need_privkey:
 192.             readv.append( (offsets['enc_privkey'],
 193.                            offsets['EOF'] - offsets['enc_privkey']) )
 194. 
 195.         m = Marker()
 196.         self._outstanding_queries[m] = (peerid, shnum, started)
 197. 
 198.         # ask the cache first
 199.         got_from_cache = False
 200.         datavs = []
 201.         for (offset, length) in readv:
 202.             (data, timestamp) = self._node._cache.read(self.verinfo, shnum,
 203.                                                        offset, length)
 204.             if data is not None:
 205.                 datavs.append(data)
 206.         if len(datavs) == len(readv):
 207.             self.log("got data from cache")
 208.             got_from_cache = True
 209.             d = fireEventually({shnum: datavs})
 210.             # datavs is a dict mapping shnum to a pair of strings
 211.         else:
 212.             d = self._do_read(ss, peerid, self._storage_index, [shnum], readv)
 213.         self.remaining_sharemap.discard(shnum, peerid)
 214. 
 215.         d.addCallback(self._got_results, m, peerid, started, got_from_cache)
 216.         d.addErrback(self._query_failed, m, peerid)
 217.         # errors that aren't handled by _query_failed (and errors caused by
 218.         # _query_failed) get logged, but we still want to check for doneness.
 219.         def _oops(f):
 220.             self.log(format="problem in _query_failed for sh#%(shnum)d to %(peerid)s",
 221.                      shnum=shnum,
 222.                      peerid=idlib.shortnodeid_b2a(peerid),
 223.                      failure=f,
 224.                      level=log.WEIRD, umid="W0xnQA")
 225.         d.addErrback(_oops)
 226.         d.addBoth(self._check_for_done)
 227.         # any error during _check_for_done means the download fails. If the
 228.         # download is successful, _check_for_done will fire _done by itself.
 229.         d.addErrback(self._done)
 230.         d.addErrback(log.err)
 231.         return d # purely for testing convenience
 232. 
 233.     def _do_read(self, ss, peerid, storage_index, shnums, readv):
 234.         # isolate the callRemote to a separate method, so tests can subclass
 235.         # Publish and override it
 236.         d = ss.callRemote("slot_readv", storage_index, shnums, readv)
 237.         return d
 238. 
 239.     def remove_peer(self, peerid):
 240.         for shnum in list(self.remaining_sharemap.keys()):
 241.             self.remaining_sharemap.discard(shnum, peerid)
 242. 
 243.     def _got_results(self, datavs, marker, peerid, started, got_from_cache):
 244.         now = time.time()
 245.         elapsed = now - started
 246.         if not got_from_cache:
 247.             self._status.add_fetch_timing(peerid, elapsed)
 248.         self.log(format="got results (%(shares)d shares) from [%(peerid)s]",
 249.                  shares=len(datavs),
 250.                  peerid=idlib.shortnodeid_b2a(peerid),
 251.                  level=log.NOISY)
 252.         self._outstanding_queries.pop(marker, None)
 253.         if not self._running:
 254.             return
 255. 
 256.         # note that we only ask for a single share per query, so we only
 257.         # expect a single share back. On the other hand, we use the extra
 258.         # shares if we get them.. seems better than an assert().
 259. 
 260.         for shnum,datav in datavs.items():
 261.             (prefix, hash_and_data) = datav[:2]
 262.             try:
 263.                 self._got_results_one_share(shnum, peerid,
 264.                                             prefix, hash_and_data)
 265.             except CorruptShareError, e:
 266.                 # log it and give the other shares a chance to be processed
 267.                 f = failure.Failure()
 268.                 self.log(format="bad share: %(f_value)s",
 269.                          f_value=str(f.value), failure=f,
 270.                          level=log.WEIRD, umid="7fzWZw")
 271.                 self.remove_peer(peerid)
 272.                 self.servermap.mark_bad_share(peerid, shnum, prefix)
 273.                 self._bad_shares.add( (peerid, shnum) )
 274.                 self._status.problems[peerid] = f
 275.                 self._last_failure = f
 276.                 pass
 277.             if self._need_privkey and len(datav) > 2:
 278.                 lp = None
 279.                 self._try_to_validate_privkey(datav[2], peerid, shnum, lp)
 280.         # all done!
 281. 
 282.     def _got_results_one_share(self, shnum, peerid,
 283.                                got_prefix, got_hash_and_data):
 284.         self.log("_got_results: got shnum #%d from peerid %s"
 285.                  % (shnum, idlib.shortnodeid_b2a(peerid)))
 286.         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
 287.          offsets_tuple) = self.verinfo
 288.         assert len(got_prefix) == len(prefix), (len(got_prefix), len(prefix))
 289.         if got_prefix != prefix:
 290.             msg = "someone wrote to the data since we read the servermap: prefix changed"
 291.             raise UncoordinatedWriteError(msg)
 292.         (share_hash_chain, block_hash_tree,
 293.          share_data) = unpack_share_data(self.verinfo, got_hash_and_data)
 294. 
 295.         assert isinstance(share_data, str)
 296.         # build the block hash tree. SDMF has only one leaf.
 297.         leaves = [hashutil.block_hash(share_data)]
 298.         t = hashtree.HashTree(leaves)
 299.         if list(t) != block_hash_tree:
 300.             raise CorruptShareError(peerid, shnum, "block hash tree failure")
 301.         share_hash_leaf = t[0]
 302.         t2 = hashtree.IncompleteHashTree(N)
 303.         # root_hash was checked by the signature
 304.         t2.set_hashes({0: root_hash})
 305.         try:
 306.             t2.set_hashes(hashes=share_hash_chain,
 307.                           leaves={shnum: share_hash_leaf})
 308.         except (hashtree.BadHashError, hashtree.NotEnoughHashesError,
 309.                 IndexError), e:
 310.             msg = "corrupt hashes: %s" % (e,)
 311.             raise CorruptShareError(peerid, shnum, msg)
 312.         self.log(" data valid! len=%d" % len(share_data))
 313.         # each query comes down to this: placing validated share data into
 314.         # self.shares
 315.         self.shares[shnum] = share_data
 316. 
 317.     def _try_to_validate_privkey(self, enc_privkey, peerid, shnum, lp):
 318. 
 319.         alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
 320.         alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
 321.         if alleged_writekey != self._node.get_writekey():
 322.             self.log("invalid privkey from %s shnum %d" %
 323.                      (idlib.nodeid_b2a(peerid)[:8], shnum),
 324.                      parent=lp, level=log.WEIRD, umid="YIw4tA")
 325.             return
 326. 
 327.         # it's good
 328.         self.log("got valid privkey from shnum %d on peerid %s" %
 329.                  (shnum, idlib.shortnodeid_b2a(peerid)),
 330.                  parent=lp)
 331.         privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
 332.         self._node._populate_encprivkey(enc_privkey)
 333.         self._node._populate_privkey(privkey)
 334.         self._need_privkey = False
 335. 
 336.     def _query_failed(self, f, marker, peerid):
 337.         self.log(format="query to [%(peerid)s] failed",
 338.                  peerid=idlib.shortnodeid_b2a(peerid),
 339.                  level=log.NOISY)
 340.         self._status.problems[peerid] = f
 341.         self._outstanding_queries.pop(marker, None)
 342.         if not self._running:
 343.             return
 344.         self._last_failure = f
 345.         self.remove_peer(peerid)
 346.         level = log.WEIRD
 347.         if f.check(DeadReferenceError):
 348.             level = log.UNUSUAL
 349.         self.log(format="error during query: %(f_value)s",
 350.                  f_value=str(f.value), failure=f, level=level, umid="gOJB5g")
 351. 
 352.     def _check_for_done(self, res):
 353.         # exit paths:
 354.         #  return : keep waiting, no new queries
 355.         #  return self._send_more_queries(outstanding) : send some more queries
 356.         #  fire self._done(plaintext) : download successful
 357.         #  raise exception : download fails
 358. 
 359.         self.log(format="_check_for_done: running=%(running)s, decoding=%(decoding)s",
 360.                  running=self._running, decoding=self._decoding,
 361.                  level=log.NOISY)
 362.         if not self._running:
 363.             return
 364.         if self._decoding:
 365.             return
 366.         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
 367.          offsets_tuple) = self.verinfo
 368. 
 369.         if len(self.shares) < k:
 370.             # we don't have enough shares yet
 371.             return self._maybe_send_more_queries(k)
 372.         if self._need_privkey:
 373.             # we got k shares, but none of them had a valid privkey. TODO:
 374.             # look further. Adding code to do this is a bit complicated, and
 375.             # I want to avoid that complication, and this should be pretty
 376.             # rare (k shares with bitflips in the enc_privkey but not in the
 377.             # data blocks). If we actually do get here, the subsequent repair
 378.             # will fail for lack of a privkey.
 379.             self.log("got k shares but still need_privkey, bummer",
 380.                      level=log.WEIRD, umid="MdRHPA")
 381. 
 382.         # we have enough to finish. All the shares have had their hashes
 383.         # checked, so if something fails at this point, we don't know how
 384.         # to fix it, so the download will fail.
 385. 
 386.         self._decoding = True # avoid reentrancy
 387.         self._status.set_status("decoding")
 388.         now = time.time()
 389.         elapsed = now - self._started
 390.         self._status.timings["fetch"] = elapsed
 391. 
 392.         d = defer.maybeDeferred(self._decode)
 393.         d.addCallback(self._decrypt, IV, self._node._readkey)
 394.         d.addBoth(self._done)
 395.         return d # purely for test convenience
 396. 
 397.     def _maybe_send_more_queries(self, k):
 398.         # we don't have enough shares yet. Should we send out more queries?
 399.         # There are some number of queries outstanding, each for a single
 400.         # share. If we can generate 'needed_shares' additional queries, we do
 401.         # so. If we can't, then we know this file is a goner, and we raise
 402.         # NotEnoughSharesError.
 403.         self.log(format=("_maybe_send_more_queries, have=%(have)d, k=%(k)d, "
 404.                          "outstanding=%(outstanding)d"),
 405.                  have=len(self.shares), k=k,
 406.                  outstanding=len(self._outstanding_queries),
 407.                  level=log.NOISY)
 408. 
 409.         remaining_shares = k - len(self.shares)
 410.         needed = remaining_shares - len(self._outstanding_queries)
 411.         if not needed:
 412.             # we have enough queries in flight already
 413. 
 414.             # TODO: but if they've been in flight for a long time, and we
 415.             # have reason to believe that new queries might respond faster
 416.             # (i.e. we've seen other queries come back faster, then consider
 417.             # sending out new queries. This could help with peers which have
 418.             # silently gone away since the servermap was updated, for which
 419.             # we're still waiting for the 15-minute TCP disconnect to happen.
 420.             self.log("enough queries are in flight, no more are needed",
 421.                      level=log.NOISY)
 422.             return
 423. 
 424.         outstanding_shnums = set([shnum
 425.                                   for (peerid, shnum, started)
 426.                                   in self._outstanding_queries.values()])
 427.         # prefer low-numbered shares, they are more likely to be primary
 428.         available_shnums = sorted(self.remaining_sharemap.keys())
 429.         for shnum in available_shnums:
 430.             if shnum in outstanding_shnums:
 431.                 # skip ones that are already in transit
 432.                 continue
 433.             if shnum not in self.remaining_sharemap:
 434.                 # no servers for that shnum. note that DictOfSets removes
 435.                 # empty sets from the dict for us.
 436.                 continue
 437.             peerid = list(self.remaining_sharemap[shnum])[0]
 438.             # get_data will remove that peerid from the sharemap, and add the
 439.             # query to self._outstanding_queries
 440.             self._status.set_status("Retrieving More Shares")
 441.             self.get_data(shnum, peerid)
 442.             needed -= 1
 443.             if not needed:
 444.                 break
 445. 
 446.         # at this point, we have as many outstanding queries as we can. If
 447.         # needed!=0 then we might not have enough to recover the file.
 448.         if needed:
 449.             format = ("ran out of peers: "
 450.                       "have %(have)d shares (k=%(k)d), "
 451.                       "%(outstanding)d queries in flight, "
 452.                       "need %(need)d more, "
 453.                       "found %(bad)d bad shares")
 454.             args = {"have": len(self.shares),
 455.                     "k": k,
 456.                     "outstanding": len(self._outstanding_queries),
 457.                     "need": needed,
 458.                     "bad": len(self._bad_shares),
 459.                     }
 460.             self.log(format=format,
 461.                      level=log.WEIRD, umid="ezTfjw", **args)
 462.             err = NotEnoughSharesError("%s, last failure: %s" %
 463.                                       (format % args, self._last_failure))
 464.             if self._bad_shares:
 465.                 self.log("We found some bad shares this pass. You should "
 466.                          "update the servermap and try again to check "
 467.                          "more peers",
 468.                          level=log.WEIRD, umid="EFkOlA")
 469.                 err.servermap = self.servermap
 470.             raise err
 471. 
 472.         return
 473. 
 474.     def _decode(self):
 475.         started = time.time()
 476.         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
 477.          offsets_tuple) = self.verinfo
 478. 
 479.         # shares_dict is a dict mapping shnum to share data, but the codec
 480.         # wants two lists.
 481.         shareids = []; shares = []
 482.         for shareid, share in self.shares.items():
 483.             shareids.append(shareid)
 484.             shares.append(share)
 485. 
 486.         assert len(shareids) >= k, len(shareids)
 487.         # zfec really doesn't want extra shares
 488.         shareids = shareids[:k]
 489.         shares = shares[:k]
 490. 
 491.         fec = codec.CRSDecoder()
 492.         params = "%d-%d-%d" % (segsize, k, N)
 493.         fec.set_serialized_params(params)
 494. 
 495.         self.log("params %s, we have %d shares" % (params, len(shares)))
 496.         self.log("about to decode, shareids=%s" % (shareids,))
 497.         d = defer.maybeDeferred(fec.decode, shares, shareids)
 498.         def _done(buffers):
 499.             self._status.timings["decode"] = time.time() - started
 500.             self.log(" decode done, %d buffers" % len(buffers))
 501.             segment = "".join(buffers)
 502.             self.log(" joined length %d, datalength %d" %
 503.                      (len(segment), datalength))
 504.             segment = segment[:datalength]
 505.             self.log(" segment len=%d" % len(segment))
 506.             return segment
 507.         def _err(f):
 508.             self.log(" decode failed: %s" % f)
 509.             return f
 510.         d.addCallback(_done)
 511.         d.addErrback(_err)
 512.         return d
 513. 
 514.     def _decrypt(self, crypttext, IV, readkey):
 515.         self._status.set_status("decrypting")
 516.         started = time.time()
 517.         key = hashutil.ssk_readkey_data_hash(IV, readkey)
 518.         decryptor = AES(key)
 519.         plaintext = decryptor.process(crypttext)
 520.         self._status.timings["decrypt"] = time.time() - started
 521.         return plaintext
 522. 
 523.     def _done(self, res):
 524.         if not self._running:
 525.             return
 526.         self._running = False
 527.         self._status.set_active(False)
 528.         self._status.timings["total"] = time.time() - self._started
 529.         # res is either the new contents, or a Failure
 530.         if isinstance(res, failure.Failure):
 531.             self.log("Retrieve done, with failure", failure=res,
 532.                      level=log.UNUSUAL)
 533.             self._status.set_status("Failed")
 534.         else:
 535.             self.log("Retrieve done, success!")
 536.             self._status.set_status("Done")
 537.             self._status.set_progress(1.0)
 538.             # remember the encoding parameters, use them again next time
 539.             (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
 540.              offsets_tuple) = self.verinfo
 541.             self._node._populate_required_shares(k)
 542.             self._node._populate_total_shares(N)
 543.         eventually(self._done_deferred.callback, res)
 544.