source file: /home/buildslave/tahoe/edgy/build/src/allmydata/immutable/download.py
file stats: 795 lines, 753 executed: 94.7% covered
1.
2. import os, random, weakref, itertools, time
3. from zope.interface import implements
4. from twisted.internet import defer
5. from twisted.internet.interfaces import IPushProducer, IConsumer
6. from twisted.application import service
7. from foolscap import DeadReferenceError
8. from foolscap.eventual import eventually
9.
10. from allmydata.util import base32, mathutil, hashutil, log
11. from allmydata.util.assertutil import _assert
12. from allmydata import codec, hashtree, storage, uri
13. from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, \
14. IDownloadStatus, IDownloadResults
15. from allmydata.immutable.encode import NotEnoughSharesError
16. from pycryptopp.cipher.aes import AES
17.
class HaveAllPeersError(Exception):
    """Raised internally to jump out of a loop."""


class IntegrityCheckError(Exception):
    """Base class for downloaded data that failed an integrity check."""


class BadURIExtensionHashValue(IntegrityCheckError):
    """A fetched URI extension block did not hash to the value in our URI."""


class BadURIExtension(IntegrityCheckError):
    """The URI extension block lacks a required field."""


class BadPlaintextHashValue(IntegrityCheckError):
    """A fetched plaintext hash tree did not have the expected root."""


class BadCrypttextHashValue(IntegrityCheckError):
    """A fetched crypttext hash tree did not have the expected root."""


class DownloadStopped(Exception):
    """The consumer aborted the download by calling stopProducing()."""
36.
class DownloadResults:
    """Statistics collected while downloading one file.

    The downloader classes in this module fill these attributes in as the
    download proceeds.
    """
    implements(IDownloadResults)

    def __init__(self):
        # size of the file being downloaded (set by FileDownloader)
        self.file_size = None
        # timing name -> seconds, or a nested dict keyed by peerid
        self.timings = {}
        # peerid -> set of share numbers that server reported holding
        self.servermap = {}
        # peerid -> str(failure) for servers that returned errors
        self.server_problems = {}
        # peerids we actually asked for blocks
        self.servers_used = set()
46.
class Output:
    """Sink for the decoded (still encrypted) segments of one file.

    For each segment, in order, I:
      - update the whole-file crypttext hash;
      - when a crypttext hash tree was supplied, feed its leaf hash to the
        tree's set_hashes();
      - decrypt the segment with AES(key);
      - do the same two hashing steps for the plaintext;
      - write the plaintext to 'downloadable', calling its open() first.
    """

    def __init__(self, downloadable, key, total_length, log_parent,
                 download_status):
        self.downloadable = downloadable
        self._decryptor = AES(key)
        self._crypttext_hasher = hashutil.crypttext_hasher()
        self._plaintext_hasher = hashutil.plaintext_hasher()
        self.length = 0 # crypttext bytes received so far
        self.total_length = total_length
        self._segment_number = 0
        self._plaintext_hash_tree = None
        self._crypttext_hash_tree = None
        self._opened = False
        self._log_parent = log_parent
        self._status = download_status
        self._status.set_progress(0.0)

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_parent
        if "facility" not in kwargs:
            kwargs["facility"] = "download.output"
        return log.msg(*args, **kwargs)

    def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
        """Install the (possibly None) trees used to check each segment."""
        self._plaintext_hash_tree = plaintext_hashtree
        self._crypttext_hash_tree = crypttext_hashtree

    def write_segment(self, crypttext):
        """Hash, check, decrypt and deliver the next segment of crypttext.

        Any exception raised by a hash tree's set_hashes() (e.g. when a
        leaf does not match) propagates to the caller before the plaintext
        is written.
        """
        self.length += len(crypttext)
        if self.total_length:
            progress = float(self.length) / self.total_length
        else:
            # zero-length file: dividing would raise ZeroDivisionError,
            # and there is nothing left to wait for anyway
            progress = 1.0
        self._status.set_progress(progress)

        # memory footprint: 'crypttext' is the only segment_size usage
        # outstanding. While we decrypt it into 'plaintext', we hit
        # 2*segment_size.
        self._crypttext_hasher.update(crypttext)
        if self._crypttext_hash_tree:
            ch = hashutil.crypttext_segment_hasher()
            ch.update(crypttext)
            crypttext_leaf = ch.digest()
            self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(crypttext),
                     segnum=self._segment_number,
                     hash=base32.b2a(crypttext_leaf),
                     level=log.NOISY)
            self._crypttext_hash_tree.set_hashes(
                leaves={self._segment_number: crypttext_leaf})

        plaintext = self._decryptor.process(crypttext)
        del crypttext

        # now we're back down to 1*segment_size.

        self._plaintext_hasher.update(plaintext)
        if self._plaintext_hash_tree:
            ph = hashutil.plaintext_segment_hasher()
            ph.update(plaintext)
            plaintext_leaf = ph.digest()
            self.log(format="plaintext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(plaintext),
                     segnum=self._segment_number,
                     hash=base32.b2a(plaintext_leaf),
                     level=log.NOISY)
            self._plaintext_hash_tree.set_hashes(
                leaves={self._segment_number: plaintext_leaf})

        self._segment_number += 1
        # We're still at 1*segment_size. The Downloadable is responsible for
        # any memory usage beyond this.
        if not self._opened:
            self._opened = True
            self.downloadable.open(self.total_length)
        self.downloadable.write(plaintext)

    def fail(self, why):
        """Log the failure 'why' and pass it on to the downloadable."""
        # this is really unusual, and deserves maximum forensics
        if why.check(DownloadStopped):
            # except DownloadStopped just means the consumer aborted the
            # download, not so scary
            self.log("download stopped", level=log.UNUSUAL)
        else:
            self.log("download failed!", failure=why,
                     level=log.SCARY, umid="lp1vaQ")
        self.downloadable.fail(why)

    def close(self):
        """Record the whole-file hashes and close the downloadable."""
        self.crypttext_hash = self._crypttext_hasher.digest()
        self.plaintext_hash = self._plaintext_hasher.digest()
        self.log("download finished, closing IDownloadable", level=log.NOISY)
        self.downloadable.close()

    def finish(self):
        return self.downloadable.finish()
