source: trunk/src/allmydata/test/no_network.py

Last change on this file was cb83b089, checked in by Itamar Turner-Trauring <itamar@…>, at 2023-10-16T15:16:10Z

Decouple from reactor

  • Property mode set to 100644
File size: 24.0 KB
Line 
1"""
2This contains a test harness that creates a full Tahoe grid in a single
3process (actually in a single MultiService) which does not use the network.
4It does not use an Introducer, and there are no foolscap Tubs. Each storage
5server puts real shares on disk, but is accessed through loopback
6RemoteReferences instead of over serialized SSL. It is not as complete as
7the common.SystemTestMixin framework (which does use the network), but
8should be considerably faster: on my laptop, it takes 50-80ms to start up,
9whereas SystemTestMixin takes close to 2s.
10
11This should be useful for tests which want to examine and/or manipulate the
12uploaded shares, checker/verifier/repairer tests, etc. The clients have no
13Tubs, so it is not useful for tests that involve a Helper.
14"""
15
16from __future__ import annotations
17
18from six import ensure_text
19
20from typing import Callable
21
22import os
23from base64 import b32encode
24from functools import (
25    partial,
26)
27from zope.interface import implementer
28from twisted.application import service
29from twisted.internet import defer
30from twisted.python.failure import Failure
31from twisted.web.error import Error
32from foolscap.api import Referenceable, fireEventually, RemoteException
33from foolscap.ipb import (
34    IRemoteReference,
35)
36import treq
37
38from allmydata.util.assertutil import _assert
39
40from allmydata import uri as tahoe_uri
41from allmydata.client import _Client
42from allmydata.storage.server import (
43    StorageServer, storage_index_to_dir, FoolscapStorageServer,
44)
45from allmydata.util import fileutil, idlib, hashutil
46from allmydata.util.hashutil import permute_server_hash
47from allmydata.util.fileutil import abspath_expanduser_unicode
48from allmydata.interfaces import IStorageBroker, IServer
49from allmydata.storage_client import (
50    _StorageServer,
51)
52from .common import (
53    SameProcessStreamEndpointAssigner,
54)
55
56
class IntentionalError(Exception):
    """Deliberate failure raised by a ``LocalWrapper`` marked as broken."""
59
class Marker(object):
    """Opaque, hashable handle returned by ``LocalWrapper.notifyOnDisconnect``."""
