Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions src/bagit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,12 +896,12 @@
if processes == 1:
hash_results = [_calc_hashes(i) for i in args]
else:
pool = multiprocessing.Pool(
processes if processes else None, initializer=worker_init
hash_results = _multiprocessing_pool_map(
_calc_hashes,
args,
processes if processes else None,
initializer=worker_init,
)
hash_results = pool.map(_calc_hashes, args)
pool.close()
pool.join()

# Any unhandled exceptions are probably fatal
except:
Expand Down Expand Up @@ -1037,6 +1037,25 @@
signal.signal(signal.SIGINT, signal.SIG_IGN)


def _multiprocessing_pool_map(func, iterable, processes, initializer=None):
    """Run ``Pool.map()`` and always clean up the pool.

    A ``multiprocessing.Pool`` is created, ``func`` is mapped over
    ``iterable``, and the pool is shut down before control leaves this
    function: it is closed after a successful map, terminated if the map
    raises, and joined in both cases.

    Arguments:
        func: callable handed to ``Pool.map()``.
        iterable: items handed to ``Pool.map()``.
        processes: forwarded as ``multiprocessing.Pool(processes=...)``.
        initializer: optional callable forwarded to the pool as its
            ``initializer``.

    Returns the value produced by ``Pool.map()``.

    Anything raised by ``Pool.map()`` is re-raised unchanged once the pool
    has been terminated and joined.
    """
    pool = multiprocessing.Pool(processes=processes, initializer=initializer)
    try:
        results = pool.map(func, iterable)
    except BaseException:
        # BaseException rather than Exception so that KeyboardInterrupt and
        # SystemExit also stop the workers instead of leaving them running.
        pool.terminate()
        raise
    else:
        # Normal completion: stop accepting work and let the workers exit.
        pool.close()
        return results
    finally:
        # join() is only valid after close() or terminate(); both branches
        # above guarantee one of them has been called.
        pool.join()


# The Unicode normalization form used here doesn't matter – all we care about
# is consistency since the input value will be preserved:

Expand Down Expand Up @@ -1245,10 +1264,9 @@
manifest_line_generator = partial(generate_manifest_lines, algorithms=algorithms)

if processes > 1:
pool = multiprocessing.Pool(processes=processes)
checksums = pool.map(manifest_line_generator, _walk(data_dir))
pool.close()
pool.join()
checksums = _multiprocessing_pool_map(
manifest_line_generator, _walk(data_dir), processes=processes
)
else:
checksums = [manifest_line_generator(i) for i in _walk(data_dir)]

Expand Down Expand Up @@ -1583,7 +1601,7 @@
else:
LOGGER.info(_("%s is valid"), bag_dir)
except BagError as e:
LOGGER.error(

Check failure on line 1604 in src/bagit/__init__.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use "logging.exception()" instead.

See more on https://sonarcloud.io/project/issues?id=LibraryOfCongress_bagit-python&issues=AZ7WfiStvaLYFV3NH828&open=AZ7WfiStvaLYFV3NH828&pullRequest=208
_("%(bag)s is invalid: %(error)s"), {"bag": bag_dir, "error": e}
)
rc = 1
Expand All @@ -1598,7 +1616,7 @@
checksums=args.checksums,
)
except Exception as exc:
LOGGER.error(

Check failure on line 1619 in src/bagit/__init__.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use "logging.exception()" instead.

See more on https://sonarcloud.io/project/issues?id=LibraryOfCongress_bagit-python&issues=AZ7WfiStvaLYFV3NH829&open=AZ7WfiStvaLYFV3NH829&pullRequest=208
_("Failed to create bag in %(bag_directory)s: %(error)s"),
{"bag_directory": bag_dir, "error": exc},
exc_info=True,
Expand Down
35 changes: 33 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
import tempfile
import unicodedata
import unittest
from io import StringIO
from os.path import join as j

from unittest import mock
from io import StringIO

import bagit

Expand Down Expand Up @@ -458,6 +457,23 @@ def validate(self, bag, *args, **kwargs):
bag, *args, processes=2, **kwargs
)

@mock.patch("bagit.multiprocessing.Pool")
def test_validate_multiprocessing_terminates_and_joins_pool_on_failure(self, pool):
    """Validation must terminate and join the pool when ``Pool.map()`` fails."""
    # Every map() call on the mocked pool instance raises.
    pool.return_value.map.side_effect = RuntimeError("boom")
    bag = bagit.make_bag(self.tmpdir)

    with self.assertRaises(RuntimeError):
        self.validate(bag)

    # The comparison is exact: map() first, then terminate(), then join(),
    # and nothing else (in particular no close() on the failure path).
    expected_calls = [
        mock.call.map(mock.ANY, mock.ANY),
        mock.call.terminate(),
        mock.call.join(),
    ]
    self.assertEqual(pool.return_value.mock_calls, expected_calls)

@mock.patch("bagit.multiprocessing.Pool")
def test_validate_pool_error(self, pool):
# Simulate the Pool constructor raising a RuntimeError.
Expand Down Expand Up @@ -745,6 +761,21 @@ def test_make_bag_multiprocessing(self):
bagit.make_bag(self.tmpdir, processes=2)
self.assertTrue(os.path.isdir(j(self.tmpdir, "data")))

@mock.patch("bagit.multiprocessing.Pool")
def test_make_bag_multiprocessing_terminates_and_joins_pool_on_failure(self, pool):
    """``make_bag`` must terminate and join the pool when ``Pool.map()`` fails."""
    # Every map() call on the mocked pool instance raises.
    pool.return_value.map.side_effect = RuntimeError("boom")

    with self.assertRaises(RuntimeError):
        bagit.make_bag(self.tmpdir, processes=2)

    # The comparison is exact: map() first, then terminate(), then join(),
    # and nothing else (in particular no close() on the failure path).
    expected_calls = [
        mock.call.map(mock.ANY, mock.ANY),
        mock.call.terminate(),
        mock.call.join(),
    ]
    self.assertEqual(pool.return_value.mock_calls, expected_calls)

def test_multiple_meta_values(self):
baginfo = {"Multival-Meta": [7, 4, 8, 6, 8]}
bag = bagit.make_bag(self.tmpdir, baginfo)
Expand Down
Loading