From 1cfa3b3d8d09241eafca92654d443dffc38dbd3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Garc=C3=ADa=20Crespo?= Date: Tue, 16 Jun 2026 07:15:13 +0200 Subject: [PATCH] Clean up multiprocessing pools on failure Ensure bagit-python reliably cleans up multiprocessing pools when parallel manifest generation or validation fails with processes > 1. Successful make_bag() calls already cleaned up correctly through the normal graceful pool path. The gap was failures inside Pool.map(): worker exceptions skipped close() and join(), allowing BagIt child processes to remain alive after the caller received the error. In long-running services, those leftover children can later become orphaned or defunct when the owning worker exits. This follows up on c451b24 ("Wait for validation Pool to finish"), the validation-pool cleanup work Douglas and I did, by extending the same reliability expectation to failure paths. --- src/bagit/__init__.py | 36 +++++++++++++++++++++++++++--------- test.py | 35 +++++++++++++++++++++++++++++++++-- 2 files changed, 60 insertions(+), 11 deletions(-) diff --git a/src/bagit/__init__.py b/src/bagit/__init__.py index 890dbd8..2b4adce 100755 --- a/src/bagit/__init__.py +++ b/src/bagit/__init__.py @@ -896,12 +896,12 @@ def _validate_entries(self, processes): if processes == 1: hash_results = [_calc_hashes(i) for i in args] else: - pool = multiprocessing.Pool( - processes if processes else None, initializer=worker_init + hash_results = _multiprocessing_pool_map( + _calc_hashes, + args, + processes if processes else None, + initializer=worker_init, ) - hash_results = pool.map(_calc_hashes, args) - pool.close() - pool.join() # Any unhandled exceptions are probably fatal except: @@ -1037,6 +1037,25 @@ def posix_multiprocessing_worker_initializer(): signal.signal(signal.SIGINT, signal.SIG_IGN) +def _multiprocessing_pool_map(func, iterable, processes, initializer=None): + """Run ``Pool.map()`` and always clean up the pool. + + This ensures worker processes are closed or terminated, then joined, under + all conditions. + """ + pool = multiprocessing.Pool(processes=processes, initializer=initializer) + try: + results = pool.map(func, iterable) + except BaseException: + pool.terminate() + raise + else: + pool.close() + return results + finally: + pool.join() + + # The Unicode normalization form used here doesn't matter – all we care about # is consistency since the input value will be preserved: @@ -1245,10 +1264,9 @@ def make_manifests(data_dir, processes, algorithms=DEFAULT_CHECKSUMS, encoding=" manifest_line_generator = partial(generate_manifest_lines, algorithms=algorithms) if processes > 1: - pool = multiprocessing.Pool(processes=processes) - checksums = pool.map(manifest_line_generator, _walk(data_dir)) - pool.close() - pool.join() + checksums = _multiprocessing_pool_map( + manifest_line_generator, _walk(data_dir), processes=processes + ) else: checksums = [manifest_line_generator(i) for i in _walk(data_dir)] diff --git a/test.py b/test.py index f79bca0..a0e7f31 100644 --- a/test.py +++ b/test.py @@ -13,10 +13,9 @@ import tempfile import unicodedata import unittest +from io import StringIO from os.path import join as j - from unittest import mock -from io import StringIO import bagit @@ -458,6 +457,23 @@ def validate(self, bag, *args, **kwargs): bag, *args, processes=2, **kwargs ) + @mock.patch("bagit.multiprocessing.Pool") + def test_validate_multiprocessing_terminates_and_joins_pool_on_failure(self, pool): + pool.return_value.map.side_effect = RuntimeError("boom") + bag = bagit.make_bag(self.tmpdir) + + with self.assertRaises(RuntimeError): + self.validate(bag) + + self.assertEqual( + pool.return_value.mock_calls, + [ + mock.call.map(mock.ANY, mock.ANY), + mock.call.terminate(), + mock.call.join(), + ], + ) + @mock.patch("bagit.multiprocessing.Pool") def test_validate_pool_error(self, pool): # Simulate the Pool constructor raising a RuntimeError. @@ -745,6 +761,21 @@ def test_make_bag_multiprocessing(self): bagit.make_bag(self.tmpdir, processes=2) self.assertTrue(os.path.isdir(j(self.tmpdir, "data"))) + @mock.patch("bagit.multiprocessing.Pool") + def test_make_bag_multiprocessing_terminates_and_joins_pool_on_failure(self, pool): + pool.return_value.map.side_effect = RuntimeError("boom") + with self.assertRaises(RuntimeError): + bagit.make_bag(self.tmpdir, processes=2) + + self.assertEqual( + pool.return_value.mock_calls, + [ + mock.call.map(mock.ANY, mock.ANY), + mock.call.terminate(), + mock.call.join(), + ], + ) + def test_multiple_meta_values(self): baginfo = {"Multival-Meta": [7, 4, 8, 6, 8]} bag = bagit.make_bag(self.tmpdir, baginfo)