From f57ebf52e9bcb63477d0d282c9b5618ad814eb7c Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Fri, 27 Mar 2026 11:33:19 -0700 Subject: [PATCH] fix(api-server): cancel orphaned agent + true interrupt on SSE disconnect (salvage #3399) (#3427) Salvage of #3399 by @binhnt92 with true agent interruption added on top. When a streaming /v1/chat/completions client disconnects mid-stream, the agent is now interrupted via agent.interrupt() so it stops making LLM API calls, and the asyncio task wrapper is cancelled. Closes #3399. --- gateway/platforms/api_server.py | 159 ++++++++------ tests/gateway/test_sse_agent_cancel.py | 280 +++++++++++++++++++++++++ 2 files changed, 377 insertions(+), 62 deletions(-) create mode 100644 tests/gateway/test_sse_agent_cancel.py diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index e5e81fe6d..0641aca28 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -495,17 +495,21 @@ class APIServerAdapter(BasePlatformAdapter): if delta is not None: _stream_q.put(delta) - # Start agent in background + # Start agent in background. agent_ref is a mutable container + # so the SSE writer can interrupt the agent on client disconnect. + agent_ref = [None] agent_task = asyncio.ensure_future(self._run_agent( user_message=user_message, conversation_history=history, ephemeral_system_prompt=system_prompt, session_id=session_id, stream_delta_callback=_on_delta, + agent_ref=agent_ref, )) return await self._write_sse_chat_completion( - request, completion_id, model_name, created, _stream_q, agent_task + request, completion_id, model_name, created, _stream_q, + agent_task, agent_ref, ) # Non-streaming: run the agent (with optional Idempotency-Key) @@ -568,9 +572,14 @@ class APIServerAdapter(BasePlatformAdapter): async def _write_sse_chat_completion( self, request: "web.Request", completion_id: str, model: str, - created: int, stream_q, agent_task, + created: int, stream_q, agent_task, agent_ref=None, ) -> "web.StreamResponse": - """Write real streaming SSE from agent's stream_delta_callback queue.""" + """Write real streaming SSE from agent's stream_delta_callback queue. + + If the client disconnects mid-stream (network drop, browser tab close), + the agent is interrupted via ``agent.interrupt()`` so it stops making + LLM API calls, and the asyncio task wrapper is cancelled. + """ import queue as _q response = web.StreamResponse( @@ -579,69 +588,87 @@ class APIServerAdapter(BasePlatformAdapter): ) await response.prepare(request) - # Role chunk - role_chunk = { - "id": completion_id, "object": "chat.completion.chunk", - "created": created, "model": model, - "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}], - } - await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode()) - - # Stream content chunks as they arrive from the agent - loop = asyncio.get_event_loop() - while True: - try: - delta = await loop.run_in_executor(None, lambda: stream_q.get(timeout=0.5)) - except _q.Empty: - if agent_task.done(): - # Drain any remaining items - while True: - try: - delta = stream_q.get_nowait() - if delta is None: - break - content_chunk = { - "id": completion_id, "object": "chat.completion.chunk", - "created": created, "model": model, - "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}], - } - await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode()) - except _q.Empty: - break - break - continue - - if delta is None: # End of stream sentinel - break - - content_chunk = { + try: + # Role chunk + role_chunk = { "id": completion_id, "object": "chat.completion.chunk", "created": created, "model": model, - "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}], + "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}], } - await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode()) + await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode()) - # Get usage from completed agent - usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} - try: - result, agent_usage = await agent_task - usage = agent_usage or usage - except Exception: - pass + # Stream content chunks as they arrive from the agent + loop = asyncio.get_event_loop() + while True: + try: + delta = await loop.run_in_executor(None, lambda: stream_q.get(timeout=0.5)) + except _q.Empty: + if agent_task.done(): + # Drain any remaining items + while True: + try: + delta = stream_q.get_nowait() + if delta is None: + break + content_chunk = { + "id": completion_id, "object": "chat.completion.chunk", + "created": created, "model": model, + "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}], + } + await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode()) + except _q.Empty: + break + break + continue - # Finish chunk - finish_chunk = { - "id": completion_id, "object": "chat.completion.chunk", - "created": created, "model": model, - "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], - "usage": { - "prompt_tokens": usage.get("input_tokens", 0), - "completion_tokens": usage.get("output_tokens", 0), - "total_tokens": usage.get("total_tokens", 0), - }, - } - await response.write(f"data: {json.dumps(finish_chunk)}\n\n".encode()) - await response.write(b"data: [DONE]\n\n") + if delta is None: # End of stream sentinel + break + + content_chunk = { + "id": completion_id, "object": "chat.completion.chunk", + "created": created, "model": model, + "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}], + } + await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode()) + + # Get usage from completed agent + usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + try: + result, agent_usage = await agent_task + usage = agent_usage or usage + except Exception: + pass + + # Finish chunk + finish_chunk = { + "id": completion_id, "object": "chat.completion.chunk", + "created": created, "model": model, + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + "usage": { + "prompt_tokens": usage.get("input_tokens", 0), + "completion_tokens": usage.get("output_tokens", 0), + "total_tokens": usage.get("total_tokens", 0), + }, + } + await response.write(f"data: {json.dumps(finish_chunk)}\n\n".encode()) + await response.write(b"data: [DONE]\n\n") + except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError, OSError): + # Client disconnected mid-stream. Interrupt the agent so it + # stops making LLM API calls at the next loop iteration, then + # cancel the asyncio task wrapper. + agent = agent_ref[0] if agent_ref else None + if agent is not None: + try: + agent.interrupt("SSE client disconnected") + except Exception: + pass + if not agent_task.done(): + agent_task.cancel() + try: + await agent_task + except (asyncio.CancelledError, Exception): + pass + logger.info("SSE client disconnected; interrupted agent task %s", completion_id) return response @@ -1144,12 +1171,18 @@ class APIServerAdapter(BasePlatformAdapter): ephemeral_system_prompt: Optional[str] = None, session_id: Optional[str] = None, stream_delta_callback=None, + agent_ref: Optional[list] = None, ) -> tuple: """ Create an agent and run a conversation in a thread executor. Returns ``(result_dict, usage_dict)`` where *usage_dict* contains ``input_tokens``, ``output_tokens`` and ``total_tokens``. + + If *agent_ref* is a one-element list, the AIAgent instance is stored + at ``agent_ref[0]`` before ``run_conversation`` begins. This allows + callers (e.g. the SSE writer) to call ``agent.interrupt()`` from + another thread to stop in-progress LLM calls. """ loop = asyncio.get_event_loop() @@ -1159,6 +1192,8 @@ class APIServerAdapter(BasePlatformAdapter): session_id=session_id, stream_delta_callback=stream_delta_callback, ) + if agent_ref is not None: + agent_ref[0] = agent result = agent.run_conversation( user_message=user_message, conversation_history=conversation_history, diff --git a/tests/gateway/test_sse_agent_cancel.py b/tests/gateway/test_sse_agent_cancel.py new file mode 100644 index 000000000..6b5306fbe --- /dev/null +++ b/tests/gateway/test_sse_agent_cancel.py @@ -0,0 +1,280 @@ +"""Tests for SSE client disconnect → agent task cancellation. + +When a streaming /v1/chat/completions client disconnects mid-stream +(network drop, browser tab close), the agent is interrupted via +agent.interrupt() so it stops making LLM API calls, and the asyncio +task wrapper is cancelled. +""" + +import asyncio +import json +import queue +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_adapter(): + """Build a minimal APIServerAdapter with mocked internals.""" + from gateway.platforms.api_server import APIServerAdapter + from gateway.config import PlatformConfig + + config = PlatformConfig(enabled=True, token="test-key") + adapter = APIServerAdapter(config) + return adapter + + +def _make_request(): + """Build a mock aiohttp request.""" + req = MagicMock() + req.headers = {} + return req + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestSSEAgentCancelOnDisconnect: + """gateway/platforms/api_server.py — _write_sse_chat_completion()""" + + def test_agent_task_cancelled_on_client_disconnect(self): + """When response.write raises ConnectionResetError (client dropped), + the agent task must be cancelled.""" + adapter = _make_adapter() + + stream_q = queue.Queue() + stream_q.put("hello ") # Some data already queued + + # Agent task that runs forever (simulates a long LLM call) + agent_done = asyncio.Event() + + async def fake_agent(): + await agent_done.wait() + return {"final_response": "done"}, {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + + async def run(): + from aiohttp import web + + agent_task = asyncio.ensure_future(fake_agent()) + + # Mock response that raises ConnectionResetError on second write + mock_response = AsyncMock(spec=web.StreamResponse) + call_count = 0 + + async def write_side_effect(data): + nonlocal call_count + call_count += 1 + if call_count >= 2: + raise ConnectionResetError("client disconnected") + + mock_response.write = AsyncMock(side_effect=write_side_effect) + mock_response.prepare = AsyncMock() + + with patch.object(type(adapter), '_write_sse_chat_completion', + adapter._write_sse_chat_completion): + # Patch StreamResponse creation + with patch("gateway.platforms.api_server.web.StreamResponse", + return_value=mock_response): + await adapter._write_sse_chat_completion( + _make_request(), "cmpl-123", "gpt-4", 1234567890, + stream_q, agent_task, + ) + + # The critical assertion: agent_task must be cancelled + assert agent_task.cancelled() or agent_task.done() + # Clean up + agent_done.set() + + asyncio.run(run()) + + def test_agent_task_not_cancelled_on_normal_completion(self): + """On normal stream completion, agent task should NOT be cancelled.""" + adapter = _make_adapter() + + stream_q = queue.Queue() + stream_q.put("hello") + stream_q.put(None) # End-of-stream sentinel + + async def fake_agent(): + return {"final_response": "done"}, {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + + async def run(): + from aiohttp import web + + agent_task = asyncio.ensure_future(fake_agent()) + await asyncio.sleep(0) # Let agent complete + + mock_response = AsyncMock(spec=web.StreamResponse) + mock_response.write = AsyncMock() + mock_response.prepare = AsyncMock() + + with patch("gateway.platforms.api_server.web.StreamResponse", + return_value=mock_response): + await adapter._write_sse_chat_completion( + _make_request(), "cmpl-456", "gpt-4", 1234567890, + stream_q, agent_task, + ) + + # Agent should have completed normally, not been cancelled + assert agent_task.done() + assert not agent_task.cancelled() + + asyncio.run(run()) + + def test_broken_pipe_also_cancels_agent(self): + """BrokenPipeError (another disconnect variant) also cancels the task.""" + adapter = _make_adapter() + + stream_q = queue.Queue() + + async def fake_agent(): + await asyncio.sleep(999) # Never completes + return {}, {} + + async def run(): + from aiohttp import web + + agent_task = asyncio.ensure_future(fake_agent()) + + mock_response = AsyncMock(spec=web.StreamResponse) + mock_response.write = AsyncMock(side_effect=BrokenPipeError("pipe broken")) + mock_response.prepare = AsyncMock() + + with patch("gateway.platforms.api_server.web.StreamResponse", + return_value=mock_response): + await adapter._write_sse_chat_completion( + _make_request(), "cmpl-789", "gpt-4", 1234567890, + stream_q, agent_task, + ) + + assert agent_task.cancelled() or agent_task.done() + + asyncio.run(run()) + + def test_already_done_task_not_cancelled_on_disconnect(self): + """If agent already finished before disconnect, don't try to cancel.""" + adapter = _make_adapter() + + stream_q = queue.Queue() + stream_q.put("data") + + async def fake_agent(): + return {"final_response": "done"}, {} + + async def run(): + from aiohttp import web + + agent_task = asyncio.ensure_future(fake_agent()) + await asyncio.sleep(0) # Let agent complete + + mock_response = AsyncMock(spec=web.StreamResponse) + call_count = 0 + + async def write_side_effect(data): + nonlocal call_count + call_count += 1 + if call_count >= 2: + raise ConnectionResetError("late disconnect") + + mock_response.write = AsyncMock(side_effect=write_side_effect) + mock_response.prepare = AsyncMock() + + with patch("gateway.platforms.api_server.web.StreamResponse", + return_value=mock_response): + await adapter._write_sse_chat_completion( + _make_request(), "cmpl-done", "gpt-4", 1234567890, + stream_q, agent_task, + ) + + # Task was already done — should not be cancelled + assert agent_task.done() + assert not agent_task.cancelled() + + asyncio.run(run()) + + def test_agent_interrupt_called_on_disconnect(self): + """When the client disconnects, agent.interrupt() must be called + so the agent thread stops making LLM API calls.""" + adapter = _make_adapter() + + stream_q = queue.Queue() + stream_q.put("hello ") + + agent_done = asyncio.Event() + + async def fake_agent(): + await agent_done.wait() + return {"final_response": "done"}, {} + + # Mock agent with an interrupt method + mock_agent = MagicMock() + mock_agent.interrupt = MagicMock() + + async def run(): + from aiohttp import web + + agent_task = asyncio.ensure_future(fake_agent()) + agent_ref = [mock_agent] + + mock_response = AsyncMock(spec=web.StreamResponse) + call_count = 0 + + async def write_side_effect(data): + nonlocal call_count + call_count += 1 + if call_count >= 2: + raise ConnectionResetError("client disconnected") + + mock_response.write = AsyncMock(side_effect=write_side_effect) + mock_response.prepare = AsyncMock() + + with patch("gateway.platforms.api_server.web.StreamResponse", + return_value=mock_response): + await adapter._write_sse_chat_completion( + _make_request(), "cmpl-int", "gpt-4", 1234567890, + stream_q, agent_task, agent_ref, + ) + + # agent.interrupt() must have been called + mock_agent.interrupt.assert_called_once_with("SSE client disconnected") + # Clean up + agent_done.set() + + asyncio.run(run()) + + def test_agent_ref_none_still_cancels_task(self): + """When agent_ref is not provided (None), the task is still cancelled + on disconnect — just without the interrupt() call.""" + adapter = _make_adapter() + + stream_q = queue.Queue() + + async def fake_agent(): + await asyncio.sleep(999) + return {}, {} + + async def run(): + from aiohttp import web + + agent_task = asyncio.ensure_future(fake_agent()) + + mock_response = AsyncMock(spec=web.StreamResponse) + mock_response.write = AsyncMock(side_effect=BrokenPipeError("gone")) + mock_response.prepare = AsyncMock() + + with patch("gateway.platforms.api_server.web.StreamResponse", + return_value=mock_response): + # No agent_ref passed — should still handle disconnect cleanly + await adapter._write_sse_chat_completion( + _make_request(), "cmpl-noref", "gpt-4", 1234567890, + stream_q, agent_task, + ) + + assert agent_task.cancelled() or agent_task.done() + + asyncio.run(run())