Skip to content
Merged

BI #44

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions src/website_profiling/llm/providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,98 @@
from __future__ import annotations

import json
import os
import sys
from typing import Any

from ..base import ChatResult, TokenCallback, ToolCall, parse_json_response

# Ephemeral (5-minute) prompt-cache marker. Placed on the static request prefix
# (tools -> system -> conversation) so Anthropic bills repeated prefix tokens at
# ~10% of base input price across the multi-round tool loop. Mirrors how Claude
# Code caches its tool/system prefix.
_CACHE_CONTROL = {"type": "ephemeral"}


def _truthy(value: str | None, *, default: bool) -> bool:
raw = (value or "").strip().lower()
if not raw:
return default
return raw in ("1", "true", "yes", "on")


def _prompt_cache_enabled() -> bool:
"""Prompt caching is on by default; set WP_LLM_PROMPT_CACHE=0 to disable."""
return _truthy(os.environ.get("WP_LLM_PROMPT_CACHE"), default=True)


def _cache_debug_enabled() -> bool:
return _truthy(os.environ.get("WP_LLM_DEBUG_CACHE"), default=False)


def _log_cache_usage(usage: Any) -> None:
"""When WP_LLM_DEBUG_CACHE is set, print cache token counts to stderr."""
if usage is None or not _cache_debug_enabled():
return
created = getattr(usage, "cache_creation_input_tokens", None)
read = getattr(usage, "cache_read_input_tokens", None)
inp = getattr(usage, "input_tokens", None)
print(
f"[wp-cache] input={inp} cache_creation={created} cache_read={read}",
file=sys.stderr,
flush=True,
)


def _apply_prompt_caching(
system: str,
tools: list[dict[str, Any]],
messages: list[dict[str, Any]],
) -> tuple[Any, list[dict[str, Any]], list[dict[str, Any]]]:
"""Add cache_control breakpoints to the static request prefix.

Returns ``(system, tools, messages)`` unchanged when caching is disabled, so
behavior is byte-identical to the no-cache path. Otherwise places three
breakpoints (the limit is four) in Anthropic's prefix order:

1. the last tool definition (caches the whole tools array),
2. the system prompt (caches tools+system),
3. the last content block of the last message (rolls forward each round,
reading the prior conversation prefix from cache and writing the suffix).

Builds new copies — never mutates the caller's lists/dicts — so the pure
converter outputs stay clean.
"""
if not _prompt_cache_enabled():
return system, tools, messages

# 1. System prompt -> single text block carrying the cache marker.
system_blocks: Any = [
{"type": "text", "text": system, "cache_control": _CACHE_CONTROL},
]

# 2. Last tool definition.
tools_out = list(tools)
if tools_out:
tools_out[-1] = {**tools_out[-1], "cache_control": _CACHE_CONTROL}

# 3. Last content block of the last message.
messages_out = list(messages)
if messages_out:
last = dict(messages_out[-1])
content = last.get("content")
if isinstance(content, list) and content:
blocks = list(content)
blocks[-1] = {**blocks[-1], "cache_control": _CACHE_CONTROL}
last["content"] = blocks
elif isinstance(content, str):
last["content"] = [
{"type": "text", "text": content, "cache_control": _CACHE_CONTROL},
]
messages_out[-1] = last

return system_blocks, tools_out, messages_out


def _to_anthropic_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
"""Convert OpenAI-shaped chat messages to ``(system, anthropic_messages)``.
Expand Down Expand Up @@ -122,6 +210,9 @@ def chat_with_tools(

system, anthropic_messages = _to_anthropic_messages(messages)
anthropic_tools = _to_anthropic_tools(tools)
system, anthropic_tools, anthropic_messages = _apply_prompt_caching(
system, anthropic_tools, anthropic_messages,
)

kwargs: dict[str, Any] = {
"model": self._model,
Expand Down Expand Up @@ -154,6 +245,7 @@ def chat_with_tools(
prev = tool_calls[-1].arguments.get("_partial", "")
tool_calls[-1].arguments["_partial"] = prev + partial
final = stream.get_final_message()
_log_cache_usage(getattr(final, "usage", None))
for tc in tool_calls:
partial = tc.arguments.pop("_partial", "")
if partial:
Expand All @@ -168,6 +260,7 @@ def chat_with_tools(
return ChatResult(content="".join(content_parts) or "".join(text_parts), tool_calls=tool_calls)

msg = client.messages.create(**kwargs)
_log_cache_usage(getattr(msg, "usage", None))
content_parts: list[str] = []
tool_calls = []
for block in msg.content:
Expand Down
99 changes: 99 additions & 0 deletions tests/test_llm_provider_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@
"""
from __future__ import annotations

import pytest