136.
class ValidatedBucket:
    """I am a front-end for a remote storage bucket, responsible for
    retrieving and validating data from that bucket.

    My get_block() method is used by BlockDownloaders.
    """

    def __init__(self, sharenum, bucket,
                 share_hash_tree, roothash,
                 num_blocks):
        self.sharenum = sharenum
        self.bucket = bucket
        self._share_hash = None # None means not validated yet
        self.share_hash_tree = share_hash_tree
        self._roothash = roothash
        self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
        self.started = False

    def get_block(self, blocknum):
        """Return a Deferred that fires with block 'blocknum' once validated.

        On first use the bucket is started. The block is checked against
        the block hash tree; the first time, this share's hash is also
        checked against the share hash tree, whose root comes from the URI.
        Hash failures are logged and re-raised (see _got_data).
        """
        if not self.started:
            d = self.bucket.start()
            def _started(res):
                self.started = True
                return self.get_block(blocknum)
            d.addCallback(_started)
            return d

        # the first time we use this bucket, we need to fetch enough elements
        # of the share hash tree to validate it from our share hash up to the
        # hashroot.
        if not self._share_hash:
            d1 = self.bucket.get_share_hashes()
        else:
            d1 = defer.succeed([])

        # we might need to grab some elements of our block hash tree, to
        # validate the requested block up to the share hash
        needed = self.block_hash_tree.needed_hashes(blocknum)
        if needed:
            # TODO: get fewer hashes, use get_block_hashes(needed)
            d2 = self.bucket.get_block_hashes()
        else:
            d2 = defer.succeed([])

        d3 = self.bucket.get_block(blocknum)

        d = defer.gatherResults([d1, d2, d3])
        d.addCallback(self._got_data, blocknum)
        return d

    def _got_data(self, res, blocknum):
        """Validate the fetched (sharehashes, blockhashes, blockdata) triple.

        Returns blockdata if it checks out. Re-raises
        hashtree.BadHashError / NotEnoughHashesError after logging
        diagnostics.
        """
        sharehashes, blockhashes, blockdata = res
        blockhash = None # to make logging it safe

        try:
            if not self._share_hash:
                sh = dict(sharehashes)
                sh[0] = self._roothash # always use our own root, from the URI
                sht = self.share_hash_tree
                if sht.get_leaf_index(self.sharenum) not in sh:
                    raise hashtree.NotEnoughHashesError
                sht.set_hashes(sh)
                self._share_hash = sht.get_leaf(self.sharenum)

            blockhash = hashutil.block_hash(blockdata)

            # we always validate the blockhash
            bh = dict(enumerate(blockhashes))
            # replace blockhash root with validated value
            bh[0] = self._share_hash
            self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})

        except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
            self._log_hash_failure(blocknum, sharehashes, blockhashes,
                                   blockdata, blockhash)
            raise

        # If we made it here, the block is good. If the hash trees didn't
        # like what they saw, they would have raised a BadHashError, causing
        # our caller to see a Failure and thus ignore this block (as well as
        # dropping this bucket).
        return blockdata

    def _log_hash_failure(self, blocknum, sharehashes, blockhashes,
                          blockdata, blockhash):
        """Log everything we know about a block that failed validation.

        The headline is logged at log.WEIRD; details are children of it.
        """
        # log.WEIRD: indicates undetected disk/network error, or more
        # likely a programming error
        lp = log.msg("hash failure in block=%d, shnum=%d on %s" %
                     (blocknum, self.sharenum, self.bucket),
                     level=log.WEIRD)
        def _detail(msg):
            log.msg(msg, parent=lp, level=log.WEIRD)

        if self._share_hash:
            _detail(" failure occurred when checking the block_hash_tree.\n"
                    "This suggests that either the block data was bad, or that the\n"
                    "block hashes we received along with it were bad.")
        else:
            _detail(" the failure probably occurred when checking the\n"
                    "share_hash_tree, which suggests that the share hashes we\n"
                    "received from the remote peer were bad.")
        _detail(" have self._share_hash: %s" % bool(self._share_hash))
        _detail(" block length: %d" % len(blockdata))
        _detail(" block hash: %s" % base32.b2a_or_none(blockhash))
        if len(blockdata) < 100:
            _detail(" block data: %r" % (blockdata,))
        else:
            _detail(" block data start/end: %r .. %r" %
                    (blockdata[:50], blockdata[-50:]))
        _detail(" root hash: %s" % base32.b2a(self._roothash))
        _detail(" share hash tree:\n" + self.share_hash_tree.dump())
        _detail(" block hash tree:\n" + self.block_hash_tree.dump())
        lines = []
        for i,h in sorted(sharehashes):
            lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
        _detail(" sharehashes:\n" + "\n".join(lines) + "\n")
        lines = []
        for i,h in enumerate(blockhashes):
            lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
        _detail(" blockhashes:\n" + "\n".join(lines) + "\n")
252.
253.
254.
class BlockDownloader:
    """Fetch a single block, for a single segment, from a single
    ValidatedBucket, and report the outcome to my parent.

    My parent (a SegmentDownloader) receives the data through
    hold_block(), or the failed bucket through bucket_failed().
    """

    def __init__(self, vbucket, blocknum, parent, results):
        self.parent = parent
        self.vbucket = vbucket
        self.blocknum = blocknum
        self.results = results
        self._log_number = parent.log("starting block %d" % blocknum)

    def log(self, *args, **kwargs):
        kwargs.setdefault("parent", self._log_number)
        return self.parent.log(*args, **kwargs)

    def start(self, segnum):
        """Request block 'segnum' from my bucket; return its Deferred.

        Errors are handled by _got_block_error and not propagated.
        """
        lognum = self.log("get_block(segnum=%d)" % segnum)
        t0 = time.time()
        d = self.vbucket.get_block(segnum)
        d.addCallbacks(self._hold_block, self._got_block_error,
                       callbackArgs=(t0, lognum), errbackArgs=(lognum,))
        return d

    def _hold_block(self, data, started, lognum):
        if self.results:
            fetch_time = time.time() - started
            server = self.vbucket.bucket.get_peerid()
            per_server = self.results.timings["fetch_per_server"]
            per_server.setdefault(server, []).append(fetch_time)
        self.log("got block", parent=lognum)
        self.parent.hold_block(self.blocknum, data)

    def _got_block_error(self, f, lognum):
        # a vanished server is less alarming than any other kind of error
        if f.check(DeadReferenceError):
            level = log.UNUSUAL
        else:
            level = log.WEIRD
        self.log("BlockDownloader[%d] got error" % self.blocknum,
                 failure=f, level=level, parent=lognum, umid="5Z4uHQ")
        if self.results:
            server = self.vbucket.bucket.get_peerid()
            self.results.server_problems[server] = str(f)
        self.parent.bucket_failed(self.vbucket)
