Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 50 additions & 28 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def generate():
assert part['response'] == next(it)


def test_client_generate_images(httpserver: HTTPServer):
def test_client_generate_images(httpserver: HTTPServer, tmp_path):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
Expand All @@ -460,10 +460,12 @@ def test_client_generate_images(httpserver: HTTPServer):

client = Client(httpserver.url_for('/'))

with tempfile.NamedTemporaryFile() as temp:
# Create a Temporary file path outide local scope .
temp_img = tmp_path / "test.png"
with open(temp_img, "wb") as temp:
temp.write(PNG_BYTES)
temp.flush()
response = client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
response = client.generate('dummy', 'Why is the sky blue?', images=[str(temp_img)])
assert response['model'] == 'dummy'
assert response['response'] == 'Because it is.'

Expand Down Expand Up @@ -774,7 +776,7 @@ def userhomedir():
os.environ['HOME'] = home


def test_client_create_with_blob(httpserver: HTTPServer):
def test_client_create_with_blob(httpserver: HTTPServer, tmp_path):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
Expand All @@ -787,12 +789,15 @@ def test_client_create_with_blob(httpserver: HTTPServer):

client = Client(httpserver.url_for('/'))

with tempfile.NamedTemporaryFile():
response = client.create('dummy', files={'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'})
# Create a Temporary file path outide local scope .
temp_file = tmp_path / "test.gguf"
with open(temp_file, "wb") as f:
f.flush()
response = client.create('dummy', files={'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'})
assert response['status'] == 'success'


def test_client_create_with_parameters_roundtrip(httpserver: HTTPServer):
def test_client_create_with_parameters_roundtrip(httpserver: HTTPServer, tmp_path):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
Expand All @@ -812,7 +817,7 @@ def test_client_create_with_parameters_roundtrip(httpserver: HTTPServer):

client = Client(httpserver.url_for('/'))

with tempfile.NamedTemporaryFile():
with tempfile.NamedTemporaryFile(delete=False):
response = client.create(
'dummy',
quantize='q4_k_m',
Expand Down Expand Up @@ -845,23 +850,29 @@ def test_client_create_from_library(httpserver: HTTPServer):
assert response['status'] == 'success'


def test_client_create_blob(httpserver: HTTPServer):
def test_client_create_blob(httpserver: HTTPServer, tmp_path):
httpserver.expect_ordered_request(re.compile('^/api/blobs/sha256[:-][0-9a-fA-F]{64}$'), method='POST').respond_with_response(Response(status=201))

client = Client(httpserver.url_for('/'))

with tempfile.NamedTemporaryFile() as blob:
response = client.create_blob(blob.name)
# Create a Temporary file path outide local scope .
temp_blob = tmp_path / "blob.bin"
with open(temp_blob, "wb") as blob:
blob.flush()
response = client.create_blob(str(temp_blob))
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'


def test_client_create_blob_exists(httpserver: HTTPServer):
def test_client_create_blob_exists(httpserver: HTTPServer, tmp_path):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))

client = Client(httpserver.url_for('/'))

with tempfile.NamedTemporaryFile() as blob:
response = client.create_blob(blob.name)
# Create a Temporary file path outide local scope .
temp_blob = tmp_path / "blob.bin"
with open(temp_blob, "wb") as blob:
blob.flush()
response = client.create_blob(str(temp_blob))
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'


Expand Down Expand Up @@ -1036,7 +1047,7 @@ def generate():
assert part['response'] == next(it)


async def test_async_client_generate_images(httpserver: HTTPServer):
async def test_async_client_generate_images(httpserver: HTTPServer, tmp_path):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
Expand All @@ -1055,10 +1066,12 @@ async def test_async_client_generate_images(httpserver: HTTPServer):

client = AsyncClient(httpserver.url_for('/'))

with tempfile.NamedTemporaryFile() as temp:
# Create a Temporary file path outide local scope .
temp_img = tmp_path / "test.png"
with open(temp_img, "wb") as temp:
temp.write(PNG_BYTES)
temp.flush()
response = await client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
response = await client.generate('dummy', 'Why is the sky blue?', images=[str(temp_img)])
assert response['model'] == 'dummy'
assert response['response'] == 'Because it is.'