from website_profiling.llm.providers.anthropic import (
_apply_prompt_caching,
_to_anthropic_messages,
_to_anthropic_tools,
)

_EPHEMERAL = {"type": "ephemeral"}


def test_assistant_tool_calls_become_matching_tool_use_blocks() -> None:
messages = [
Expand Down Expand Up @@ -75,3 +80,97 @@ def test_to_anthropic_tools_maps_schema() -> None:
assert _to_anthropic_tools(tools) == [
{"name": "t", "description": "d", "input_schema": {"type": "object", "properties": {}}},
]


# --- prompt caching --------------------------------------------------------


@pytest.fixture(autouse=True)
def _cache_on(monkeypatch: pytest.MonkeyPatch) -> None:
"""Caching defaults to on; pin it for deterministic tests."""
monkeypatch.setenv("WP_LLM_PROMPT_CACHE", "1")


def test_caching_marks_last_tool_only() -> None:
tools = [{"name": "a"}, {"name": "b"}, {"name": "c"}]
_, tools_out, _ = _apply_prompt_caching("sys", tools, [])
assert "cache_control" not in tools_out[0]
assert "cache_control" not in tools_out[1]
assert tools_out[-1]["cache_control"] == _EPHEMERAL
# original list/dicts are untouched
assert all("cache_control" not in t for t in tools)


def test_caching_empty_tools_is_safe() -> None:
system, tools_out, _ = _apply_prompt_caching("sys", [], [])
assert tools_out == []
assert system == [{"type": "text", "text": "sys", "cache_control": _EPHEMERAL}]


def test_caching_system_becomes_text_block() -> None:
system, _, _ = _apply_prompt_caching("the system prompt", [], [])
assert system == [
{"type": "text", "text": "the system prompt", "cache_control": _EPHEMERAL},
]


def test_caching_last_message_string_content_becomes_block() -> None:
messages = [
{"role": "user", "content": "first"},
{"role": "user", "content": "second"},
]
_, _, out = _apply_prompt_caching("sys", [], messages)
# earlier message untouched
assert out[0] == {"role": "user", "content": "first"}
assert out[-1]["content"] == [
{"type": "text", "text": "second", "cache_control": _EPHEMERAL},
]
# caller's list/dicts not mutated
assert messages[-1] == {"role": "user", "content": "second"}


def test_caching_last_message_list_content_marks_last_block() -> None:
messages = [{
"role": "user",
"content": [
{"type": "tool_result", "tool_use_id": "c1", "content": "{}"},
{"type": "tool_result", "tool_use_id": "c2", "content": "{}"},
],
}]
_, _, out = _apply_prompt_caching("sys", [], messages)
blocks = out[-1]["content"]
assert "cache_control" not in blocks[0]
assert blocks[-1]["cache_control"] == _EPHEMERAL
# original untouched
assert all("cache_control" not in b for b in messages[0]["content"])


def test_caching_disabled_returns_inputs_unchanged(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("WP_LLM_PROMPT_CACHE", "0")
tools = [{"name": "a"}]
messages = [{"role": "user", "content": "hi"}]
system, tools_out, msgs_out = _apply_prompt_caching("sys", tools, messages)
assert system == "sys" # stays a plain string
assert tools_out is tools
assert msgs_out is messages


def test_caching_uses_at_most_four_breakpoints() -> None:
tools = [{"name": "a"}, {"name": "b"}]
messages = [
{"role": "user", "content": "u"},
{"role": "assistant", "content": [{"type": "text", "text": "a"}]},
]
system, tools_out, msgs_out = _apply_prompt_caching("sys", tools, messages)

def _count(obj: object) -> int:
if isinstance(obj, dict):
n = 1 if obj.get("cache_control") == _EPHEMERAL else 0
return n + sum(_count(v) for v in obj.values())
if isinstance(obj, list):
return sum(_count(v) for v in obj)
return 0

total = _count(system) + _count(tools_out) + _count(msgs_out)
assert total == 3
assert total <= 4
9 changes: 7 additions & 2 deletions web/app/api/chat/artifacts/[id]/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,18 @@ export const GET: ApiRouteHandlerWithParams<{ id: string }> = async (
return;
}
const body = Buffer.from(parsed.data_base64, 'base64');
const filename = parsed.filename || 'export.bin';
const rawName = parsed.filename || 'export.bin';
// Sanitize the ASCII fallback (strip non-printable/quote/slash chars so
// a CR/LF or quote can't break or inject the header) and provide an
// RFC 5987 filename* for the full UTF-8 name.
const asciiName =
rawName.replace(/[^\x20-\x7e]/g, '_').replace(/["\\/]/g, '_') || 'export.bin';
const mime = parsed.mime_type || 'application/octet-stream';
resolve(
new NextResponse(body, {
headers: {
'Content-Type': mime,
'Content-Disposition': `attachment; filename="${filename}"`,
'Content-Disposition': `attachment; filename="${asciiName}"; filename*=UTF-8''${encodeURIComponent(rawName)}`,
},
}),
);
Expand Down
47 changes: 46 additions & 1 deletion web/app/api/chat/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,31 @@ export const POST: ApiRouteHandler = async (request: NextRequest): Promise<Respo
const chatTimeoutMs = await resolveChatTimeoutMs();
const timeoutSec = Math.round(chatTimeoutMs / 1000);

// Track the spawned child so we can kill it if the client disconnects
// (ReadableStream.cancel) instead of leaking it until the timeout fires.
let activeProc: ReturnType<typeof spawn> | null = null;
let activeKillTimer: ReturnType<typeof setTimeout> | null = null;
let cancelled = false;

const cancelChild = () => {
cancelled = true;
const p = activeProc;
if (!p) return;
try {
p.kill('SIGTERM');
activeKillTimer = setTimeout(() => {
try {
p.kill('SIGKILL');
} catch {
/* already exited */
}
}, 2000);
(activeKillTimer as { unref?: () => void }).unref?.();
} catch {
/* already exited */
}
};

const stream = new ReadableStream({
start(controller) {
const encoder = new TextEncoder();
Expand Down Expand Up @@ -170,6 +195,7 @@ export const POST: ApiRouteHandler = async (request: NextRequest): Promise<Respo
shell: false,
},
);
activeProc = proc;

const timer = setTimeout(() => {
timedOut = true;
Expand All @@ -182,6 +208,14 @@ export const POST: ApiRouteHandler = async (request: NextRequest): Promise<Respo
closeStream();
}, chatTimeoutMs);

// Without an error listener, an EPIPE/ERR_STREAM_DESTROYED on the stdin
// pipe (child exits before reading) would surface as an unhandled stream
// error and crash the Node process instead of a clean chat error.
proc.stdin?.on('error', (err: Error) => {
clearTimeout(timer);
push('error', { message: `Failed to send request to assistant: ${err.message}` });
closeStream();
});
proc.stdin?.write(stdinPayload);
proc.stdin?.end();

Expand Down Expand Up @@ -265,7 +299,13 @@ export const POST: ApiRouteHandler = async (request: NextRequest): Promise<Respo

proc.on('close', async (code: number | null) => {
clearTimeout(timer);
if (timedOut) return;
if (activeKillTimer) {
clearTimeout(activeKillTimer);
activeKillTimer = null;
}
// On client cancel we drop the partial turn (the user navigated away);
// on timeout the error was already streamed.
if (timedOut || cancelled) return;
exitCode = code;

if (!sawError && !assistantText.trim() && !narrative) {
Expand Down Expand Up @@ -327,6 +367,11 @@ export const POST: ApiRouteHandler = async (request: NextRequest): Promise<Respo
closeStream();
});
},
cancel() {
// Client disconnected mid-stream (reload/navigate/abort): terminate the
// agent process so it does not keep holding the LLM connection/CPU.
cancelChild();
},
});

return new Response(stream, {
Expand Down
14 changes: 12 additions & 2 deletions web/app/api/chat/sessions/[id]/messages/route.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import { NextResponse, type NextRequest } from 'next/server';
import { forbiddenIfNotLocal } from '@/server/localOnly';
import { requireApiAuthForChat } from '@/server/auth';
import { getChatMessages } from '@/server/chatDb';
import { getChatMessages, getChatSession } from '@/server/chatDb';
import type { ApiRouteHandler } from '@/types/api';

export const runtime = 'nodejs';
export const dynamic = 'force-dynamic';

/** GET /api/chat/sessions/[id]/messages */
/** GET /api/chat/sessions/[id]/messages?propertyId= */
export const GET: ApiRouteHandler = async (
request: NextRequest,
context?: { params?: Promise<{ id: string }> },
Expand All @@ -22,8 +22,18 @@ export const GET: ApiRouteHandler = async (
if (!sessionId) {
return NextResponse.json({ error: 'invalid session id' }, { status: 400 });
}
const propertyId = Number(request.nextUrl.searchParams.get('propertyId') || '0');
if (!propertyId) {
return NextResponse.json({ error: 'propertyId required' }, { status: 400 });
}

try {
// Scope conversation history to the caller's property to avoid leaking
// another property's messages by enumerating session ids.
const session = await getChatSession(sessionId);
if (!session || session.property_id !== propertyId) {
return NextResponse.json({ error: 'session not found' }, { status: 404 });
}
const messages = await getChatMessages(sessionId);
return NextResponse.json({ messages });
} catch (e) {
Expand Down
Loading
Loading