Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 86 additions & 8 deletions src/bagit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import os
import re
import signal
import shutil
import sys
import tempfile
import time
Expand Down Expand Up @@ -537,6 +538,77 @@ def save(self, processes=1, manifests=False):

os.chdir(old_dir)


def update_payload(self, processes=1):
"""
Rebuild payload manifests and tag manifests after payload changes.
"""
self.save(processes=processes, manifests=True)


def add_payload(self, src, dest=None, processes=1):
"""
Copy a file into the payload directory and rebuild manifests.
"""
if not os.path.isfile(src) and not os.path.isdir(src):
raise ValueError("Payload source must be a file or directory: %s" % src)

if dest is None:
dest = os.path.basename(src)

dest = os.path.normpath(dest)

if os.path.isabs(dest) or dest.startswith("..") or os.path.expanduser(dest) != dest:
raise ValueError("Payload destination is unsafe: %s" % dest)

payload_dest = os.path.join("data", dest)

if self._path_is_dangerous(payload_dest):
raise ValueError("Payload destination is unsafe: %s" % dest)

dst = os.path.join(self.path, payload_dest)

if os.path.isfile(src):
dst_dir = os.path.dirname(dst)
if not os.path.isdir(dst_dir):
os.makedirs(dst_dir)

shutil.copy2(src, dst)

elif os.path.isdir(src):
if os.path.exists(dst):
raise ValueError("Payload destination already exists: %s" % dest)

shutil.copytree(src, dst)

self.update_payload(processes=processes)


def remove_payload(self, path, processes=1, recursive=False):
"""
Remove a payload file and rebuild manifests.
"""
payload_path = os.path.normpath(path)

if self._path_is_dangerous(payload_path):
raise ValueError("Payload path is unsafe: %s" % path)

if not payload_path.startswith("data" + os.sep):
raise ValueError("Payload path must start with data/: %s" % path)

full_path = os.path.join(self.path, payload_path)

if os.path.isfile(full_path):
os.remove(full_path)
elif os.path.isdir(full_path):
if not recursive:
raise ValueError("Payload path is a directory: %s" % path)
shutil.rmtree(full_path)
else:
raise ValueError("Payload path does not exist: %s" % path)

self.update_payload(processes=processes)

