fix(acp): use modes for edit auto-approval

This commit is contained in:
HenkDz
2026-05-16 19:15:08 +01:00
committed by Teknium
parent f70e0b85dd
commit 0292398604
5 changed files with 165 additions and 87 deletions

View File

@@ -117,6 +117,7 @@ def make_tool_progress_cb(
loop: asyncio.AbstractEventLoop, loop: asyncio.AbstractEventLoop,
tool_call_ids: Dict[str, Deque[str]], tool_call_ids: Dict[str, Deque[str]],
tool_call_meta: Dict[str, Dict[str, Any]], tool_call_meta: Dict[str, Dict[str, Any]],
edit_approval_policy_getter: Callable[[], tuple[str, str | None]] | None = None,
) -> Callable: ) -> Callable:
"""Create a ``tool_progress_callback`` for AIAgent. """Create a ``tool_progress_callback`` for AIAgent.
@@ -162,7 +163,20 @@ def make_tool_progress_cb(
logger.debug("Failed to capture ACP edit snapshot for %s", name, exc_info=True) logger.debug("Failed to capture ACP edit snapshot for %s", name, exc_info=True)
tool_call_meta[tc_id] = {"args": args, "snapshot": snapshot} tool_call_meta[tc_id] = {"args": args, "snapshot": snapshot}
update = build_tool_start(tc_id, name, args) edit_diff = None
if name in {"write_file", "patch"} and edit_approval_policy_getter is not None:
try:
from acp_adapter.edit_approval import build_edit_proposal, should_auto_approve_edit
proposal = build_edit_proposal(name, args)
if proposal is not None:
policy, cwd = edit_approval_policy_getter()
if should_auto_approve_edit(proposal, policy, cwd):
edit_diff = proposal
except Exception:
logger.debug("Failed to prepare auto-approved ACP edit diff for %s", name, exc_info=True)
update = build_tool_start(tc_id, name, args, edit_diff=edit_diff)
_send_update(conn, session_id, loop, update) _send_update(conn, session_id, loop, update)
return _tool_progress return _tool_progress

View File

@@ -45,10 +45,10 @@ from acp.schema import (
SetSessionModeResponse, SetSessionModeResponse,
ResourceContentBlock, ResourceContentBlock,
SessionCapabilities, SessionCapabilities,
SessionConfigOptionSelect,
SessionConfigSelectOption,
SessionForkCapabilities, SessionForkCapabilities,
SessionListCapabilities, SessionListCapabilities,
SessionMode,
SessionModeState,
SessionModelState, SessionModelState,
SessionResumeCapabilities, SessionResumeCapabilities,
SessionInfo, SessionInfo,
@@ -499,6 +499,17 @@ class HermesACPAgent(acp.Agent):
_EDIT_APPROVAL_POLICY_CONFIG_ID = "edit_approval_policy" _EDIT_APPROVAL_POLICY_CONFIG_ID = "edit_approval_policy"
_EDIT_APPROVAL_POLICY_DEFAULT = "ask" _EDIT_APPROVAL_POLICY_DEFAULT = "ask"
_MODE_DEFAULT = "default"
_MODE_ACCEPT_EDITS = "accept_edits"
_MODE_DONT_ASK = "dont_ask"
_MODE_TO_EDIT_APPROVAL_POLICY = {
_MODE_DEFAULT: "ask",
_MODE_ACCEPT_EDITS: "workspace_session",
_MODE_DONT_ASK: "session",
}
_EDIT_APPROVAL_POLICY_TO_MODE = {
value: key for key, value in _MODE_TO_EDIT_APPROVAL_POLICY.items()
}
def __init__(self, session_manager: SessionManager | None = None): def __init__(self, session_manager: SessionManager | None = None):
super().__init__() super().__init__()
@@ -513,47 +524,43 @@ class HermesACPAgent(acp.Agent):
logger.info("ACP client connected") logger.info("ACP client connected")
def _session_config_options(self, state: SessionState) -> list[Any]: def _session_modes(self, state: SessionState) -> SessionModeState:
values = getattr(state, "config_options", None) """Return ACP session modes while preserving Zed's separate model picker.
if not isinstance(values, dict):
values = {} Zed renders ``config_options`` in the prominent selector slot where the
current = str(values.get(self._EDIT_APPROVAL_POLICY_CONFIG_ID) or self._EDIT_APPROVAL_POLICY_DEFAULT) model picker was visible. Claude/Codex expose policy-like controls as ACP
allowed = {"ask", "workspace_session", "session"} modes, which coexist with the model picker, so Hermes maps edit approval
if current not in allowed: policy onto modes instead of advertising config options.
current = self._EDIT_APPROVAL_POLICY_DEFAULT """
return [
SessionConfigOptionSelect( current = str(getattr(state, "mode", "") or self._MODE_DEFAULT)
id=self._EDIT_APPROVAL_POLICY_CONFIG_ID, if current not in self._MODE_TO_EDIT_APPROVAL_POLICY:
name="Edit approvals", current = self._MODE_DEFAULT
description="Control ACP edit approvals for this session.", return SessionModeState(
category="permissions", current_mode_id=current,
type="select", available_modes=[
current_value=current, SessionMode(
options=[ id=self._MODE_DEFAULT,
SessionConfigSelectOption( name="Default",
value="ask", description="Ask before edits.",
name="Ask before edits", ),
description="Require approval for every file edit.", SessionMode(
), id=self._MODE_ACCEPT_EDITS,
SessionConfigSelectOption( name="Accept Edits",
value="workspace_session", description="Auto-allow workspace and /tmp edits; still asks for sensitive paths.",
name="Auto-allow workspace edits", ),
description="Allow workspace and /tmp edits for this session; still asks for sensitive paths.", SessionMode(
), id=self._MODE_DONT_ASK,
SessionConfigSelectOption( name="Don't Ask",
value="session", description="Auto-allow file edits for this session except sensitive paths.",
name="Auto-allow all edits this session", ),
description="Allow file edits for this session except sensitive paths.", ],
), )
],
)
]
def _edit_approval_policy_for_state(self, state: SessionState) -> tuple[str, str | None]: def _edit_approval_policy_for_state(self, state: SessionState) -> tuple[str, str | None]:
values = getattr(state, "config_options", None) mode = str(getattr(state, "mode", "") or self._MODE_DEFAULT)
if not isinstance(values, dict): policy = self._MODE_TO_EDIT_APPROVAL_POLICY.get(mode, self._EDIT_APPROVAL_POLICY_DEFAULT)
values = {} return policy, state.cwd
return str(values.get(self._EDIT_APPROVAL_POLICY_CONFIG_ID) or self._EDIT_APPROVAL_POLICY_DEFAULT), state.cwd
@staticmethod @staticmethod
def _encode_model_choice(provider: str | None, model: str | None) -> str: def _encode_model_choice(provider: str | None, model: str | None) -> str:
@@ -1040,7 +1047,7 @@ class HermesACPAgent(acp.Agent):
return NewSessionResponse( return NewSessionResponse(
session_id=state.session_id, session_id=state.session_id,
models=self._build_model_state(state), models=self._build_model_state(state),
config_options=self._session_config_options(state), modes=self._session_modes(state),
) )
async def load_session( async def load_session(
@@ -1084,7 +1091,7 @@ class HermesACPAgent(acp.Agent):
self._schedule_usage_update(state) self._schedule_usage_update(state)
return LoadSessionResponse( return LoadSessionResponse(
models=self._build_model_state(state), models=self._build_model_state(state),
config_options=self._session_config_options(state), modes=self._session_modes(state),
) )
async def resume_session( async def resume_session(
@@ -1116,7 +1123,7 @@ class HermesACPAgent(acp.Agent):
self._schedule_usage_update(state) self._schedule_usage_update(state)
return ResumeSessionResponse( return ResumeSessionResponse(
models=self._build_model_state(state), models=self._build_model_state(state),
config_options=self._session_config_options(state), modes=self._session_modes(state),
) )
async def cancel(self, session_id: str, **kwargs: Any) -> None: async def cancel(self, session_id: str, **kwargs: Any) -> None:
@@ -1150,7 +1157,7 @@ class HermesACPAgent(acp.Agent):
return ForkSessionResponse( return ForkSessionResponse(
session_id=new_id, session_id=new_id,
models=self._build_model_state(state) if state is not None else None, models=self._build_model_state(state) if state is not None else None,
config_options=self._session_config_options(state) if state is not None else None, modes=self._session_modes(state) if state is not None else None,
) )
async def list_sessions( async def list_sessions(
@@ -1307,7 +1314,14 @@ class HermesACPAgent(acp.Agent):
streamed_message = False streamed_message = False
if conn: if conn:
tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids, tool_call_meta) tool_progress_cb = make_tool_progress_cb(
conn,
session_id,
loop,
tool_call_ids,
tool_call_meta,
edit_approval_policy_getter=lambda: self._edit_approval_policy_for_state(state),
)
reasoning_cb = make_thinking_cb(conn, session_id, loop) reasoning_cb = make_thinking_cb(conn, session_id, loop)
step_cb = make_step_cb(conn, session_id, loop, tool_call_ids, tool_call_meta) step_cb = make_step_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
message_cb = make_message_cb(conn, session_id, loop) message_cb = make_message_cb(conn, session_id, loop)
@@ -1849,9 +1863,12 @@ class HermesACPAgent(acp.Agent):
if state is None: if state is None:
logger.warning("Session %s: mode switch requested for missing session", session_id) logger.warning("Session %s: mode switch requested for missing session", session_id)
return None return None
setattr(state, "mode", mode_id) normalized_mode = str(mode_id or "").strip()
if normalized_mode not in self._MODE_TO_EDIT_APPROVAL_POLICY:
normalized_mode = self._MODE_DEFAULT
setattr(state, "mode", normalized_mode)
self.session_manager.save_session(session_id) self.session_manager.save_session(session_id)
logger.info("Session %s: mode switched to %s", session_id, mode_id) logger.info("Session %s: mode switched to %s", session_id, normalized_mode)
return SetSessionModeResponse() return SetSessionModeResponse()
async def set_config_option( async def set_config_option(
@@ -1863,11 +1880,15 @@ class HermesACPAgent(acp.Agent):
logger.warning("Session %s: config update requested for missing session", session_id) logger.warning("Session %s: config update requested for missing session", session_id)
return None return None
options = getattr(state, "config_options", None) if str(config_id) == self._EDIT_APPROVAL_POLICY_CONFIG_ID:
if not isinstance(options, dict): mode = self._EDIT_APPROVAL_POLICY_TO_MODE.get(str(value), self._MODE_DEFAULT)
options = {} setattr(state, "mode", mode)
options[str(config_id)] = value else:
setattr(state, "config_options", options) options = getattr(state, "config_options", None)
if not isinstance(options, dict):
options = {}
options[str(config_id)] = value
setattr(state, "config_options", options)
self.session_manager.save_session(session_id) self.session_manager.save_session(session_id)
logger.info("Session %s: config option %s updated", session_id, config_id) logger.info("Session %s: config option %s updated", session_id, config_id)
return SetSessionConfigOptionResponse(config_options=self._session_config_options(state)) return SetSessionConfigOptionResponse(config_options=[])

View File

@@ -928,6 +928,8 @@ def build_tool_start(
tool_call_id: str, tool_call_id: str,
tool_name: str, tool_name: str,
arguments: Dict[str, Any], arguments: Dict[str, Any],
*,
edit_diff: Any = None,
) -> ToolCallStart: ) -> ToolCallStart:
"""Create a ToolCallStart event for the given hermes tool invocation.""" """Create a ToolCallStart event for the given hermes tool invocation."""
kind = get_tool_kind(tool_name) kind = get_tool_kind(tool_name)
@@ -935,16 +937,34 @@ def build_tool_start(
locations = extract_locations(arguments) locations = extract_locations(arguments)
if tool_name == "patch": if tool_name == "patch":
mode = arguments.get("mode", "replace") if edit_diff is not None:
path = arguments.get("path") or "patch input" content = [
content = [_text(f"Preparing {mode} edit for {path}. Approval prompt shows the diff.")] acp.tool_diff_content(
path=edit_diff.path,
old_text=edit_diff.old_text,
new_text=edit_diff.new_text,
)
]
else:
mode = arguments.get("mode", "replace")
path = arguments.get("path") or "patch input"
content = [_text(f"Preparing {mode} edit for {path}. Approval prompt shows the diff.")]
return acp.start_tool_call( return acp.start_tool_call(
tool_call_id, title, kind=kind, content=content, locations=locations, tool_call_id, title, kind=kind, content=content, locations=locations,
) )
if tool_name == "write_file": if tool_name == "write_file":
path = arguments.get("path", "") if edit_diff is not None:
content = [_text(f"Preparing write to {path}. Approval prompt shows the diff." if path else "Preparing file write. Approval prompt shows the diff.")] content = [
acp.tool_diff_content(
path=edit_diff.path,
old_text=edit_diff.old_text,
new_text=edit_diff.new_text,
)
]
else:
path = arguments.get("path", "")
content = [_text(f"Preparing write to {path}. Approval prompt shows the diff." if path else "Preparing file write. Approval prompt shows the diff.")]
return acp.start_tool_call( return acp.start_tool_call(
tool_call_id, title, kind=kind, content=content, locations=locations, tool_call_id, title, kind=kind, content=content, locations=locations,
) )

View File

@@ -24,6 +24,7 @@ from acp.schema import (
PromptResponse, PromptResponse,
ResumeSessionResponse, ResumeSessionResponse,
SessionModelState, SessionModelState,
SessionModeState,
SetSessionConfigOptionResponse, SetSessionConfigOptionResponse,
SetSessionModelResponse, SetSessionModelResponse,
SetSessionModeResponse, SetSessionModeResponse,
@@ -52,31 +53,34 @@ def agent(mock_manager):
"""HermesACPAgent backed by a mock session manager.""" """HermesACPAgent backed by a mock session manager."""
return HermesACPAgent(session_manager=mock_manager) return HermesACPAgent(session_manager=mock_manager)
@pytest.mark.asyncio
async def test_new_session_includes_edit_approval_config_option(self, agent):
resp = await agent.new_session(cwd="/tmp")
assert resp.config_options @pytest.mark.asyncio
option = resp.config_options[0] async def test_new_session_exposes_edit_approvals_as_modes_not_config_options(agent):
assert option.id == "edit_approval_policy" resp = await agent.new_session(cwd="/tmp")
assert option.current_value == "ask"
assert {choice.value for choice in option.options} == {
"ask",
"workspace_session",
"session",
}
@pytest.mark.asyncio assert resp.config_options is None
async def test_set_config_option_persists_edit_approval_policy(self, agent): assert isinstance(resp.modes, SessionModeState)
resp = await agent.new_session(cwd="/tmp") assert resp.modes.current_mode_id == "default"
update = await agent.set_config_option( assert [(mode.id, mode.name) for mode in resp.modes.available_modes] == [
"edit_approval_policy", ("default", "Default"),
resp.session_id, ("accept_edits", "Accept Edits"),
"workspace_session", ("dont_ask", "Don't Ask"),
) ]
assert isinstance(update, SetSessionConfigOptionResponse)
assert update.config_options[0].current_value == "workspace_session" @pytest.mark.asyncio
async def test_set_config_option_persists_edit_approval_policy_without_advertising_config(agent):
resp = await agent.new_session(cwd="/tmp")
update = await agent.set_config_option(
"edit_approval_policy",
resp.session_id,
"workspace_session",
)
state = agent.session_manager.get_session(resp.session_id)
assert isinstance(update, SetSessionConfigOptionResponse)
assert update.config_options == []
assert getattr(state, "mode", None) == "accept_edits"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -891,11 +895,11 @@ class TestSessionConfiguration:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_set_session_mode_returns_response(self, agent): async def test_set_session_mode_returns_response(self, agent):
new_resp = await agent.new_session(cwd="/tmp") new_resp = await agent.new_session(cwd="/tmp")
resp = await agent.set_session_mode(mode_id="chat", session_id=new_resp.session_id) resp = await agent.set_session_mode(mode_id="accept_edits", session_id=new_resp.session_id)
state = agent.session_manager.get_session(new_resp.session_id) state = agent.session_manager.get_session(new_resp.session_id)
assert isinstance(resp, SetSessionModeResponse) assert isinstance(resp, SetSessionModeResponse)
assert getattr(state, "mode", None) == "chat" assert getattr(state, "mode", None) == "accept_edits"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_router_accepts_stable_session_config_methods(self, agent): async def test_router_accepts_stable_session_config_methods(self, agent):
@@ -904,7 +908,7 @@ class TestSessionConfiguration:
mode_result = await router( mode_result = await router(
"session/set_mode", "session/set_mode",
{"modeId": "chat", "sessionId": new_resp.session_id}, {"modeId": "accept_edits", "sessionId": new_resp.session_id},
False, False,
) )
config_result = await router( config_result = await router(
@@ -918,8 +922,7 @@ class TestSessionConfiguration:
) )
assert mode_result == {} assert mode_result == {}
assert config_result["configOptions"] assert config_result["configOptions"] == []
assert config_result["configOptions"][0]["id"] == "edit_approval_policy"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_router_accepts_unstable_model_switch_when_enabled(self, agent): async def test_router_accepts_unstable_model_switch_when_enabled(self, agent):

View File

@@ -2,6 +2,7 @@
import pytest import pytest
from acp_adapter.edit_approval import EditProposal
from acp_adapter.tools import ( from acp_adapter.tools import (
TOOL_KIND_MAP, TOOL_KIND_MAP,
build_tool_complete, build_tool_complete,
@@ -174,6 +175,25 @@ class TestBuildToolStart:
assert "Approval prompt shows the diff" in item.content.text assert "Approval prompt shows the diff" in item.content.text
assert "new_file.py" in item.content.text assert "new_file.py" in item.content.text
def test_auto_approved_edit_start_shows_diff_content(self):
"""Auto-approved edit starts need the diff because no approval card exists."""
args = {"path": "/tmp/acp.txt", "old_string": "old", "new_string": "new"}
result = build_tool_start(
"tc-auto-edit",
"patch",
args,
edit_diff=EditProposal("patch", "/tmp/acp.txt", "old\n", "new\n", args),
)
assert isinstance(result, ToolCallStart)
assert result.kind == "edit"
assert len(result.content) == 1
item = result.content[0]
assert isinstance(item, FileEditToolCallContent)
assert item.path == "/tmp/acp.txt"
assert item.old_text == "old\n"
assert item.new_text == "new\n"
def test_build_tool_start_for_terminal(self): def test_build_tool_start_for_terminal(self):
"""terminal should produce text content with the command.""" """terminal should produce text content with the command."""
args = {"command": "ls -la /tmp"} args = {"command": "ls -la /tmp"}