62
# Zero-argument callable returning an already-fired Deferred; pass it as
# LocalWrapper's ``fireEventually`` to make callRemote run synchronously.
fireNow = partial(defer.succeed, None)
64
@implementer(IRemoteReference)  # type: ignore  # warner/foolscap#79
class LocalWrapper(object):
    """
    A ``LocalWrapper`` presents the remote reference interface to a local
    object which implements a ``RemoteInterface``.
    """
    def __init__(self, original, fireEventually=fireEventually):
        """
        :param original: the local object; ``callRemote("foo", ...)`` is
            dispatched to its ``remote_foo`` method.

        :param Callable[[], Deferred[None]] fireEventually: Get a Deferred
            that will fire at some point.  This is used to control when
            ``callRemote`` calls the remote method.  The default value allows
            the reactor to iterate before the call happens.  Use ``fireNow``
            to call the remote method synchronously.
        """
        self.original = original
        # False for normal operation, True to fail every call with
        # IntentionalError, or an integer to fail that many calls and then
        # recover (see NoNetworkGrid.break_server).
        self.broken = False
        # While set to a Deferred, calls queue behind it instead of running
        # (see NoNetworkGrid.hang_server/unhang_server).
        self.hung_until = None
        # Optional callable chained after every completed call.
        self.post_call_notifier = None
        # Marker -> (f, args, kwargs), registered via notifyOnDisconnect.
        self.disconnectors = {}
        # methname -> number of callRemote invocations, for tests.
        self.counter_by_methname = {}
        self._fireEventually = fireEventually

    def _clear_counters(self):
        """Reset the per-method invocation counters."""
        self.counter_by_methname = {}

    def callRemoteOnly(self, methname, *args, **kwargs):
        # Fire-and-forget variant mirroring foolscap's callRemoteOnly: the
        # result (and any eventual failure) is deliberately discarded.
        d = self.callRemote(methname, *args, **kwargs)
        del d # explicitly ignored
        return None

    def callRemote(self, methname, *args, **kwargs):
        # this is ideally a Membrane, but that's too hard. We do a shallow
        # wrapping of inbound arguments, and per-methodname wrapping of
        # selected return values.
        def wrap(a):
            # Referenceable arguments would cross a real wire as
            # RemoteReferences; emulate that with a LocalWrapper.
            if isinstance(a, Referenceable):
                return self._wrap(a)
            else:
                return a
        args = tuple([wrap(a) for a in args])
        kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])

        def _really_call():
            def incr(d, k): d[k] = d.setdefault(k, 0) + 1
            incr(self.counter_by_methname, methname)
            meth = getattr(self.original, "remote_" + methname)
            return meth(*args, **kwargs)

        def _call():
            if self.broken:
                if self.broken is not True: # a counter, not boolean
                    self.broken -= 1
                raise IntentionalError("I was asked to break")
            if self.hung_until:
                # Queue this call on the hang-control Deferred's chain; d2
                # fires with the call's result once unhang_server fires
                # hung_until.
                d2 = defer.Deferred()
                self.hung_until.addCallback(lambda ign: _really_call())
                self.hung_until.addCallback(lambda res: d2.callback(res))
                def _err(res):
                    d2.errback(res)
                    return res
                self.hung_until.addErrback(_err)
                return d2
            return _really_call()

        d = self._fireEventually()
        d.addCallback(lambda res: _call())
        def _wrap_exception(f):
            # Emulate the network: failures arrive wrapped as
            # RemoteException, just as foolscap would deliver them.
            return Failure(RemoteException(f))
        d.addErrback(_wrap_exception)
        def _return_membrane(res):
            # rather than complete the difficult task of building a
            # fully-general Membrane (which would locate all Referenceable
            # objects that cross the simulated wire and replace them with
            # wrappers), we special-case certain methods that we happen to
            # know will return Referenceables.
            if methname == "allocate_buckets":
                (alreadygot, allocated) = res
                for shnum in allocated:
                    allocated[shnum] = self._wrap(allocated[shnum])
            if methname == "get_buckets":
                for shnum in res:
                    res[shnum] = self._wrap(res[shnum])
            return res
        d.addCallback(_return_membrane)
        if self.post_call_notifier:
            d.addCallback(self.post_call_notifier, self, methname)
        return d

    def notifyOnDisconnect(self, f, *args, **kwargs):
        # There is no network, so nothing ever disconnects; we just record
        # the callback and hand back a removal token.
        m = Marker()
        self.disconnectors[m] = (f, args, kwargs)
        return m
    def dontNotifyOnDisconnect(self, marker):
        del self.disconnectors[marker]

    def _wrap(self, value):
        # Nested wrappers inherit the same call-scheduling policy.
        return LocalWrapper(value, self._fireEventually)
162
163
def wrap_storage_server(original):
    """
    Return a ``LocalWrapper`` around ``original`` with ``.version``
    pre-populated.

    Much of the upload/download code uses rref.version (which normally
    comes from rrefutil.add_version_to_remote_reference). To avoid using a
    network, we want a LocalWrapper here. Try to satisfy all these
    constraints at the same time.
    """
    rref = LocalWrapper(original)
    rref.version = original.remote_get_version()
    return rref