Expand Down Expand Up @@ -1151,7 +1164,7 @@ def generate():
assert part['status'] == next(it)


async def test_async_client_create_with_blob(httpserver: HTTPServer):
async def test_async_client_create_with_blob(httpserver: HTTPServer, tmp_path):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
Expand All @@ -1164,12 +1177,15 @@ async def test_async_client_create_with_blob(httpserver: HTTPServer):

client = AsyncClient(httpserver.url_for('/'))

with tempfile.NamedTemporaryFile():
response = await client.create('dummy', files={'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'})
# Create a Temporary file path outide local scope .
temp_file = tmp_path / "test.gguf"
with open(temp_file, "wb") as f:
f.flush()
response = await client.create('dummy', files={'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'})
assert response['status'] == 'success'


async def test_async_client_create_with_parameters_roundtrip(httpserver: HTTPServer):
async def test_async_client_create_with_parameters_roundtrip(httpserver: HTTPServer, tmp_path):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
Expand All @@ -1189,7 +1205,7 @@ async def test_async_client_create_with_parameters_roundtrip(httpserver: HTTPSer

client = AsyncClient(httpserver.url_for('/'))

with tempfile.NamedTemporaryFile():
with tempfile.NamedTemporaryFile(delete=False):
response = await client.create(
'dummy',
quantize='q4_k_m',
Expand Down Expand Up @@ -1222,23 +1238,29 @@ async def test_async_client_create_from_library(httpserver: HTTPServer):
assert response['status'] == 'success'


async def test_async_client_create_blob(httpserver: HTTPServer):
async def test_async_client_create_blob(httpserver: HTTPServer, tmp_path):
httpserver.expect_ordered_request(re.compile('^/api/blobs/sha256[:-][0-9a-fA-F]{64}$'), method='POST').respond_with_response(Response(status=201))

client = AsyncClient(httpserver.url_for('/'))

with tempfile.NamedTemporaryFile() as blob:
response = await client.create_blob(blob.name)
# Create a Temporary file path outide local scope .
temp_blob = tmp_path / "blob.bin"
with open(temp_blob, "wb") as blob:
blob.flush()
response = await client.create_blob(str(temp_blob))
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'


async def test_async_client_create_blob_exists(httpserver: HTTPServer):
async def test_async_client_create_blob_exists(httpserver: HTTPServer, tmp_path):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))

client = AsyncClient(httpserver.url_for('/'))

with tempfile.NamedTemporaryFile() as blob:
response = await client.create_blob(blob.name)
# Create a Temporary file path outide local scope .
temp_blob = tmp_path / "blob.bin"
with open(temp_blob, "wb") as blob:
blob.flush()
response = await client.create_blob(str(temp_blob))
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'


Expand Down
16 changes: 10 additions & 6 deletions tests/test_type_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,23 @@ def test_image_serialization_plain_string():
assert img.model_dump() == 'not a path or base64' # Should return as-is


def test_image_serialization_path():
with tempfile.NamedTemporaryFile() as temp_file:
def test_image_serialization_path(tmp_path):
# Create a Temporary file path outide local scope .
temp_file_path = tmp_path / "temp.txt"
with open(temp_file_path, "wb") as temp_file:
temp_file.write(b'test file content')
temp_file.flush()
img = Image(value=Path(temp_file.name))
img = Image(value=temp_file_path)
assert img.model_dump() == b64encode(b'test file content').decode()


def test_image_serialization_string_path():
with tempfile.NamedTemporaryFile() as temp_file:
def test_image_serialization_string_path(tmp_path):
# Create a Temporary file path outide local scope .
temp_file_path = tmp_path / "temp.txt"
with open(temp_file_path, "wb") as temp_file:
temp_file.write(b'test file content')
temp_file.flush()
img = Image(value=temp_file.name)
img = Image(value=str(temp_file_path))
assert img.model_dump() == b64encode(b'test file content').decode()

with pytest.raises(ValueError):
Expand Down