diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/_emerge/EbuildFetcher.py | 68 | ||||
-rw-r--r-- | lib/portage/package/ebuild/fetch.py | 102 | ||||
-rw-r--r-- | lib/portage/tests/ebuild/test_fetch.py | 100 | ||||
-rw-r--r-- | lib/portage/tests/util/test_socks5.py | 16 |
4 files changed, 254 insertions, 32 deletions
diff --git a/lib/_emerge/EbuildFetcher.py b/lib/_emerge/EbuildFetcher.py index 81d4b1054..994271236 100644 --- a/lib/_emerge/EbuildFetcher.py +++ b/lib/_emerge/EbuildFetcher.py @@ -4,6 +4,8 @@ import copy import functools import io +import multiprocessing +import signal import sys import portage @@ -17,11 +19,12 @@ from portage.package.ebuild.fetch import ( _check_distfile, _drop_privs_userfetch, _want_userfetch, - fetch, + async_fetch, ) from portage.util._async.AsyncTaskFuture import AsyncTaskFuture from portage.util._async.ForkProcess import ForkProcess from portage.util._pty import _create_pty_or_pipe +from portage.util.futures import asyncio from _emerge.CompositeTask import CompositeTask @@ -34,6 +37,7 @@ class EbuildFetcher(CompositeTask): "logfile", "pkg", "prefetch", + "pre_exec", "_fetcher_proc", ) @@ -253,6 +257,7 @@ class _EbuildFetcherProcess(ForkProcess): self._get_manifest(), self._uri_map, self.fetchonly, + self.pre_exec, ) ForkProcess._start(self) @@ -263,7 +268,10 @@ class _EbuildFetcherProcess(ForkProcess): self._settings = None @staticmethod - def _target(settings, manifest, uri_map, fetchonly): + def _target(settings, manifest, uri_map, fetchonly, pre_exec): + if pre_exec is not None: + pre_exec() + # Force consistent color output, in case we are capturing fetch # output through a normal pipe due to unavailability of ptys. portage.output.havecolor = settings.get("NOCOLOR") not in ("yes", "true") @@ -273,17 +281,53 @@ class _EbuildFetcherProcess(ForkProcess): if _want_userfetch(settings): _drop_privs_userfetch(settings) - rval = 1 allow_missing = manifest.allow_missing or "digest" in settings.features - if fetch( - uri_map, - settings, - fetchonly=fetchonly, - digests=copy.deepcopy(manifest.getTypeDigests("DIST")), - allow_missing_digests=allow_missing, - ): - rval = os.EX_OK - return rval + + async def main(): + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + async_fetch( + uri_map, + settings, + fetchonly=fetchonly, + digests=copy.deepcopy(manifest.getTypeDigests("DIST")), + allow_missing_digests=allow_missing, + ) + ) + + def sigterm_handler(signum, _frame): + loop.call_soon_threadsafe(task.cancel) + signal.signal(signal.SIGTERM, signal.SIG_IGN) + + signal.signal(signal.SIGTERM, sigterm_handler) + try: + await task + except asyncio.CancelledError: + # If asyncio.CancelledError arrives too soon after fork/spawn + # then handers will not have an opportunity to terminate + # the corresponding process, so clean up after this race. + for proc in multiprocessing.active_children(): + proc.terminate() + + # Use a non-zero timeout only for the first join because + # later joins are delayed by the first join. + timeout = 0.25 + for proc in multiprocessing.active_children(): + proc.join(timeout) + timeout = 0 + + for proc in multiprocessing.active_children(): + proc.kill() + # Wait upon the process in order to ensure that its + # pid will trigger ProcessLookupError for tests. + proc.join() + + signal.signal(signal.SIGTERM, signal.SIG_DFL) + os.kill(os.getpid(), signal.SIGTERM) + + return os.EX_OK if task.result() else 1 + + return asyncio.run(main()) def _get_ebuild_path(self): if self.ebuild_path is not None: diff --git a/lib/portage/package/ebuild/fetch.py b/lib/portage/package/ebuild/fetch.py index ed40cf6ed..0947f45aa 100644 --- a/lib/portage/package/ebuild/fetch.py +++ b/lib/portage/package/ebuild/fetch.py @@ -1,4 +1,4 @@ -# Copyright 2010-2021 Gentoo Authors +# Copyright 2010-2024 Gentoo Authors # Distributed under the terms of the GNU General Public License v2 __all__ = ["fetch"] @@ -73,6 +73,7 @@ from portage.util import ( writemsg_level, writemsg_stdout, ) +from portage.util.futures import asyncio from portage.process import spawn _download_suffix = ".__download__" @@ -111,7 +112,7 @@ def _drop_privs_userfetch(settings): """ spawn_kwargs = dict(_userpriv_spawn_kwargs) try: - _ensure_distdir(settings, settings["DISTDIR"]) + asyncio.run(_ensure_distdir(settings, settings["DISTDIR"])) except PortageException: if not os.path.isdir(settings["DISTDIR"]): raise @@ -179,13 +180,37 @@ def _spawn_fetch(settings, args, **kwargs): return rval +# Instrumentation hooks for use by unit tests. +_async_spawn_fetch_pre_wait = None +_async_spawn_fetch_post_terminate = None + + +async def _async_spawn_fetch(settings, args, **kwargs): + kwargs["returnproc"] = True + proc = _spawn_fetch(settings, args, **kwargs) + try: + if _async_spawn_fetch_pre_wait is not None: + _async_spawn_fetch_pre_wait(proc) + return await proc.wait() + except asyncio.CancelledError: + proc.terminate() + if _async_spawn_fetch_post_terminate is not None: + _async_spawn_fetch_post_terminate(proc) + raise + + +_async_spawn_fetch.__doc__ = _spawn_fetch.__doc__ +_async_spawn_fetch.__doc__ += """ + This function is a coroutine. +""" + _userpriv_test_write_file_cache = {} _userpriv_test_write_cmd_script = ( ">> %(file_path)s 2>/dev/null ; rval=$? ; " + "rm -f %(file_path)s ; exit $rval" ) -def _userpriv_test_write_file(settings, file_path): +async def _userpriv_test_write_file(settings, file_path): """ Drop privileges and try to open a file for writing. The file may or may not exist, and the parent directory is assumed to exist. The file @@ -201,20 +226,26 @@ def _userpriv_test_write_file(settings, file_path): if rval is not None: return rval + # Optimize away the spawn when privileges do not need to be dropped. + if not _want_userfetch(settings): + rval = os.access(os.path.dirname(file_path), os.W_OK) + _userpriv_test_write_file_cache[file_path] = rval + return rval + args = [ BASH_BINARY, "-c", _userpriv_test_write_cmd_script % {"file_path": _shell_quote(file_path)}, ] - returncode = _spawn_fetch(settings, args) + returncode = await _async_spawn_fetch(settings, args) rval = returncode == os.EX_OK _userpriv_test_write_file_cache[file_path] = rval return rval -def _ensure_distdir(settings, distdir): +async def _ensure_distdir(settings, distdir): """ Ensure that DISTDIR exists with appropriate permissions. @@ -240,7 +271,7 @@ def _ensure_distdir(settings, distdir): userpriv = portage.data.secpass >= 2 and "userpriv" in settings.features write_test_file = os.path.join(distdir, ".__portage_test_write__") - if _userpriv_test_write_file(settings, write_test_file): + if await _userpriv_test_write_file(settings, write_test_file): return _userpriv_test_write_file_cache.pop(write_test_file, None) @@ -687,7 +718,12 @@ def get_mirror_url(mirror_url, filename, mysettings, cache_path=None): @param cache_path: Path for mirror metadata cache @return: Full URL to fetch """ + return asyncio.run( + async_mirror_url(mirror_url, filename, mysettings, cache_path=cache_path) + ) + +async def async_mirror_url(mirror_url, filename, mysettings, cache_path=None): mirror_conf = MirrorLayoutConfig() cache = {} @@ -708,7 +744,7 @@ def get_mirror_url(mirror_url, filename, mysettings, cache_path=None): if mirror_url[:1] == "/": tmpfile = os.path.join(mirror_url, "layout.conf") mirror_conf.read_from_file(tmpfile) - elif fetch( + elif await async_fetch( {tmpfile: (mirror_url + "/distfiles/layout.conf",)}, mysettings, force=1, @@ -738,6 +774,12 @@ def get_mirror_url(mirror_url, filename, mysettings, cache_path=None): return mirror_url + "/distfiles/" + path +async_mirror_url.__doc__ = get_mirror_url.__doc__ +async_mirror_url.__doc__ += """ + This function is a coroutine. +""" + + def fetch( myuris, mysettings, @@ -783,6 +825,34 @@ def fetch( @rtype: int @return: 1 if successful, 0 otherwise. """ + return asyncio.run( + async_fetch( + myuris, + mysettings, + listonly=listonly, + fetchonly=fetchonly, + locks_in_subdir=locks_in_subdir, + use_locks=use_locks, + try_mirrors=try_mirrors, + digests=digests, + allow_missing_digests=allow_missing_digests, + force=force, + ) + ) + + +async def async_fetch( + myuris, + mysettings, + listonly=0, + fetchonly=0, + locks_in_subdir=".locks", + use_locks=1, + try_mirrors=1, + digests=None, + allow_missing_digests=True, + force=False, +): if force and digests: # Since the force parameter can trigger unnecessary fetch when the @@ -1050,7 +1120,7 @@ def fetch( for l in itertools.chain(*location_lists): filedict[myfile].append( functools.partial( - get_mirror_url, l, myfile, mysettings, mirror_cache + async_mirror_url, l, myfile, mysettings, mirror_cache ) ) if myuri is None: @@ -1119,7 +1189,7 @@ def fetch( if can_fetch and not fetch_to_ro: try: - _ensure_distdir(mysettings, mysettings["DISTDIR"]) + await _ensure_distdir(mysettings, mysettings["DISTDIR"]) except PortageException as e: if not os.path.isdir(mysettings["DISTDIR"]): writemsg(f"!!! {str(e)}\n", noiselevel=-1) @@ -1381,7 +1451,7 @@ def fetch( if distdir_writable and ro_distdirs: readonly_file = None for x in ro_distdirs: - filename = get_mirror_url(x, myfile, mysettings) + filename = await async_mirror_url(x, myfile, mysettings) match, mystat = _check_distfile( filename, pruned_digests, eout, hash_filter=hash_filter ) @@ -1427,7 +1497,7 @@ def fetch( if fsmirrors and not os.path.exists(myfile_path) and has_space: for mydir in fsmirrors: - mirror_file = get_mirror_url(mydir, myfile, mysettings) + mirror_file = await async_mirror_url(mydir, myfile, mysettings) try: shutil.copyfile(mirror_file, download_path) writemsg(_("Local mirror has file: %s\n") % myfile) @@ -1554,7 +1624,7 @@ def fetch( while uri_list: loc = uri_list.pop() if isinstance(loc, functools.partial): - loc = loc() + loc = await loc() # Eliminate duplicates here in case we've switched to # "primaryuri" mode on the fly due to a checksum failure. if loc in tried_locations: @@ -1740,7 +1810,7 @@ def fetch( myret = -1 try: - myret = _spawn_fetch(mysettings, myfetch) + myret = await _async_spawn_fetch(mysettings, myfetch) finally: try: @@ -1992,3 +2062,9 @@ def fetch( if failed_files: return 0 return 1 + + +async_fetch.__doc__ = fetch.__doc__ +async_fetch.__doc__ += """ + This function is a coroutine. +""" diff --git a/lib/portage/tests/ebuild/test_fetch.py b/lib/portage/tests/ebuild/test_fetch.py index 1856bb52b..1ad958036 100644 --- a/lib/portage/tests/ebuild/test_fetch.py +++ b/lib/portage/tests/ebuild/test_fetch.py @@ -1,9 +1,11 @@ -# Copyright 2019-2023 Gentoo Authors +# Copyright 2019-2024 Gentoo Authors # Distributed under the terms of the GNU General Public License v2 import functools import io +import multiprocessing import shlex +import signal import tempfile import types @@ -36,6 +38,85 @@ from _emerge.Package import Package class EbuildFetchTestCase(TestCase): + + async def _test_interrupt(self, loop, server, async_fetch, pkg, ebuild_path): + """Test interrupt, with server responses temporarily paused.""" + server.pause() + pr, pw = multiprocessing.Pipe(duplex=False) + timeout = loop.create_future() + loop.add_reader(pr.fileno(), lambda: timeout.done() or timeout.set_result(None)) + self.assertEqual( + await async_fetch( + pkg, + ebuild_path, + timeout=timeout, + pre_exec=functools.partial(self._pre_exec_interrupt_patch, pw), + ), + -signal.SIGTERM, + ) + loop.remove_reader(pr.fileno()) + pw.close() + + # Read pid written by _async_spawn_fetch_pre_wait hook (the + # corresponding write served to trigger the timeout above). + pid = pr.recv() + + # Read pid written by _async_spawn_fetch_post_terminate hook, + # in order to know when the ProcessLookupError test should + # succeed. + pid = pr.recv() + pr.close() + + # Poll the process table until the pid has disappeared, + # and fail if a short timeout expires. + tries = 10 + while tries: + tries -= 1 + + msg = None + if tries <= 0: + try: + with open(f"/proc/{pid}/status") as f: + for line in f: + if line.startswith("State:"): + msg = line + break + except OSError: + pass + + try: + with self.assertRaises(ProcessLookupError, msg=msg): + os.kill(pid, 0) + except Exception: + if tries <= 0: + raise + await asyncio.sleep(0.1) + else: + break + + server.resume() + + @staticmethod + def _pre_exec_interrupt_patch(pw): + portage.package.ebuild.fetch._async_spawn_fetch_pre_wait = functools.partial( + EbuildFetchTestCase._fetch_pre_wait, + pw, + ) + portage.package.ebuild.fetch._async_spawn_fetch_post_terminate = ( + functools.partial( + EbuildFetchTestCase._fetch_post_terminate, + pw, + ) + ) + + @staticmethod + def _fetch_pre_wait(pw, proc): + pw.send(proc.pid) + + @staticmethod + def _fetch_post_terminate(pw, proc): + pw.send(proc.pid) + def testEbuildFetch(self): user_config = { "make.conf": ('GENTOO_MIRRORS="{scheme}://{host}:{port}"',), @@ -338,7 +419,7 @@ class EbuildFetchTestCase(TestCase): config_pool = config_pool_cls(settings) - def async_fetch(pkg, ebuild_path): + def async_fetch(pkg, ebuild_path, pre_exec=None, timeout=None): fetcher = EbuildFetcher( config_pool=config_pool, ebuild_path=ebuild_path, @@ -346,9 +427,15 @@ class EbuildFetchTestCase(TestCase): fetchall=True, pkg=pkg, scheduler=loop, + pre_exec=pre_exec, ) fetcher.start() - return fetcher.async_wait() + waiter = fetcher.async_wait() + if timeout is not None: + timeout.add_done_callback( + lambda timeout: waiter.done() or fetcher.cancel() + ) + return waiter for cpv in ebuilds: metadata = dict( @@ -414,6 +501,13 @@ class EbuildFetchTestCase(TestCase): with open(os.path.join(settings["DISTDIR"], k), "rb") as f: self.assertEqual(f.read(), distfiles[k]) + # Test interrupt, with server responses temporarily paused. + for k in settings["AA"].split(): + os.unlink(os.path.join(settings["DISTDIR"], k)) + loop.run_until_complete( + self._test_interrupt(loop, server, async_fetch, pkg, ebuild_path) + ) + # Test empty files in DISTDIR for k in settings["AA"].split(): file_path = os.path.join(settings["DISTDIR"], k) diff --git a/lib/portage/tests/util/test_socks5.py b/lib/portage/tests/util/test_socks5.py index a8cd0c46c..e7bc2d699 100644 --- a/lib/portage/tests/util/test_socks5.py +++ b/lib/portage/tests/util/test_socks5.py @@ -58,19 +58,27 @@ class AsyncHTTPServer: self.server_port = None self._httpd = None + def pause(self): + """Pause responses (useful for testing timeouts).""" + self._loop.remove_reader(self._httpd.socket.fileno()) + + def resume(self): + """Resume responses following a previous call to pause.""" + self._loop.add_reader( + self._httpd.socket.fileno(), self._httpd._handle_request_noblock + ) + def __enter__(self): httpd = self._httpd = HTTPServer( (self._host, 0), functools.partial(_Handler, self._content) ) self.server_port = httpd.server_port - self._loop.add_reader( - httpd.socket.fileno(), self._httpd._handle_request_noblock - ) + self.resume() return self def __exit__(self, exc_type, exc_value, exc_traceback): if self._httpd is not None: - self._loop.remove_reader(self._httpd.socket.fileno()) + self.pause() self._httpd.socket.close() self._httpd = None |