172
@implementer(IServer)
class NoNetworkServer(object):
    """
    An ``IServer`` whose ``rref`` is a ``LocalWrapper`` around an in-process
    storage server; the serverid doubles as every seed value.
    """
    def __init__(self, serverid, rref):
        self.serverid = serverid
        self.rref = rref

    def __repr__(self):
        return f"<NoNetworkServer for {self.get_name()}>"

    # Special method used by copy.copy() and copy.deepcopy(). When those are
    # used in allmydata.immutable.filenode to copy CheckResults during
    # repair, we want it to treat the IServer instances as singletons.
    def __copy__(self):
        return self

    def __deepcopy__(self, memodict):
        return self

    def upload_permitted(self):
        return True

    def get_serverid(self):
        return self.serverid

    def get_permutation_seed(self):
        return self.serverid

    def get_lease_seed(self):
        return self.serverid

    def get_foolscap_write_enabler_seed(self):
        return self.serverid

    def get_name(self):
        # Other implementations return bytes.
        return idlib.shortnodeid_b2a(self.serverid).encode("utf-8")

    def get_longname(self):
        return idlib.nodeid_b2a(self.serverid)

    def get_nickname(self):
        return "nickname"

    def get_rref(self):
        return self.rref

    def get_storage_server(self):
        if self.rref is None:
            return None
        return _StorageServer(lambda: self.rref)

    def get_version(self):
        return self.rref.version

    def start_connecting(self, trigger_cb):
        raise NotImplementedError
217
218
@implementer(IStorageBroker)
class NoNetworkStorageBroker(object):  # type: ignore # missing many methods
    """
    A minimal ``IStorageBroker`` that serves the grid's server set without
    any connection management.  ``self.client`` is assigned externally.
    """
    def get_servers_for_psi(self, peer_selection_index, for_upload=True):
        # Order servers by their permuted hash for this selection index.
        def sort_key(server):
            return permute_server_hash(
                peer_selection_index, server.get_permutation_seed())
        return sorted(self.get_connected_servers(), key=sort_key)

    def get_connected_servers(self):
        return self.client._servers

    def get_nickname_for_serverid(self, serverid):
        return None

    def when_connected_enough(self, threshold):
        # Never fires: there are no connections to wait for.
        return defer.Deferred()

    def get_all_serverids(self):
        return []  # FIXME?

    def get_known_servers(self):
        return []  # FIXME?
236
237
def create_no_network_client(basedir):
    """
    :return: a Deferred yielding an instance of _Client subclass which
        does no actual networking but has the same API.
    """
    node_dir = abspath_expanduser_unicode(str(basedir))
    fileutil.make_dirs(os.path.join(node_dir, "private"), 0o700)

    # imported here rather than at module level to match the original's
    # import-time behavior
    from allmydata.client import read_config
    node_config = read_config(node_dir, u'client.port')
    broker = NoNetworkStorageBroker()
    client = _NoNetworkClient(
        node_config,
        main_tub=None,
        i2p_provider=None,
        tor_provider=None,
        introducer_clients=[],
        storage_farm_broker=broker,
    )
    # this is a (pre-existing) reference-cycle and also a bad idea, see:
    # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/2949
    broker.client = client
    return defer.succeed(client)
261
262
class _NoNetworkClient(_Client):  # type: ignore  # tahoe-lafs/ticket/3573
    """
    Overrides all _Client networking functionality to do nothing.
    """

    def init_connections(self):
        # no Tub, no introducer: nothing to connect
        pass
    def create_main_tub(self):
        # no foolscap Tub in the no-network grid
        pass
    def init_introducer_client(self):
        pass
    def create_log_tub(self):
        pass
    def setup_logging(self):
        pass
    def startService(self):
        # skip _Client's startup work; just start the MultiService children
        service.MultiService.startService(self)
    def stopService(self):
        return service.MultiService.stopService(self)
    def init_helper(self):
        pass
    def init_key_gen(self):
        pass
    def init_storage(self):
        # clients in this grid run no storage server of their own
        pass
    def init_client_storage_broker(self):
        self.storage_broker = NoNetworkStorageBroker()
        self.storage_broker.client = self
    def init_stub_client(self):
        pass
    #._servers will be set by the NoNetworkGrid which creates us