302.
class SegmentDownloader:
    """I am responsible for downloading all the blocks for a single segment
    of data.

    I am a child of the FileDownloader.
    """

    def __init__(self, parent, segmentnumber, needed_shares, results):
        self.parent = parent
        self.segmentnumber = segmentnumber
        self.needed_blocks = needed_shares
        self.blocks = {} # k: blocknum, v: data
        self.results = results
        self._log_number = self.parent.log("starting segment %d" %
                                           segmentnumber)

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        return self.parent.log(*args, **kwargs)

    def start(self):
        """Return a Deferred firing with (blocks, blockids) for this segment.

        It errbacks (via the parent's _activate_enough_buckets) with
        NotEnoughSharesError if too many buckets fail.
        """
        return self._download()

    def _download(self):
        d = self._try()
        def _done(res):
            if len(self.blocks) >= self.needed_blocks:
                # we only need self.needed_blocks blocks
                # we want to get the smallest blockids, because they are
                # more likely to be fast "primary blocks"
                blockids = sorted(self.blocks.keys())[:self.needed_blocks]
                blocks = [self.blocks[blocknum] for blocknum in blockids]
                return (blocks, blockids)
            # some buckets failed: try again with replacement buckets
            return self._download()
        d.addCallback(_done)
        return d

    def _try(self):
        # fill our set of active buckets, maybe raising NotEnoughSharesError
        active_buckets = self.parent._activate_enough_buckets()
        # Now we have enough buckets, in self.parent.active_buckets.

        # in test cases, bd.start might mutate active_buckets right away, so
        # we need to put off calling start() until we've iterated all the way
        # through it.
        downloaders = []
        for blocknum, vbucket in active_buckets.iteritems():
            if blocknum in self.blocks:
                # already fetched and validated on an earlier pass for this
                # segment; asking for it again would only waste a round trip
                continue
            bd = BlockDownloader(vbucket, blocknum, self, self.results)
            downloaders.append(bd)
            if self.results:
                self.results.servers_used.add(vbucket.bucket.get_peerid())
        l = [bd.start(self.segmentnumber) for bd in downloaders]
        return defer.DeferredList(l, fireOnOneErrback=True)

    def hold_block(self, blocknum, data):
        self.blocks[blocknum] = data

    def bucket_failed(self, vbucket):
        self.parent.bucket_failed(vbucket)
366.
class DownloadStatus:
    """Mutable status record for one download, exposed via IDownloadStatus.

    Each instance gets a unique counter value from a class-wide
    itertools.count.
    """
    implements(IDownloadStatus)
    statusid_counter = itertools.count(0)

    def __init__(self):
        self.counter = self.statusid_counter.next()
        self.started = time.time()
        self.storage_index = self.size = None
        self.results = None
        self.helper = self.paused = self.stopped = False
        self.active = True
        self.status = "Not started"
        self.progress = 0.0

    # --- readers ---
    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_size(self):
        return self.size
    def using_helper(self):
        return self.helper
    def get_status(self):
        """Return the status text, with pause/stop annotations appended."""
        extras = []
        if self.paused:
            extras.append(" (output paused)")
        if self.stopped:
            extras.append(" (output stopped)")
        return self.status + "".join(extras)
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_results(self):
        return self.results
    def get_counter(self):
        return self.counter

    # --- writers ---
    def set_storage_index(self, si):
        self.storage_index = si
    def set_size(self, size):
        self.size = size
    def set_helper(self, helper):
        self.helper = helper
    def set_status(self, status):
        self.status = status
    def set_paused(self, paused):
        self.paused = paused
    def set_stopped(self, stopped):
        self.stopped = stopped
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value
    def set_results(self, value):
        self.results = value
426.
427. class FileDownloader:
428. implements(IPushProducer)
429. check_crypttext_hash = True
430. check_plaintext_hash = True
431. _status = None
432.
def __init__(self, client, u, downloadable):
    """Prepare to download the file named by 'u' into 'downloadable'.

    'u' is adapted with IFileURI. If 'downloadable' provides IConsumer,
    I register myself with it as a streaming (push) producer.
    """
    self._client = client

    u = IFileURI(u)
    self._storage_index = u.storage_index
    self._uri_extension_hash = u.uri_extension_hash
    self._total_shares = u.total_shares
    self._size = u.size
    self._num_needed_shares = u.needed_shares

    self._si_s = storage.si_b2a(self._storage_index)
    self.init_logging()

    self._started = time.time()
    status = DownloadStatus()
    self._status = status
    status.set_status("Starting")
    status.set_storage_index(self._storage_index)
    status.set_size(self._size)
    status.set_helper(False)
    status.set_active(True)

    results = DownloadResults()
    self._results = results
    status.set_results(results)
    results.file_size = self._size
    results.timings.update({"servers_peer_selection": {},
                            "fetch_per_server": {},
                            "cumulative_fetch": 0.0,
                            "cumulative_decode": 0.0,
                            "cumulative_decrypt": 0.0,
                            "paused": 0.0,
                            })

    self._paused = False
    self._stopped = False
    if IConsumer.providedBy(downloadable):
        downloadable.registerProducer(self, True)
    self._downloadable = downloadable
    self._output = Output(downloadable, u.key, self._size, self._log_number,
                          self._status)

    self.active_buckets = {} # k: shnum, v: bucket
    self._share_buckets = [] # list of (sharenum, bucket) tuples
    self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
    self._uri_extension_sources = []

    self._uri_extension_data = None

    # per-item counters of bad data received from servers
    self._fetch_failures = dict.fromkeys(["uri_extension",
                                          "plaintext_hashroot",
                                          "plaintext_hashtree",
                                          "crypttext_hashroot",
                                          "crypttext_hashtree"], 0)
