diff --git a/src/bagit/__init__.py b/src/bagit/__init__.py index 890dbd8..fe17cbf 100755 --- a/src/bagit/__init__.py +++ b/src/bagit/__init__.py @@ -10,6 +10,7 @@ import os import re import signal +import shutil import sys import tempfile import time @@ -537,6 +538,77 @@ def save(self, processes=1, manifests=False): os.chdir(old_dir) + + def update_payload(self, processes=1): + """ + Rebuild payload manifests and tag manifests after payload changes. + """ + self.save(processes=processes, manifests=True) + + + def add_payload(self, src, dest=None, processes=1): + """ + Copy a file into the payload directory and rebuild manifests. + """ + if not os.path.isfile(src) and not os.path.isdir(src): + raise ValueError("Payload source must be a file or directory: %s" % src) + + if dest is None: + dest = os.path.basename(src) + + dest = os.path.normpath(dest) + + if os.path.isabs(dest) or dest.startswith("..") or os.path.expanduser(dest) != dest: + raise ValueError("Payload destination is unsafe: %s" % dest) + + payload_dest = os.path.join("data", dest) + + if self._path_is_dangerous(payload_dest): + raise ValueError("Payload destination is unsafe: %s" % dest) + + dst = os.path.join(self.path, payload_dest) + + if os.path.isfile(src): + dst_dir = os.path.dirname(dst) + if not os.path.isdir(dst_dir): + os.makedirs(dst_dir) + + shutil.copy2(src, dst) + + elif os.path.isdir(src): + if os.path.exists(dst): + raise ValueError("Payload destination already exists: %s" % dest) + + shutil.copytree(src, dst) + + self.update_payload(processes=processes) + + + def remove_payload(self, path, processes=1, recursive=False): + """ + Remove a payload file and rebuild manifests. + """ + payload_path = os.path.normpath(path) + + if self._path_is_dangerous(payload_path): + raise ValueError("Payload path is unsafe: %s" % path) + + if not payload_path.startswith("data" + os.sep): + raise ValueError("Payload path must start with data/: %s" % path) + + full_path = os.path.join(self.path, payload_path) + + if os.path.isfile(full_path): + os.remove(full_path) + elif os.path.isdir(full_path): + if not recursive: + raise ValueError("Payload path is a directory: %s" % path) + shutil.rmtree(full_path) + else: + raise ValueError("Payload path does not exist: %s" % path) + + self.update_payload(processes=processes) + def tagfile_entries(self): return dict( (key, value) @@ -808,6 +880,17 @@ def _validate_contents(self, processes=1, fast=False, completeness_only=False): self._validate_entries(processes) + def payload_oxum(self): + total_bytes = 0 + total_files = 0 + + for payload_file in self.payload_files(): + payload_file = os.path.join(self.path, payload_file) + total_bytes += os.stat(payload_file).st_size + total_files += 1 + + return total_bytes, total_files + def _validate_oxum(self): oxum = self.info.get("Payload-Oxum") @@ -827,13 +910,8 @@ def _validate_oxum(self): oxum_byte_count = int(oxum_byte_count) oxum_file_count = int(oxum_file_count) - total_bytes = 0 - total_files = 0 - for payload_file in self.payload_files(): - payload_file = os.path.join(self.path, payload_file) - total_bytes += os.stat(payload_file).st_size - total_files += 1 + total_bytes, total_files = self.payload_oxum() if oxum_file_count != total_files or oxum_byte_count != total_bytes: raise BagValidationError( @@ -1583,7 +1661,7 @@ def main(): else: LOGGER.info(_("%s is valid"), bag_dir) except BagError as e: - LOGGER.error( + LOGGER.exception( _("%(bag)s is invalid: %(error)s"), {"bag": bag_dir, "error": e} ) rc = 1 @@ -1598,7 +1676,7 @@ def main(): checksums=args.checksums, ) except Exception as exc: - LOGGER.error( + LOGGER.exception( _("Failed to create bag in %(bag_directory)s: %(error)s"), {"bag_directory": bag_dir, "error": exc}, exc_info=True, diff --git a/test.py b/test.py index f79bca0..703f2c7 100644 --- a/test.py +++ b/test.py @@ -1011,6 +1011,248 @@ def test_open_bag_with_unknown_encoding(self): self.assertEqual("Unsupported encoding: WTF-8", str(error_catcher.exception)) + def test_update_payload_added_file(self): + bag = bagit.make_bag(self.tmpdir) + self.assertTrue(bag.is_valid()) + + with open(j(self.tmpdir, "data", "newfile"), "w") as nf: + nf.write("newfile") + + bag = bagit.Bag(self.tmpdir) + self.assertFalse(bag.is_valid()) + + bag.update_payload() + + bag = bagit.Bag(self.tmpdir) + self.assertTrue(bag.is_valid()) + + def test_update_payload_deleted_file(self): + bag = bagit.make_bag(self.tmpdir) + self.assertTrue(bag.is_valid()) + + os.remove(j(self.tmpdir, "data", "loc", "2478433644_2839c5e8b8_o_d.jpg")) + + bag = bagit.Bag(self.tmpdir) + self.assertFalse(bag.is_valid()) + + bag.update_payload() + + bag = bagit.Bag(self.tmpdir) + self.assertTrue(bag.is_valid()) + + def test_payload_oxum(self): + bag = bagit.make_bag(self.tmpdir, checksums=["md5"]) + self.assertEqual(bag.payload_oxum(), (991765, 5)) + + def test_payload_oxum_after_payload_change(self): + bagit.make_bag(self.tmpdir, checksums=["md5"]) + + with open(j(self.tmpdir, "data", "newfile"), "w") as nf: + nf.write("newfile") + + bag = bagit.Bag(self.tmpdir) + self.assertEqual(bag.payload_oxum(), (991772, 6)) + + def test_add_payload(self): + bag = bagit.make_bag(self.tmpdir) + + extra = os.path.join(self.tmpdir, "extra.txt") + + with open(extra, "w") as f: + f.write("hello world") + + bag.add_payload(extra) + + bag = bagit.Bag(self.tmpdir) + + self.assertTrue( + os.path.isfile( + os.path.join(self.tmpdir, "data", "extra.txt") + ) + ) + self.assertTrue(bag.is_valid()) + + def test_add_payload_updates_oxum(self): + bag = bagit.make_bag(self.tmpdir) + + old_bytes, old_files = bag.payload_oxum() + + extra = os.path.join(self.tmpdir, "extra.txt") + + with open(extra, "w") as f: + f.write("hello") + + bag.add_payload(extra) + + bag = bagit.Bag(self.tmpdir) + + new_bytes, new_files = bag.payload_oxum() + + self.assertEqual(new_files, old_files + 1) + self.assertEqual(new_bytes, old_bytes + 5) + + self.assertTrue(bag.is_valid()) + + def test_add_payload_missing_file(self): + bag = bagit.make_bag(self.tmpdir) + + with self.assertRaises(ValueError): + bag.add_payload("/does/not/exist") + + def test_remove_payload(self): + bag = bagit.make_bag(self.tmpdir) + + payload_file = "data/README" + full_path = j(self.tmpdir, payload_file) + + self.assertTrue(os.path.isfile(full_path)) + + bag.remove_payload(payload_file) + + bag = bagit.Bag(self.tmpdir) + + self.assertFalse(os.path.exists(full_path)) + self.assertTrue(bag.is_valid()) + + def test_remove_payload_updates_oxum(self): + bag = bagit.make_bag(self.tmpdir) + + old_bytes, old_files = bag.payload_oxum() + removed_size = os.stat(j(self.tmpdir, "data", "README")).st_size + + bag.remove_payload("data/README") + + bag = bagit.Bag(self.tmpdir) + + new_bytes, new_files = bag.payload_oxum() + + self.assertEqual(new_files, old_files - 1) + self.assertEqual(new_bytes, old_bytes - removed_size) + + def test_remove_payload_rejects_non_payload_path(self): + bag = bagit.make_bag(self.tmpdir) + + with self.assertRaises(ValueError): + bag.remove_payload("bag-info.txt") + + def test_add_payload_with_destination(self): + bag = bagit.make_bag(self.tmpdir) + + extra = os.path.join(self.tmpdir, "extra.txt") + with open(extra, "w") as f: + f.write("hello") + + bag.add_payload(extra, "masters/extra.txt") + + self.assertTrue(os.path.isfile(j(self.tmpdir, "data", "masters", "extra.txt"))) + + bag = bagit.Bag(self.tmpdir) + self.assertTrue(bag.is_valid()) + + def test_add_payload_rejects_unsafe_destination(self): + bag = bagit.make_bag(self.tmpdir) + + extra = os.path.join(self.tmpdir, "extra.txt") + with open(extra, "w") as f: + f.write("hello") + + unsafe_destinations = [ + "../extra.txt", + "subdir/../../extra.txt", + os.path.abspath(os.path.join(self.tmpdir, "..", "extra.txt")), + "~/.ssh/id_rsa", + ] + + if os.name == "nt": + unsafe_destinations.extend( + [ + r"..\extra.txt", + r"subdir\..\..\extra.txt", + r"C:\Windows\system32\cmd.exe", + r"\\server\share\file.txt", + ] + ) + + for dest in unsafe_destinations: + with self.subTest(dest=dest): + with self.assertRaises(ValueError): + bag.add_payload(extra, "../extra.txt") + + def test_remove_payload_directory_recursive(self): + bag = bagit.make_bag(self.tmpdir) + + payload_dir = j(self.tmpdir, "data", "extra") + os.makedirs(payload_dir) + + with open(j(payload_dir, "one.txt"), "w") as f: + f.write("one") + + with open(j(payload_dir, "two.txt"), "w") as f: + f.write("two") + + bag.update_payload() + self.assertTrue(bag.is_valid()) + + bag.remove_payload("data/extra", recursive=True) + + bag = bagit.Bag(self.tmpdir) + + self.assertFalse(os.path.exists(payload_dir)) + self.assertTrue(bag.is_valid()) + + def test_remove_payload_directory_requires_recursive(self): + bag = bagit.make_bag(self.tmpdir) + + payload_dir = j(self.tmpdir, "data", "extra") + os.makedirs(payload_dir) + + with self.assertRaises(ValueError): + bag.remove_payload("data/extra") + + def test_add_payload_directory(self): + bag = bagit.make_bag(self.tmpdir) + + src_dir = tempfile.mkdtemp() + try: + with open(j(src_dir, "one.txt"), "w") as f: + f.write("one") + + os.mkdir(j(src_dir, "nested")) + + with open(j(src_dir, "nested", "two.txt"), "w") as f: + f.write("two") + + bag.add_payload(src_dir) + + dest_dir = j(self.tmpdir, "data", os.path.basename(src_dir)) + + self.assertTrue(os.path.isdir(dest_dir)) + self.assertTrue(os.path.isfile(j(dest_dir, "one.txt"))) + self.assertTrue(os.path.isfile(j(dest_dir, "nested", "two.txt"))) + + bag = bagit.Bag(self.tmpdir) + self.assertTrue(bag.is_valid()) + finally: + shutil.rmtree(src_dir) + + def test_add_payload_directory_with_destination(self): + bag = bagit.make_bag(self.tmpdir) + + src_dir = tempfile.mkdtemp() + try: + with open(j(src_dir, "one.txt"), "w") as f: + f.write("one") + + bag.add_payload(src_dir, "imports/sip") + + self.assertTrue(os.path.isdir(j(self.tmpdir, "data", "imports", "sip"))) + self.assertTrue(os.path.isfile(j(self.tmpdir, "data", "imports", "sip", "one.txt"))) + + bag = bagit.Bag(self.tmpdir) + self.assertTrue(bag.is_valid()) + finally: + shutil.rmtree(src_dir) + class TestFetch(SelfCleaningTestCase): def setUp(self):