294
295
class SimpleStats(object):
    """
    Minimal in-memory stats provider: named counters plus a list of
    registered producers whose stats are merged on demand.
    """
    def __init__(self):
        self.counters = {}
        self.stats_producers = []

    def count(self, name, delta=1):
        """Add ``delta`` to the counter called ``name`` (creating it at 0)."""
        self.counters[name] = self.counters.get(name, 0) + delta

    def register_producer(self, stats_producer):
        """Record a producer whose ``get_stats()`` feeds into ours."""
        self.stats_producers.append(stats_producer)

    def get_stats(self):
        """Return ``{'counters': ..., 'stats': ...}`` merged from producers."""
        gathered = {}
        for producer in self.stats_producers:
            gathered.update(producer.get_stats())
        return {'counters': self.counters, 'stats': gathered}
314
class NoNetworkGrid(service.MultiService):
    """
    A MultiService holding some number of in-process storage servers and
    client nodes, connected to each other via ``LocalWrapper``s instead of
    a network.
    """
    def __init__(self, basedir, num_clients, num_servers,
                 client_config_hooks, port_assigner):
        """
        :param basedir: directory under which all client/server state lives
        :param num_clients: how many no-network client nodes to create
        :param num_servers: how many StorageServers to create
        :param client_config_hooks: maps client number to a hook called with
            the client directory; the hook may edit tahoe.cfg or return an
            entirely new client instance (see make_client)
        :param port_assigner: assigns web.port endpoints for the clients
        """
        service.MultiService.__init__(self)

        # We really need to get rid of this pattern here (and
        # everywhere) in Tahoe where "async work" is started in
        # __init__ For now, we at least keep the errors so they can
        # cause tests to fail less-improperly (see _check_clients)
        self._setup_errors = []

        self.port_assigner = port_assigner
        self.basedir = basedir
        fileutil.make_dirs(basedir)

        self.servers_by_number = {} # maps to StorageServer instance
        self.wrappers_by_id = {} # maps to wrapped StorageServer instance
        self.proxies_by_id = {} # maps to IServer on which .rref is a wrapped
                                # StorageServer
        self.clients = []
        self.client_config_hooks = client_config_hooks

        for i in range(num_servers):
            ss = self.make_server(i)
            self.add_server(i, ss)
        self.rebuild_serverlist()

        # client creation is asynchronous; failures land in _setup_errors
        # instead of being raised here (see _check_clients)
        for i in range(num_clients):
            d = self.make_client(i)
            d.addCallback(lambda c: self.clients.append(c))

            def _bad(f):
                self._setup_errors.append(f)
            d.addErrback(_bad)

    def _check_clients(self):
        """
        The anti-pattern of doing async work in __init__ means we need to
        check if that work completed successfully. This method either
        returns nothing or raises an exception in case __init__ failed
        to complete properly
        """
        if self._setup_errors:
            self._setup_errors[0].raiseException()

    @defer.inlineCallbacks
    def make_client(self, i, write_config=True):
        """
        Create client number ``i`` and attach it to this grid.

        :param write_config: if True, write a fresh tahoe.cfg; if False,
            require one to already exist (used by restart_client).
        :return: a Deferred firing with the new client
        """
        clientid = hashutil.tagged_hash(b"clientid", b"%d" % i)[:20]
        clientdir = os.path.join(self.basedir, "clients",
                                 idlib.shortnodeid_b2a(clientid))
        fileutil.make_dirs(clientdir)

        tahoe_cfg_path = os.path.join(clientdir, "tahoe.cfg")
        if write_config:
            from twisted.internet import reactor
            _, port_endpoint = self.port_assigner.assign(reactor)
            with open(tahoe_cfg_path, "w") as f:
                f.write("[node]\n")
                f.write("nickname = client-%d\n" % i)
                f.write("web.port = {}\n".format(port_endpoint))
                f.write("[storage]\n")
                f.write("enabled = false\n")
        else:
            _assert(os.path.exists(tahoe_cfg_path), tahoe_cfg_path=tahoe_cfg_path)

        c = None
        if i in self.client_config_hooks:
            # this hook can either modify tahoe.cfg, or return an
            # entirely new Client instance
            c = self.client_config_hooks[i](clientdir)

        if not c:
            c = yield create_no_network_client(clientdir)

        c.nodeid = clientid
        c.short_nodeid = b32encode(clientid).lower()[:8]
        c._servers = self.all_servers # can be updated later
        c.setServiceParent(self)
        defer.returnValue(c)

    def make_server(self, i, readonly=False):
        """Create (but do not attach) StorageServer number ``i``."""
        serverid = hashutil.tagged_hash(b"serverid", b"%d" % i)[:20]
        serverdir = os.path.join(self.basedir, "servers",
                                 idlib.shortnodeid_b2a(serverid), "storage")
        fileutil.make_dirs(serverdir)
        ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats(),
                           readonly_storage=readonly)
        ss._no_network_server_number = i
        return ss

    def add_server(self, i, ss):
        """Attach server ``ss`` as number ``i`` and rebuild the server list."""
        # to deal with the fact that all StorageServers are named 'storage',
        # we interpose a middleman
        middleman = service.MultiService()
        middleman.setServiceParent(self)
        ss.setServiceParent(middleman)
        serverid = ss.my_nodeid
        self.servers_by_number[i] = ss
        wrapper = wrap_storage_server(FoolscapStorageServer(ss))
        self.wrappers_by_id[serverid] = wrapper
        self.proxies_by_id[serverid] = NoNetworkServer(serverid, wrapper)
        self.rebuild_serverlist()

    def get_all_serverids(self):
        return list(self.proxies_by_id.keys())

    def rebuild_serverlist(self):
        # refresh every client's view of the current server set
        self._check_clients()
        self.all_servers = frozenset(list(self.proxies_by_id.values()))
        for c in self.clients:
            c._servers = self.all_servers

    def remove_server(self, serverid):
        # it's enough to remove the server from c._servers (we don't actually
        # have to detach and stopService it)
        # NOTE(review): if serverid is unknown, ``ss`` is left bound to the
        # last server examined (or unbound when there are none) and the
        # ``del`` below raises KeyError -- callers must pass a known id.
        for i,ss in list(self.servers_by_number.items()):
            if ss.my_nodeid == serverid:
                del self.servers_by_number[i]
                break
        del self.wrappers_by_id[serverid]
        del self.proxies_by_id[serverid]
        self.rebuild_serverlist()
        return ss

    def break_server(self, serverid, count=True):
        # mark the given server as broken, so it will throw exceptions when
        # asked to hold a share or serve a share. If count= is a number,
        # throw that many exceptions before starting to work again.
        self.wrappers_by_id[serverid].broken = count

    def hang_server(self, serverid):
        # hang the given server
        ss = self.wrappers_by_id[serverid]
        assert ss.hung_until is None
        ss.hung_until = defer.Deferred()

    def unhang_server(self, serverid):
        # unhang the given server
        ss = self.wrappers_by_id[serverid]
        assert ss.hung_until is not None
        ss.hung_until.callback(None)
        ss.hung_until = None

    def nuke_from_orbit(self):
        """ Empty all share directories in this grid. It's the only way to be sure ;-) """
        for server in list(self.servers_by_number.values()):
            for prefixdir in os.listdir(server.sharedir):
                if prefixdir != 'incoming':
                    fileutil.rm_dir(os.path.join(server.sharedir, prefixdir))