485.
def init_logging(self):
    """Set up the log prefix and emit the parent log entry for this download.

    Sets self._log_prefix (the first 5 characters of the base32 storage
    index) and self._log_number (returned by the client's log()).
    """
    si_s = storage.si_b2a(self._storage_index)
    self._log_prefix = si_s[:5]
    self._log_number = self._client.log(
        format="FileDownloader(%(si)s): starting", si=si_s)
491.
def log(self, *args, **kwargs):
    """log.msg() wrapper: parent defaults to my log entry, facility to
    "tahoe.download"."""
    kwargs.setdefault("parent", self._log_number)
    kwargs.setdefault("facility", "tahoe.download")
    return log.msg(*args, **kwargs)
498.
def pauseProducing(self):
    """IPushProducer: stop producing until resumeProducing() is called.

    Stores a Deferred in self._paused; _check_for_pause() waits on it.
    Calling me while already paused does nothing.
    """
    if not self._paused:
        self._paused = defer.Deferred()
        self._paused_at = time.time()
        if self._status:
            self._status.set_paused(True)
506.
def resumeProducing(self):
    """IPushProducer: resume after a pause.

    Adds the time spent paused to the 'paused' timing, clears self._paused
    and fires the old pause Deferred in a later turn (via eventually()).
    Does nothing when not paused.
    """
    waiter = self._paused
    if not waiter:
        return
    self._results.timings['paused'] += time.time() - self._paused_at
    self._paused = None
    eventually(waiter.callback, None)
    if self._status:
        self._status.set_paused(False)
516.
def stopProducing(self):
    """IPushProducer: abort the download.

    Sets self._stopped and resumes any pause so the segment pipeline is not
    left blocked; _check_for_pause() then raises DownloadStopped.
    """
    self.log("Download.stopProducing")
    self._stopped = True
    self.resumeProducing()
    status = self._status
    if status:
        status.set_stopped(True)
        status.set_active(False)
524.
def start(self):
    """Run the whole download; return a Deferred for the result of _done.

    Steps: locate shareholders, fetch and validate the URI extension
    block, fetch the hash trees, build ValidatedBuckets, then download
    every segment. On any failure the status is set to "Failed" and the
    failure is reported to the Output.
    """
    self.log("starting download")

    # first step: who should we download from?
    d = defer.maybeDeferred(self._get_all_shareholders)
    pipeline = (self._got_all_shareholders,
                # now get the uri_extension block from somebody and validate it
                self._obtain_uri_extension,
                self._got_uri_extension,
                self._get_hashtrees,
                self._create_validated_buckets,
                # once we know that, we can download blocks from everybody
                self._download_all_segments)
    for step in pipeline:
        d.addCallback(step)

    def _cleanup(res):
        # runs on success and failure alike
        status = self._status
        if status:
            status.set_status("Finished")
            status.set_active(False)
            status.set_paused(False)
        if IConsumer.providedBy(self._downloadable):
            self._downloadable.unregisterProducer()
        return res
    def _report_failure(why):
        status = self._status
        if status:
            status.set_status("Failed")
            status.set_active(False)
        self._output.fail(why)
        return why

    d.addBoth(_cleanup)
    d.addErrback(_report_failure)
    d.addCallback(self._done)
    return d
556.
def _get_all_shareholders(self):
    """Ask every storage server (in permuted order) which buckets it holds.

    Responses go to _got_response, errors to _got_error. Returns a
    DeferredList over all of the queries.
    """
    peers = list(self._client.get_permuted_peers("storage",
                                                 self._storage_index))
    # Initialize the progress counters *before* sending any query: a
    # callRemote() whose Deferred has already fired runs _got_response()
    # immediately, and that method reads and updates both counters.
    self._responses_received = 0
    self._queries_sent = len(peers)
    if self._status:
        self._status.set_status("Locating Shares (%d/%d)" %
                                (self._responses_received,
                                 self._queries_sent))
    dl = []
    for (peerid, ss) in peers:
        d = ss.callRemote("get_buckets", self._storage_index)
        d.addCallbacks(self._got_response, self._got_error,
                       callbackArgs=(peerid,))
        dl.append(d)
    return defer.DeferredList(dl)
572.
def _got_response(self, buckets, peerid):
    """Record one server's get_buckets answer.

    'buckets' maps share number to a remote bucket reference. Each one is
    wrapped in a ReadBucketProxy and remembered both as a share source and
    as a URI-extension source.
    """
    self._responses_received += 1
    results = self._results
    if results:
        results.timings["servers_peer_selection"][peerid] = \
            time.time() - self._started
    if self._status:
        self._status.set_status("Locating Shares (%d/%d)" %
                                (self._responses_received,
                                 self._queries_sent))
    for sharenum, bucket in buckets.iteritems():
        proxy = storage.ReadBucketProxy(bucket, peerid, self._si_s)
        self.add_share_bucket(sharenum, proxy)
        self._uri_extension_sources.append(proxy)
        if results:
            results.servermap.setdefault(peerid, set()).add(sharenum)
590.
def add_share_bucket(self, sharenum, bucket):
    """Remember that 'bucket' can supply share 'sharenum'."""
    # this is split out for the benefit of test_encode.py
    entry = (sharenum, bucket)
    self._share_buckets.append(entry)
594.
def _got_error(self, f):
    """Errback for a failed get_buckets query.

    Logs the failure and returns None, so the failure goes no further.
    """
    if f.check(DeadReferenceError):
        severity = log.UNUSUAL
    else:
        severity = log.WEIRD
    self._client.log("Error during get_buckets", failure=f, level=severity,
                     umid="3uuBUQ")
