diff --git a/acp_adapter/server.py b/acp_adapter/server.py index e69ff5955..ab37c5c0b 100644 --- a/acp_adapter/server.py +++ b/acp_adapter/server.py @@ -569,6 +569,9 @@ class HermesACPAgent(acp.Agent): async def cancel(self, session_id: str, **kwargs: Any) -> None: state = self.session_manager.get_session(session_id) if state and state.cancel_event: + with state.runtime_lock: + if state.is_running and state.current_prompt_text: + state.interrupted_prompt_text = state.current_prompt_text state.cancel_event.set() try: if getattr(state, "agent", None) and hasattr(state.agent, "interrupt"): @@ -666,6 +669,25 @@ class HermesACPAgent(acp.Agent): if not has_content: return PromptResponse(stop_reason="end_turn") + # Zed currently interrupts an active ACP request before delivering a + # follow-up slash command. If that follow-up is /steer, there may be no + # live AIAgent left to steer by the time this method runs. Salvage that + # UX by replaying the interrupted prompt with the steer text attached as + # explicit correction/guidance. + if isinstance(user_content, str) and user_text.startswith("/steer"): + steer_text = user_text.split(maxsplit=1)[1].strip() if len(user_text.split(maxsplit=1)) > 1 else "" + interrupted_prompt = "" + with state.runtime_lock: + if not state.is_running and steer_text and state.interrupted_prompt_text: + interrupted_prompt = state.interrupted_prompt_text + state.interrupted_prompt_text = "" + if interrupted_prompt: + user_text = ( + f"{interrupted_prompt}\n\n" + f"User correction/guidance after interrupt: {steer_text}" + ) + user_content = user_text + # Intercept slash commands — handle locally without calling the LLM. # Slash commands are text-only; if the client included images/resources, # send the whole multimodal prompt to the agent instead of treating it as @@ -694,6 +716,7 @@ class HermesACPAgent(acp.Agent): await self._conn.session_update(session_id, update) return PromptResponse(stop_reason="end_turn") state.is_running = True + state.current_prompt_text = user_text or "[Image attachment]" logger.info("Prompt on session %s: %s", session_id, user_text[:100]) @@ -808,6 +831,7 @@ class HermesACPAgent(acp.Agent): logger.exception("Executor error for session %s", session_id) with state.runtime_lock: state.is_running = False + state.current_prompt_text = "" return PromptResponse(stop_reason="end_turn") if result.get("messages"): @@ -838,6 +862,7 @@ class HermesACPAgent(acp.Agent): # normal follow-up user prompts, preserving role alternation and history. with state.runtime_lock: state.is_running = False + state.current_prompt_text = "" while True: with state.runtime_lock: diff --git a/acp_adapter/session.py b/acp_adapter/session.py index 0b627aabe..d1fb1a874 100644 --- a/acp_adapter/session.py +++ b/acp_adapter/session.py @@ -148,6 +148,8 @@ class SessionState: is_running: bool = False queued_prompts: List[str] = field(default_factory=list) runtime_lock: Any = field(default_factory=Lock) + current_prompt_text: str = "" + interrupted_prompt_text: str = "" class SessionManager: diff --git a/tests/acp_adapter/test_acp_commands.py b/tests/acp_adapter/test_acp_commands.py index f8a0ad45e..20082fe28 100644 --- a/tests/acp_adapter/test_acp_commands.py +++ b/tests/acp_adapter/test_acp_commands.py @@ -81,6 +81,24 @@ async def test_acp_steer_slash_command_injects_into_running_agent(): assert fake.runs == [] +@pytest.mark.asyncio +async def test_acp_steer_after_zed_interrupt_replays_interrupted_prompt_with_guidance(): + acp_agent, state, fake, _conn = make_agent_and_state() + state.interrupted_prompt_text = "write hi to a text file" + + response = await acp_agent.prompt( + session_id=state.session_id, + prompt=[TextContentBlock(type="text", text="/steer write HELLO instead")], + ) + + assert response.stop_reason == "end_turn" + assert fake.steers == [] + assert fake.runs == [ + "write hi to a text file\n\nUser correction/guidance after interrupt: write HELLO instead" + ] + assert state.interrupted_prompt_text == "" + + @pytest.mark.asyncio async def test_acp_queue_slash_command_adds_next_turn_without_running_now(): acp_agent, state, fake, _conn = make_agent_and_state()