464
465
class GridTestMixin(object):
    """
    Test mixin providing a ``NoNetworkGrid`` plus helpers for examining and
    manipulating the shares it stores.  Subclasses must set ``self.basedir``
    before calling ``set_up_grid``.
    """
    def setUp(self):
        # parent MultiService for the grid; started now, stopped in tearDown
        self.s = service.MultiService()
        self.s.startService()
        return super(GridTestMixin, self).setUp()

    def tearDown(self):
        return defer.gatherResults([
            self.s.stopService(),
            defer.maybeDeferred(super(GridTestMixin, self).tearDown),
        ])

    def set_up_grid(self, num_clients=1, num_servers=10,
                    client_config_hooks=None, oneshare=False):
        """
        Create a Tahoe-LAFS storage grid.

        :param num_clients: See ``NoNetworkGrid``
        :param num_servers: See ``NoNetworkGrid``
        :param client_config_hooks: See ``NoNetworkGrid``

        :param bool oneshare: If ``True`` then the first client node is
            configured with ``n == k == happy == 1``.

        :return: ``None``
        """
        if client_config_hooks is None:
            client_config_hooks = {}
        # self.basedir must be set
        port_assigner = SameProcessStreamEndpointAssigner()
        port_assigner.setUp()
        self.addCleanup(port_assigner.tearDown)
        self.g = NoNetworkGrid(self.basedir,
                               num_clients=num_clients,
                               num_servers=num_servers,
                               client_config_hooks=client_config_hooks,
                               port_assigner=port_assigner,
        )
        self.g.setServiceParent(self.s)
        if oneshare:
            c = self.get_client(0)
            c.encoding_params["k"] = 1
            c.encoding_params["happy"] = 1
            c.encoding_params["n"] = 1
        self._record_webports_and_baseurls()

    def _record_webports_and_baseurls(self):
        # cache each client's web port and base URL for GET/PUT helpers
        self.g._check_clients()
        self.client_webports = [c.getServiceNamed("webish").getPortnum()
                                for c in self.g.clients]
        self.client_baseurls = [c.getServiceNamed("webish").getURL()
                                for c in self.g.clients]

    def get_client_config(self, i=0):
        self.g._check_clients()
        return self.g.clients[i].config

    def get_clientdir(self, i=0):
        # ideally, use something get_client_config() only, we
        # shouldn't need to manipulate raw paths..
        return self.get_client_config(i).get_config_path()

    def get_client(self, i=0):
        self.g._check_clients()
        return self.g.clients[i]

    def restart_client(self, i=0):
        """Detach client ``i`` and replace it with a freshly-made client
        reusing the existing tahoe.cfg."""
        self.g._check_clients()
        client = self.g.clients[i]
        d = defer.succeed(None)
        d.addCallback(lambda ign: self.g.removeService(client))

        @defer.inlineCallbacks
        def _make_client(ign):
            c = yield self.g.make_client(i, write_config=False)
            self.g.clients[i] = c
            self._record_webports_and_baseurls()
        d.addCallback(_make_client)
        return d

    def get_serverdir(self, i):
        return self.g.servers_by_number[i].storedir

    def iterate_servers(self):
        # yields (number, StorageServer, storedir) in server-number order
        for i in sorted(self.g.servers_by_number.keys()):
            ss = self.g.servers_by_number[i]
            yield (i, ss, ss.storedir)

    def find_uri_shares(self, uri):
        """Return a sorted list of (shnum, serverid, sharefile-path) for
        every share of the given capability found on disk."""
        si = tahoe_uri.from_string(uri).get_storage_index()
        prefixdir = storage_index_to_dir(si)
        shares = []
        for i,ss in list(self.g.servers_by_number.items()):
            serverid = ss.my_nodeid
            basedir = os.path.join(ss.sharedir, prefixdir)
            if not os.path.exists(basedir):
                continue
            for f in os.listdir(basedir):
                try:
                    shnum = int(f)
                    shares.append((shnum, serverid, os.path.join(basedir, f)))
                except ValueError:
                    # non-numeric names are not share files; skip them
                    pass
        return sorted(shares)

    def copy_shares(self, uri: bytes) -> dict[bytes, bytes]:
        """
        Read all of the share files for the given capability from the storage area
        of the storage servers created by ``set_up_grid``.

        :param bytes uri: A Tahoe-LAFS data capability.

        :return: A ``dict`` mapping share file names to share file contents.
        """
        shares = {}
        for (shnum, serverid, sharefile) in self.find_uri_shares(uri):
            with open(sharefile, "rb") as f:
                shares[sharefile] = f.read()
        return shares

    def restore_all_shares(self, shares):
        # inverse of copy_shares: write each file's saved contents back
        for sharefile, data in list(shares.items()):
            with open(sharefile, "wb") as f:
                f.write(data)

    def delete_share(self, sharenum_and_serverid_and_sharefile):
        # accepts one tuple as produced by find_uri_shares
        (shnum, serverid, sharefile) = sharenum_and_serverid_and_sharefile
        os.unlink(sharefile)

    def delete_shares_numbered(self, uri, shnums):
        # delete every share of ``uri`` whose number is in ``shnums``
        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
            if i_shnum in shnums:
                os.unlink(i_sharefile)

    def delete_all_shares(self, serverdir):
        # wipe one server's share tree, preserving the 'incoming' directory
        sharedir = os.path.join(serverdir, "shares")
        for prefixdir in os.listdir(sharedir):
            if prefixdir != 'incoming':
                fileutil.rm_dir(os.path.join(sharedir, prefixdir))

    def corrupt_share(self, sharenum_and_serverid_and_sharefile, corruptor_function):
        # rewrite one share file with corruptor_function(old contents)
        (shnum, serverid, sharefile) = sharenum_and_serverid_and_sharefile
        with open(sharefile, "rb") as f:
            sharedata = f.read()
        corruptdata = corruptor_function(sharedata)
        with open(sharefile, "wb") as f:
            f.write(corruptdata)

    def corrupt_shares_numbered(self, uri, shnums, corruptor, debug=False):
        # corrupt every share of ``uri`` whose number is in ``shnums``
        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
            if i_shnum in shnums:
                with open(i_sharefile, "rb") as f:
                    sharedata = f.read()
                corruptdata = corruptor(sharedata, debug=debug)
                with open(i_sharefile, "wb") as f:
                    f.write(corruptdata)

    def corrupt_all_shares(self, uri: bytes, corruptor: Callable[[bytes, bool], bytes], debug: bool=False):
        """
        Apply ``corruptor`` to the contents of all share files associated with a
        given capability and replace the share file contents with its result.
        """
        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
            with open(i_sharefile, "rb") as f:
                sharedata = f.read()
            corruptdata = corruptor(sharedata, debug)
            with open(i_sharefile, "wb") as f:
                f.write(corruptdata)

    @defer.inlineCallbacks
    def GET(self, urlpath, followRedirect=False, return_response=False,
            method="GET", clientnum=0, **kwargs):
        # if return_response=True, this fires with (data, statuscode,
        # respheaders) instead of just data.
        url = self.client_baseurls[clientnum] + ensure_text(urlpath)

        response = yield treq.request(method, url, persistent=False,
                                      allow_redirects=followRedirect,
                                      **kwargs)
        data = yield response.content()
        if return_response:
            # we emulate the old HTTPClientGetFactory-based response, which
            # wanted a tuple of (bytestring of data, bytestring of response
            # code like "200" or "404", and a
            # twisted.web.http_headers.Headers instance). Fortunately treq's
            # response.headers has one.
            defer.returnValue( (data, str(response.code), response.headers) )
        if 400 <= response.code < 600:
            raise Error(response.code, response=data)
        defer.returnValue(data)

    def PUT(self, urlpath, **kwargs):
        # same keyword arguments as GET; issues a PUT instead
        return self.GET(urlpath, method="PUT", **kwargs)
Note: See TracBrowser for help on using the repository browser.