601.
def bucket_failed(self, vbucket):
    """Stop using 'vbucket': deactivate it and forget it as a share source.

    When no other bucket can provide that share, its key is removed from
    self._share_vbuckets, which may prompt us to use a different share.
    """
    shnum = vbucket.sharenum
    del self.active_buckets[shnum]
    # the set of ValidatedBucket instances offering this share
    providers = self._share_vbuckets[shnum]
    providers.remove(vbucket)
    if len(providers) == 0:
        del self._share_vbuckets[shnum]
613.
def _got_all_shareholders(self, res):
    """Check that the servers reported enough distinct shares.

    Raises NotEnoughSharesError when fewer than the needed number of
    distinct share numbers were found.
    """
    if self._results:
        now = time.time()
        self._results.timings["peer_selection"] = now - self._started

    # Count distinct share numbers, not buckets: several servers may hold
    # copies of the same share, and duplicate copies do not help decoding.
    shnums = set([shnum for (shnum, bucket) in self._share_buckets])
    if len(shnums) < self._num_needed_shares:
        raise NotEnoughSharesError("found %d distinct shares, need %d"
                                   % (len(shnums), self._num_needed_shares))
626.
def _unpack_uri_extension_data(self, data):
    """Decode a serialized URI extension block with uri.unpack_extension."""
    unpacked = uri.unpack_extension(data)
    return unpacked
629.
def _obtain_uri_extension(self, ignored):
    """Fetch the uri_extension block from some shareholder and validate it.

    Every copy is hashed and compared with the hash in our URI. A
    mismatching copy is counted, logged, and rejected with
    BadURIExtensionHashValue, so the next source is tried.
    """
    if self._status:
        self._status.set_status("Obtaining URI Extension")

    self._uri_extension_fetch_started = time.time()

    def _check_against_uri(proposal, bucket):
        expected = self._uri_extension_hash
        got = hashutil.uri_extension_hash(proposal)
        if got == expected:
            return self._unpack_uri_extension_data(proposal)
        self._fetch_failures["uri_extension"] += 1
        msg = ("The copy of uri_extension we received from "
               "%s was bad: wanted %s, got %s" %
               (bucket,
                base32.b2a(expected),
                base32.b2a(got)))
        self.log(msg, level=log.SCARY, umid="jnkTtQ")
        raise BadURIExtensionHashValue(msg)

    return self._obtain_validated_thing(None,
                                        self._uri_extension_sources,
                                        "uri_extension",
                                        "get_uri_extension", (),
                                        _check_against_uri)
655.
def _obtain_validated_thing(self, ignored, sources, name, methname, args,
                            validatorfunc):
    """Fetch something from the first source that yields valid data.

    For each bucket in 'sources', in order: start it if necessary, call
    bucket.<methname>(*args), and pass the result to
    validatorfunc(result, bucket). If any of that fails, the error is
    logged and the next source is tried. NotEnoughSharesError is raised
    when there are no sources, or when they have all failed.
    """
    if not sources:
        raise NotEnoughSharesError("started with zero peers while fetching "
                                   "%s" % name)
    bucket, remaining = sources[0], sources[1:]

    def _fetch(ignored):
        method = getattr(bucket, methname)
        return method(*args)

    def _try_next(f):
        if f.check(DeadReferenceError):
            level = log.UNUSUAL
        else:
            level = log.WEIRD
        self.log(format="operation %(op)s from vbucket %(vbucket)s failed",
                 op=name, vbucket=str(bucket),
                 failure=f, level=level, umid="JGXxBA")
        if not remaining:
            raise NotEnoughSharesError("ran out of peers, last error was %s"
                                       % (f,))
        # try again with a different one
        return self._obtain_validated_thing(None, remaining, name,
                                            methname, args, validatorfunc)

    d = bucket.startIfNecessary()
    d.addCallback(_fetch)
    d.addCallback(validatorfunc, bucket)
    d.addErrback(_try_next)
    return d
682.
def _got_uri_extension(self, uri_extension_data):
    """Configure the download from the validated URI extension dict.

    Sets up the codecs, the optional whole-file hashes, the share root
    hash, the segment geometry and the share hash tree. Raises
    BadURIExtension if the optional 'crypttext_hash' is present but is not
    a 32-byte string.
    """
    if self._results:
        elapsed = time.time() - self._uri_extension_fetch_started
        self._results.timings["uri_extension"] = elapsed

    d = self._uri_extension_data = uri_extension_data

    self._codec = codec.get_decoder_by_name(d['codec_name'])
    self._codec.set_serialized_params(d['codec_params'])
    self._tail_codec = codec.get_decoder_by_name(d['codec_name'])
    self._tail_codec.set_serialized_params(d['tail_codec_params'])

    crypttext_hash = d.get('crypttext_hash', None) # optional
    if crypttext_hash:
        # validate with an explicit check rather than assert, which is
        # stripped under python -O
        if not (isinstance(crypttext_hash, str)
                and len(crypttext_hash) == 32):
            raise BadURIExtension("URI Extension block has a malformed "
                                  "crypttext_hash: %r" % (crypttext_hash,))
    # Always define the attribute (None when the field is absent), the same
    # way _plaintext_hash is handled, so readers never hit AttributeError.
    self._crypttext_hash = crypttext_hash
    self._plaintext_hash = d.get('plaintext_hash', None) # optional

    self._roothash = d['share_root_hash']

    self._segment_size = segment_size = d['segment_size']
    self._total_segments = mathutil.div_ceil(self._size, segment_size)
    self._current_segnum = 0

    self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
    self._share_hashtree.set_hashes({0: self._roothash})
710.
def _get_hashtrees(self, res):
    """Fetch the plaintext and crypttext hash trees, then install them.

    Returns a Deferred.
    """
    self._get_hashtrees_started = time.time()
    if self._status:
        self._status.set_status("Retrieving Hash Trees")
    d = defer.maybeDeferred(self._get_plaintext_hashtrees)
    for step in (self._get_crypttext_hashtrees, self._setup_hashtrees):
        d.addCallback(step)
    return d
