diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py
index 2eae88ccd..fe97eb69c 100644
--- a/gateway/platforms/base.py
+++ b/gateway/platforms/base.py
@@ -2832,10 +2832,23 @@ class BasePlatformAdapter(ABC):
             # reset-like command that already swapped in its own
             # command_guard (and cancelled us) can't be accidentally
             # cleared by our unwind. The command owns the session now.
+            #
+            # The owner-check also covers the in-band drain handoff
+            # above: when we spawned a drain_task and transferred
+            # ownership via ``_session_tasks[session_key] = drain_task``,
+            # ``_session_tasks.get(session_key) is current_task`` is
+            # False, so we leave _active_sessions populated. Without
+            # this guard, the drain task picks up the same
+            # interrupt_event in its own _process_message_background
+            # entry, _release_session_guard's guard-match succeeds,
+            # and we'd delete the entry while the drain task is still
+            # running — letting a concurrent inbound message pass
+            # the Level-1 guard and spawn a second handler for the
+            # same session.
             current_task = asyncio.current_task()
             if current_task is not None and self._session_tasks.get(session_key) is current_task:
                 del self._session_tasks[session_key]
-            self._release_session_guard(session_key, guard=interrupt_event)
+                self._release_session_guard(session_key, guard=interrupt_event)
 
     async def cancel_background_tasks(self) -> None:
         """Cancel any in-flight background message-processing tasks.
diff --git a/tests/gateway/test_pending_drain_no_recursion.py b/tests/gateway/test_pending_drain_no_recursion.py
index a005061eb..620dfe9f3 100644
--- a/tests/gateway/test_pending_drain_no_recursion.py
+++ b/tests/gateway/test_pending_drain_no_recursion.py
@@ -21,9 +21,9 @@ task spawning keeps it constant (1 every time).
 
 import asyncio
 import sys
+from unittest.mock import AsyncMock
 
 import pytest
-from unittest.mock import AsyncMock
 
 from gateway.config import Platform, PlatformConfig
 from gateway.platforms.base import (
@@ -127,3 +127,62 @@ async def test_in_band_drain_does_not_grow_stack():
         f"in-band drain is recursing instead of spawning a fresh task — "
         f"stack depth grew with chain length: {depths!r}"
     )
+
+
+@pytest.mark.asyncio
+async def test_in_band_drain_preserves_active_session_guard():
+    """The original task must NOT release ``_active_sessions[session_key]``
+    after handing off to the drain task.
+
+    When the in-band drain spawns ``drain_task`` and transfers ownership
+    via ``_session_tasks[session_key] = drain_task``, the original task
+    still unwinds through the ``finally`` block. The drain task picks
+    up the same ``interrupt_event`` in its own
+    ``_process_message_background`` entry, so a naive
+    ``_release_session_guard(session_key, guard=interrupt_event)`` in
+    the unwind matches and deletes ``_active_sessions[session_key]``.
+    That briefly reopens the Level-1 guard between the original task's
+    finally and the drain task's first await — a concurrent inbound
+    arriving in that window passes the guard and spawns a second
+    handler for the same session.
+
+    Invariant: ``_active_sessions[sk]`` must hold the SAME interrupt
+    Event identity at every handler entry across an in-band drain
+    chain. Pre-fix, the original task's finally deletes the entry, so
+    the drain task falls through to the ``or asyncio.Event()`` branch
+    in ``_process_message_background`` and installs a *new* Event —
+    the identity diverges. Post-fix, the entry is preserved across
+    handoff and the drain task reuses the original Event.
+    """
+    adapter = _make_adapter()
+    sk = _sk()
+
+    seen_guards: list = []  # value of _active_sessions.get(sk) at each handler entry
+
+    async def handler(event):
+        seen_guards.append(adapter._active_sessions.get(sk))
+        if len(seen_guards) == 1:  # first entry (M0) only: queue M1 as the pending message
+            adapter._pending_messages[sk] = _make_event(text="M1")
+        return "ok"
+
+    adapter._message_handler = handler
+
+    await adapter.handle_message(_make_event(text="M0"))
+
+    for _ in range(400):  # poll up to 400 x 0.01 s for both runs and the guard release
+        if len(seen_guards) >= 2 and sk not in adapter._active_sessions:
+            break
+        await asyncio.sleep(0.01)
+
+    await adapter.cancel_background_tasks()  # cancel anything still in flight before asserting
+
+    assert len(seen_guards) == 2, f"expected 2 handler runs, got {len(seen_guards)}"
+    assert seen_guards[0] is not None, "M0 saw no active-session guard"
+    assert seen_guards[1] is not None, "M1 saw no active-session guard"
+    assert seen_guards[0] is seen_guards[1], (
+        "in-band drain handoff replaced the active-session guard — the "
+        "original task's finally deleted _active_sessions[sk] and the "
+        "drain task installed a new Event. Concurrent inbounds during "
+        "the handoff window would bypass the Level-1 guard and spawn a "
+        "second handler for the same session."
+    )