feat: pre-call sanitization and post-call tool guardrails (#1732)

Salvage of PR #1321 by @alireza78a (cherry-picked concept, reimplemented
against current main).

Phase 1 — Pre-call message sanitization:
  _sanitize_api_messages() now runs unconditionally before every LLM call.
  Previously gated on context_compressor being present, so sessions loaded
  from disk or running without compression could accumulate dangling
  tool_call/tool_result pairs causing API errors.

Phase 2a — Delegate task cap:
  _cap_delegate_task_calls() truncates excess delegate_task calls per turn
  to MAX_CONCURRENT_CHILDREN. The existing cap in delegate_tool.py only
  limits the task array within a single call; this catches multiple
  separate delegate_task tool_calls in one turn.

Phase 2b — Tool call deduplication:
  _deduplicate_tool_calls() drops duplicate (tool_name, arguments) pairs
  within a single turn when models stutter.

All three are static methods on AIAgent, independently testable.
29 tests covering happy paths and edge cases.
This commit is contained in:
Teknium
2026-03-17 04:24:27 -07:00
committed by GitHub
parent fb20a9e120
commit 85993fbb5a
2 changed files with 394 additions and 7 deletions

View File

@@ -1957,7 +1957,124 @@ class AIAgent:
prompt_parts.append(PLATFORM_HINTS[platform_key])
return "\n\n".join(prompt_parts)
# =========================================================================
# Pre/post-call guardrails (inspired by PR #1321 — @alireza78a)
# =========================================================================
@staticmethod
def _get_tool_call_id_static(tc) -> str:
"""Extract call ID from a tool_call entry (dict or object)."""
if isinstance(tc, dict):
return tc.get("id", "") or ""
return getattr(tc, "id", "") or ""
@staticmethod
def _sanitize_api_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Fix orphaned tool_call / tool_result pairs before every LLM call.
Runs unconditionally — not gated on whether the context compressor
is present — so orphans from session loading or manual message
manipulation are always caught.
"""
surviving_call_ids: set = set()
for msg in messages:
if msg.get("role") == "assistant":
for tc in msg.get("tool_calls") or []:
cid = AIAgent._get_tool_call_id_static(tc)
if cid:
surviving_call_ids.add(cid)
result_call_ids: set = set()
for msg in messages:
if msg.get("role") == "tool":
cid = msg.get("tool_call_id")
if cid:
result_call_ids.add(cid)
# 1. Drop tool results with no matching assistant call
orphaned_results = result_call_ids - surviving_call_ids
if orphaned_results:
messages = [
m for m in messages
if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
]
logger.debug(
"Pre-call sanitizer: removed %d orphaned tool result(s)",
len(orphaned_results),
)
# 2. Inject stub results for calls whose result was dropped
missing_results = surviving_call_ids - result_call_ids
if missing_results:
patched: List[Dict[str, Any]] = []
for msg in messages:
patched.append(msg)
if msg.get("role") == "assistant":
for tc in msg.get("tool_calls") or []:
cid = AIAgent._get_tool_call_id_static(tc)
if cid in missing_results:
patched.append({
"role": "tool",
"content": "[Result unavailable — see context summary above]",
"tool_call_id": cid,
})
messages = patched
logger.debug(
"Pre-call sanitizer: added %d stub tool result(s)",
len(missing_results),
)
return messages
@staticmethod
def _cap_delegate_task_calls(tool_calls: list) -> list:
"""Truncate excess delegate_task calls to MAX_CONCURRENT_CHILDREN.
The delegate_tool caps the task list inside a single call, but the
model can emit multiple separate delegate_task tool_calls in one
turn. This truncates the excess, preserving all non-delegate calls.
Returns the original list if no truncation was needed.
"""
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
delegate_count = sum(1 for tc in tool_calls if tc.function.name == "delegate_task")
if delegate_count <= MAX_CONCURRENT_CHILDREN:
return tool_calls
kept_delegates = 0
truncated = []
for tc in tool_calls:
if tc.function.name == "delegate_task":
if kept_delegates < MAX_CONCURRENT_CHILDREN:
truncated.append(tc)
kept_delegates += 1
else:
truncated.append(tc)
logger.warning(
"Truncated %d excess delegate_task call(s) to enforce "
"MAX_CONCURRENT_CHILDREN=%d limit",
delegate_count - MAX_CONCURRENT_CHILDREN, MAX_CONCURRENT_CHILDREN,
)
return truncated
@staticmethod
def _deduplicate_tool_calls(tool_calls: list) -> list:
"""Remove duplicate (tool_name, arguments) pairs within a single turn.
Only the first occurrence of each unique pair is kept.
Returns the original list if no duplicates were found.
"""
seen: set = set()
unique: list = []
for tc in tool_calls:
key = (tc.function.name, tc.function.arguments)
if key not in seen:
seen.add(key)
unique.append(tc)
else:
logger.warning("Removed duplicate tool call: %s", tc.function.name)
return unique if len(unique) < len(tool_calls) else tool_calls
def _repair_tool_call(self, tool_name: str) -> str | None:
"""Attempt to repair a mismatched tool name before aborting.
@@ -4992,11 +5109,10 @@ class AIAgent:
api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl)
# Safety net: strip orphaned tool results / add stubs for missing
# results before sending to the API. The compressor handles this
# during compression, but orphans can also sneak in from session
# loading or manual message manipulation.
if hasattr(self, 'context_compressor') and self.context_compressor:
api_messages = self.context_compressor._sanitize_tool_pairs(api_messages)
# results before sending to the API. Runs unconditionally — not
# gated on context_compressor — so orphans from session loading or
# manual message manipulation are always caught.
api_messages = self._sanitize_api_messages(api_messages)
# Calculate approximate request size for logging
total_chars = sum(len(str(msg)) for msg in api_messages)
@@ -6026,7 +6142,15 @@ class AIAgent:
# Reset retry counter on successful JSON validation
self._invalid_json_retries = 0
# ── Post-call guardrails ──────────────────────────
assistant_message.tool_calls = self._cap_delegate_task_calls(
assistant_message.tool_calls
)
assistant_message.tool_calls = self._deduplicate_tool_calls(
assistant_message.tool_calls
)
assistant_msg = self._build_assistant_message(assistant_message, finish_reason)
# If this turn has both content AND tool_calls, capture the content