719.
def _get_plaintext_hashtrees(self):
    """Fetch and validate the optional plaintext hash tree.

    When the URI extension block has no plaintext_root_hash, set
    self._plaintext_hashtree to None without fetching anything. Otherwise
    the fetched hash list must begin with that root and be internally
    consistent, or the next source is tried (see _obtain_validated_thing).
    """
    # plaintext hashes are optional. If the root isn't in the UEB, then
    # the share will be holding an empty list. We don't even bother
    # fetching it.
    if "plaintext_root_hash" not in self._uri_extension_data:
        self._plaintext_hashtree = None
        return

    def _check_plaintext_hashes(proposal, bucket):
        expected_root = self._uri_extension_data['plaintext_root_hash']
        if proposal[0] != expected_root:
            self._fetch_failures["plaintext_hashroot"] += 1
            msg = ("The copy of the plaintext_root_hash we received from"
                   " %s was bad" % bucket)
            raise BadPlaintextHashValue(msg)
        tree = hashtree.IncompleteHashTree(self._total_segments)
        try:
            tree.set_hashes(dict(enumerate(proposal)))
        except hashtree.BadHashError:
            # the hashes they gave us were not self-consistent, even
            # though the root matched what we saw in the uri_extension
            # block
            self._fetch_failures["plaintext_hashtree"] += 1
            raise
        self._plaintext_hashtree = tree

    return self._obtain_validated_thing(None,
                                        self._uri_extension_sources,
                                        "plaintext_hashes",
                                        "get_plaintext_hashes", (),
                                        _check_plaintext_hashes)
750.
def _get_crypttext_hashtrees(self, res):
    """Fetch and validate the mandatory crypttext hash tree.

    Raises BadURIExtension if the URI extension block has no
    crypttext_root_hash. The fetched hash list must begin with that root
    and be internally consistent, or the next source is tried (see
    _obtain_validated_thing). On success self._crypttext_hashtree holds
    the populated tree.
    """
    # Ciphertext hash tree root is mandatory, so that there is at
    # most one ciphertext that matches this read-cap or
    # verify-cap. The integrity check on the shares is not
    # sufficient to prevent the original encoder from creating
    # some shares of file A and other shares of file B.
    if "crypttext_root_hash" not in self._uri_extension_data:
        raise BadURIExtension("URI Extension block did not have the ciphertext hash tree root")

    def _validate_crypttext_hashtree(proposal, bucket):
        if proposal[0] != self._uri_extension_data['crypttext_root_hash']:
            self._fetch_failures["crypttext_hashroot"] += 1
            msg = ("The copy of the crypttext_root_hash we received from"
                   " %s was bad" % bucket)
            raise BadCrypttextHashValue(msg)
        ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
        ct_hashes = dict(enumerate(proposal))
        try:
            # one set_hashes() call both stores and validates the hashes
            ct_hashtree.set_hashes(ct_hashes)
        except hashtree.BadHashError:
            # not self-consistent, even though the root matched
            self._fetch_failures["crypttext_hashtree"] += 1
            raise
        self._crypttext_hashtree = ct_hashtree

    d = self._obtain_validated_thing(None,
                                     self._uri_extension_sources,
                                     "crypttext_hashes",
                                     "get_crypttext_hashes", (),
                                     _validate_crypttext_hashtree)
    return d
780.
def _setup_hashtrees(self, res):
    """Give the fetched hash trees to the Output and record the timing."""
    self._output.setup_hashtrees(self._plaintext_hashtree,
                                 self._crypttext_hashtree)
    if not self._results:
        return
    self._results.timings["hashtrees"] = \
        time.time() - self._get_hashtrees_started
787.
def _create_validated_buckets(self, ignored=None):
    """Wrap every known (sharenum, bucket) in a ValidatedBucket.

    Rebuilds self._share_vbuckets as {sharenum: set of ValidatedBucket}.
    """
    self._share_vbuckets = by_share = {}
    for sharenum, bucket in self._share_buckets:
        vb = ValidatedBucket(sharenum, bucket,
                             self._share_hashtree,
                             self._roothash,
                             self._total_segments)
        by_share.setdefault(sharenum, set()).add(vb)
797.
def _activate_enough_buckets(self):
    """either return a mapping from shnum to a ValidatedBucket that can
    provide data for that share, or raise NotEnoughSharesError"""
    needed = self._num_needed_shares
    while len(self.active_buckets) < needed:
        # pick an unused share number at random ...
        unused = list(set(self._share_vbuckets.keys())
                      - set(self.active_buckets.keys()))
        if not unused:
            raise NotEnoughSharesError
        shnum = random.choice(unused)
        # ... and a random bucket that will provide it
        providers = list(self._share_vbuckets[shnum])
        self.active_buckets[shnum] = random.choice(providers)
    return self.active_buckets
815.
816.
def _download_all_segments(self, res):
    """Chain the download of every segment, tail segment last.

    Returns a Deferred.
    """
    # the promise: upon entry to this function, self._share_vbuckets
    # contains enough buckets to complete the download, and some extra
    # ones to tolerate some buckets dropping out or having errors.
    # self._share_vbuckets is a dictionary that maps from shnum to a set
    # of ValidatedBuckets, which themselves are wrappers around
    # RIBucketReader references.
    self.active_buckets = {} # k: shnum, v: ValidatedBucket instance

    self._started_fetching = time.time()

    tail_segnum = self._total_segments - 1
    d = defer.succeed(None)
    for segnum in range(tail_segnum):
        d.addCallback(self._download_segment, segnum)
        # this pause, at the end of write, prevents pre-fetch from
        # happening until the consumer is ready for more data.
        d.addCallback(self._check_for_pause)
    d.addCallback(self._download_tail_segment, tail_segnum)
    return d
836.
def _check_for_pause(self, res):
    """Pass 'res' through, waiting while paused and honoring a stop.

    If we are paused, return a Deferred that fires with 'res' after the
    pause ends. The check is then repeated, because stopProducing()
    resumes us in order to deliver the stop; without the re-check a stop
    issued during a pause would let the next step (e.g. a segment write)
    run anyway. Raises DownloadStopped once stopProducing() was called.
    """
    if self._paused:
        d = defer.Deferred()
        # re-check (pause may recur, or we may have been stopped meanwhile)
        d.addCallback(self._check_for_pause)
        self._paused.addCallback(lambda ignored: d.callback(res))
        return d
    if self._stopped:
        raise DownloadStopped("our Consumer called stopProducing()")
    return res