def tagfile_entries(self):
return dict(
(key, value)
Expand Down Expand Up @@ -808,6 +880,17 @@ def _validate_contents(self, processes=1, fast=False, completeness_only=False):

self._validate_entries(processes)

def payload_oxum(self):
total_bytes = 0
total_files = 0

for payload_file in self.payload_files():
payload_file = os.path.join(self.path, payload_file)
total_bytes += os.stat(payload_file).st_size
total_files += 1

return total_bytes, total_files

def _validate_oxum(self):
oxum = self.info.get("Payload-Oxum")

Expand All @@ -827,13 +910,8 @@ def _validate_oxum(self):

oxum_byte_count = int(oxum_byte_count)
oxum_file_count = int(oxum_file_count)
total_bytes = 0
total_files = 0

for payload_file in self.payload_files():
payload_file = os.path.join(self.path, payload_file)
total_bytes += os.stat(payload_file).st_size
total_files += 1
total_bytes, total_files = self.payload_oxum()

if oxum_file_count != total_files or oxum_byte_count != total_bytes:
raise BagValidationError(
Expand Down Expand Up @@ -1583,7 +1661,7 @@ def main():
else:
LOGGER.info(_("%s is valid"), bag_dir)
except BagError as e:
LOGGER.error(
LOGGER.exception(
_("%(bag)s is invalid: %(error)s"), {"bag": bag_dir, "error": e}
)
rc = 1
Expand All @@ -1598,7 +1676,7 @@ def main():
checksums=args.checksums,
)
except Exception as exc:
LOGGER.error(
LOGGER.exception(
_("Failed to create bag in %(bag_directory)s: %(error)s"),
{"bag_directory": bag_dir, "error": exc},
exc_info=True,
Expand Down
242 changes: 242 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,248 @@ def test_open_bag_with_unknown_encoding(self):

self.assertEqual("Unsupported encoding: WTF-8", str(error_catcher.exception))

def test_update_payload_added_file(self):
bag = bagit.make_bag(self.tmpdir)
self.assertTrue(bag.is_valid())

with open(j(self.tmpdir, "data", "newfile"), "w") as nf:
nf.write("newfile")

bag = bagit.Bag(self.tmpdir)
self.assertFalse(bag.is_valid())

bag.update_payload()

bag = bagit.Bag(self.tmpdir)
self.assertTrue(bag.is_valid())

def test_update_payload_deleted_file(self):
bag = bagit.make_bag(self.tmpdir)
self.assertTrue(bag.is_valid())

os.remove(j(self.tmpdir, "data", "loc", "2478433644_2839c5e8b8_o_d.jpg"))

bag = bagit.Bag(self.tmpdir)
self.assertFalse(bag.is_valid())

bag.update_payload()

bag = bagit.Bag(self.tmpdir)
self.assertTrue(bag.is_valid())

def test_payload_oxum(self):
bag = bagit.make_bag(self.tmpdir, checksums=["md5"])
self.assertEqual(bag.payload_oxum(), (991765, 5))

def test_payload_oxum_after_payload_change(self):
bagit.make_bag(self.tmpdir, checksums=["md5"])

with open(j(self.tmpdir, "data", "newfile"), "w") as nf:
nf.write("newfile")

bag = bagit.Bag(self.tmpdir)
self.assertEqual(bag.payload_oxum(), (991772, 6))

def test_add_payload(self):
bag = bagit.make_bag(self.tmpdir)

extra = os.path.join(self.tmpdir, "extra.txt")

with open(extra, "w") as f:
f.write("hello world")

bag.add_payload(extra)

bag = bagit.Bag(self.tmpdir)

self.assertTrue(
os.path.isfile(
os.path.join(self.tmpdir, "data", "extra.txt")
)
)
self.assertTrue(bag.is_valid())

def test_add_payload_updates_oxum(self):
bag = bagit.make_bag(self.tmpdir)

old_bytes, old_files = bag.payload_oxum()

extra = os.path.join(self.tmpdir, "extra.txt")

with open(extra, "w") as f:
f.write("hello")

bag.add_payload(extra)

bag = bagit.Bag(self.tmpdir)

new_bytes, new_files = bag.payload_oxum()

self.assertEqual(new_files, old_files + 1)
self.assertEqual(new_bytes, old_bytes + 5)

self.assertTrue(bag.is_valid())

def test_add_payload_missing_file(self):
bag = bagit.make_bag(self.tmpdir)

with self.assertRaises(ValueError):
bag.add_payload("/does/not/exist")

def test_remove_payload(self):
bag = bagit.make_bag(self.tmpdir)

payload_file = "data/README"
full_path = j(self.tmpdir, payload_file)

self.assertTrue(os.path.isfile(full_path))

bag.remove_payload(payload_file)

bag = bagit.Bag(self.tmpdir)

self.assertFalse(os.path.exists(full_path))
self.assertTrue(bag.is_valid())

def test_remove_payload_updates_oxum(self):
bag = bagit.make_bag(self.tmpdir)

old_bytes, old_files = bag.payload_oxum()
removed_size = os.stat(j(self.tmpdir, "data", "README")).st_size

bag.remove_payload("data/README")

bag = bagit.Bag(self.tmpdir)

new_bytes, new_files = bag.payload_oxum()

self.assertEqual(new_files, old_files - 1)
self.assertEqual(new_bytes, old_bytes - removed_size)

def test_remove_payload_rejects_non_payload_path(self):
bag = bagit.make_bag(self.tmpdir)

with self.assertRaises(ValueError):
bag.remove_payload("bag-info.txt")

def test_add_payload_with_destination(self):
bag = bagit.make_bag(self.tmpdir)

extra = os.path.join(self.tmpdir, "extra.txt")
with open(extra, "w") as f:
f.write("hello")

bag.add_payload(extra, "masters/extra.txt")

self.assertTrue(os.path.isfile(j(self.tmpdir, "data", "masters", "extra.txt")))

bag = bagit.Bag(self.tmpdir)
self.assertTrue(bag.is_valid())

def test_add_payload_rejects_unsafe_destination(self):
bag = bagit.make_bag(self.tmpdir)

extra = os.path.join(self.tmpdir, "extra.txt")
with open(extra, "w") as f:
f.write("hello")

unsafe_destinations = [
"../extra.txt",
"subdir/../../extra.txt",
os.path.abspath(os.path.join(self.tmpdir, "..", "extra.txt")),
"~/.ssh/id_rsa",
]

if os.name == "nt":
unsafe_destinations.extend(
[
r"..\extra.txt",
r"subdir\..\..\extra.txt",
r"C:\Windows\system32\cmd.exe",
r"\\server\share\file.txt",
]
)

for dest in unsafe_destinations:
with self.subTest(dest=dest):
with self.assertRaises(ValueError):
bag.add_payload(extra, "../extra.txt")

def test_remove_payload_directory_recursive(self):
bag = bagit.make_bag(self.tmpdir)

payload_dir = j(self.tmpdir, "data", "extra")
os.makedirs(payload_dir)

with open(j(payload_dir, "one.txt"), "w") as f:
f.write("one")

with open(j(payload_dir, "two.txt"), "w") as f:
f.write("two")

bag.update_payload()
self.assertTrue(bag.is_valid())

bag.remove_payload("data/extra", recursive=True)

bag = bagit.Bag(self.tmpdir)

self.assertFalse(os.path.exists(payload_dir))
self.assertTrue(bag.is_valid())

def test_remove_payload_directory_requires_recursive(self):
bag = bagit.make_bag(self.tmpdir)

payload_dir = j(self.tmpdir, "data", "extra")
os.makedirs(payload_dir)

with self.assertRaises(ValueError):
bag.remove_payload("data/extra")

def test_add_payload_directory(self):
bag = bagit.make_bag(self.tmpdir)

src_dir = tempfile.mkdtemp()
try:
with open(j(src_dir, "one.txt"), "w") as f:
f.write("one")

os.mkdir(j(src_dir, "nested"))

with open(j(src_dir, "nested", "two.txt"), "w") as f:
f.write("two")

bag.add_payload(src_dir)

dest_dir = j(self.tmpdir, "data", os.path.basename(src_dir))

self.assertTrue(os.path.isdir(dest_dir))
self.assertTrue(os.path.isfile(j(dest_dir, "one.txt")))
self.assertTrue(os.path.isfile(j(dest_dir, "nested", "two.txt")))

bag = bagit.Bag(self.tmpdir)
self.assertTrue(bag.is_valid())
finally:
shutil.rmtree(src_dir)

def test_add_payload_directory_with_destination(self):
bag = bagit.make_bag(self.tmpdir)

src_dir = tempfile.mkdtemp()
try:
with open(j(src_dir, "one.txt"), "w") as f:
f.write("one")

bag.add_payload(src_dir, "imports/sip")

self.assertTrue(os.path.isdir(j(self.tmpdir, "data", "imports", "sip")))
self.assertTrue(os.path.isfile(j(self.tmpdir, "data", "imports", "sip", "one.txt")))

bag = bagit.Bag(self.tmpdir)
self.assertTrue(bag.is_valid())
finally:
shutil.rmtree(src_dir)


class TestFetch(SelfCleaningTestCase):
def setUp(self):
Expand Down