211 lines
6.9 KiB
TypeScript
211 lines
6.9 KiB
TypeScript
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||
import type { TextContent } from "@mariozechner/pi-ai";
|
||
import type { SessionManager } from "@mariozechner/pi-coding-agent";
|
||
import { emitSessionTranscriptUpdate } from "../sessions/transcript-events.js";
|
||
import { HARD_MAX_TOOL_RESULT_CHARS } from "./pi-embedded-runner/tool-result-truncation.js";
|
||
import { makeMissingToolResult, sanitizeToolCallInputs } from "./session-transcript-repair.js";
|
||
import { extractToolCallsFromAssistant, extractToolResultId } from "./tool-call-id.js";
|
||
|
||
const GUARD_TRUNCATION_SUFFIX =
  "\n\n⚠️ [Content truncated during persistence — original exceeded size limit. " +
  "Use offset/limit parameters or request specific sections for large content.]";

/**
 * Lower bound on the character budget given to any single text block, so a block
 * holding a small share of the total is not cut down to almost nothing.
 *
 * NOTE(review): because of this floor, a tool result made of many text blocks can
 * still exceed HARD_MAX_TOOL_RESULT_CHARS after truncation — confirm acceptable.
 */
const MIN_TEXT_BLOCK_BUDGET_CHARS = 2_000;

/**
 * A newline is used as the cut point only if it lies past this fraction of the
 * block's budget; otherwise the block is cut exactly at the budget.
 */
const NEWLINE_CUT_MIN_FRACTION = 0.8;

/**
 * Return `block` typed as a text block when it is an object with `type: "text"`
 * and a string `text` field; otherwise `undefined`.
 */
function asTextBlock(block: unknown): TextContent | undefined {
  if (!block || typeof block !== "object" || (block as { type?: string }).type !== "text") {
    return undefined;
  }
  const candidate = block as TextContent;
  return typeof candidate.text === "string" ? candidate : undefined;
}

/**
 * Move a slice end index back by one UTF-16 code unit when slicing `text` at `cut`
 * would separate a surrogate pair (e.g. an emoji), which would otherwise leave a
 * lone high surrogate at the end of the truncated text.
 */
function avoidSplittingSurrogatePair(text: string, cut: number): number {
  if (cut <= 0 || cut >= text.length) {
    return cut;
  }
  const before = text.charCodeAt(cut - 1);
  const after = text.charCodeAt(cut);
  const splitsPair = before >= 0xd800 && before <= 0xdbff && after >= 0xdc00 && after <= 0xdfff;
  return splitsPair ? cut - 1 : cut;
}

/**
 * Truncate oversized text content blocks in a tool result message.
 *
 * Only messages with `role: "toolResult"` and an array `content` are considered.
 * If the combined length (UTF-16 code units) of their string text blocks is within
 * HARD_MAX_TOOL_RESULT_CHARS, the original message object is returned unchanged.
 * Otherwise each text block receives a budget proportional to its share of the
 * total (never below MIN_TEXT_BLOCK_BUDGET_CHARS, minus room for the suffix);
 * blocks over budget are cut — preferably at a late newline, never inside a
 * surrogate pair — and get GUARD_TRUNCATION_SUFFIX appended. Non-text blocks are
 * kept as-is. The input message and its blocks are never mutated.
 *
 * @param msg - Message about to be persisted.
 * @returns `msg` itself, or a shallow copy with a new, truncated `content` array.
 */
function capToolResultSize(msg: AgentMessage): AgentMessage {
  const role = (msg as { role?: string }).role;
  if (role !== "toolResult") {
    return msg;
  }
  const content = (msg as { content?: unknown }).content;
  if (!Array.isArray(content)) {
    return msg;
  }

  let totalTextChars = 0;
  for (const block of content) {
    totalTextChars += asTextBlock(block)?.text.length ?? 0;
  }
  if (totalTextChars <= HARD_MAX_TOOL_RESULT_CHARS) {
    return msg;
  }

  const newContent = content.map((block: unknown) => {
    const textBlock = asTextBlock(block);
    if (!textBlock) {
      return block;
    }
    const { text } = textBlock;
    // totalTextChars > HARD_MAX_TOOL_RESULT_CHARS here, so the division is safe.
    const blockShare = text.length / totalTextChars;
    const blockBudget = Math.max(
      MIN_TEXT_BLOCK_BUDGET_CHARS,
      Math.floor(HARD_MAX_TOOL_RESULT_CHARS * blockShare) - GUARD_TRUNCATION_SUFFIX.length,
    );
    if (text.length <= blockBudget) {
      return block;
    }

    // Prefer cutting at a newline boundary, unless it would waste too much budget.
    let cutPoint = blockBudget;
    const lastNewline = text.lastIndexOf("\n", blockBudget);
    if (lastNewline > blockBudget * NEWLINE_CUT_MIN_FRACTION) {
      cutPoint = lastNewline;
    }
    // Never leave a lone surrogate behind (invalid Unicode in the transcript).
    cutPoint = avoidSplittingSurrogatePair(text, cutPoint);

    return {
      ...textBlock,
      text: text.slice(0, cutPoint) + GUARD_TRUNCATION_SUFFIX,
    };
  });

  return { ...msg, content: newContent } as AgentMessage;
}
|
||
|
||
/**
 * Replace `sessionManager.appendMessage` with a guarded version that keeps
 * assistant tool calls and their tool results paired in the persisted transcript.
 *
 * The guarded appender:
 * - runs assistant messages through `sanitizeToolCallInputs`, dropping the
 *   message (and returning `undefined`) when nothing survives;
 * - for toolResult messages, removes the matching id from the pending set and
 *   applies `capToolResultSize` before appending;
 * - for every other message, when synthetic results are enabled and calls are
 *   still pending, appends synthetic results for those calls first;
 * - records the tool calls of each appended assistant message as pending.
 *
 * @param sessionManager - Session whose `appendMessage` is patched in place.
 * @param opts - Optional persistence transforms and the synthetic-result switch.
 * @returns Handles to flush outstanding tool calls and to list their ids.
 */