845.
def _download_segment(self, res, segnum):
    """Fetch, decode, decrypt and write (non-tail) segment 'segnum'.

    Pause/stop is checked after fetching and again just before writing.
    Timings are accumulated into self._results when it is set.
    """
    if self._status:
        self._status.set_status("Downloading segment %d of %d" %
                                (segnum+1, self._total_segments))
    self.log("downloading seg#%d of %d (%d%%)"
             % (segnum, self._total_segments,
                100.0 * segnum / self._total_segments))

    def _record_fetch_time(res):
        self._results.timings["cumulative_fetch"] += \
            time.time() - fetch_started
        return res

    def _note_decode_start(res):
        self._started_decode = time.time()
        return res

    def _decode(shares_and_ids):
        shares, shareids = shares_and_ids
        return self._codec.decode(shares, shareids)

    def _record_decode_time(res):
        self._results.timings["cumulative_decode"] += \
            time.time() - self._started_decode
        return res

    def _write(buffers):
        # we start by joining all these buffers together into a single
        # string. This makes Output.write easier, since it wants to hash
        # data one segment at a time anyways, and doesn't impact our
        # memory footprint since we're already peaking at 2*segment_size
        # inside the codec a moment ago.
        segment = "".join(buffers)
        del buffers
        # we're down to 1*segment_size right now, but write_segment()
        # will decrypt a copy of the segment internally, which will push
        # us up to 2*segment_size while it runs.
        decrypt_started = time.time()
        self._output.write_segment(segment)
        if self._results:
            self._results.timings["cumulative_decrypt"] += \
                time.time() - decrypt_started

    # memory footprint: when the SegmentDownloader finishes pulling down
    # all shares, we have 1*segment_size of usage.
    fetcher = SegmentDownloader(self, segnum, self._num_needed_shares,
                                self._results)
    fetch_started = time.time()
    d = fetcher.start()
    if self._results:
        d.addCallback(_record_fetch_time)
    # pause before using more memory
    d.addCallback(self._check_for_pause)
    # while the codec does its job, we hit 2*segment_size
    if self._results:
        d.addCallback(_note_decode_start)
    d.addCallback(_decode)
    # once the codec is done, we drop back to 1*segment_size, because
    # 'shares' goes out of scope. The memory usage is all in the
    # plaintext now, spread out into a bunch of tiny buffers.
    if self._results:
        d.addCallback(_record_decode_time)
    # pause/check-for-stop just before writing, to honor stopProducing
    d.addCallback(self._check_for_pause)
    d.addCallback(_write)
    return d
905.
def _download_tail_segment(self, res, segnum):
    """Fetch, decode, trim and write out the final (tail) segment.

    Works like _download_segment(), except for two differences:
    - it decodes with self._tail_codec;
    - it cuts off the padding the upload side added to the last segment
      before handing it to self._output.write_segment().

    :param res: result of the previous step in the Deferred chain (ignored).
    :param segnum: zero-based index of the segment to fetch; the download
        loop passes self._total_segments-1.
    :return: a Deferred that fires (with None) after the trimmed segment
        has been written.
    """
    # Report progress the same way _download_segment() does. Without this,
    # the status kept showing the previous segment while the tail was being
    # fetched, and a single-segment file (where the tail is the only
    # segment) never reported any progress at all.
    if self._status:
        self._status.set_status("Downloading segment %d of %d" %
                                (segnum+1, self._total_segments))
    self.log("downloading seg#%d of %d (%d%%)"
             % (segnum, self._total_segments,
                100.0 * segnum / self._total_segments))
    segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
                                    self._results)
    started = time.time()
    d = segmentdler.start()
    def _finished_fetching(res):
        elapsed = time.time() - started
        self._results.timings["cumulative_fetch"] += elapsed
        return res
    if self._results:
        d.addCallback(_finished_fetching)
    # pause before using more memory
    d.addCallback(self._check_for_pause)
    def _started_decode(res):
        self._started_decode = time.time()
        return res
    if self._results:
        d.addCallback(_started_decode)
    def _decode(fetched):
        # fetched is the (shares, shareids) pair from the SegmentDownloader
        shares, shareids = fetched
        return self._tail_codec.decode(shares, shareids)
    d.addCallback(_decode)
    def _finished_decode(res):
        elapsed = time.time() - self._started_decode
        self._results.timings["cumulative_decode"] += elapsed
        return res
    if self._results:
        d.addCallback(_finished_decode)
    # pause/check-for-stop just before writing, to honor stopProducing
    d.addCallback(self._check_for_pause)
    def _done(buffers):
        # trim off any padding added by the upload side
        segment = "".join(buffers)
        del buffers
        # we never send empty segments. If the data was an exact multiple
        # of the segment size, the last segment will be full.
        pad_size = mathutil.pad_size(self._size, self._segment_size)
        tail_size = self._segment_size - pad_size
        segment = segment[:tail_size]
        started_decrypt = time.time()
        self._output.write_segment(segment)
        if self._results:
            elapsed = time.time() - started_decrypt
            self._results.timings["cumulative_decrypt"] += elapsed
    d.addCallback(_done)
    return d
953.
def _done(self, res):
    """Finish the download after every segment has been written.

    Steps, in order:
    1. Record the "total" and "segments" timings (only when self._results
       is set).
    2. Close the output.
    3. Use _assert() to confirm that the computed crypttext and plaintext
       hashes match the expected ones. Each check runs only when its
       check_* flag is set and an expected hash is known.
    4. Use _assert() to confirm that the output length equals self._size.

    Returns whatever self._output.finish() returns. *res* is ignored.
    """
    self.log("download done")
    if self._results:
        finished_at = time.time()
        timings = self._results.timings
        timings["total"] = finished_at - self._started
        timings["segments"] = finished_at - self._started_fetching

    self._output.close()
    output = self._output

    if self.check_crypttext_hash and self._crypttext_hash:
        expected = self._crypttext_hash
        computed = output.crypttext_hash
        _assert(expected == computed,
                "bad crypttext_hash: computed=%s, expected=%s" %
                (base32.b2a(computed), base32.b2a(expected)))

    if self.check_plaintext_hash and self._plaintext_hash:
        expected = self._plaintext_hash
        computed = output.plaintext_hash
        _assert(expected == computed,
                "bad plaintext_hash: computed=%s, expected=%s" %
                (base32.b2a(computed), base32.b2a(expected)))

    _assert(output.length == self._size,
            got=output.length, expected=self._size)
    return output.finish()
974.
def get_download_status(self):
    """Return the status object this downloader updates as it progresses."""
    status = self._status
    return status
977.
978.
class FileName:
    """Download target that writes the file's contents to a named local file.

    The file is created (truncating any existing file at that path) when
    open() is called. If the download fails after that, fail() closes and
    deletes the partially-written file.
    """
    implements(IDownloadTarget)
    def __init__(self, filename):
        self._filename = filename
        self.f = None
    def open(self, size):
        # size is part of the target interface; it is not needed here.
        self.f = open(self._filename, "wb")
        return self.f
    def write(self, data):
        self.f.write(data)
    def close(self):
        if self.f:
            self.f.close()
    def fail(self, why):
        # Only remove the file if we created it. If open() was never called,
        # anything at self._filename is not ours: unlinking it would destroy
        # unrelated data, or raise OSError and mask the original failure.
        if self.f:
            self.f.close()
            os.unlink(self._filename)
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        pass
1000.
class Data:
    """Download target that accumulates the file's contents in memory.

    write() appends each chunk to a list. close() joins the chunks into
    the string self.data, which finish() returns.
    """
    implements(IDownloadTarget)
    def __init__(self):
        self._data = []
    def open(self, size):
        pass
    def write(self, data):
        self._data.append(data)
    def close(self):
        self.data = "".join(self._data)
        del self._data
    def fail(self, why):
        # Drop any buffered chunks. If fail() arrives after close() (or a
        # second time), self._data is already gone. Tolerate that instead of
        # raising AttributeError, which would hide the original failure.
        if hasattr(self, "_data"):
            del self._data
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        return self.data
1018.
class FileHandle:
    """Download target that writes into a caller-supplied file-like object.

    Only the object's write() method is used. The handle is never closed
    here; closing it is left to whoever created it. finish() hands the
    same handle back once the download completes.
    """
    implements(IDownloadTarget)

    def __init__(self, filehandle):
        self._filehandle = filehandle

    def write(self, data):
        self._filehandle.write(data)

    def finish(self):
        return self._filehandle

    # The remaining target hooks are deliberate no-ops for this target.
    def open(self, size):
        return None

    def close(self):
        # the originator of the filehandle reserves the right to close it
        return None

    def fail(self, why):
        return None

    def register_canceller(self, cb):
        return None
1041.
1042. class Downloader(service.MultiService):
1043. """I am a service that allows file downloading.
1044. """
1045. # TODO: in fact, this service only downloads immutable files (URI:CHK:).
1046. # It is scheduled to go away, to be replaced by filenode.download()
1047. implements(IDownloader)
1048. name = "downloader"
1049. MAX_DOWNLOAD_STATUSES = 10
1050.
def __init__(self, stats_provider=None):
    """Initialize the downloader service.

    :param stats_provider: optional object whose count() method download()
        uses to tally files and bytes; None disables that counting.
    """
    service.MultiService.__init__(self)
    self.stats_provider = stats_provider
    # Weak-keyed registries: they never keep a download (or its status)
    # alive. They exist for debugging and status listing.
    self._all_downloads = weakref.WeakKeyDictionary()
    self._all_download_statuses = weakref.WeakKeyDictionary()
    # Strongly-held recent history, trimmed by _add_download().
    self._recent_download_statuses = []
1057.
def download(self, u, t):
    """Start downloading the immutable (CHK) file *u* into target *t*.

    *u* is adapted to IFileURI and must be a uri.CHKFileURI. *t* is adapted
    to IDownloadTarget and must provide write() and close(). The service
    must be attached to a parent and running.

    :return: whatever FileDownloader.start() returns (presumably a
        Deferred for the download's result -- confirm in FileDownloader).
    """
    assert self.parent
    assert self.running
    u = IFileURI(u)
    t = IDownloadTarget(t)
    assert t.write
    assert t.close
    assert isinstance(u, uri.CHKFileURI)

    stats = self.stats_provider
    if stats:
        # these counters are meant for network traffic, and don't
        # include LIT files
        stats.count('downloader.files_downloaded', 1)
        stats.count('downloader.bytes_downloaded', u.get_size())

    file_downloader = FileDownloader(self.parent, u, t)
    self._add_download(file_downloader)
    return file_downloader.start()
1076.
1077. # utility functions
def download_to_data(self, uri):
    """Download *uri* into a fresh in-memory Data target."""
    target = Data()
    return self.download(uri, target)
def download_to_filename(self, uri, filename):
    """Download *uri* into the local file named *filename*."""
    target = FileName(filename)
    return self.download(uri, target)
def download_to_filehandle(self, uri, filehandle):
    """Download *uri* by writing into the caller's file-like *filehandle*."""
    target = FileHandle(filehandle)
    return self.download(uri, target)
1084.
def _add_download(self, downloader):
    """Register *downloader* and remember its status object.

    The downloader and its status go into the all-downloads registries.
    The status is also appended to the recent-history list. That list is
    trimmed in place, oldest first, to at most MAX_DOWNLOAD_STATUSES
    entries.
    """
    self._all_downloads[downloader] = None
    status = downloader.get_download_status()
    self._all_download_statuses[status] = None
    recent = self._recent_download_statuses
    recent.append(status)
    overflow = len(recent) - self.MAX_DOWNLOAD_STATUSES
    if overflow > 0:
        del recent[:overflow]
1092.
1093. def list_all_download_statuses(self):
1094. for ds in self._all_download_statuses:
1095. yield ds