export function installSessionToolResultGuard(
  sessionManager: SessionManager,
  opts?: {
    /**
     * Optional transform applied to any message before persistence.
     */
    transformMessageForPersistence?: (message: AgentMessage) => AgentMessage;
    /**
     * Optional, synchronous transform applied to toolResult messages *before* they are
     * persisted to the session transcript.
     */
    transformToolResultForPersistence?: (
      message: AgentMessage,
      meta: { toolCallId?: string; toolName?: string; isSynthetic?: boolean },
    ) => AgentMessage;
    /**
     * Whether to synthesize missing tool results to satisfy strict providers.
     * Defaults to true.
     */
    allowSyntheticToolResults?: boolean;
  },
): {
  flushPendingToolResults: () => void;
  getPendingIds: () => string[];
} {
  // Captured before patching so internal writes bypass the guard.
  const appendUnguarded = sessionManager.appendMessage.bind(sessionManager);
  // Tool calls awaiting a result: tool call id -> tool name (may be undefined).
  const pendingCalls = new Map<string, string | undefined>();
  const synthesizeMissing = opts?.allowSyntheticToolResults ?? true;

  // Transforms are looked up on every call rather than captured up front.
  const applyMessageTransform = (message: AgentMessage) => {
    const transform = opts?.transformMessageForPersistence;
    if (!transform) {
      return message;
    }
    return transform(message);
  };

  const applyToolResultTransform = (
    message: AgentMessage,
    meta: { toolCallId?: string; toolName?: string; isSynthetic?: boolean },
  ) => {
    const transform = opts?.transformToolResultForPersistence;
    if (!transform) {
      return message;
    }
    return transform(message, meta);
  };

  // Resolve every pending call: append a synthetic result for each one (when
  // enabled), then forget them. With synthesis disabled this only clears the set.
  const flushPendingToolResults = () => {
    if (pendingCalls.size === 0) {
      return;
    }
    if (synthesizeMissing) {
      pendingCalls.forEach((toolName, toolCallId) => {
        const placeholder = makeMissingToolResult({ toolCallId, toolName });
        const persisted = applyToolResultTransform(applyMessageTransform(placeholder), {
          toolCallId,
          toolName,
          isSynthetic: true,
        });
        appendUnguarded(persisted as never);
      });
    }
    pendingCalls.clear();
  };

  // NOTE(review): with synthesis disabled, nothing below clears pendingCalls; ids
  // accumulate until a caller invokes flushPendingToolResults() explicitly.
  const flushIfSynthesizing = () => {
    if (synthesizeMissing && pendingCalls.size > 0) {
      flushPendingToolResults();
    }
  };

  const guardedAppend = (message: AgentMessage) => {
    let outgoing = message;

    if ((message as { role?: unknown }).role === "assistant") {
      const sanitized = sanitizeToolCallInputs([message]);
      if (sanitized.length === 0) {
        // Nothing left to persist; still resolve any calls left open.
        flushIfSynthesizing();
        return undefined;
      }
      outgoing = sanitized[0];
    }

    const outgoingRole = (outgoing as { role?: unknown }).role;

    if (outgoingRole === "toolResult") {
      const toolCallId = extractToolResultId(
        outgoing as Extract<AgentMessage, { role: "toolResult" }>,
      );
      const toolName = toolCallId ? pendingCalls.get(toolCallId) : undefined;
      if (toolCallId) {
        pendingCalls.delete(toolCallId);
      }
      // Hard size cap before persistence so one huge tool result cannot consume
      // the whole context window on later LLM calls.
      const capped = capToolResultSize(applyMessageTransform(outgoing));
      // NOTE(review): unlike the path below, this branch (and synthetic flushes)
      // does not call emitSessionTranscriptUpdate — confirm that is intended.
      return appendUnguarded(
        applyToolResultTransform(capped, {
          toolCallId: toolCallId ?? undefined,
          toolName,
          isSynthetic: false,
        }) as never,
      );
    }

    const newToolCalls =
      outgoingRole === "assistant"
        ? extractToolCallsFromAssistant(outgoing as Extract<AgentMessage, { role: "assistant" }>)
        : [];

    // Calls still pending from an earlier turn are resolved before this message is
    // written, whether or not it carries new tool calls of its own.
    flushIfSynthesizing();

    const appended = appendUnguarded(applyMessageTransform(outgoing) as never);

    const sessionFile = (
      sessionManager as { getSessionFile?: () => string | null }
    ).getSessionFile?.();
    if (sessionFile) {
      emitSessionTranscriptUpdate(sessionFile);
    }

    for (const call of newToolCalls) {
      pendingCalls.set(call.id, call.name);
    }

    return appended;
  };

  // Monkey-patch appendMessage with the guarded version.
  sessionManager.appendMessage = guardedAppend as SessionManager["appendMessage"];

  return {
    flushPendingToolResults,
    getPendingIds: () => [...pendingCalls.keys()],
  };
}
|