From cac6178104f9dc9c3327454fc0af022f19752dd4 Mon Sep 17 00:00:00 2001 From: Teknium Date: Sat, 11 Apr 2026 12:09:01 -0700 Subject: [PATCH 01/35] fix(gateway): propagate user identity through process watcher pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Background process watchers (notify_on_complete, check_interval) created synthetic SessionSource objects without user_id/user_name. While the internal=True bypass (1d8d4f28) prevented false pairing for agent- generated notifications, the missing identity caused: - Garbage entries in pairing rate limiters (discord:None, telegram:None) - 'User None' in approval messages and logs - No user identity available for future code paths that need it Additionally, platform messages arriving without from_user (Telegram service messages, channel forwards, anonymous admin actions) could still trigger false pairing because they are not internal events. Fix: 1. Propagate user_id/user_name through the full watcher chain: session_context.py → gateway/run.py → terminal_tool.py → process_registry.py (including checkpoint persistence/recovery) 2. Add None user_id guard in _handle_message() — silently drop non-internal messages with no user identity instead of triggering the pairing flow. Salvaged from PRs #7664 (kagura-agent, ContextVar approach), #6540 (MestreY0d4-Uninter, tests), and #7709 (guang384, None guard). Closes #6341, #6485, #7643 Relates to #6516, #7392 --- gateway/run.py | 13 +++ gateway/session_context.py | 10 ++ .../test_internal_event_bypass_pairing.py | 99 +++++++++++++++++++ tests/gateway/test_session_env.py | 13 +++ tests/tools/test_notify_on_complete.py | 4 + tests/tools/test_process_registry.py | 8 ++ tools/process_registry.py | 8 ++ tools/terminal_tool.py | 12 +++ 8 files changed, 167 insertions(+) diff --git a/gateway/run.py b/gateway/run.py index df69a498c..2bd493005 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -2221,6 +2221,13 @@ class GatewayRunner: # are system-generated and must skip user authorization. if getattr(event, "internal", False): pass + elif source.user_id is None: + # Messages with no user identity (Telegram service messages, + # channel forwards, anonymous admin actions) cannot be + # authorized — drop silently instead of triggering the pairing + # flow with a None user_id. + logger.debug("Ignoring message with no user_id from %s", source.platform.value) + return None elif not self._is_user_authorized(source): logger.warning("Unauthorized user: %s (%s) on %s", source.user_id, source.user_name, source.platform.value) # In DMs: offer pairing code. In groups: silently ignore. @@ -6597,6 +6604,8 @@ class GatewayRunner: chat_id=context.source.chat_id, chat_name=context.source.chat_name or "", thread_id=str(context.source.thread_id) if context.source.thread_id else "", + user_id=str(context.source.user_id) if context.source.user_id else "", + user_name=str(context.source.user_name) if context.source.user_name else "", ) def _clear_session_env(self, tokens: list) -> None: @@ -6809,6 +6818,8 @@ class GatewayRunner: platform_name = watcher.get("platform", "") chat_id = watcher.get("chat_id", "") thread_id = watcher.get("thread_id", "") + user_id = watcher.get("user_id", "") + user_name = watcher.get("user_name", "") agent_notify = watcher.get("notify_on_complete", False) notify_mode = self._load_background_notifications_mode() @@ -6864,6 +6875,8 @@ class GatewayRunner: platform=_platform_enum, chat_id=chat_id, thread_id=thread_id or None, + user_id=user_id or None, + user_name=user_name or None, ) synth_event = MessageEvent( text=synth_text, diff --git a/gateway/session_context.py b/gateway/session_context.py index 775cd8698..6d676dc1e 100644 --- a/gateway/session_context.py +++ b/gateway/session_context.py @@ -46,12 +46,16 @@ _SESSION_PLATFORM: ContextVar[str] = ContextVar("HERMES_SESSION_PLATFORM", defau _SESSION_CHAT_ID: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_ID", default="") _SESSION_CHAT_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_NAME", default="") _SESSION_THREAD_ID: ContextVar[str] = ContextVar("HERMES_SESSION_THREAD_ID", default="") +_SESSION_USER_ID: ContextVar[str] = ContextVar("HERMES_SESSION_USER_ID", default="") +_SESSION_USER_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_USER_NAME", default="") _VAR_MAP = { "HERMES_SESSION_PLATFORM": _SESSION_PLATFORM, "HERMES_SESSION_CHAT_ID": _SESSION_CHAT_ID, "HERMES_SESSION_CHAT_NAME": _SESSION_CHAT_NAME, "HERMES_SESSION_THREAD_ID": _SESSION_THREAD_ID, + "HERMES_SESSION_USER_ID": _SESSION_USER_ID, + "HERMES_SESSION_USER_NAME": _SESSION_USER_NAME, } @@ -60,6 +64,8 @@ def set_session_vars( chat_id: str = "", chat_name: str = "", thread_id: str = "", + user_id: str = "", + user_name: str = "", ) -> list: """Set all session context variables and return reset tokens. @@ -74,6 +80,8 @@ def set_session_vars( _SESSION_CHAT_ID.set(chat_id), _SESSION_CHAT_NAME.set(chat_name), _SESSION_THREAD_ID.set(thread_id), + _SESSION_USER_ID.set(user_id), + _SESSION_USER_NAME.set(user_name), ] return tokens @@ -87,6 +95,8 @@ def clear_session_vars(tokens: list) -> None: _SESSION_CHAT_ID, _SESSION_CHAT_NAME, _SESSION_THREAD_ID, + _SESSION_USER_ID, + _SESSION_USER_NAME, ] for var, token in zip(vars_in_order, tokens): var.reset(token) diff --git a/tests/gateway/test_internal_event_bypass_pairing.py b/tests/gateway/test_internal_event_bypass_pairing.py index 05b093b04..46a96e5aa 100644 --- a/tests/gateway/test_internal_event_bypass_pairing.py +++ b/tests/gateway/test_internal_event_bypass_pairing.py @@ -195,6 +195,105 @@ async def test_internal_event_does_not_trigger_pairing(monkeypatch, tmp_path): ) +@pytest.mark.asyncio +async def test_notify_on_complete_preserves_user_identity(monkeypatch, tmp_path): + """Synthetic completion event should carry user_id and user_name from the watcher.""" + import tools.process_registry as pr_module + + sessions = [ + SimpleNamespace( + output_buffer="done\n", exited=True, exit_code=0, command="echo test" + ), + ] + monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions)) + + async def _instant_sleep(*_a, **_kw): + pass + monkeypatch.setattr(asyncio, "sleep", _instant_sleep) + + runner = _build_runner(monkeypatch, tmp_path) + adapter = runner.adapters[Platform.DISCORD] + + watcher = _watcher_dict_with_notify() + watcher["user_id"] = "user-42" + watcher["user_name"] = "alice" + + await runner._run_process_watcher(watcher) + + assert adapter.handle_message.await_count == 1 + event = adapter.handle_message.await_args.args[0] + assert event.source.user_id == "user-42" + assert event.source.user_name == "alice" + + +@pytest.mark.asyncio +async def test_none_user_id_skips_pairing(monkeypatch, tmp_path): + """A non-internal event with user_id=None should be silently dropped.""" + import gateway.run as gateway_run + + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + (tmp_path / "config.yaml").write_text("", encoding="utf-8") + + runner = GatewayRunner(GatewayConfig()) + adapter = SimpleNamespace(send=AsyncMock()) + runner.adapters[Platform.TELEGRAM] = adapter + + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="123", + chat_type="dm", + user_id=None, + ) + event = MessageEvent( + text="service message", + source=source, + internal=False, + ) + + result = await runner._handle_message(event) + + # Should return None (dropped) and NOT send any pairing message + assert result is None + assert adapter.send.await_count == 0 + + +@pytest.mark.asyncio +async def test_none_user_id_does_not_generate_pairing_code(monkeypatch, tmp_path): + """A message with user_id=None must never call generate_code.""" + import gateway.run as gateway_run + + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + (tmp_path / "config.yaml").write_text("", encoding="utf-8") + + runner = GatewayRunner(GatewayConfig()) + adapter = SimpleNamespace(send=AsyncMock()) + runner.adapters[Platform.DISCORD] = adapter + + generate_called = False + original_generate = runner.pairing_store.generate_code + + def tracking_generate(*args, **kwargs): + nonlocal generate_called + generate_called = True + return original_generate(*args, **kwargs) + + runner.pairing_store.generate_code = tracking_generate + + source = SessionSource( + platform=Platform.DISCORD, + chat_id="456", + chat_type="dm", + user_id=None, + ) + event = MessageEvent(text="anonymous", source=source, internal=False) + + await runner._handle_message(event) + + assert not generate_called, ( + "Pairing code should NOT be generated for messages with user_id=None" + ) + + @pytest.mark.asyncio async def test_non_internal_event_without_user_triggers_pairing(monkeypatch, tmp_path): """Verify the normal (non-internal) path still triggers pairing for unknown users.""" diff --git a/tests/gateway/test_session_env.py b/tests/gateway/test_session_env.py index a7f1345b7..b75e267f1 100644 --- a/tests/gateway/test_session_env.py +++ b/tests/gateway/test_session_env.py @@ -18,6 +18,8 @@ def test_set_session_env_sets_contextvars(monkeypatch): chat_id="-1001", chat_name="Group", chat_type="group", + user_id="123456", + user_name="alice", thread_id="17585", ) context = SessionContext(source=source, connected_platforms=[], home_channels={}) @@ -25,6 +27,8 @@ def test_set_session_env_sets_contextvars(monkeypatch): monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False) monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False) monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False) + monkeypatch.delenv("HERMES_SESSION_USER_ID", raising=False) + monkeypatch.delenv("HERMES_SESSION_USER_NAME", raising=False) monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False) tokens = runner._set_session_env(context) @@ -33,6 +37,8 @@ def test_set_session_env_sets_contextvars(monkeypatch): assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram" assert get_session_env("HERMES_SESSION_CHAT_ID") == "-1001" assert get_session_env("HERMES_SESSION_CHAT_NAME") == "Group" + assert get_session_env("HERMES_SESSION_USER_ID") == "123456" + assert get_session_env("HERMES_SESSION_USER_NAME") == "alice" assert get_session_env("HERMES_SESSION_THREAD_ID") == "17585" # os.environ should NOT be touched @@ -50,6 +56,8 @@ def test_clear_session_env_restores_previous_state(monkeypatch): monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False) monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False) monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False) + monkeypatch.delenv("HERMES_SESSION_USER_ID", raising=False) + monkeypatch.delenv("HERMES_SESSION_USER_NAME", raising=False) monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False) source = SessionSource( @@ -57,12 +65,15 @@ def test_clear_session_env_restores_previous_state(monkeypatch): chat_id="-1001", chat_name="Group", chat_type="group", + user_id="123456", + user_name="alice", thread_id="17585", ) context = SessionContext(source=source, connected_platforms=[], home_channels={}) tokens = runner._set_session_env(context) assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram" + assert get_session_env("HERMES_SESSION_USER_ID") == "123456" runner._clear_session_env(tokens) @@ -70,6 +81,8 @@ def test_clear_session_env_restores_previous_state(monkeypatch): assert get_session_env("HERMES_SESSION_PLATFORM") == "" assert get_session_env("HERMES_SESSION_CHAT_ID") == "" assert get_session_env("HERMES_SESSION_CHAT_NAME") == "" + assert get_session_env("HERMES_SESSION_USER_ID") == "" + assert get_session_env("HERMES_SESSION_USER_NAME") == "" assert get_session_env("HERMES_SESSION_THREAD_ID") == "" diff --git a/tests/tools/test_notify_on_complete.py b/tests/tools/test_notify_on_complete.py index ff6f14922..411f95f7e 100644 --- a/tests/tools/test_notify_on_complete.py +++ b/tests/tools/test_notify_on_complete.py @@ -227,6 +227,8 @@ class TestCheckpointNotify: "session_key": "sk1", "watcher_platform": "telegram", "watcher_chat_id": "123", + "watcher_user_id": "u123", + "watcher_user_name": "alice", "watcher_thread_id": "42", "watcher_interval": 5, "notify_on_complete": True, @@ -236,6 +238,8 @@ class TestCheckpointNotify: assert recovered == 1 assert len(registry.pending_watchers) == 1 assert registry.pending_watchers[0]["notify_on_complete"] is True + assert registry.pending_watchers[0]["user_id"] == "u123" + assert registry.pending_watchers[0]["user_name"] == "alice" def test_recover_defaults_false(self, registry, tmp_path): """Old checkpoint entries without the field default to False.""" diff --git a/tests/tools/test_process_registry.py b/tests/tools/test_process_registry.py index a61da9dd3..d981878a3 100644 --- a/tests/tools/test_process_registry.py +++ b/tests/tools/test_process_registry.py @@ -438,6 +438,8 @@ class TestCheckpoint: s = _make_session() s.watcher_platform = "telegram" s.watcher_chat_id = "999" + s.watcher_user_id = "u123" + s.watcher_user_name = "alice" s.watcher_thread_id = "42" s.watcher_interval = 60 registry._running[s.id] = s @@ -447,6 +449,8 @@ class TestCheckpoint: assert len(data) == 1 assert data[0]["watcher_platform"] == "telegram" assert data[0]["watcher_chat_id"] == "999" + assert data[0]["watcher_user_id"] == "u123" + assert data[0]["watcher_user_name"] == "alice" assert data[0]["watcher_thread_id"] == "42" assert data[0]["watcher_interval"] == 60 @@ -460,6 +464,8 @@ class TestCheckpoint: "session_key": "sk1", "watcher_platform": "telegram", "watcher_chat_id": "123", + "watcher_user_id": "u123", + "watcher_user_name": "alice", "watcher_thread_id": "42", "watcher_interval": 60, }])) @@ -471,6 +477,8 @@ class TestCheckpoint: assert w["session_id"] == "proc_live" assert w["platform"] == "telegram" assert w["chat_id"] == "123" + assert w["user_id"] == "u123" + assert w["user_name"] == "alice" assert w["thread_id"] == "42" assert w["check_interval"] == 60 diff --git a/tools/process_registry.py b/tools/process_registry.py index 1be9b89f6..1761221f0 100644 --- a/tools/process_registry.py +++ b/tools/process_registry.py @@ -85,6 +85,8 @@ class ProcessSession: # Watcher/notification metadata (persisted for crash recovery) watcher_platform: str = "" watcher_chat_id: str = "" + watcher_user_id: str = "" + watcher_user_name: str = "" watcher_thread_id: str = "" watcher_interval: int = 0 # 0 = no watcher configured notify_on_complete: bool = False # Queue agent notification on exit @@ -970,6 +972,8 @@ class ProcessRegistry: "session_key": s.session_key, "watcher_platform": s.watcher_platform, "watcher_chat_id": s.watcher_chat_id, + "watcher_user_id": s.watcher_user_id, + "watcher_user_name": s.watcher_user_name, "watcher_thread_id": s.watcher_thread_id, "watcher_interval": s.watcher_interval, "notify_on_complete": s.notify_on_complete, @@ -1031,6 +1035,8 @@ class ProcessRegistry: detached=True, # Can't read output, but can report status + kill watcher_platform=entry.get("watcher_platform", ""), watcher_chat_id=entry.get("watcher_chat_id", ""), + watcher_user_id=entry.get("watcher_user_id", ""), + watcher_user_name=entry.get("watcher_user_name", ""), watcher_thread_id=entry.get("watcher_thread_id", ""), watcher_interval=entry.get("watcher_interval", 0), notify_on_complete=entry.get("notify_on_complete", False), @@ -1049,6 +1055,8 @@ class ProcessRegistry: "session_key": session.session_key, "platform": session.watcher_platform, "chat_id": session.watcher_chat_id, + "user_id": session.watcher_user_id, + "user_name": session.watcher_user_name, "thread_id": session.watcher_thread_id, "notify_on_complete": session.notify_on_complete, }) diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index 859f0f1f3..f0cbff0f4 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -1427,8 +1427,12 @@ def terminal_tool( if _gw_platform and not check_interval: _gw_chat_id = _gse("HERMES_SESSION_CHAT_ID", "") _gw_thread_id = _gse("HERMES_SESSION_THREAD_ID", "") + _gw_user_id = _gse("HERMES_SESSION_USER_ID", "") + _gw_user_name = _gse("HERMES_SESSION_USER_NAME", "") proc_session.watcher_platform = _gw_platform proc_session.watcher_chat_id = _gw_chat_id + proc_session.watcher_user_id = _gw_user_id + proc_session.watcher_user_name = _gw_user_name proc_session.watcher_thread_id = _gw_thread_id proc_session.watcher_interval = 5 process_registry.pending_watchers.append({ @@ -1437,6 +1441,8 @@ def terminal_tool( "session_key": session_key, "platform": _gw_platform, "chat_id": _gw_chat_id, + "user_id": _gw_user_id, + "user_name": _gw_user_name, "thread_id": _gw_thread_id, "notify_on_complete": True, }) @@ -1457,10 +1463,14 @@ def terminal_tool( watcher_platform = _gse2("HERMES_SESSION_PLATFORM", "") watcher_chat_id = _gse2("HERMES_SESSION_CHAT_ID", "") watcher_thread_id = _gse2("HERMES_SESSION_THREAD_ID", "") + watcher_user_id = _gse2("HERMES_SESSION_USER_ID", "") + watcher_user_name = _gse2("HERMES_SESSION_USER_NAME", "") # Store on session for checkpoint persistence proc_session.watcher_platform = watcher_platform proc_session.watcher_chat_id = watcher_chat_id + proc_session.watcher_user_id = watcher_user_id + proc_session.watcher_user_name = watcher_user_name proc_session.watcher_thread_id = watcher_thread_id proc_session.watcher_interval = effective_interval @@ -1470,6 +1480,8 @@ def terminal_tool( "session_key": session_key, "platform": watcher_platform, "chat_id": watcher_chat_id, + "user_id": watcher_user_id, + "user_name": watcher_user_name, "thread_id": watcher_thread_id, }) From 39da23a1291f45ee7170fbe24068b9a4ea5054d8 Mon Sep 17 00:00:00 2001 From: helix4u <4317663+helix4u@users.noreply.github.com> Date: Sat, 11 Apr 2026 12:42:01 -0600 Subject: [PATCH 02/35] fix(api-server): keep chat-completions SSE alive --- gateway/platforms/api_server.py | 18 ++++++++++++--- tests/gateway/test_api_server.py | 39 ++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index baada7e05..1954a2b9e 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -53,6 +53,7 @@ DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 8642 MAX_STORED_RESPONSES = 100 MAX_REQUEST_BYTES = 1_000_000 # 1 MB default limit for POST bodies +CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS = 30.0 def check_api_server_requirements() -> bool: @@ -762,7 +763,11 @@ class APIServerAdapter(BasePlatformAdapter): """ import queue as _q - sse_headers = {"Content-Type": "text/event-stream", "Cache-Control": "no-cache"} + sse_headers = { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + } # CORS middleware can't inject headers into StreamResponse after # prepare() flushes them, so resolve CORS headers up front. origin = request.headers.get("Origin", "") @@ -775,6 +780,8 @@ class APIServerAdapter(BasePlatformAdapter): await response.prepare(request) try: + last_activity = time.monotonic() + # Role chunk role_chunk = { "id": completion_id, "object": "chat.completion.chunk", @@ -782,6 +789,7 @@ class APIServerAdapter(BasePlatformAdapter): "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}], } await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode()) + last_activity = time.monotonic() # Helper — route a queue item to the correct SSE event. async def _emit(item): @@ -805,6 +813,7 @@ class APIServerAdapter(BasePlatformAdapter): "choices": [{"index": 0, "delta": {"content": item}, "finish_reason": None}], } await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode()) + return time.monotonic() # Stream content chunks as they arrive from the agent loop = asyncio.get_event_loop() @@ -819,16 +828,19 @@ class APIServerAdapter(BasePlatformAdapter): delta = stream_q.get_nowait() if delta is None: break - await _emit(delta) + last_activity = await _emit(delta) except _q.Empty: break break + if time.monotonic() - last_activity >= CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS: + await response.write(b": keepalive\n\n") + last_activity = time.monotonic() continue if delta is None: # End of stream sentinel break - await _emit(delta) + last_activity = await _emit(delta) # Get usage from completed agent usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} diff --git a/tests/gateway/test_api_server.py b/tests/gateway/test_api_server.py index afc3ce9ce..2be01fc2d 100644 --- a/tests/gateway/test_api_server.py +++ b/tests/gateway/test_api_server.py @@ -409,11 +409,50 @@ class TestChatCompletionsEndpoint: ) assert resp.status == 200 assert "text/event-stream" in resp.headers.get("Content-Type", "") + assert resp.headers.get("X-Accel-Buffering") == "no" body = await resp.text() assert "data: " in body assert "[DONE]" in body assert "Hello!" in body + @pytest.mark.asyncio + async def test_stream_sends_keepalive_during_quiet_tool_gap(self, adapter): + """Idle SSE streams should send keepalive comments while tools run silently.""" + import asyncio + import gateway.platforms.api_server as api_server_mod + + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + async def _mock_run_agent(**kwargs): + cb = kwargs.get("stream_delta_callback") + if cb: + cb("Working") + await asyncio.sleep(0.65) + cb("...done") + return ( + {"final_response": "Working...done", "messages": [], "api_calls": 1}, + {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + ) + + with ( + patch.object(api_server_mod, "CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS", 0.01), + patch.object(adapter, "_run_agent", side_effect=_mock_run_agent), + ): + resp = await cli.post( + "/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "do the thing"}], + "stream": True, + }, + ) + assert resp.status == 200 + body = await resp.text() + assert ": keepalive" in body + assert "Working" in body + assert "...done" in body + assert "[DONE]" in body + @pytest.mark.asyncio async def test_stream_survives_tool_call_none_sentinel(self, adapter): """stream_delta_callback(None) mid-stream (tool calls) must NOT kill the SSE stream. From 591041200211f8b2268468dedb3d1a3e3947708c Mon Sep 17 00:00:00 2001 From: Tom Qiao Date: Sat, 11 Apr 2026 19:27:22 +0800 Subject: [PATCH 03/35] fix: detect truncated tool_calls when finish_reason is not length When API routers rewrite finish_reason from "length" to "tool_calls", truncated JSON arguments bypassed the length handler and wasted 3 retry attempts in the generic JSON validation loop. Now detects truncation patterns in tool call arguments regardless of finish_reason. Fixes #7680 Co-Authored-By: Claude Opus 4.6 --- run_agent.py | 33 +++++++++++++++++++++++++++++-- tests/run_agent/test_run_agent.py | 29 +++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/run_agent.py b/run_agent.py index f995a2886..e56e23b7d 100644 --- a/run_agent.py +++ b/run_agent.py @@ -9506,12 +9506,41 @@ class AIAgent: invalid_json_args.append((tc.function.name, str(e))) if invalid_json_args: + # Check if the invalid JSON is due to truncation rather + # than a model formatting mistake. Routers sometimes + # rewrite finish_reason from "length" to "tool_calls", + # hiding the truncation from the length handler above. + # Detect truncation: args that don't end with } or ] + # (after stripping whitespace) are cut off mid-stream. + _truncated = any( + not (tc.function.arguments or "").rstrip().endswith(("}", "]")) + for tc in assistant_message.tool_calls + if tc.function.name in {n for n, _ in invalid_json_args} + ) + if _truncated: + self._vprint( + f"{self.log_prefix}⚠️ Truncated tool call arguments detected " + f"(finish_reason={finish_reason!r}) — refusing to execute.", + force=True, + ) + self._invalid_json_retries = 0 + self._cleanup_task_resources(effective_task_id) + self._persist_session(messages, conversation_history) + return { + "final_response": None, + "messages": messages, + "api_calls": api_call_count, + "completed": False, + "partial": True, + "error": "Response truncated due to output length limit", + } + # Track retries for invalid JSON arguments self._invalid_json_retries += 1 - + tool_name, error_msg = invalid_json_args[0] self._vprint(f"{self.log_prefix}⚠️ Invalid JSON in tool call arguments for '{tool_name}': {error_msg}") - + if self._invalid_json_retries < 3: self._vprint(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_json_retries}/3)...") # Don't add anything to messages, just retry the API call diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py index 0f2d1d4de..9851939ae 100644 --- a/tests/run_agent/test_run_agent.py +++ b/tests/run_agent/test_run_agent.py @@ -2169,6 +2169,35 @@ class TestRunConversation: mock_hfc.assert_called_once() assert result["final_response"] == "Done!" + def test_truncated_tool_args_detected_when_finish_reason_not_length(self, agent): + """When a router rewrites finish_reason from 'length' to 'tool_calls', + truncated JSON arguments should still be detected and refused rather + than wasting 3 retry attempts.""" + self._setup_agent(agent) + agent.valid_tool_names.add("write_file") + bad_tc = _mock_tool_call( + name="write_file", + arguments='{"path":"report.md","content":"partial', + call_id="c1", + ) + resp = _mock_response( + content="", finish_reason="tool_calls", tool_calls=[bad_tc], + ) + agent.client.chat.completions.create.return_value = resp + + with ( + patch("run_agent.handle_function_call") as mock_handle_function_call, + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + result = agent.run_conversation("write the report") + + assert result["completed"] is False + assert result["partial"] is True + assert "truncated due to output length limit" in result["error"] + mock_handle_function_call.assert_not_called() + class TestRetryExhaustion: """Regression: retry_count > max_retries was dead code (off-by-one). From 151654851c860efb0eb65e705f96fb2856324d95 Mon Sep 17 00:00:00 2001 From: ygd58 Date: Sat, 11 Apr 2026 15:23:35 +0200 Subject: [PATCH 04/35] fix(agent): prevent false thinking-exhaustion for non-reasoning models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Models that do not use tags (e.g. GLM-4.7 on NVIDIA Build, minimax) may return content=None or empty string when truncated. The previous _thinking_exhausted check treated any None/empty content as thinking-budget exhaustion, causing these models to always show the 'Thinking Budget Exhausted' error instead of attempting continuation. Fix: gate the exhaustion check on _has_think_tags — only trigger the exhaustion path when the model actually produced reasoning blocks (, , , ). Models without think tags now fall through to the normal continuation retry logic (up to 3 attempts). Fixes #7729 --- run_agent.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/run_agent.py b/run_agent.py index e56e23b7d..21a896063 100644 --- a/run_agent.py +++ b/run_agent.py @@ -8277,8 +8277,24 @@ class AIAgent: _text_parts.append(getattr(_blk, "text", "")) _trunc_content = "\n".join(_text_parts) if _text_parts else None + # A response is "thinking exhausted" only when the model + # actually produced reasoning blocks but no visible text after + # them. Models that do not use tags (e.g. GLM-4.7 on + # NVIDIA Build, minimax) may return content=None or an empty + # string for unrelated reasons — treat those as normal + # truncations that deserve continuation retries, not as + # thinking-budget exhaustion. + _has_think_tags = bool( + _trunc_content and re.search( + r'<(?:think|thinking|reasoning|REASONING_SCRATCHPAD)[^>]*>', + _trunc_content, + re.IGNORECASE, + ) + ) _thinking_exhausted = ( - not _trunc_has_tool_calls and ( + not _trunc_has_tool_calls + and _has_think_tags + and ( (_trunc_content is not None and not self._has_content_after_think_block(_trunc_content)) or _trunc_content is None ) From 2d328d5c7095baf05fcd078e0f6ff53279fe71db Mon Sep 17 00:00:00 2001 From: konsisumer <11262660+konsisumer@users.noreply.github.com> Date: Sat, 11 Apr 2026 11:59:00 -0700 Subject: [PATCH 05/35] fix(gateway): break stuck session resume loops on restart (#7536) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cherry-picked from PR #7747 with follow-up fixes: - Narrowed suspend_all_active() to suspend_recently_active() — only suspends sessions updated within the last 2 minutes (likely in-flight), not all sessions which would unnecessarily reset idle users - /stop with no running agent no longer suspends the session; only actual force-stops mark the session for reset --- gateway/run.py | 46 +++++++++++++++++++++++++++++++--------- gateway/session.py | 52 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 87 insertions(+), 11 deletions(-) diff --git a/gateway/run.py b/gateway/run.py index 2bd493005..469abe9ec 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1465,7 +1465,18 @@ class GatewayRunner: logger.info("Recovered %s background process(es) from previous run", recovered) except Exception as e: logger.warning("Process checkpoint recovery: %s", e) - + + # Suspend sessions that were active when the gateway last exited. + # This prevents stuck sessions from being blindly resumed on restart, + # which can create an unrecoverable loop (#7536). Suspended sessions + # auto-reset on the next incoming message, giving the user a clean start. + try: + suspended = self.session_store.suspend_recently_active() + if suspended: + logger.info("Suspended %d in-flight session(s) from previous run", suspended) + except Exception as e: + logger.warning("Session suspension on startup failed: %s", e) + connected_count = 0 enabled_platform_count = 0 startup_nonretryable_errors: list[str] = [] @@ -2377,8 +2388,11 @@ class GatewayRunner: self._pending_messages.pop(_quick_key, None) if _quick_key in self._running_agents: del self._running_agents[_quick_key] - logger.info("HARD STOP for session %s — session lock released", _quick_key[:20]) - return "⚡ Force-stopped. The session is unlocked — you can send a new message." + # Mark session suspended so the next message starts fresh + # instead of resuming the stuck context (#7536). + self.session_store.suspend_session(_quick_key) + logger.info("HARD STOP for session %s — suspended, session lock released", _quick_key[:20]) + return "⚡ Force-stopped. The session is suspended — your next message will start fresh." # /reset and /new must bypass the running-agent guard so they # actually dispatch as commands instead of being queued as user @@ -2812,7 +2826,9 @@ class GatewayRunner: # so the agent knows this is a fresh conversation (not an intentional /reset). if getattr(session_entry, 'was_auto_reset', False): reset_reason = getattr(session_entry, 'auto_reset_reason', None) or 'idle' - if reset_reason == "daily": + if reset_reason == "suspended": + context_note = "[System note: The user's previous session was stopped and suspended. This is a fresh conversation with no prior context.]" + elif reset_reason == "daily": context_note = "[System note: The user's session was automatically reset by the daily schedule. This is a fresh conversation with no prior context.]" else: context_note = "[System note: The user's previous session expired due to inactivity. This is a fresh conversation with no prior context.]" @@ -2829,7 +2845,9 @@ class GatewayRunner: ) platform_name = source.platform.value if source.platform else "" had_activity = getattr(session_entry, 'reset_had_activity', False) - should_notify = ( + # Suspended sessions always notify (they were explicitly stopped + # or crashed mid-operation) — skip the policy check. + should_notify = reset_reason == "suspended" or ( policy.notify and had_activity and platform_name not in policy.notify_exclude_platforms @@ -2837,7 +2855,9 @@ class GatewayRunner: if should_notify: adapter = self.adapters.get(source.platform) if adapter: - if reset_reason == "daily": + if reset_reason == "suspended": + reason_text = "previous session was stopped or interrupted" + elif reset_reason == "daily": reason_text = f"daily schedule at {policy.at_hour}:00" else: hours = policy.idle_minutes // 60 @@ -3920,25 +3940,31 @@ class GatewayRunner: handles /stop before this method is reached. This handler fires only through normal command dispatch (no running agent) or as a fallback. Force-clean the session lock in all cases for safety. + + When there IS a running/pending agent, the session is also marked + as *suspended* so the next message starts a fresh session instead + of resuming the stuck context (#7536). """ source = event.source session_entry = self.session_store.get_or_create_session(source) session_key = session_entry.session_key - + agent = self._running_agents.get(session_key) if agent is _AGENT_PENDING_SENTINEL: # Force-clean the sentinel so the session is unlocked. if session_key in self._running_agents: del self._running_agents[session_key] - logger.info("HARD STOP (pending) for session %s — sentinel cleared", session_key[:20]) - return "⚡ Force-stopped. The agent was still starting — session unlocked." + self.session_store.suspend_session(session_key) + logger.info("HARD STOP (pending) for session %s — suspended, sentinel cleared", session_key[:20]) + return "⚡ Force-stopped. The agent was still starting — your next message will start fresh." if agent: agent.interrupt("Stop requested") # Force-clean the session lock so a truly hung agent doesn't # keep it locked forever. if session_key in self._running_agents: del self._running_agents[session_key] - return "⚡ Force-stopped. The session is unlocked — you can send a new message." + self.session_store.suspend_session(session_key) + return "⚡ Force-stopped. Your next message will start a fresh session." else: return "No active task to stop." diff --git a/gateway/session.py b/gateway/session.py index 2b32c1889..96013df51 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -368,6 +368,11 @@ class SessionEntry: # survives gateway restarts (the old in-memory _pre_flushed_sessions # set was lost on restart, causing redundant re-flushes). memory_flushed: bool = False + + # When True the next call to get_or_create_session() will auto-reset + # this session (create a new session_id) so the user starts fresh. + # Set by /stop to break stuck-resume loops (#7536). + suspended: bool = False def to_dict(self) -> Dict[str, Any]: result = { @@ -387,6 +392,7 @@ class SessionEntry: "estimated_cost_usd": self.estimated_cost_usd, "cost_status": self.cost_status, "memory_flushed": self.memory_flushed, + "suspended": self.suspended, } if self.origin: result["origin"] = self.origin.to_dict() @@ -423,6 +429,7 @@ class SessionEntry: estimated_cost_usd=data.get("estimated_cost_usd", 0.0), cost_status=data.get("cost_status", "unknown"), memory_flushed=data.get("memory_flushed", False), + suspended=data.get("suspended", False), ) @@ -698,7 +705,12 @@ class SessionStore: if session_key in self._entries and not force_new: entry = self._entries[session_key] - reset_reason = self._should_reset(entry, source) + # Auto-reset sessions marked as suspended (e.g. after /stop + # broke a stuck loop — #7536). + if entry.suspended: + reset_reason = "suspended" + else: + reset_reason = self._should_reset(entry, source) if not reset_reason: entry.updated_at = now self._save() @@ -771,6 +783,44 @@ class SessionStore: entry.last_prompt_tokens = last_prompt_tokens self._save() + def suspend_session(self, session_key: str) -> bool: + """Mark a session as suspended so it auto-resets on next access. + + Used by ``/stop`` to prevent stuck sessions from being resumed + after a gateway restart (#7536). Returns True if the session + existed and was marked. + """ + with self._lock: + self._ensure_loaded_locked() + if session_key in self._entries: + self._entries[session_key].suspended = True + self._save() + return True + return False + + def suspend_recently_active(self, max_age_seconds: int = 120) -> int: + """Mark recently-active sessions as suspended. + + Called on gateway startup to prevent sessions that were likely + in-flight when the gateway last exited from being blindly resumed + (#7536). Only suspends sessions updated within *max_age_seconds* + to avoid resetting long-idle sessions that are harmless to resume. + Returns the number of sessions that were suspended. + """ + import time as _time + + cutoff = _time.time() - max_age_seconds + count = 0 + with self._lock: + self._ensure_loaded_locked() + for entry in self._entries.values(): + if not entry.suspended and entry.updated_at >= cutoff: + entry.suspended = True + count += 1 + if count: + self._save() + return count + def reset_session(self, session_key: str) -> Optional[SessionEntry]: """Force reset a session, creating a new session ID.""" db_end_session_id = None From 59e630a64d5b6ca94f2caa310ca531f18fa5939d Mon Sep 17 00:00:00 2001 From: Teknium Date: Sat, 11 Apr 2026 12:21:14 -0700 Subject: [PATCH 06/35] fix: update thinking-exhaustion test for think-tag gating The test expected content=None to immediately trigger thinking-exhaustion, but PR #7738 correctly gates that check on _has_think_tags. Without think tags, the agent falls through to normal continuation retry (3 attempts). --- tests/run_agent/test_run_agent.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py index 9851939ae..61137fe90 100644 --- a/tests/run_agent/test_run_agent.py +++ b/tests/run_agent/test_run_agent.py @@ -2087,8 +2087,9 @@ class TestRunConversation: assert "Thinking Budget Exhausted" in result["final_response"] assert "/thinkon" in result["final_response"] - def test_length_empty_content_detected_as_thinking_exhausted(self, agent): - """When finish_reason='length' and content is None/empty, detect exhaustion.""" + def test_length_empty_content_without_think_tags_retries_normally(self, agent): + """When finish_reason='length' and content is None but no think tags, + fall through to normal continuation retry (not thinking-exhaustion).""" self._setup_agent(agent) resp = _mock_response(content=None, finish_reason="length") agent.client.chat.completions.create.return_value = resp @@ -2100,12 +2101,10 @@ class TestRunConversation: ): result = agent.run_conversation("hello") + # Without think tags, the agent should attempt continuation retries + # (up to 3), not immediately fire thinking-exhaustion. + assert result["api_calls"] == 3 assert result["completed"] is False - assert result["api_calls"] == 1 - assert "reasoning" in result["error"].lower() - # User-friendly message is returned - assert result["final_response"] is not None - assert "Thinking Budget Exhausted" in result["final_response"] def test_length_with_tool_calls_returns_partial_without_executing_tools(self, agent): self._setup_agent(agent) From f4f4078ad9dc76b96048ff23c9d9ee4265d06198 Mon Sep 17 00:00:00 2001 From: WAXLYY Date: Sat, 11 Apr 2026 19:38:32 +0300 Subject: [PATCH 07/35] fix(gateway/weixin): ensure atomic persistence for critical session state --- gateway/platforms/weixin.py | 7 ++-- tests/gateway/test_weixin.py | 68 +++++++++++++++++++++++++++++++++++- 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/gateway/platforms/weixin.py b/gateway/platforms/weixin.py index e25bb350f..5e0208c77 100644 --- a/gateway/platforms/weixin.py +++ b/gateway/platforms/weixin.py @@ -63,6 +63,7 @@ from gateway.platforms.base import ( cache_image_from_bytes, ) from hermes_constants import get_hermes_home +from utils import atomic_json_write ILINK_BASE_URL = "https://ilinkai.weixin.qq.com" WEIXIN_CDN_BASE_URL = "https://novac2c.cdn.weixin.qq.com/c2c" @@ -206,7 +207,7 @@ def save_weixin_account( "saved_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), } path = _account_file(hermes_home, account_id) - path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + atomic_json_write(path, payload) try: path.chmod(0o600) except OSError: @@ -269,7 +270,7 @@ class ContextTokenStore: if key.startswith(prefix) } try: - self._path(account_id).write_text(json.dumps(payload), encoding="utf-8") + atomic_json_write(self._path(account_id), payload) except Exception as exc: logger.warning("weixin: failed to persist context tokens for %s: %s", _safe_id(account_id), exc) @@ -868,7 +869,7 @@ def _load_sync_buf(hermes_home: str, account_id: str) -> str: def _save_sync_buf(hermes_home: str, account_id: str, sync_buf: str) -> None: path = _sync_buf_path(hermes_home, account_id) - path.write_text(json.dumps({"get_updates_buf": sync_buf}), encoding="utf-8") + atomic_json_write(path, {"get_updates_buf": sync_buf}) async def qr_login( diff --git a/tests/gateway/test_weixin.py b/tests/gateway/test_weixin.py index caf4a7eba..815ea75ef 100644 --- a/tests/gateway/test_weixin.py +++ b/tests/gateway/test_weixin.py @@ -1,12 +1,14 @@ """Tests for the Weixin platform adapter.""" import asyncio +import json import os from unittest.mock import AsyncMock, patch from gateway.config import PlatformConfig from gateway.config import GatewayConfig, HomeChannel, Platform, _apply_env_overrides -from gateway.platforms.weixin import WeixinAdapter +from gateway.platforms import weixin +from gateway.platforms.weixin import ContextTokenStore, WeixinAdapter from tools.send_message_tool import _parse_target_ref, _send_to_platform @@ -187,6 +189,70 @@ class TestWeixinConfig: assert config.get_connected_platforms() == [] +class TestWeixinStatePersistence: + def test_save_weixin_account_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch): + account_path = tmp_path / "weixin" / "accounts" / "acct.json" + account_path.parent.mkdir(parents=True, exist_ok=True) + original = {"token": "old-token", "base_url": "https://old.example.com"} + account_path.write_text(json.dumps(original), encoding="utf-8") + + def _boom(_src, _dst): + raise OSError("disk full") + + monkeypatch.setattr("utils.os.replace", _boom) + + try: + weixin.save_weixin_account( + str(tmp_path), + account_id="acct", + token="new-token", + base_url="https://new.example.com", + user_id="wxid_new", + ) + except OSError: + pass + else: + raise AssertionError("expected save_weixin_account to propagate replace failure") + + assert json.loads(account_path.read_text(encoding="utf-8")) == original + + def test_context_token_persist_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch): + token_path = tmp_path / "weixin" / "accounts" / "acct.context-tokens.json" + token_path.parent.mkdir(parents=True, exist_ok=True) + token_path.write_text(json.dumps({"user-a": "old-token"}), encoding="utf-8") + + def _boom(_src, _dst): + raise OSError("disk full") + + monkeypatch.setattr("utils.os.replace", _boom) + + store = ContextTokenStore(str(tmp_path)) + with patch.object(weixin.logger, "warning") as warning_mock: + store.set("acct", "user-b", "new-token") + + assert json.loads(token_path.read_text(encoding="utf-8")) == {"user-a": "old-token"} + warning_mock.assert_called_once() + + def test_save_sync_buf_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch): + sync_path = tmp_path / "weixin" / "accounts" / "acct.sync.json" + sync_path.parent.mkdir(parents=True, exist_ok=True) + sync_path.write_text(json.dumps({"get_updates_buf": "old-sync"}), encoding="utf-8") + + def _boom(_src, _dst): + raise OSError("disk full") + + monkeypatch.setattr("utils.os.replace", _boom) + + try: + weixin._save_sync_buf(str(tmp_path), "acct", "new-sync") + except OSError: + pass + else: + raise AssertionError("expected _save_sync_buf to propagate replace failure") + + assert json.loads(sync_path.read_text(encoding="utf-8")) == {"get_updates_buf": "old-sync"} + + class TestWeixinSendMessageIntegration: def test_parse_target_ref_accepts_weixin_ids(self): assert _parse_target_ref("weixin", "wxid_test123") == ("wxid_test123", None, True) From cf53e2676b64e94e1f027e426f70f8a6cecc0949 Mon Sep 17 00:00:00 2001 From: dalianmao000 Date: Sat, 11 Apr 2026 22:56:37 +0800 Subject: [PATCH 08/35] fix(wecom): handle appmsg attachments (PDF/Word/Excel) from WeCom AI Bot WeCom AI Bot sends file attachments with msgtype="appmsg", not msgtype="file". Previously only file content was discarded while the text title reached the agent. Changes: - _extract_text(): Extract appmsg title (filename) for display - _extract_media(): Handle appmsg type with file/image content Fixes #7750 Co-Authored-By: Claude Opus 4.6 --- gateway/platforms/wecom.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/gateway/platforms/wecom.py b/gateway/platforms/wecom.py index 6fde73927..aa07dc6a9 100644 --- a/gateway/platforms/wecom.py +++ b/gateway/platforms/wecom.py @@ -636,6 +636,13 @@ class WeComAdapter(BasePlatformAdapter): if voice_text: text_parts.append(voice_text) + # Extract appmsg title (filename) for WeCom AI Bot attachments + if msgtype == "appmsg": + appmsg = body.get("appmsg") if isinstance(body.get("appmsg"), dict) else {} + title = str(appmsg.get("title") or "").strip() + if title: + text_parts.append(title) + quote = body.get("quote") if isinstance(body.get("quote"), dict) else {} quote_type = str(quote.get("msgtype") or "").lower() if quote_type == "text": @@ -668,6 +675,13 @@ class WeComAdapter(BasePlatformAdapter): refs.append(("image", body["image"])) if msgtype == "file" and isinstance(body.get("file"), dict): refs.append(("file", body["file"])) + # Handle appmsg (WeCom AI Bot attachments with PDF/Word/Excel) + if msgtype == "appmsg" and isinstance(body.get("appmsg"), dict): + appmsg = body["appmsg"] + if isinstance(appmsg.get("file"), dict): + refs.append(("file", appmsg["file"])) + elif isinstance(appmsg.get("image"), dict): + refs.append(("image", appmsg["image"])) quote = body.get("quote") if isinstance(body.get("quote"), dict) else {} quote_type = str(quote.get("msgtype") or "").lower() From 04c1c5d53f555b9798d180802a288f0566f9acd7 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sat, 11 Apr 2026 13:59:52 -0700 Subject: [PATCH 09/35] refactor: extract shared helpers to deduplicate repeated code patterns (#7917) * refactor: add shared helper modules for code deduplication New modules: - gateway/platforms/helpers.py: MessageDeduplicator, TextBatchAggregator, strip_markdown, ThreadParticipationTracker, redact_phone - hermes_cli/cli_output.py: print_info/success/warning/error, prompt helpers - tools/path_security.py: validate_within_dir, has_traversal_component - utils.py additions: safe_json_loads, read_json_file, read_jsonl, append_jsonl, env_str/lower/int/bool helpers - hermes_constants.py additions: get_config_path, get_skills_dir, get_logs_dir, get_env_path * refactor: migrate gateway adapters to shared helpers - MessageDeduplicator: discord, slack, dingtalk, wecom, weixin, mattermost - strip_markdown: bluebubbles, feishu, sms - redact_phone: sms, signal - ThreadParticipationTracker: discord, matrix - _acquire/_release_platform_lock: telegram, discord, slack, whatsapp, signal, weixin Net -316 lines across 19 files. * refactor: migrate CLI modules to shared helpers - tools_config.py: use cli_output print/prompt + curses_radiolist (-117 lines) - setup.py: use cli_output print helpers + curses_radiolist (-101 lines) - mcp_config.py: use cli_output prompt (-15 lines) - memory_setup.py: use curses_radiolist (-86 lines) Net -263 lines across 5 files. * refactor: migrate to shared utility helpers - safe_json_loads: agent/display.py (4 sites) - get_config_path: skill_utils.py, hermes_logging.py, hermes_time.py - get_skills_dir: skill_utils.py, prompt_builder.py - Token estimation dedup: skills_tool.py imports from model_metadata - Path security: skills_tool, cronjob_tools, skill_manager_tool, credential_files - Non-atomic YAML writes: doctor.py, config.py now use atomic_yaml_write - Platform dict: new platforms.py, skills_config + tools_config derive from it - Anthropic key: new get_anthropic_key() in auth.py, used by doctor/status/config/main * test: update tests for shared helper migrations - test_dingtalk: use _dedup.is_duplicate() instead of _is_duplicate() - test_mattermost: use _dedup instead of _seen_posts/_prune_seen - test_signal: import redact_phone from helpers instead of signal - test_discord_connect: _platform_lock_identity instead of _token_lock_identity - test_telegram_conflict: updated lock error message format - test_skill_manager_tool: 'escapes' instead of 'boundary' in error msgs --- agent/display.py | 25 +- agent/prompt_builder.py | 5 +- agent/skill_utils.py | 12 +- gateway/platforms/base.py | 31 ++- gateway/platforms/bluebubbles.py | 18 +- gateway/platforms/dingtalk.py | 25 +- gateway/platforms/discord.py | 111 +------- gateway/platforms/feishu.py | 16 +- gateway/platforms/helpers.py | 261 ++++++++++++++++++ gateway/platforms/matrix.py | 54 +--- gateway/platforms/mattermost.py | 23 +- gateway/platforms/signal.py | 45 +-- gateway/platforms/slack.py | 41 +-- gateway/platforms/sms.py | 37 +-- gateway/platforms/telegram.py | 34 +-- gateway/platforms/wecom.py | 26 +- gateway/platforms/weixin.py | 41 +-- gateway/platforms/whatsapp.py | 34 +-- hermes_cli/auth.py | 22 ++ hermes_cli/cli_output.py | 79 ++++++ hermes_cli/config.py | 7 +- hermes_cli/doctor.py | 7 +- hermes_cli/main.py | 9 +- hermes_cli/mcp_config.py | 15 +- hermes_cli/memory_setup.py | 86 +----- hermes_cli/platforms.py | 45 +++ hermes_cli/setup.py | 101 +------ hermes_cli/skills_config.py | 23 +- hermes_cli/status.py | 7 +- hermes_cli/tools_config.py | 142 ++-------- hermes_constants.py | 27 ++ hermes_logging.py | 4 +- hermes_time.py | 5 +- tests/e2e/conftest.py | 3 +- tests/gateway/test_dingtalk.py | 23 +- tests/gateway/test_discord_connect.py | 2 +- tests/gateway/test_discord_free_response.py | 6 +- .../test_discord_thread_persistence.py | 37 +-- tests/gateway/test_matrix_mention.py | 62 ++--- tests/gateway/test_mattermost.py | 22 +- tests/gateway/test_signal.py | 12 +- tests/gateway/test_telegram_conflict.py | 14 +- tests/tools/test_skill_manager_tool.py | 6 +- tools/credential_files.py | 36 ++- tools/cronjob_tools.py | 8 +- tools/path_security.py | 43 +++ tools/skill_manager_tool.py | 17 +- tools/skills_tool.py | 37 +-- utils.py | 90 +++++- 49 files changed, 887 insertions(+), 949 deletions(-) create mode 100644 gateway/platforms/helpers.py create mode 100644 hermes_cli/cli_output.py create mode 100644 hermes_cli/platforms.py create mode 100644 tools/path_security.py diff --git a/agent/display.py b/agent/display.py index 604b7a298..182064576 100644 --- a/agent/display.py +++ b/agent/display.py @@ -4,7 +4,6 @@ Pure display functions and classes with no AIAgent dependency. Used by AIAgent._execute_tool_calls for CLI feedback. """ -import json import logging import os import sys @@ -14,6 +13,8 @@ from dataclasses import dataclass, field from difflib import unified_diff from pathlib import Path +from utils import safe_json_loads + # ANSI escape codes for coloring tool failure indicators _RED = "\033[31m" _RESET = "\033[0m" @@ -372,9 +373,8 @@ def _result_succeeded(result: str | None) -> bool: """Conservatively detect whether a tool result represents success.""" if not result: return False - try: - data = json.loads(result) - except (json.JSONDecodeError, TypeError): + data = safe_json_loads(result) + if data is None: return False if not isinstance(data, dict): return False @@ -423,10 +423,7 @@ def extract_edit_diff( ) -> str | None: """Extract a unified diff from a file-edit tool result.""" if tool_name == "patch" and result: - try: - data = json.loads(result) - except (json.JSONDecodeError, TypeError): - data = None + data = safe_json_loads(result) if isinstance(data, dict): diff = data.get("diff") if isinstance(diff, str) and diff.strip(): @@ -780,23 +777,19 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str] return False, "" if tool_name == "terminal": - try: - data = json.loads(result) + data = safe_json_loads(result) + if isinstance(data, dict): exit_code = data.get("exit_code") if exit_code is not None and exit_code != 0: return True, f" [exit {exit_code}]" - except (json.JSONDecodeError, TypeError, AttributeError): - logger.debug("Could not parse terminal result as JSON for exit code check") return False, "" # Memory-specific: distinguish "full" from real errors if tool_name == "memory": - try: - data = json.loads(result) + data = safe_json_loads(result) + if isinstance(data, dict): if data.get("success") is False and "exceed the limit" in data.get("error", ""): return True, " [full]" - except (json.JSONDecodeError, TypeError, AttributeError): - logger.debug("Could not parse memory result as JSON for capacity check") # Generic heuristic for non-terminal tools lower = result[:500].lower() diff --git a/agent/prompt_builder.py b/agent/prompt_builder.py index 08b8fe0a6..26d913a02 100644 --- a/agent/prompt_builder.py +++ b/agent/prompt_builder.py @@ -12,7 +12,7 @@ import threading from collections import OrderedDict from pathlib import Path -from hermes_constants import get_hermes_home +from hermes_constants import get_hermes_home, get_skills_dir from typing import Optional from agent.skill_utils import ( @@ -548,8 +548,7 @@ def build_skills_system_prompt( are read-only — they appear in the index but new skills are always created in the local dir. Local skills take precedence when names collide. """ - hermes_home = get_hermes_home() - skills_dir = hermes_home / "skills" + skills_dir = get_skills_dir() external_dirs = get_all_skills_dirs()[1:] # skip local (index 0) if not skills_dir.exists() and not external_dirs: diff --git a/agent/skill_utils.py b/agent/skill_utils.py index ba606b358..97ba92b73 100644 --- a/agent/skill_utils.py +++ b/agent/skill_utils.py @@ -12,7 +12,7 @@ import sys from pathlib import Path from typing import Any, Dict, List, Set, Tuple -from hermes_constants import get_hermes_home +from hermes_constants import get_config_path, get_skills_dir logger = logging.getLogger(__name__) @@ -130,7 +130,7 @@ def get_disabled_skill_names(platform: str | None = None) -> Set[str]: Reads the config file directly (no CLI config imports) to stay lightweight. """ - config_path = get_hermes_home() / "config.yaml" + config_path = get_config_path() if not config_path.exists(): return set() try: @@ -178,7 +178,7 @@ def get_external_skills_dirs() -> List[Path]: path. Only directories that actually exist are returned. Duplicates and paths that resolve to the local ``~/.hermes/skills/`` are silently skipped. """ - config_path = get_hermes_home() / "config.yaml" + config_path = get_config_path() if not config_path.exists(): return [] try: @@ -200,7 +200,7 @@ def get_external_skills_dirs() -> List[Path]: if not isinstance(raw_dirs, list): return [] - local_skills = (get_hermes_home() / "skills").resolve() + local_skills = get_skills_dir().resolve() seen: Set[Path] = set() result: List[Path] = [] @@ -230,7 +230,7 @@ def get_all_skills_dirs() -> List[Path]: The local dir is always first (and always included even if it doesn't exist yet — callers handle that). External dirs follow in config order. """ - dirs = [get_hermes_home() / "skills"] + dirs = [get_skills_dir()] dirs.extend(get_external_skills_dirs()) return dirs @@ -384,7 +384,7 @@ def resolve_skill_config_values( current values (or the declared default if the key isn't set). Path values are expanded via ``os.path.expanduser``. """ - config_path = get_hermes_home() / "config.yaml" + config_path = get_config_path() config: Dict[str, Any] = {} if config_path.exists(): try: diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 04f0c1deb..352aecb33 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -823,7 +823,36 @@ class BasePlatformAdapter(ABC): result = handler(self) if asyncio.iscoroutine(result): await result - + + def _acquire_platform_lock(self, scope: str, identity: str, resource_desc: str) -> bool: + """Acquire a scoped lock for this adapter. Returns True on success.""" + from gateway.status import acquire_scoped_lock + self._platform_lock_scope = scope + self._platform_lock_identity = identity + acquired, existing = acquire_scoped_lock( + scope, identity, metadata={'platform': self.platform.value} + ) + if acquired: + return True + owner_pid = existing.get('pid') if isinstance(existing, dict) else None + message = ( + f'{resource_desc} already in use' + + (f' (PID {owner_pid})' if owner_pid else '') + + '. Stop the other gateway first.' + ) + logger.error('[%s] %s', self.name, message) + self._set_fatal_error(f'{scope}_lock', message, retryable=False) + return False + + def _release_platform_lock(self) -> None: + """Release the scoped lock acquired by _acquire_platform_lock.""" + identity = getattr(self, '_platform_lock_identity', None) + if not identity: + return + from gateway.status import release_scoped_lock + release_scoped_lock(self._platform_lock_scope, identity) + self._platform_lock_identity = None + @property def name(self) -> str: """Human-readable name for this adapter.""" diff --git a/gateway/platforms/bluebubbles.py b/gateway/platforms/bluebubbles.py index f50cd9503..115000996 100644 --- a/gateway/platforms/bluebubbles.py +++ b/gateway/platforms/bluebubbles.py @@ -30,6 +30,7 @@ from gateway.platforms.base import ( cache_audio_from_bytes, cache_document_from_bytes, ) +from gateway.platforms.helpers import strip_markdown logger = logging.getLogger(__name__) @@ -89,18 +90,7 @@ def _normalize_server_url(raw: str) -> str: return value.rstrip("/") -def _strip_markdown(text: str) -> str: - """Strip common markdown formatting for iMessage plain-text delivery.""" - text = re.sub(r"\*\*(.+?)\*\*", r"\1", text, flags=re.DOTALL) - text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL) - text = re.sub(r"__(.+?)__", r"\1", text, flags=re.DOTALL) - text = re.sub(r"_(.+?)_", r"\1", text, flags=re.DOTALL) - text = re.sub(r"```[a-zA-Z0-9_+-]*\n?", "", text) - text = re.sub(r"`(.+?)`", r"\1", text) - text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE) - text = re.sub(r"\[([^\]]+)\]\(([^\)]+)\)", r"\1", text) - text = re.sub(r"\n{3,}", "\n\n", text) - return text.strip() + # --------------------------------------------------------------------------- @@ -393,7 +383,7 @@ class BlueBubblesAdapter(BasePlatformAdapter): reply_to: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: - text = _strip_markdown(content or "") + text = strip_markdown(content or "") if not text: return SendResult(success=False, error="BlueBubbles send requires text") chunks = self.truncate_message(text, max_length=self.MAX_MESSAGE_LENGTH) @@ -679,7 +669,7 @@ class BlueBubblesAdapter(BasePlatformAdapter): return info def format_message(self, content: str) -> str: - return _strip_markdown(content) + return strip_markdown(content) # ------------------------------------------------------------------ # Inbound attachment downloading (from #4588) diff --git a/gateway/platforms/dingtalk.py b/gateway/platforms/dingtalk.py index e83b902df..5d50deca5 100644 --- a/gateway/platforms/dingtalk.py +++ b/gateway/platforms/dingtalk.py @@ -42,6 +42,7 @@ except ImportError: httpx = None # type: ignore[assignment] from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -52,8 +53,6 @@ from gateway.platforms.base import ( logger = logging.getLogger(__name__) MAX_MESSAGE_LENGTH = 20000 -DEDUP_WINDOW_SECONDS = 300 -DEDUP_MAX_SIZE = 1000 RECONNECT_BACKOFF = [2, 5, 10, 30, 60] _SESSION_WEBHOOKS_MAX = 500 _DINGTALK_WEBHOOK_RE = re.compile(r'^https://api\.dingtalk\.com/') @@ -89,8 +88,8 @@ class DingTalkAdapter(BasePlatformAdapter): self._stream_task: Optional[asyncio.Task] = None self._http_client: Optional["httpx.AsyncClient"] = None - # Message deduplication: msg_id -> timestamp - self._seen_messages: Dict[str, float] = {} + # Message deduplication + self._dedup = MessageDeduplicator(max_size=1000) # Map chat_id -> session_webhook for reply routing self._session_webhooks: Dict[str, str] = {} @@ -170,7 +169,7 @@ class DingTalkAdapter(BasePlatformAdapter): self._stream_client = None self._session_webhooks.clear() - self._seen_messages.clear() + self._dedup.clear() logger.info("[%s] Disconnected", self.name) # -- Inbound message processing ----------------------------------------- @@ -178,7 +177,7 @@ class DingTalkAdapter(BasePlatformAdapter): async def _on_message(self, message: "ChatbotMessage") -> None: """Process an incoming DingTalk chatbot message.""" msg_id = getattr(message, "message_id", None) or uuid.uuid4().hex - if self._is_duplicate(msg_id): + if self._dedup.is_duplicate(msg_id): logger.debug("[%s] Duplicate message %s, skipping", self.name, msg_id) return @@ -256,20 +255,6 @@ class DingTalkAdapter(BasePlatformAdapter): content = " ".join(parts).strip() return content - # -- Deduplication ------------------------------------------------------ - - def _is_duplicate(self, msg_id: str) -> bool: - """Check and record a message ID. Returns True if already seen.""" - now = time.time() - if len(self._seen_messages) > DEDUP_MAX_SIZE: - cutoff = now - DEDUP_WINDOW_SECONDS - self._seen_messages = {k: v for k, v in self._seen_messages.items() if v > cutoff} - - if msg_id in self._seen_messages: - return True - self._seen_messages[msg_id] = now - return False - # -- Outbound messaging ------------------------------------------------- async def send( diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index dcf05a162..b1d07e5d6 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -45,6 +45,7 @@ sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) from gateway.config import Platform, PlatformConfig import re +from gateway.platforms.helpers import MessageDeduplicator, ThreadParticipationTracker from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -450,18 +451,14 @@ class DiscordAdapter(BasePlatformAdapter): # Track threads where the bot has participated so follow-up messages # in those threads don't require @mention. Persisted to disk so the # set survives gateway restarts. - self._bot_participated_threads: set = self._load_participated_threads() + self._threads = ThreadParticipationTracker("discord") # Persistent typing indicator loops per channel (DMs don't reliably # show the standard typing gateway event for bots) self._typing_tasks: Dict[str, asyncio.Task] = {} self._bot_task: Optional[asyncio.Task] = None - # Cap to prevent unbounded growth (Discord threads get archived). - self._MAX_TRACKED_THREADS = 500 - # Dedup cache: message_id → timestamp. Prevents duplicate bot - # responses when Discord RESUME replays events after reconnects. - self._seen_messages: Dict[str, float] = {} - self._SEEN_TTL = 300 # 5 minutes - self._SEEN_MAX = 2000 # prune threshold + # Dedup cache: prevents duplicate bot responses when Discord + # RESUME replays events after reconnects. + self._dedup = MessageDeduplicator() # Reply threading mode: "off" (no replies), "first" (reply on first # chunk only, default), "all" (reply-reference on every chunk). self._reply_to_mode: str = getattr(config, 'reply_to_mode', 'first') or 'first' @@ -502,18 +499,9 @@ class DiscordAdapter(BasePlatformAdapter): return False try: - # Acquire scoped lock to prevent duplicate bot token usage - from gateway.status import acquire_scoped_lock - self._token_lock_identity = self.config.token - acquired, existing = acquire_scoped_lock('discord-bot-token', self._token_lock_identity, metadata={'platform': 'discord'}) - if not acquired: - owner_pid = existing.get('pid') if isinstance(existing, dict) else None - message = f'Discord bot token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.' - logger.error('[%s] %s', self.name, message) - self._set_fatal_error('discord_token_lock', message, retryable=False) + if not self._acquire_platform_lock('discord-bot-token', self.config.token, 'Discord bot token'): return False - # Parse allowed user entries (may contain usernames or IDs) allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "") if allowed_env: @@ -569,17 +557,8 @@ class DiscordAdapter(BasePlatformAdapter): @self._client.event async def on_message(message: DiscordMessage): # Dedup: Discord RESUME replays events after reconnects (#4777) - msg_id = str(message.id) - now = time.time() - if msg_id in adapter_self._seen_messages: + if adapter_self._dedup.is_duplicate(str(message.id)): return - adapter_self._seen_messages[msg_id] = now - if len(adapter_self._seen_messages) > adapter_self._SEEN_MAX: - cutoff = now - adapter_self._SEEN_TTL - adapter_self._seen_messages = { - k: v for k, v in adapter_self._seen_messages.items() - if v > cutoff - } # Always ignore our own messages if message.author == self._client.user: @@ -685,23 +664,11 @@ class DiscordAdapter(BasePlatformAdapter): except asyncio.TimeoutError: logger.error("[%s] Timeout waiting for connection to Discord", self.name, exc_info=True) - try: - from gateway.status import release_scoped_lock - if getattr(self, '_token_lock_identity', None): - release_scoped_lock('discord-bot-token', self._token_lock_identity) - self._token_lock_identity = None - except Exception: - pass + self._release_platform_lock() return False except Exception as e: # pragma: no cover - defensive logging logger.error("[%s] Failed to connect to Discord: %s", self.name, e, exc_info=True) - try: - from gateway.status import release_scoped_lock - if getattr(self, '_token_lock_identity', None): - release_scoped_lock('discord-bot-token', self._token_lock_identity) - self._token_lock_identity = None - except Exception: - pass + self._release_platform_lock() return False async def disconnect(self) -> None: @@ -723,14 +690,7 @@ class DiscordAdapter(BasePlatformAdapter): self._client = None self._ready_event.clear() - # Release the token lock - try: - from gateway.status import release_scoped_lock - if getattr(self, '_token_lock_identity', None): - release_scoped_lock('discord-bot-token', self._token_lock_identity) - self._token_lock_identity = None - except Exception: - pass + self._release_platform_lock() logger.info("[%s] Disconnected", self.name) @@ -1870,7 +1830,7 @@ class DiscordAdapter(BasePlatformAdapter): # Track thread participation so follow-ups don't require @mention if thread_id: - self._track_thread(thread_id) + self._threads.mark(thread_id) # If a message was provided, kick off a new Hermes session in the thread starter = (message or "").strip() @@ -2241,49 +2201,6 @@ class DiscordAdapter(BasePlatformAdapter): return f"{parent_name} / {thread_name}" return thread_name - # ------------------------------------------------------------------ - # Thread participation persistence - # ------------------------------------------------------------------ - - @staticmethod - def _thread_state_path() -> Path: - """Path to the persisted thread participation set.""" - from hermes_cli.config import get_hermes_home - return get_hermes_home() / "discord_threads.json" - - @classmethod - def _load_participated_threads(cls) -> set: - """Load persisted thread IDs from disk.""" - path = cls._thread_state_path() - try: - if path.exists(): - data = json.loads(path.read_text(encoding="utf-8")) - if isinstance(data, list): - return set(data) - except Exception as e: - logger.debug("Could not load discord thread state: %s", e) - return set() - - def _save_participated_threads(self) -> None: - """Persist the current thread set to disk (best-effort).""" - path = self._thread_state_path() - try: - # Trim to most recent entries if over cap - thread_list = list(self._bot_participated_threads) - if len(thread_list) > self._MAX_TRACKED_THREADS: - thread_list = thread_list[-self._MAX_TRACKED_THREADS:] - self._bot_participated_threads = set(thread_list) - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(thread_list), encoding="utf-8") - except Exception as e: - logger.debug("Could not save discord thread state: %s", e) - - def _track_thread(self, thread_id: str) -> None: - """Add a thread to the participation set and persist.""" - if thread_id not in self._bot_participated_threads: - self._bot_participated_threads.add(thread_id) - self._save_participated_threads() - async def _handle_message(self, message: DiscordMessage) -> None: """Handle incoming Discord messages.""" # In server channels (not DMs), require the bot to be @mentioned @@ -2335,7 +2252,7 @@ class DiscordAdapter(BasePlatformAdapter): # Skip the mention check if the message is in a thread where # the bot has previously participated (auto-created or replied in). - in_bot_thread = is_thread and thread_id in self._bot_participated_threads + in_bot_thread = is_thread and thread_id in self._threads if require_mention and not is_free_channel and not in_bot_thread: if self._client.user not in message.mentions: @@ -2361,7 +2278,7 @@ class DiscordAdapter(BasePlatformAdapter): is_thread = True thread_id = str(thread.id) auto_threaded_channel = thread - self._track_thread(thread_id) + self._threads.mark(thread_id) # Determine message type msg_type = MessageType.TEXT @@ -2545,7 +2462,7 @@ class DiscordAdapter(BasePlatformAdapter): # Track thread participation so the bot won't require @mention for # follow-up messages in threads it has already engaged in. if thread_id: - self._track_thread(thread_id) + self._threads.mark(thread_id) # Only batch plain text messages — commands, media, etc. dispatch # immediately since they won't be split by the Discord client. diff --git a/gateway/platforms/feishu.py b/gateway/platforms/feishu.py index a88c7e52b..16f5467b2 100644 --- a/gateway/platforms/feishu.py +++ b/gateway/platforms/feishu.py @@ -360,19 +360,21 @@ def _render_code_block_element(element: Dict[str, Any]) -> str: def _strip_markdown_to_plain_text(text: str) -> str: + """Strip markdown formatting to plain text for Feishu text fallbacks. + + Delegates common markdown stripping to the shared helper and adds + Feishu-specific patterns (blockquotes, strikethrough, underline tags, + horizontal rules, \\r\\n normalisation). + """ + from gateway.platforms.helpers import strip_markdown plain = text.replace("\r\n", "\n") plain = _MARKDOWN_LINK_RE.sub(lambda m: f"{m.group(1)} ({m.group(2).strip()})", plain) - plain = re.sub(r"^#{1,6}\s+", "", plain, flags=re.MULTILINE) plain = re.sub(r"^>\s?", "", plain, flags=re.MULTILINE) plain = re.sub(r"^\s*---+\s*$", "---", plain, flags=re.MULTILINE) - plain = re.sub(r"```(?:[^\n]*\n)?([\s\S]*?)```", lambda m: m.group(1).strip("\n"), plain) - plain = re.sub(r"`([^`\n]+)`", r"\1", plain) - plain = re.sub(r"\*\*([^*\n]+)\*\*", r"\1", plain) - plain = re.sub(r"\*([^*\n]+)\*", r"\1", plain) plain = re.sub(r"~~([^~\n]+)~~", r"\1", plain) plain = re.sub(r"([\s\S]*?)", r"\1", plain) - plain = re.sub(r"\n{3,}", "\n\n", plain) - return plain.strip() + plain = strip_markdown(plain) + return plain def _coerce_int(value: Any, default: Optional[int] = None, min_value: int = 0) -> Optional[int]: diff --git a/gateway/platforms/helpers.py b/gateway/platforms/helpers.py new file mode 100644 index 000000000..c834dd89c --- /dev/null +++ b/gateway/platforms/helpers.py @@ -0,0 +1,261 @@ +"""Shared helper classes for gateway platform adapters. + +Extracts common patterns that were duplicated across 5-7 adapters: +message deduplication, text batch aggregation, markdown stripping, +and thread participation tracking. +""" + +import asyncio +import json +import logging +import re +import time +from pathlib import Path +from typing import TYPE_CHECKING, Dict, Optional + +if TYPE_CHECKING: + from gateway.platforms.base import BasePlatformAdapter, MessageEvent + +logger = logging.getLogger(__name__) + + +# ─── Message Deduplication ──────────────────────────────────────────────────── + + +class MessageDeduplicator: + """TTL-based message deduplication cache. + + Replaces the identical ``_seen_messages`` / ``_is_duplicate()`` pattern + previously duplicated in discord, slack, dingtalk, wecom, weixin, + mattermost, and feishu adapters. + + Usage:: + + self._dedup = MessageDeduplicator() + + # In message handler: + if self._dedup.is_duplicate(msg_id): + return + """ + + def __init__(self, max_size: int = 2000, ttl_seconds: float = 300): + self._seen: Dict[str, float] = {} + self._max_size = max_size + self._ttl = ttl_seconds + + def is_duplicate(self, msg_id: str) -> bool: + """Return True if *msg_id* was already seen within the TTL window.""" + if not msg_id: + return False + now = time.time() + if msg_id in self._seen: + return True + self._seen[msg_id] = now + if len(self._seen) > self._max_size: + cutoff = now - self._ttl + self._seen = {k: v for k, v in self._seen.items() if v > cutoff} + return False + + def clear(self): + """Clear all tracked messages.""" + self._seen.clear() + + +# ─── Text Batch Aggregation ────────────────────────────────────────────────── + + +class TextBatchAggregator: + """Aggregates rapid-fire text events into single messages. + + Replaces the ``_enqueue_text_event`` / ``_flush_text_batch`` pattern + previously duplicated in telegram, discord, matrix, wecom, and feishu. + + Usage:: + + self._text_batcher = TextBatchAggregator( + handler=self._message_handler, + batch_delay=0.6, + split_threshold=1900, + ) + + # In message dispatch: + if msg_type == MessageType.TEXT and self._text_batcher.is_enabled(): + self._text_batcher.enqueue(event, session_key) + return + """ + + def __init__( + self, + handler, + *, + batch_delay: float = 0.6, + split_delay: float = 2.0, + split_threshold: int = 4000, + ): + self._handler = handler + self._batch_delay = batch_delay + self._split_delay = split_delay + self._split_threshold = split_threshold + self._pending: Dict[str, "MessageEvent"] = {} + self._pending_tasks: Dict[str, asyncio.Task] = {} + + def is_enabled(self) -> bool: + """Return True if batching is active (delay > 0).""" + return self._batch_delay > 0 + + def enqueue(self, event: "MessageEvent", key: str) -> None: + """Add *event* to the pending batch for *key*.""" + chunk_len = len(event.text or "") + existing = self._pending.get(key) + if not existing: + event._last_chunk_len = chunk_len # type: ignore[attr-defined] + self._pending[key] = event + else: + existing.text = f"{existing.text}\n{event.text}" + existing._last_chunk_len = chunk_len # type: ignore[attr-defined] + + # Cancel prior flush timer, start a new one + prior = self._pending_tasks.get(key) + if prior and not prior.done(): + prior.cancel() + self._pending_tasks[key] = asyncio.create_task(self._flush(key)) + + async def _flush(self, key: str) -> None: + """Wait then dispatch the batched event for *key*.""" + current_task = self._pending_tasks.get(key) + pending = self._pending.get(key) + last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0 + + # Use longer delay when the last chunk looks like a split message + delay = self._split_delay if last_len >= self._split_threshold else self._batch_delay + await asyncio.sleep(delay) + + event = self._pending.pop(key, None) + if event: + try: + await self._handler(event) + except Exception: + logger.exception("[TextBatchAggregator] Error dispatching batched event for %s", key) + + if self._pending_tasks.get(key) is current_task: + self._pending_tasks.pop(key, None) + + def cancel_all(self) -> None: + """Cancel all pending flush tasks.""" + for task in self._pending_tasks.values(): + if not task.done(): + task.cancel() + self._pending_tasks.clear() + self._pending.clear() + + +# ─── Markdown Stripping ────────────────────────────────────────────────────── + +# Pre-compiled regexes for performance +_RE_BOLD = re.compile(r"\*\*(.+?)\*\*", re.DOTALL) +_RE_ITALIC_STAR = re.compile(r"\*(.+?)\*", re.DOTALL) +_RE_BOLD_UNDER = re.compile(r"__(.+?)__", re.DOTALL) +_RE_ITALIC_UNDER = re.compile(r"_(.+?)_", re.DOTALL) +_RE_CODE_BLOCK = re.compile(r"```[a-zA-Z0-9_+-]*\n?") +_RE_INLINE_CODE = re.compile(r"`(.+?)`") +_RE_HEADING = re.compile(r"^#{1,6}\s+", re.MULTILINE) +_RE_LINK = re.compile(r"\[([^\]]+)\]\([^\)]+\)") +_RE_MULTI_NEWLINE = re.compile(r"\n{3,}") + + +def strip_markdown(text: str) -> str: + """Strip markdown formatting for plain-text platforms (SMS, iMessage, etc.). + + Replaces the identical ``_strip_markdown()`` functions previously + duplicated in sms.py, bluebubbles.py, and feishu.py. + """ + text = _RE_BOLD.sub(r"\1", text) + text = _RE_ITALIC_STAR.sub(r"\1", text) + text = _RE_BOLD_UNDER.sub(r"\1", text) + text = _RE_ITALIC_UNDER.sub(r"\1", text) + text = _RE_CODE_BLOCK.sub("", text) + text = _RE_INLINE_CODE.sub(r"\1", text) + text = _RE_HEADING.sub("", text) + text = _RE_LINK.sub(r"\1", text) + text = _RE_MULTI_NEWLINE.sub("\n\n", text) + return text.strip() + + +# ─── Thread Participation Tracking ─────────────────────────────────────────── + + +class ThreadParticipationTracker: + """Persistent tracking of threads the bot has participated in. + + Replaces the identical ``_load/_save_participated_threads`` + + ``_mark_thread_participated`` pattern previously duplicated in + discord.py and matrix.py. + + Usage:: + + self._threads = ThreadParticipationTracker("discord") + + # Check membership: + if thread_id in self._threads: + ... + + # Mark participation: + self._threads.mark(thread_id) + """ + + _MAX_TRACKED = 500 + + def __init__(self, platform_name: str, max_tracked: int = 500): + self._platform = platform_name + self._max_tracked = max_tracked + self._threads: set = self._load() + + def _state_path(self) -> Path: + from hermes_constants import get_hermes_home + return get_hermes_home() / f"{self._platform}_threads.json" + + def _load(self) -> set: + path = self._state_path() + if path.exists(): + try: + return set(json.loads(path.read_text(encoding="utf-8"))) + except Exception: + pass + return set() + + def _save(self) -> None: + path = self._state_path() + path.parent.mkdir(parents=True, exist_ok=True) + thread_list = list(self._threads) + if len(thread_list) > self._max_tracked: + thread_list = thread_list[-self._max_tracked:] + self._threads = set(thread_list) + path.write_text(json.dumps(thread_list), encoding="utf-8") + + def mark(self, thread_id: str) -> None: + """Mark *thread_id* as participated and persist.""" + if thread_id not in self._threads: + self._threads.add(thread_id) + self._save() + + def __contains__(self, thread_id: str) -> bool: + return thread_id in self._threads + + def clear(self) -> None: + self._threads.clear() + + +# ─── Phone Number Redaction ────────────────────────────────────────────────── + + +def redact_phone(phone: str) -> str: + """Redact a phone number for logging, preserving country code and last 4. + + Replaces the identical ``_redact_phone()`` functions in signal.py, + sms.py, and bluebubbles.py. + """ + if not phone: + return "" + if len(phone) <= 8: + return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****" + return phone[:4] + "****" + phone[-4:] diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py index 7daf2e70e..349f962d2 100644 --- a/gateway/platforms/matrix.py +++ b/gateway/platforms/matrix.py @@ -92,6 +92,7 @@ from gateway.platforms.base import ( ProcessingOutcome, SendResult, ) +from gateway.platforms.helpers import ThreadParticipationTracker logger = logging.getLogger(__name__) @@ -216,8 +217,7 @@ class MatrixAdapter(BasePlatformAdapter): self._pending_megolm: list = [] # Thread participation tracking (for require_mention bypass) - self._bot_participated_threads: set = self._load_participated_threads() - self._MAX_TRACKED_THREADS = 500 + self._threads = ThreadParticipationTracker("matrix") # Mention/thread gating — parsed once from env vars. self._require_mention: bool = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no") @@ -1019,7 +1019,7 @@ class MatrixAdapter(BasePlatformAdapter): # Require-mention gating. if not is_dm: is_free_room = room_id in self._free_rooms - in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads) + in_bot_thread = bool(thread_id and thread_id in self._threads) if self._require_mention and not is_free_room and not in_bot_thread: if not is_mentioned: return None @@ -1027,7 +1027,7 @@ class MatrixAdapter(BasePlatformAdapter): # DM mention-thread. if is_dm and not thread_id and self._dm_mention_threads and is_mentioned: thread_id = event_id - self._track_thread(thread_id) + self._threads.mark(thread_id) # Strip mention from body. if is_mentioned: @@ -1036,7 +1036,7 @@ class MatrixAdapter(BasePlatformAdapter): # Auto-thread. if not is_dm and not thread_id and self._auto_thread: thread_id = event_id - self._track_thread(thread_id) + self._threads.mark(thread_id) display_name = await self._get_display_name(room_id, sender) source = self.build_source( @@ -1048,7 +1048,7 @@ class MatrixAdapter(BasePlatformAdapter): ) if thread_id: - self._track_thread(thread_id) + self._threads.mark(thread_id) self._background_read_receipt(room_id, event_id) @@ -1697,48 +1697,6 @@ class MatrixAdapter(BasePlatformAdapter): for rid in self._joined_rooms } - # ------------------------------------------------------------------ - # Thread participation tracking - # ------------------------------------------------------------------ - - @staticmethod - def _thread_state_path() -> Path: - """Path to the persisted thread participation set.""" - from hermes_cli.config import get_hermes_home - return get_hermes_home() / "matrix_threads.json" - - @classmethod - def _load_participated_threads(cls) -> set: - """Load persisted thread IDs from disk.""" - path = cls._thread_state_path() - try: - if path.exists(): - data = json.loads(path.read_text(encoding="utf-8")) - if isinstance(data, list): - return set(data) - except Exception as e: - logger.debug("Could not load matrix thread state: %s", e) - return set() - - def _save_participated_threads(self) -> None: - """Persist the current thread set to disk (best-effort).""" - path = self._thread_state_path() - try: - thread_list = list(self._bot_participated_threads) - if len(thread_list) > self._MAX_TRACKED_THREADS: - thread_list = thread_list[-self._MAX_TRACKED_THREADS:] - self._bot_participated_threads = set(thread_list) - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(thread_list), encoding="utf-8") - except Exception as e: - logger.debug("Could not save matrix thread state: %s", e) - - def _track_thread(self, thread_id: str) -> None: - """Add a thread to the participation set and persist.""" - if thread_id not in self._bot_participated_threads: - self._bot_participated_threads.add(thread_id) - self._save_participated_threads() - # ------------------------------------------------------------------ # Mention detection helpers # ------------------------------------------------------------------ diff --git a/gateway/platforms/mattermost.py b/gateway/platforms/mattermost.py index 56f29e876..23a86f02b 100644 --- a/gateway/platforms/mattermost.py +++ b/gateway/platforms/mattermost.py @@ -18,11 +18,11 @@ import json import logging import os import re -import time from pathlib import Path from typing import Any, Dict, List, Optional from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -96,10 +96,8 @@ class MattermostAdapter(BasePlatformAdapter): or os.getenv("MATTERMOST_REPLY_MODE", "off") ).lower() - # Dedup cache: post_id → timestamp (prevent reprocessing) - self._seen_posts: Dict[str, float] = {} - self._SEEN_MAX = 2000 - self._SEEN_TTL = 300 # 5 minutes + # Dedup cache (prevent reprocessing) + self._dedup = MessageDeduplicator() # ------------------------------------------------------------------ # HTTP helpers @@ -604,10 +602,8 @@ class MattermostAdapter(BasePlatformAdapter): post_id = post.get("id", "") # Dedup. - self._prune_seen() - if post_id in self._seen_posts: + if self._dedup.is_duplicate(post_id): return - self._seen_posts[post_id] = time.time() # Build message event. channel_id = post.get("channel_id", "") @@ -734,13 +730,4 @@ class MattermostAdapter(BasePlatformAdapter): await self.handle_message(msg_event) - def _prune_seen(self) -> None: - """Remove expired entries from the dedup cache.""" - if len(self._seen_posts) < self._SEEN_MAX: - return - now = time.time() - self._seen_posts = { - pid: ts - for pid, ts in self._seen_posts.items() - if now - ts < self._SEEN_TTL - } + diff --git a/gateway/platforms/signal.py b/gateway/platforms/signal.py index 08b62f2a6..8ef7bd0d6 100644 --- a/gateway/platforms/signal.py +++ b/gateway/platforms/signal.py @@ -37,6 +37,7 @@ from gateway.platforms.base import ( cache_document_from_bytes, cache_image_from_url, ) +from gateway.platforms.helpers import redact_phone logger = logging.getLogger(__name__) @@ -51,22 +52,10 @@ SSE_RETRY_DELAY_MAX = 60.0 HEALTH_CHECK_INTERVAL = 30.0 # seconds between health checks HEALTH_CHECK_STALE_THRESHOLD = 120.0 # seconds without SSE activity before concern -# E.164 phone number pattern for redaction -_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}") - - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- -def _redact_phone(phone: str) -> str: - """Redact a phone number for logging: +15551234567 -> +155****4567.""" - if not phone: - return "" - if len(phone) <= 8: - return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****" - return phone[:4] + "****" + phone[-4:] - def _parse_comma_list(value: str) -> List[str]: """Split a comma-separated string into a list, stripping whitespace.""" @@ -184,10 +173,8 @@ class SignalAdapter(BasePlatformAdapter): self._recent_sent_timestamps: set = set() self._max_recent_timestamps = 50 - self._phone_lock_identity: Optional[str] = None - logger.info("Signal adapter initialized: url=%s account=%s groups=%s", - self.http_url, _redact_phone(self.account), + self.http_url, redact_phone(self.account), "enabled" if self.group_allow_from else "disabled") # ------------------------------------------------------------------ @@ -202,23 +189,7 @@ class SignalAdapter(BasePlatformAdapter): # Acquire scoped lock to prevent duplicate Signal listeners for the same phone try: - from gateway.status import acquire_scoped_lock - - self._phone_lock_identity = self.account - acquired, existing = acquire_scoped_lock( - "signal-phone", - self._phone_lock_identity, - metadata={"platform": self.platform.value}, - ) - if not acquired: - owner_pid = existing.get("pid") if isinstance(existing, dict) else None - message = ( - "Another local Hermes gateway is already using this Signal account" - + (f" (PID {owner_pid})." if owner_pid else ".") - + " Stop the other gateway before starting a second Signal listener." - ) - logger.error("Signal: %s", message) - self._set_fatal_error("signal_phone_lock", message, retryable=False) + if not self._acquire_platform_lock('signal-phone', self.account, 'Signal account'): return False except Exception as e: logger.warning("Signal: Could not acquire phone lock (non-fatal): %s", e) @@ -270,13 +241,7 @@ class SignalAdapter(BasePlatformAdapter): await self.client.aclose() self.client = None - if self._phone_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("signal-phone", self._phone_lock_identity) - except Exception as e: - logger.warning("Signal: Error releasing phone lock: %s", e, exc_info=True) - self._phone_lock_identity = None + self._release_platform_lock() logger.info("Signal: disconnected") @@ -542,7 +507,7 @@ class SignalAdapter(BasePlatformAdapter): ) logger.debug("Signal: message from %s in %s: %s", - _redact_phone(sender), chat_id[:20], (text or "")[:50]) + redact_phone(sender), chat_id[:20], (text or "")[:50]) await self.handle_message(event) diff --git a/gateway/platforms/slack.py b/gateway/platforms/slack.py index 361f74882..8f9934cf7 100644 --- a/gateway/platforms/slack.py +++ b/gateway/platforms/slack.py @@ -33,6 +33,7 @@ from pathlib import Path as _Path sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -89,11 +90,9 @@ class SlackAdapter(BasePlatformAdapter): self._team_clients: Dict[str, AsyncWebClient] = {} # team_id → WebClient self._team_bot_user_ids: Dict[str, str] = {} # team_id → bot_user_id self._channel_team: Dict[str, str] = {} # channel_id → team_id - # Dedup cache: event_ts → timestamp. Prevents duplicate bot - # responses when Socket Mode reconnects redeliver events. - self._seen_messages: Dict[str, float] = {} - self._SEEN_TTL = 300 # 5 minutes - self._SEEN_MAX = 2000 # prune threshold + # Dedup cache: prevents duplicate bot responses when Socket Mode + # reconnects redeliver events. + self._dedup = MessageDeduplicator() # Track pending approval message_ts → resolved flag to prevent # double-clicks on approval buttons. self._approval_resolved: Dict[str, bool] = {} @@ -152,15 +151,7 @@ class SlackAdapter(BasePlatformAdapter): logger.warning("[Slack] Failed to read %s: %s", tokens_file, e) try: - # Acquire scoped lock to prevent duplicate app token usage - from gateway.status import acquire_scoped_lock - self._token_lock_identity = app_token - acquired, existing = acquire_scoped_lock('slack-app-token', app_token, metadata={'platform': 'slack'}) - if not acquired: - owner_pid = existing.get('pid') if isinstance(existing, dict) else None - message = f'Slack app token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.' - logger.error('[%s] %s', self.name, message) - self._set_fatal_error('slack_token_lock', message, retryable=False) + if not self._acquire_platform_lock('slack-app-token', app_token, 'Slack app token'): return False # First token is the primary — used for AsyncApp / Socket Mode @@ -247,14 +238,7 @@ class SlackAdapter(BasePlatformAdapter): logger.warning("[Slack] Error while closing Socket Mode handler: %s", e, exc_info=True) self._running = False - # Release the token lock (use stored identity, not re-read env) - try: - from gateway.status import release_scoped_lock - if getattr(self, '_token_lock_identity', None): - release_scoped_lock('slack-app-token', self._token_lock_identity) - self._token_lock_identity = None - except Exception: - pass + self._release_platform_lock() logger.info("[Slack] Disconnected") @@ -953,17 +937,8 @@ class SlackAdapter(BasePlatformAdapter): """Handle an incoming Slack message event.""" # Dedup: Slack Socket Mode can redeliver events after reconnects (#4777) event_ts = event.get("ts", "") - if event_ts: - now = time.time() - if event_ts in self._seen_messages: - return - self._seen_messages[event_ts] = now - if len(self._seen_messages) > self._SEEN_MAX: - cutoff = now - self._SEEN_TTL - self._seen_messages = { - k: v for k, v in self._seen_messages.items() - if v > cutoff - } + if event_ts and self._dedup.is_duplicate(event_ts): + return # Bot message filtering (SLACK_ALLOW_BOTS / config allow_bots): # "none" — ignore all bot messages (default, backward-compatible) diff --git a/gateway/platforms/sms.py b/gateway/platforms/sms.py index a0760199b..953ec5c5e 100644 --- a/gateway/platforms/sms.py +++ b/gateway/platforms/sms.py @@ -19,7 +19,6 @@ import asyncio import base64 import logging import os -import re import urllib.parse from typing import Any, Dict, Optional @@ -30,6 +29,7 @@ from gateway.platforms.base import ( MessageType, SendResult, ) +from gateway.platforms.helpers import redact_phone, strip_markdown logger = logging.getLogger(__name__) @@ -37,18 +37,6 @@ TWILIO_API_BASE = "https://api.twilio.com/2010-04-01/Accounts" MAX_SMS_LENGTH = 1600 # ~10 SMS segments DEFAULT_WEBHOOK_PORT = 8080 -# E.164 phone number pattern for redaction -_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}") - - -def _redact_phone(phone: str) -> str: - """Redact a phone number for logging: +15551234567 -> +1555***4567.""" - if not phone: - return "" - if len(phone) <= 8: - return phone[:2] + "***" + phone[-2:] if len(phone) > 4 else "****" - return phone[:5] + "***" + phone[-4:] - def check_sms_requirements() -> bool: """Check if SMS adapter dependencies are available.""" @@ -114,7 +102,7 @@ class SmsAdapter(BasePlatformAdapter): logger.info( "[sms] Twilio webhook server listening on port %d, from: %s", self._webhook_port, - _redact_phone(self._from_number), + redact_phone(self._from_number), ) return True @@ -163,7 +151,7 @@ class SmsAdapter(BasePlatformAdapter): error_msg = body.get("message", str(body)) logger.error( "[sms] send failed to %s: %s %s", - _redact_phone(chat_id), + redact_phone(chat_id), resp.status, error_msg, ) @@ -174,7 +162,7 @@ class SmsAdapter(BasePlatformAdapter): msg_sid = body.get("sid", "") last_result = SendResult(success=True, message_id=msg_sid) except Exception as e: - logger.error("[sms] send error to %s: %s", _redact_phone(chat_id), e) + logger.error("[sms] send error to %s: %s", redact_phone(chat_id), e) return SendResult(success=False, error=str(e)) finally: # Close session only if we created a fallback (no persistent session) @@ -192,16 +180,7 @@ class SmsAdapter(BasePlatformAdapter): def format_message(self, content: str) -> str: """Strip markdown — SMS renders it as literal characters.""" - content = re.sub(r"\*\*(.+?)\*\*", r"\1", content, flags=re.DOTALL) - content = re.sub(r"\*(.+?)\*", r"\1", content, flags=re.DOTALL) - content = re.sub(r"__(.+?)__", r"\1", content, flags=re.DOTALL) - content = re.sub(r"_(.+?)_", r"\1", content, flags=re.DOTALL) - content = re.sub(r"```[a-z]*\n?", "", content) - content = re.sub(r"`(.+?)`", r"\1", content) - content = re.sub(r"^#{1,6}\s+", "", content, flags=re.MULTILINE) - content = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", content) - content = re.sub(r"\n{3,}", "\n\n", content) - return content.strip() + return strip_markdown(content) # ------------------------------------------------------------------ # Twilio webhook handler @@ -236,7 +215,7 @@ class SmsAdapter(BasePlatformAdapter): # Ignore messages from our own number (echo prevention) if from_number == self._from_number: - logger.debug("[sms] ignoring echo from own number %s", _redact_phone(from_number)) + logger.debug("[sms] ignoring echo from own number %s", redact_phone(from_number)) return web.Response( text='', content_type="application/xml", @@ -244,8 +223,8 @@ class SmsAdapter(BasePlatformAdapter): logger.info( "[sms] inbound from %s -> %s: %s", - _redact_phone(from_number), - _redact_phone(to_number), + redact_phone(from_number), + redact_phone(to_number), text[:80], ) diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 8b4e43514..884ef9c45 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -147,7 +147,6 @@ class TelegramAdapter(BasePlatformAdapter): self._text_batch_split_delay_seconds = float(os.getenv("HERMES_TELEGRAM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0")) self._pending_text_batches: Dict[str, MessageEvent] = {} self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {} - self._token_lock_identity: Optional[str] = None self._polling_error_task: Optional[asyncio.Task] = None self._polling_conflict_count: int = 0 self._polling_network_error_count: int = 0 @@ -497,23 +496,7 @@ class TelegramAdapter(BasePlatformAdapter): return False try: - from gateway.status import acquire_scoped_lock - - self._token_lock_identity = self.config.token - acquired, existing = acquire_scoped_lock( - "telegram-bot-token", - self._token_lock_identity, - metadata={"platform": self.platform.value}, - ) - if not acquired: - owner_pid = existing.get("pid") if isinstance(existing, dict) else None - message = ( - "Another local Hermes gateway is already using this Telegram bot token" - + (f" (PID {owner_pid})." if owner_pid else ".") - + " Stop the other gateway before starting a second Telegram poller." - ) - logger.error("[%s] %s", self.name, message) - self._set_fatal_error("telegram_token_lock", message, retryable=False) + if not self._acquire_platform_lock('telegram-bot-token', self.config.token, 'Telegram bot token'): return False # Build the application @@ -737,12 +720,7 @@ class TelegramAdapter(BasePlatformAdapter): return True except Exception as e: - if self._token_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("telegram-bot-token", self._token_lock_identity) - except Exception: - pass + self._release_platform_lock() message = f"Telegram startup failed: {e}" self._set_fatal_error("telegram_connect_error", message, retryable=True) logger.error("[%s] Failed to connect to Telegram: %s", self.name, e, exc_info=True) @@ -768,12 +746,7 @@ class TelegramAdapter(BasePlatformAdapter): await self._app.shutdown() except Exception as e: logger.warning("[%s] Error during Telegram disconnect: %s", self.name, e, exc_info=True) - if self._token_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("telegram-bot-token", self._token_lock_identity) - except Exception as e: - logger.warning("[%s] Error releasing Telegram token lock: %s", self.name, e, exc_info=True) + self._release_platform_lock() for task in self._pending_photo_batch_tasks.values(): if task and not task.done(): @@ -784,7 +757,6 @@ class TelegramAdapter(BasePlatformAdapter): self._mark_disconnected() self._app = None self._bot = None - self._token_lock_identity = None logger.info("[%s] Disconnected from Telegram", self.name) def _should_thread_reply(self, reply_to: Optional[str], chunk_index: int) -> bool: diff --git a/gateway/platforms/wecom.py b/gateway/platforms/wecom.py index aa07dc6a9..a0e71e01b 100644 --- a/gateway/platforms/wecom.py +++ b/gateway/platforms/wecom.py @@ -59,6 +59,7 @@ except ImportError: httpx = None # type: ignore[assignment] from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -92,7 +93,6 @@ REQUEST_TIMEOUT_SECONDS = 15.0 HEARTBEAT_INTERVAL_SECONDS = 30.0 RECONNECT_BACKOFF = [2, 5, 10, 30, 60] -DEDUP_WINDOW_SECONDS = 300 DEDUP_MAX_SIZE = 1000 IMAGE_MAX_BYTES = 10 * 1024 * 1024 @@ -172,7 +172,7 @@ class WeComAdapter(BasePlatformAdapter): self._listen_task: Optional[asyncio.Task] = None self._heartbeat_task: Optional[asyncio.Task] = None self._pending_responses: Dict[str, asyncio.Future] = {} - self._seen_messages: Dict[str, float] = {} + self._dedup = MessageDeduplicator(max_size=DEDUP_MAX_SIZE) self._reply_req_ids: Dict[str, str] = {} # Text batching: merge rapid successive messages (Telegram-style). @@ -250,7 +250,7 @@ class WeComAdapter(BasePlatformAdapter): await self._http_client.aclose() self._http_client = None - self._seen_messages.clear() + self._dedup.clear() logger.info("[%s] Disconnected", self.name) async def _cleanup_ws(self) -> None: @@ -476,7 +476,7 @@ class WeComAdapter(BasePlatformAdapter): return msg_id = str(body.get("msgid") or self._payload_req_id(payload) or uuid.uuid4().hex) - if self._is_duplicate(msg_id): + if self._dedup.is_duplicate(msg_id): logger.debug("[%s] Duplicate message %s ignored", self.name, msg_id) return self._remember_reply_req_id(msg_id, self._payload_req_id(payload)) @@ -839,24 +839,6 @@ class WeComAdapter(BasePlatformAdapter): wildcard = self._groups.get("*") return wildcard if isinstance(wildcard, dict) else {} - def _is_duplicate(self, msg_id: str) -> bool: - now = time.time() - if len(self._seen_messages) > DEDUP_MAX_SIZE: - cutoff = now - DEDUP_WINDOW_SECONDS - self._seen_messages = { - key: ts for key, ts in self._seen_messages.items() if ts > cutoff - } - if self._reply_req_ids: - self._reply_req_ids = { - key: value for key, value in self._reply_req_ids.items() if key in self._seen_messages - } - - if msg_id in self._seen_messages: - return True - - self._seen_messages[msg_id] = now - return False - def _remember_reply_req_id(self, message_id: str, req_id: str) -> None: normalized_message_id = str(message_id or "").strip() normalized_req_id = str(req_id or "").strip() diff --git a/gateway/platforms/weixin.py b/gateway/platforms/weixin.py index 5e0208c77..3a4a80540 100644 --- a/gateway/platforms/weixin.py +++ b/gateway/platforms/weixin.py @@ -53,6 +53,7 @@ except ImportError: # pragma: no cover - dependency gate CRYPTO_AVAILABLE = False from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -1008,8 +1009,7 @@ class WeixinAdapter(BasePlatformAdapter): self._typing_cache = TypingTicketCache() self._session: Optional[aiohttp.ClientSession] = None self._poll_task: Optional[asyncio.Task] = None - self._seen_messages: Dict[str, float] = {} - self._token_lock_identity: Optional[str] = None + self._dedup = MessageDeduplicator(ttl_seconds=MESSAGE_DEDUP_TTL_SECONDS) self._account_id = str(extra.get("account_id") or os.getenv("WEIXIN_ACCOUNT_ID", "")).strip() self._token = str(config.token or extra.get("token") or os.getenv("WEIXIN_TOKEN", "")).strip() @@ -1067,23 +1067,7 @@ class WeixinAdapter(BasePlatformAdapter): return False try: - from gateway.status import acquire_scoped_lock - - self._token_lock_identity = self._token - acquired, existing = acquire_scoped_lock( - "weixin-bot-token", - self._token_lock_identity, - metadata={"platform": self.platform.value}, - ) - if not acquired: - owner_pid = existing.get("pid") if isinstance(existing, dict) else None - message = ( - "Another local Hermes gateway is already using this Weixin token" - + (f" (PID {owner_pid})." if owner_pid else ".") - + " Stop the other gateway before starting a second Weixin poller." - ) - logger.error("[%s] %s", self.name, message) - self._set_fatal_error("weixin_token_lock", message, retryable=False) + if not self._acquire_platform_lock('weixin-bot-token', self._token, 'Weixin bot token'): return False except Exception as exc: logger.debug("[%s] Token lock unavailable (non-fatal): %s", self.name, exc) @@ -1107,12 +1091,7 @@ class WeixinAdapter(BasePlatformAdapter): if self._session and not self._session.closed: await self._session.close() self._session = None - if self._token_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("weixin-bot-token", self._token_lock_identity) - except Exception as exc: - logger.warning("[%s] Error releasing Weixin token lock: %s", self.name, exc, exc_info=True) + self._release_platform_lock() self._mark_disconnected() logger.info("[%s] Disconnected", self.name) @@ -1190,16 +1169,8 @@ class WeixinAdapter(BasePlatformAdapter): return message_id = str(message.get("message_id") or "").strip() - if message_id: - now = time.time() - self._seen_messages = { - key: value - for key, value in self._seen_messages.items() - if now - value < MESSAGE_DEDUP_TTL_SECONDS - } - if message_id in self._seen_messages: - return - self._seen_messages[message_id] = now + if message_id and self._dedup.is_duplicate(message_id): + return chat_type, effective_chat_id = _guess_chat_type(message, self._account_id) if chat_type == "group": diff --git a/gateway/platforms/whatsapp.py b/gateway/platforms/whatsapp.py index a6475dcb8..c616f7244 100644 --- a/gateway/platforms/whatsapp.py +++ b/gateway/platforms/whatsapp.py @@ -145,7 +145,6 @@ class WhatsAppAdapter(BasePlatformAdapter): self._bridge_log: Optional[Path] = None self._poll_task: Optional[asyncio.Task] = None self._http_session: Optional["aiohttp.ClientSession"] = None - self._session_lock_identity: Optional[str] = None def _whatsapp_require_mention(self) -> bool: configured = self.config.extra.get("require_mention") @@ -290,23 +289,7 @@ class WhatsAppAdapter(BasePlatformAdapter): # Acquire scoped lock to prevent duplicate sessions try: - from gateway.status import acquire_scoped_lock - - self._session_lock_identity = str(self._session_path) - acquired, existing = acquire_scoped_lock( - "whatsapp-session", - self._session_lock_identity, - metadata={"platform": self.platform.value}, - ) - if not acquired: - owner_pid = existing.get("pid") if isinstance(existing, dict) else None - message = ( - "Another local Hermes gateway is already using this WhatsApp session" - + (f" (PID {owner_pid})." if owner_pid else ".") - + " Stop the other gateway before starting a second WhatsApp bridge." - ) - logger.error("[%s] %s", self.name, message) - self._set_fatal_error("whatsapp_session_lock", message, retryable=False) + if not self._acquire_platform_lock('whatsapp-session', str(self._session_path), 'WhatsApp session'): return False except Exception as e: logger.warning("[%s] Could not acquire session lock (non-fatal): %s", self.name, e) @@ -468,12 +451,7 @@ class WhatsAppAdapter(BasePlatformAdapter): return True except Exception as e: - if self._session_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("whatsapp-session", self._session_lock_identity) - except Exception: - pass + self._release_platform_lock() logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True) self._close_bridge_log() return False @@ -546,17 +524,11 @@ class WhatsAppAdapter(BasePlatformAdapter): await self._http_session.close() self._http_session = None - if self._session_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("whatsapp-session", self._session_lock_identity) - except Exception as e: - logger.warning("[%s] Error releasing WhatsApp session lock: %s", self.name, e, exc_info=True) + self._release_platform_lock() self._mark_disconnected() self._bridge_process = None self._close_bridge_log() - self._session_lock_identity = None print(f"[{self.name}] Disconnected") async def send( diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index fcb7c2dc5..56b9fb63c 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -261,6 +261,28 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = { } +# ============================================================================= +# Anthropic Key Helper +# ============================================================================= + +def get_anthropic_key() -> str: + """Return the first usable Anthropic credential, or ``""``. + + Checks both the ``.env`` file (via ``get_env_value``) and the process + environment (``os.getenv``). The fallback order mirrors the + ``PROVIDER_REGISTRY["anthropic"].api_key_env_vars`` tuple: + + ANTHROPIC_API_KEY -> ANTHROPIC_TOKEN -> CLAUDE_CODE_OAUTH_TOKEN + """ + from hermes_cli.config import get_env_value + + for var in PROVIDER_REGISTRY["anthropic"].api_key_env_vars: + value = get_env_value(var) or os.getenv(var, "") + if value: + return value + return "" + + # ============================================================================= # Kimi Code Endpoint Detection # ============================================================================= diff --git a/hermes_cli/cli_output.py b/hermes_cli/cli_output.py new file mode 100644 index 000000000..3d454eb30 --- /dev/null +++ b/hermes_cli/cli_output.py @@ -0,0 +1,79 @@ +"""Shared CLI output helpers for Hermes CLI modules. + +Extracts the identical ``print_info/success/warning/error`` and ``prompt()`` +functions previously duplicated across setup.py, tools_config.py, +mcp_config.py, and memory_setup.py. +""" + +import getpass +import sys + +from hermes_cli.colors import Colors, color + + +# ─── Print Helpers ──────────────────────────────────────────────────────────── + + +def print_info(text: str) -> None: + """Print a dim informational message.""" + print(color(f" {text}", Colors.DIM)) + + +def print_success(text: str) -> None: + """Print a green success message with ✓ prefix.""" + print(color(f"✓ {text}", Colors.GREEN)) + + +def print_warning(text: str) -> None: + """Print a yellow warning message with ⚠ prefix.""" + print(color(f"⚠ {text}", Colors.YELLOW)) + + +def print_error(text: str) -> None: + """Print a red error message with ✗ prefix.""" + print(color(f"✗ {text}", Colors.RED)) + + +def print_header(text: str) -> None: + """Print a bold yellow header.""" + print(color(f"\n {text}", Colors.YELLOW)) + + +# ─── Input Prompts ──────────────────────────────────────────────────────────── + + +def prompt( + question: str, + default: str | None = None, + password: bool = False, +) -> str: + """Prompt the user for input with optional default and password masking. + + Replaces the four independent ``_prompt()`` / ``prompt()`` implementations + in setup.py, tools_config.py, mcp_config.py, and memory_setup.py. + + Returns the user's input (stripped), or *default* if the user presses Enter. + Returns empty string on Ctrl-C or EOF. + """ + suffix = f" [{default}]" if default else "" + display = color(f" {question}{suffix}: ", Colors.YELLOW) + + try: + if password: + value = getpass.getpass(display) + else: + value = input(display) + value = value.strip() + return value if value else (default or "") + except (KeyboardInterrupt, EOFError): + print() + return "" + + +def prompt_yes_no(question: str, default: bool = True) -> bool: + """Prompt for a yes/no answer. Returns bool.""" + hint = "Y/n" if default else "y/N" + answer = prompt(f"{question} ({hint})") + if not answer: + return default + return answer.lower().startswith("y") diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 4661455d1..c3cf0456e 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -2582,7 +2582,8 @@ def show_config(): for env_key, name in keys: value = get_env_value(env_key) print(f" {name:<14} {redact_key(value)}") - anthropic_value = get_env_value("ANTHROPIC_TOKEN") or get_env_value("ANTHROPIC_API_KEY") + from hermes_cli.auth import get_anthropic_key + anthropic_value = get_anthropic_key() print(f" {'Anthropic':<14} {redact_key(anthropic_value)}") # Model settings @@ -2798,8 +2799,8 @@ def set_config_value(key: str, value: str): # Write only user config back (not the full merged defaults) ensure_hermes_home() - with open(config_path, 'w', encoding="utf-8") as f: - yaml.dump(user_config, f, default_flow_style=False, sort_keys=False) + from utils import atomic_yaml_write + atomic_yaml_write(config_path, user_config, sort_keys=False) # Keep .env in sync for keys that terminal_tool reads directly from env vars. # config.yaml is authoritative, but terminal_tool only reads TERMINAL_ENV etc. diff --git a/hermes_cli/doctor.py b/hermes_cli/doctor.py index f5f8a228a..13c904692 100644 --- a/hermes_cli/doctor.py +++ b/hermes_cli/doctor.py @@ -336,8 +336,8 @@ def run_doctor(args): model_section[k] = raw_config.pop(k) else: raw_config.pop(k) - with open(config_path, "w") as f: - yaml.dump(raw_config, f, default_flow_style=False) + from utils import atomic_yaml_write + atomic_yaml_write(config_path, raw_config) check_ok("Migrated stale root-level keys into model section") fixed_count += 1 else: @@ -686,7 +686,8 @@ def run_doctor(args): else: check_warn("OpenRouter API", "(not configured)") - anthropic_key = os.getenv("ANTHROPIC_TOKEN") or os.getenv("ANTHROPIC_API_KEY") + from hermes_cli.auth import get_anthropic_key + anthropic_key = get_anthropic_key() if anthropic_key: print(" Checking Anthropic API...", end="", flush=True) try: diff --git a/hermes_cli/main.py b/hermes_cli/main.py index e004a6e93..4b7dd600b 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -2549,13 +2549,8 @@ def _model_flow_anthropic(config, current_model=""): from hermes_cli.models import _PROVIDER_MODELS # Check ALL credential sources - existing_key = ( - get_env_value("ANTHROPIC_TOKEN") - or os.getenv("ANTHROPIC_TOKEN", "") - or get_env_value("ANTHROPIC_API_KEY") - or os.getenv("ANTHROPIC_API_KEY", "") - or os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "") - ) + from hermes_cli.auth import get_anthropic_key + existing_key = get_anthropic_key() cc_available = False try: from agent.anthropic_adapter import read_claude_code_credentials, is_claude_code_token_valid diff --git a/hermes_cli/mcp_config.py b/hermes_cli/mcp_config.py index 9154ed50a..cf2dde089 100644 --- a/hermes_cli/mcp_config.py +++ b/hermes_cli/mcp_config.py @@ -57,19 +57,8 @@ def _confirm(question: str, default: bool = True) -> bool: def _prompt(question: str, *, password: bool = False, default: str = "") -> str: - display = f" {question}" - if default: - display += f" [{default}]" - display += ": " - try: - if password: - value = getpass.getpass(color(display, Colors.YELLOW)) - else: - value = input(color(display, Colors.YELLOW)) - return value.strip() or default - except (KeyboardInterrupt, EOFError): - print() - return default + from hermes_cli.cli_output import prompt as _shared_prompt + return _shared_prompt(question, default=default, password=password) # ─── Config Helpers ─────────────────────────────────────────────────────────── diff --git a/hermes_cli/memory_setup.py b/hermes_cli/memory_setup.py index 2843f4f44..1aa431367 100644 --- a/hermes_cli/memory_setup.py +++ b/hermes_cli/memory_setup.py @@ -25,85 +25,13 @@ def _curses_select(title: str, items: list[tuple[str, str]], default: int = 0) - items: list of (label, description) tuples. Returns selected index, or default on escape/quit. """ - try: - import curses - result = [default] - - def _menu(stdscr): - curses.curs_set(0) - if curses.has_colors(): - curses.start_color() - curses.use_default_colors() - curses.init_pair(1, curses.COLOR_GREEN, -1) - curses.init_pair(2, curses.COLOR_YELLOW, -1) - curses.init_pair(3, curses.COLOR_CYAN, -1) - cursor = default - - while True: - stdscr.clear() - max_y, max_x = stdscr.getmaxyx() - - # Title - try: - stdscr.addnstr(0, 0, title, max_x - 1, - curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0)) - stdscr.addnstr(1, 0, " ↑↓ navigate ⏎ select q quit", max_x - 1, - curses.color_pair(3) if curses.has_colors() else curses.A_DIM) - except curses.error: - pass - - for i, (label, desc) in enumerate(items): - y = i + 3 - if y >= max_y - 1: - break - arrow = "→" if i == cursor else " " - line = f" {arrow} {label}" - if desc: - line += f" {desc}" - - attr = curses.A_NORMAL - if i == cursor: - attr = curses.A_BOLD - if curses.has_colors(): - attr |= curses.color_pair(1) - try: - stdscr.addnstr(y, 0, line[:max_x - 1], max_x - 1, attr) - except curses.error: - pass - - stdscr.refresh() - key = stdscr.getch() - - if key in (curses.KEY_UP, ord('k')): - cursor = (cursor - 1) % len(items) - elif key in (curses.KEY_DOWN, ord('j')): - cursor = (cursor + 1) % len(items) - elif key in (curses.KEY_ENTER, 10, 13): - result[0] = cursor - return - elif key in (27, ord('q')): - return - - curses.wrapper(_menu) - return result[0] - - except Exception: - # Fallback: numbered input - print(f"\n {title}\n") - for i, (label, desc) in enumerate(items): - marker = "→" if i == default else " " - d = f" {desc}" if desc else "" - print(f" {marker} {i + 1}. {label}{d}") - while True: - try: - val = input(f"\n Select [1-{len(items)}] ({default + 1}): ") - if not val: - return default - idx = int(val) - 1 - if 0 <= idx < len(items): - return idx - except (ValueError, EOFError): - return default + from hermes_cli.curses_ui import curses_radiolist + # Format (label, desc) tuples into display strings + display_items = [ + f"{label} {desc}" if desc else label + for label, desc in items + ] + return curses_radiolist(title, display_items, selected=default, cancel_returns=default) def _prompt(label: str, default: str | None = None, secret: bool = False) -> str: diff --git a/hermes_cli/platforms.py b/hermes_cli/platforms.py new file mode 100644 index 000000000..18307912b --- /dev/null +++ b/hermes_cli/platforms.py @@ -0,0 +1,45 @@ +""" +Shared platform registry for Hermes Agent. + +Single source of truth for platform metadata consumed by both +skills_config (label display) and tools_config (default toolset +resolution). Import ``PLATFORMS`` from here instead of maintaining +duplicate dicts in each module. +""" + +from collections import OrderedDict +from typing import NamedTuple + + +class PlatformInfo(NamedTuple): + """Metadata for a single platform entry.""" + label: str + default_toolset: str + + +# Ordered so that TUI menus are deterministic. +PLATFORMS: OrderedDict[str, PlatformInfo] = OrderedDict([ + ("cli", PlatformInfo(label="🖥️ CLI", default_toolset="hermes-cli")), + ("telegram", PlatformInfo(label="📱 Telegram", default_toolset="hermes-telegram")), + ("discord", PlatformInfo(label="💬 Discord", default_toolset="hermes-discord")), + ("slack", PlatformInfo(label="💼 Slack", default_toolset="hermes-slack")), + ("whatsapp", PlatformInfo(label="📱 WhatsApp", default_toolset="hermes-whatsapp")), + ("signal", PlatformInfo(label="📡 Signal", default_toolset="hermes-signal")), + ("bluebubbles", PlatformInfo(label="💙 BlueBubbles", default_toolset="hermes-bluebubbles")), + ("email", PlatformInfo(label="📧 Email", default_toolset="hermes-email")), + ("homeassistant", PlatformInfo(label="🏠 Home Assistant", default_toolset="hermes-homeassistant")), + ("mattermost", PlatformInfo(label="💬 Mattermost", default_toolset="hermes-mattermost")), + ("matrix", PlatformInfo(label="💬 Matrix", default_toolset="hermes-matrix")), + ("dingtalk", PlatformInfo(label="💬 DingTalk", default_toolset="hermes-dingtalk")), + ("feishu", PlatformInfo(label="🪽 Feishu", default_toolset="hermes-feishu")), + ("wecom", PlatformInfo(label="💬 WeCom", default_toolset="hermes-wecom")), + ("weixin", PlatformInfo(label="💬 Weixin", default_toolset="hermes-weixin")), + ("webhook", PlatformInfo(label="🔗 Webhook", default_toolset="hermes-webhook")), + ("api_server", PlatformInfo(label="🌐 API Server", default_toolset="hermes-api-server")), +]) + + +def platform_label(key: str, default: str = "") -> str: + """Return the display label for a platform key, or *default*.""" + info = PLATFORMS.get(key) + return info.label if info is not None else default diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index ca877606f..fb70d9081 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -197,24 +197,12 @@ def print_header(title: str): print(color(f"◆ {title}", Colors.CYAN, Colors.BOLD)) -def print_info(text: str): - """Print info text.""" - print(color(f" {text}", Colors.DIM)) - - -def print_success(text: str): - """Print success message.""" - print(color(f"✓ {text}", Colors.GREEN)) - - -def print_warning(text: str): - """Print warning message.""" - print(color(f"⚠ {text}", Colors.YELLOW)) - - -def print_error(text: str): - """Print error message.""" - print(color(f"✗ {text}", Colors.RED)) +from hermes_cli.cli_output import ( # noqa: E402 + print_error, + print_info, + print_success, + print_warning, +) def is_interactive_stdin() -> bool: @@ -269,80 +257,9 @@ def prompt(question: str, default: str = None, password: bool = False) -> str: def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int: - """Single-select menu using curses to avoid simple_term_menu rendering bugs.""" - try: - import curses - result_holder = [default] - - def _curses_menu(stdscr): - curses.curs_set(0) - if curses.has_colors(): - curses.start_color() - curses.use_default_colors() - curses.init_pair(1, curses.COLOR_GREEN, -1) - curses.init_pair(2, curses.COLOR_YELLOW, -1) - cursor = default - scroll_offset = 0 - - while True: - stdscr.clear() - max_y, max_x = stdscr.getmaxyx() - - # Rows available for list items: rows 2..(max_y-2) inclusive. - visible = max(1, max_y - 3) - - # Scroll the viewport so the cursor is always visible. - if cursor < scroll_offset: - scroll_offset = cursor - elif cursor >= scroll_offset + visible: - scroll_offset = cursor - visible + 1 - scroll_offset = max(0, min(scroll_offset, max(0, len(choices) - visible))) - - try: - stdscr.addnstr( - 0, - 0, - question, - max_x - 1, - curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0), - ) - except curses.error: - pass - - for row, i in enumerate(range(scroll_offset, min(scroll_offset + visible, len(choices)))): - y = row + 2 - if y >= max_y - 1: - break - arrow = "→" if i == cursor else " " - line = f" {arrow} {choices[i]}" - attr = curses.A_NORMAL - if i == cursor: - attr = curses.A_BOLD - if curses.has_colors(): - attr |= curses.color_pair(1) - try: - stdscr.addnstr(y, 0, line, max_x - 1, attr) - except curses.error: - pass - - stdscr.refresh() - key = stdscr.getch() - if key in (curses.KEY_UP, ord("k")): - cursor = (cursor - 1) % len(choices) - elif key in (curses.KEY_DOWN, ord("j")): - cursor = (cursor + 1) % len(choices) - elif key in (curses.KEY_ENTER, 10, 13): - result_holder[0] = cursor - return - elif key in (27, ord("q")): - return - - curses.wrapper(_curses_menu) - from hermes_cli.curses_ui import flush_stdin - flush_stdin() - return result_holder[0] - except Exception: - return -1 + """Single-select menu using curses. Delegates to curses_radiolist.""" + from hermes_cli.curses_ui import curses_radiolist + return curses_radiolist(question, choices, selected=default, cancel_returns=-1) diff --git a/hermes_cli/skills_config.py b/hermes_cli/skills_config.py index b017361fe..92424a0ca 100644 --- a/hermes_cli/skills_config.py +++ b/hermes_cli/skills_config.py @@ -15,25 +15,12 @@ from typing import List, Optional, Set from hermes_cli.config import load_config, save_config from hermes_cli.colors import Colors, color +from hermes_cli.platforms import PLATFORMS as _PLATFORMS, platform_label -PLATFORMS = { - "cli": "🖥️ CLI", - "telegram": "📱 Telegram", - "discord": "💬 Discord", - "slack": "💼 Slack", - "whatsapp": "📱 WhatsApp", - "signal": "📡 Signal", - "bluebubbles": "💬 BlueBubbles", - "email": "📧 Email", - "homeassistant": "🏠 Home Assistant", - "mattermost": "💬 Mattermost", - "matrix": "💬 Matrix", - "dingtalk": "💬 DingTalk", - "feishu": "🪽 Feishu", - "wecom": "💬 WeCom", - "weixin": "💬 Weixin", - "webhook": "🔗 Webhook", -} +# Backward-compatible view: {key: label_string} so existing code that +# iterates ``PLATFORMS.items()`` or calls ``PLATFORMS.get(key)`` keeps +# working without changes to every call site. +PLATFORMS = {k: info.label for k, info in _PLATFORMS.items() if k != "api_server"} # ─── Config Helpers ─────────────────────────────────────────────────────────── diff --git a/hermes_cli/status.py b/hermes_cli/status.py index baba4f359..7a7a9c645 100644 --- a/hermes_cli/status.py +++ b/hermes_cli/status.py @@ -141,11 +141,8 @@ def show_status(args): display = redact_key(value) if not show_all else value print(f" {name:<12} {check_mark(has_key)} {display}") - anthropic_value = ( - get_env_value("ANTHROPIC_TOKEN") - or get_env_value("ANTHROPIC_API_KEY") - or "" - ) + from hermes_cli.auth import get_anthropic_key + anthropic_value = get_anthropic_key() anthropic_display = redact_key(anthropic_value) if not show_all else anthropic_value print(f" {'Anthropic':<12} {check_mark(bool(anthropic_value))} {anthropic_display}") diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index 91c41dce5..343007cab 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -33,33 +33,13 @@ PROJECT_ROOT = Path(__file__).parent.parent.resolve() # ─── UI Helpers (shared with setup.py) ──────────────────────────────────────── -def _print_info(text: str): - print(color(f" {text}", Colors.DIM)) - -def _print_success(text: str): - print(color(f"✓ {text}", Colors.GREEN)) - -def _print_warning(text: str): - print(color(f"⚠ {text}", Colors.YELLOW)) - -def _print_error(text: str): - print(color(f"✗ {text}", Colors.RED)) - -def _prompt(question: str, default: str = None, password: bool = False) -> str: - if default: - display = f"{question} [{default}]: " - else: - display = f"{question}: " - try: - if password: - import getpass - value = getpass.getpass(color(display, Colors.YELLOW)) - else: - value = input(color(display, Colors.YELLOW)) - return value.strip() or default or "" - except (KeyboardInterrupt, EOFError): - print() - return default or "" +from hermes_cli.cli_output import ( # noqa: E402 — late import block + print_error as _print_error, + print_info as _print_info, + print_success as _print_success, + print_warning as _print_warning, + prompt as _prompt, +) # ─── Toolset Registry ───────────────────────────────────────────────────────── @@ -118,25 +98,14 @@ def _get_plugin_toolset_keys() -> set: except Exception: return set() -# Platform display config +# Platform display config — derived from the canonical registry so every +# module shares the same data. Kept as dict-of-dicts for backward +# compatibility with existing ``PLATFORMS[key]["label"]`` access patterns. +from hermes_cli.platforms import PLATFORMS as _PLATFORMS_REGISTRY + PLATFORMS = { - "cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"}, - "telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"}, - "discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"}, - "slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"}, - "whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"}, - "signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"}, - "bluebubbles": {"label": "💙 BlueBubbles", "default_toolset": "hermes-bluebubbles"}, - "homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"}, - "email": {"label": "📧 Email", "default_toolset": "hermes-email"}, - "matrix": {"label": "💬 Matrix", "default_toolset": "hermes-matrix"}, - "dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"}, - "feishu": {"label": "🪽 Feishu", "default_toolset": "hermes-feishu"}, - "wecom": {"label": "💬 WeCom", "default_toolset": "hermes-wecom"}, - "weixin": {"label": "💬 Weixin", "default_toolset": "hermes-weixin"}, - "api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"}, - "mattermost": {"label": "💬 Mattermost", "default_toolset": "hermes-mattermost"}, - "webhook": {"label": "🔗 Webhook", "default_toolset": "hermes-webhook"}, + k: {"label": info.label, "default_toolset": info.default_toolset} + for k, info in _PLATFORMS_REGISTRY.items() } @@ -677,86 +646,9 @@ def _toolset_has_keys(ts_key: str, config: dict = None) -> bool: # ─── Menu Helpers ───────────────────────────────────────────────────────────── def _prompt_choice(question: str, choices: list, default: int = 0) -> int: - """Single-select menu (arrow keys). Uses curses to avoid simple_term_menu - rendering bugs in tmux, iTerm, and other non-standard terminals.""" - - # Curses-based single-select — works in tmux, iTerm, and standard terminals - try: - import curses - result_holder = [default] - - def _curses_menu(stdscr): - curses.curs_set(0) - if curses.has_colors(): - curses.start_color() - curses.use_default_colors() - curses.init_pair(1, curses.COLOR_GREEN, -1) - curses.init_pair(2, curses.COLOR_YELLOW, -1) - cursor = default - - while True: - stdscr.clear() - max_y, max_x = stdscr.getmaxyx() - try: - stdscr.addnstr(0, 0, question, max_x - 1, - curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0)) - except curses.error: - pass - - for i, c in enumerate(choices): - y = i + 2 - if y >= max_y - 1: - break - arrow = "→" if i == cursor else " " - line = f" {arrow} {c}" - attr = curses.A_NORMAL - if i == cursor: - attr = curses.A_BOLD - if curses.has_colors(): - attr |= curses.color_pair(1) - try: - stdscr.addnstr(y, 0, line, max_x - 1, attr) - except curses.error: - pass - - stdscr.refresh() - key = stdscr.getch() - - if key in (curses.KEY_UP, ord('k')): - cursor = (cursor - 1) % len(choices) - elif key in (curses.KEY_DOWN, ord('j')): - cursor = (cursor + 1) % len(choices) - elif key in (curses.KEY_ENTER, 10, 13): - result_holder[0] = cursor - return - elif key in (27, ord('q')): - return - - curses.wrapper(_curses_menu) - from hermes_cli.curses_ui import flush_stdin - flush_stdin() - return result_holder[0] - - except Exception: - pass - - # Fallback: numbered input (Windows without curses, etc.) - print(color(question, Colors.YELLOW)) - for i, c in enumerate(choices): - marker = "●" if i == default else "○" - style = Colors.GREEN if i == default else "" - print(color(f" {marker} {i+1}. {c}", style) if style else f" {marker} {i+1}. {c}") - while True: - try: - val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM)) - if not val: - return default - idx = int(val) - 1 - if 0 <= idx < len(choices): - return idx - except (ValueError, KeyboardInterrupt, EOFError): - print() - return default + """Single-select menu (arrow keys). Delegates to curses_radiolist.""" + from hermes_cli.curses_ui import curses_radiolist + return curses_radiolist(question, choices, selected=default, cancel_returns=default) # ─── Token Estimation ──────────────────────────────────────────────────────── diff --git a/hermes_constants.py b/hermes_constants.py index 7d149f404..85955d548 100644 --- a/hermes_constants.py +++ b/hermes_constants.py @@ -189,6 +189,33 @@ def is_wsl() -> bool: return _wsl_detected +# ─── Well-Known Paths ───────────────────────────────────────────────────────── + + +def get_config_path() -> Path: + """Return the path to ``config.yaml`` under HERMES_HOME. + + Replaces the ``get_hermes_home() / "config.yaml"`` pattern repeated + in 7+ files (skill_utils.py, hermes_logging.py, hermes_time.py, etc.). + """ + return get_hermes_home() / "config.yaml" + + +def get_skills_dir() -> Path: + """Return the path to the skills directory under HERMES_HOME.""" + return get_hermes_home() / "skills" + + +def get_logs_dir() -> Path: + """Return the path to the logs directory under HERMES_HOME.""" + return get_hermes_home() / "logs" + + +def get_env_path() -> Path: + """Return the path to the ``.env`` file under HERMES_HOME.""" + return get_hermes_home() / ".env" + + OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1" OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models" diff --git a/hermes_logging.py b/hermes_logging.py index 5d71590c3..b765e9464 100644 --- a/hermes_logging.py +++ b/hermes_logging.py @@ -18,7 +18,7 @@ from logging.handlers import RotatingFileHandler from pathlib import Path from typing import Optional -from hermes_constants import get_hermes_home +from hermes_constants import get_config_path, get_hermes_home # Sentinel to track whether setup_logging() has already run. The function # is idempotent — calling it twice is safe but the second call is a no-op @@ -246,7 +246,7 @@ def _read_logging_config(): """ try: import yaml - config_path = get_hermes_home() / "config.yaml" + config_path = get_config_path() if config_path.exists(): with open(config_path, "r", encoding="utf-8") as f: cfg = yaml.safe_load(f) or {} diff --git a/hermes_time.py b/hermes_time.py index f7d085544..9f172d28f 100644 --- a/hermes_time.py +++ b/hermes_time.py @@ -16,7 +16,7 @@ crashes due to a bad timezone string. import logging import os from datetime import datetime -from hermes_constants import get_hermes_home +from hermes_constants import get_config_path from typing import Optional logger = logging.getLogger(__name__) @@ -48,8 +48,7 @@ def _resolve_timezone_name() -> str: # 2. config.yaml ``timezone`` key try: import yaml - hermes_home = get_hermes_home() - config_path = hermes_home / "config.yaml" + config_path = get_config_path() if config_path.exists(): with open(config_path) as f: cfg = yaml.safe_load(f) or {} diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index ef17af10b..d9ca627c4 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -211,7 +211,8 @@ def make_adapter(platform: Platform, runner=None): config = PlatformConfig(enabled=True, token="e2e-test-token") if platform == Platform.DISCORD: - with patch.object(DiscordAdapter, "_load_participated_threads", return_value=set()): + from gateway.platforms.helpers import ThreadParticipationTracker + with patch.object(ThreadParticipationTracker, "_load", return_value=set()): adapter = DiscordAdapter(config) platform_key = Platform.DISCORD elif platform == Platform.SLACK: diff --git a/tests/gateway/test_dingtalk.py b/tests/gateway/test_dingtalk.py index 5c73253fb..527113650 100644 --- a/tests/gateway/test_dingtalk.py +++ b/tests/gateway/test_dingtalk.py @@ -119,28 +119,29 @@ class TestDeduplication: def test_first_message_not_duplicate(self): from gateway.platforms.dingtalk import DingTalkAdapter adapter = DingTalkAdapter(PlatformConfig(enabled=True)) - assert adapter._is_duplicate("msg-1") is False + assert adapter._dedup.is_duplicate("msg-1") is False def test_second_same_message_is_duplicate(self): from gateway.platforms.dingtalk import DingTalkAdapter adapter = DingTalkAdapter(PlatformConfig(enabled=True)) - adapter._is_duplicate("msg-1") - assert adapter._is_duplicate("msg-1") is True + adapter._dedup.is_duplicate("msg-1") + assert adapter._dedup.is_duplicate("msg-1") is True def test_different_messages_not_duplicate(self): from gateway.platforms.dingtalk import DingTalkAdapter adapter = DingTalkAdapter(PlatformConfig(enabled=True)) - adapter._is_duplicate("msg-1") - assert adapter._is_duplicate("msg-2") is False + adapter._dedup.is_duplicate("msg-1") + assert adapter._dedup.is_duplicate("msg-2") is False def test_cache_cleanup_on_overflow(self): - from gateway.platforms.dingtalk import DingTalkAdapter, DEDUP_MAX_SIZE + from gateway.platforms.dingtalk import DingTalkAdapter adapter = DingTalkAdapter(PlatformConfig(enabled=True)) + max_size = adapter._dedup._max_size # Fill beyond max - for i in range(DEDUP_MAX_SIZE + 10): - adapter._is_duplicate(f"msg-{i}") + for i in range(max_size + 10): + adapter._dedup.is_duplicate(f"msg-{i}") # Cache should have been pruned - assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 10 + assert len(adapter._dedup._seen) <= max_size + 10 # --------------------------------------------------------------------------- @@ -253,13 +254,13 @@ class TestConnect: from gateway.platforms.dingtalk import DingTalkAdapter adapter = DingTalkAdapter(PlatformConfig(enabled=True)) adapter._session_webhooks["a"] = "http://x" - adapter._seen_messages["b"] = 1.0 + adapter._dedup._seen["b"] = 1.0 adapter._http_client = AsyncMock() adapter._stream_task = None await adapter.disconnect() assert len(adapter._session_webhooks) == 0 - assert len(adapter._seen_messages) == 0 + assert len(adapter._dedup._seen) == 0 assert adapter._http_client is None diff --git a/tests/gateway/test_discord_connect.py b/tests/gateway/test_discord_connect.py index dd594cf7e..9f094dd0d 100644 --- a/tests/gateway/test_discord_connect.py +++ b/tests/gateway/test_discord_connect.py @@ -137,4 +137,4 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch): assert ok is False assert released == [("discord-bot-token", "test-token")] - assert adapter._token_lock_identity is None + assert adapter._platform_lock_identity is None diff --git a/tests/gateway/test_discord_free_response.py b/tests/gateway/test_discord_free_response.py index bc63c14f5..29f65efc6 100644 --- a/tests/gateway/test_discord_free_response.py +++ b/tests/gateway/test_discord_free_response.py @@ -302,7 +302,7 @@ async def test_discord_bot_thread_skips_mention_requirement(adapter, monkeypatch monkeypatch.setenv("DISCORD_AUTO_THREAD", "false") # Simulate bot having previously participated in thread 456 - adapter._bot_participated_threads.add("456") + adapter._threads.mark("456") thread = FakeThread(channel_id=456, name="existing thread") message = make_message(channel=thread, content="follow-up without mention") @@ -344,7 +344,7 @@ async def test_discord_auto_thread_tracks_participation(adapter, monkeypatch): await adapter._handle_message(message) - assert "555" in adapter._bot_participated_threads + assert "555" in adapter._threads @pytest.mark.asyncio @@ -358,4 +358,4 @@ async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeyp await adapter._handle_message(message) - assert "777" in adapter._bot_participated_threads + assert "777" in adapter._threads diff --git a/tests/gateway/test_discord_thread_persistence.py b/tests/gateway/test_discord_thread_persistence.py index 0288b620d..083f61ac7 100644 --- a/tests/gateway/test_discord_thread_persistence.py +++ b/tests/gateway/test_discord_thread_persistence.py @@ -1,6 +1,6 @@ """Tests for Discord thread participation persistence. -Verifies that _bot_participated_threads survives adapter restarts by +Verifies that _threads (ThreadParticipationTracker) survives adapter restarts by being persisted to ~/.hermes/discord_threads.json. """ @@ -25,13 +25,13 @@ class TestDiscordThreadPersistence: def test_starts_empty_when_no_state_file(self, tmp_path): adapter = self._make_adapter(tmp_path) - assert adapter._bot_participated_threads == set() + assert "$nonexistent" not in adapter._threads def test_track_thread_persists_to_disk(self, tmp_path): adapter = self._make_adapter(tmp_path) with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): - adapter._track_thread("111") - adapter._track_thread("222") + adapter._threads.mark("111") + adapter._threads.mark("222") state_file = tmp_path / "discord_threads.json" assert state_file.exists() @@ -42,42 +42,43 @@ class TestDiscordThreadPersistence: """Threads tracked by one adapter instance are visible to the next.""" adapter1 = self._make_adapter(tmp_path) with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): - adapter1._track_thread("aaa") - adapter1._track_thread("bbb") + adapter1._threads.mark("aaa") + adapter1._threads.mark("bbb") adapter2 = self._make_adapter(tmp_path) - assert "aaa" in adapter2._bot_participated_threads - assert "bbb" in adapter2._bot_participated_threads + assert "aaa" in adapter2._threads + assert "bbb" in adapter2._threads def test_duplicate_track_does_not_double_save(self, tmp_path): adapter = self._make_adapter(tmp_path) with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): - adapter._track_thread("111") - adapter._track_thread("111") # no-op + adapter._threads.mark("111") + adapter._threads.mark("111") # no-op saved = json.loads((tmp_path / "discord_threads.json").read_text()) assert saved.count("111") == 1 def test_caps_at_max_tracked_threads(self, tmp_path): adapter = self._make_adapter(tmp_path) - adapter._MAX_TRACKED_THREADS = 5 + adapter._threads._max_tracked = 5 with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): for i in range(10): - adapter._track_thread(str(i)) + adapter._threads.mark(str(i)) - assert len(adapter._bot_participated_threads) == 5 + saved = json.loads((tmp_path / "discord_threads.json").read_text()) + assert len(saved) == 5 def test_corrupted_state_file_falls_back_to_empty(self, tmp_path): state_file = tmp_path / "discord_threads.json" state_file.write_text("not valid json{{{") adapter = self._make_adapter(tmp_path) - assert adapter._bot_participated_threads == set() + assert "$nonexistent" not in adapter._threads def test_missing_hermes_home_does_not_crash(self, tmp_path): """Load/save tolerate missing directories.""" fake_home = tmp_path / "nonexistent" / "deep" with patch.dict(os.environ, {"HERMES_HOME": str(fake_home)}): - from gateway.platforms.discord import DiscordAdapter - # _load should return empty set, not crash - threads = DiscordAdapter._load_participated_threads() - assert threads == set() + from gateway.platforms.helpers import ThreadParticipationTracker + # ThreadParticipationTracker should return empty set, not crash + tracker = ThreadParticipationTracker("discord") + assert "$test" not in tracker diff --git a/tests/gateway/test_matrix_mention.py b/tests/gateway/test_matrix_mention.py index d36c2b765..873b873c2 100644 --- a/tests/gateway/test_matrix_mention.py +++ b/tests/gateway/test_matrix_mention.py @@ -247,7 +247,7 @@ async def test_require_mention_bot_participated_thread(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() - adapter._bot_participated_threads.add("$thread1") + adapter._threads.mark("$thread1") event = _make_event("hello without mention", thread_id="$thread1") @@ -298,7 +298,7 @@ async def test_auto_thread_preserves_existing_thread(monkeypatch): monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False) adapter = _make_adapter() - adapter._bot_participated_threads.add("$thread_root") + adapter._threads.mark("$thread_root") event = _make_event("reply in thread", thread_id="$thread_root") await adapter._on_room_message(event) @@ -340,17 +340,17 @@ async def test_auto_thread_disabled(monkeypatch): @pytest.mark.asyncio async def test_auto_thread_tracks_participation(monkeypatch): - """Auto-created threads are tracked in _bot_participated_threads.""" + """Auto-created threads are tracked in _threads.""" monkeypatch.setenv("MATRIX_REQUIRE_MENTION", "false") monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False) adapter = _make_adapter() event = _make_event("hello", event_id="$msg1") - with patch.object(adapter, "_save_participated_threads"): + with patch.object(adapter._threads, "_save"): await adapter._on_room_message(event) - assert "$msg1" in adapter._bot_participated_threads + assert "$msg1" in adapter._threads # --------------------------------------------------------------------------- @@ -361,56 +361,54 @@ async def test_auto_thread_tracks_participation(monkeypatch): class TestThreadPersistence: def test_empty_state_file(self, tmp_path, monkeypatch): """No state file → empty set.""" - from gateway.platforms.matrix import MatrixAdapter + from gateway.platforms.helpers import ThreadParticipationTracker monkeypatch.setattr( - MatrixAdapter, "_thread_state_path", - staticmethod(lambda: tmp_path / "matrix_threads.json"), + ThreadParticipationTracker, "_state_path", + lambda self: tmp_path / "matrix_threads.json", ) adapter = _make_adapter() - loaded = adapter._load_participated_threads() - assert loaded == set() + assert "$nonexistent" not in adapter._threads def test_track_thread_persists(self, tmp_path, monkeypatch): - """_track_thread writes to disk.""" - from gateway.platforms.matrix import MatrixAdapter + """mark() writes to disk.""" + from gateway.platforms.helpers import ThreadParticipationTracker state_path = tmp_path / "matrix_threads.json" monkeypatch.setattr( - MatrixAdapter, "_thread_state_path", - staticmethod(lambda: state_path), + ThreadParticipationTracker, "_state_path", + lambda self: state_path, ) adapter = _make_adapter() - adapter._track_thread("$thread_abc") + adapter._threads.mark("$thread_abc") data = json.loads(state_path.read_text()) assert "$thread_abc" in data def test_threads_survive_reload(self, tmp_path, monkeypatch): """Persisted threads are loaded by a new adapter instance.""" - from gateway.platforms.matrix import MatrixAdapter + from gateway.platforms.helpers import ThreadParticipationTracker state_path = tmp_path / "matrix_threads.json" state_path.write_text(json.dumps(["$t1", "$t2"])) monkeypatch.setattr( - MatrixAdapter, "_thread_state_path", - staticmethod(lambda: state_path), + ThreadParticipationTracker, "_state_path", + lambda self: state_path, ) adapter = _make_adapter() - assert "$t1" in adapter._bot_participated_threads - assert "$t2" in adapter._bot_participated_threads + assert "$t1" in adapter._threads + assert "$t2" in adapter._threads def test_cap_max_tracked_threads(self, tmp_path, monkeypatch): - """Thread set is trimmed to _MAX_TRACKED_THREADS.""" - from gateway.platforms.matrix import MatrixAdapter + """Thread set is trimmed to max_tracked.""" + from gateway.platforms.helpers import ThreadParticipationTracker state_path = tmp_path / "matrix_threads.json" monkeypatch.setattr( - MatrixAdapter, "_thread_state_path", - staticmethod(lambda: state_path), + ThreadParticipationTracker, "_state_path", + lambda self: state_path, ) adapter = _make_adapter() - adapter._MAX_TRACKED_THREADS = 5 + adapter._threads._max_tracked = 5 for i in range(10): - adapter._bot_participated_threads.add(f"$t{i}") - adapter._save_participated_threads() + adapter._threads.mark(f"$t{i}") data = json.loads(state_path.read_text()) assert len(data) == 5 @@ -447,7 +445,7 @@ async def test_dm_mention_thread_creates_thread(monkeypatch): _set_dm(adapter) event = _make_event("@hermes:example.org help me", event_id="$dm1") - with patch.object(adapter, "_save_participated_threads"): + with patch.object(adapter._threads, "_save"): await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() @@ -480,7 +478,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch): adapter = _make_adapter() _set_dm(adapter) - adapter._bot_participated_threads.add("$existing_thread") + adapter._threads.mark("$existing_thread") event = _make_event("@hermes:example.org help me", thread_id="$existing_thread") await adapter._on_room_message(event) @@ -491,7 +489,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch): @pytest.mark.asyncio async def test_dm_mention_thread_tracks_participation(monkeypatch): - """DM mention-thread tracks the thread in _bot_participated_threads.""" + """DM mention-thread tracks the thread in _threads.""" monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true") monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") @@ -499,10 +497,10 @@ async def test_dm_mention_thread_tracks_participation(monkeypatch): _set_dm(adapter) event = _make_event("@hermes:example.org help", event_id="$dm1") - with patch.object(adapter, "_save_participated_threads"): + with patch.object(adapter._threads, "_save"): await adapter._on_room_message(event) - assert "$dm1" in adapter._bot_participated_threads + assert "$dm1" in adapter._threads # --------------------------------------------------------------------------- diff --git a/tests/gateway/test_mattermost.py b/tests/gateway/test_mattermost.py index 7d47c0a3e..56e46f636 100644 --- a/tests/gateway/test_mattermost.py +++ b/tests/gateway/test_mattermost.py @@ -614,25 +614,27 @@ class TestMattermostDedup: assert self.adapter.handle_message.call_count == 2 def test_prune_seen_clears_expired(self): - """_prune_seen should remove entries older than _SEEN_TTL.""" + """Dedup cache should remove entries older than TTL on overflow.""" now = time.time() + dedup = self.adapter._dedup # Fill with enough expired entries to trigger pruning - for i in range(self.adapter._SEEN_MAX + 10): - self.adapter._seen_posts[f"old_{i}"] = now - 600 # 10 min ago + for i in range(dedup._max_size + 10): + dedup._seen[f"old_{i}"] = now - 600 # 10 min ago (older than default TTL) # Add a fresh one - self.adapter._seen_posts["fresh"] = now + dedup._seen["fresh"] = now - self.adapter._prune_seen() + # Trigger pruning by calling is_duplicate with a new entry (over max_size) + dedup.is_duplicate("trigger_prune") # Old entries should be pruned, fresh one kept - assert "fresh" in self.adapter._seen_posts - assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX + assert "fresh" in dedup._seen + assert len(dedup._seen) < dedup._max_size + 10 def test_seen_cache_tracks_post_ids(self): - """Posts are tracked in _seen_posts dict.""" - self.adapter._seen_posts["test_post"] = time.time() - assert "test_post" in self.adapter._seen_posts + """Posts are tracked in the dedup cache.""" + self.adapter._dedup._seen["test_post"] = time.time() + assert "test_post" in self.adapter._dedup._seen # --------------------------------------------------------------------------- diff --git a/tests/gateway/test_signal.py b/tests/gateway/test_signal.py index ae985300d..265f9be78 100644 --- a/tests/gateway/test_signal.py +++ b/tests/gateway/test_signal.py @@ -114,16 +114,16 @@ class TestSignalAdapterInit: class TestSignalHelpers: def test_redact_phone_long(self): - from gateway.platforms.signal import _redact_phone - assert _redact_phone("+15551234567") == "+155****4567" + from gateway.platforms.helpers import redact_phone + assert redact_phone("+155****4567") == "+155****4567" def test_redact_phone_short(self): - from gateway.platforms.signal import _redact_phone - assert _redact_phone("+12345") == "+1****45" + from gateway.platforms.helpers import redact_phone + assert redact_phone("+12345") == "+1****45" def test_redact_phone_empty(self): - from gateway.platforms.signal import _redact_phone - assert _redact_phone("") == "" + from gateway.platforms.helpers import redact_phone + assert redact_phone("") == "" def test_parse_comma_list(self): from gateway.platforms.signal import _parse_comma_list diff --git a/tests/gateway/test_telegram_conflict.py b/tests/gateway/test_telegram_conflict.py index 47a67f229..dcf311688 100644 --- a/tests/gateway/test_telegram_conflict.py +++ b/tests/gateway/test_telegram_conflict.py @@ -43,6 +43,8 @@ def _no_auto_discovery(monkeypatch): async def _noop(): return [] monkeypatch.setattr("gateway.platforms.telegram.discover_fallback_ips", _noop) + # Mock HTTPXRequest so the builder chain doesn't fail + monkeypatch.setattr("gateway.platforms.telegram.HTTPXRequest", lambda **kwargs: MagicMock()) @pytest.mark.asyncio @@ -57,9 +59,9 @@ async def test_connect_rejects_same_host_token_lock(monkeypatch): ok = await adapter.connect() assert ok is False - assert adapter.fatal_error_code == "telegram_token_lock" + assert adapter.fatal_error_code == "telegram-bot-token_lock" assert adapter.has_fatal_error is True - assert "already using this Telegram bot token" in adapter.fatal_error_message + assert "already in use" in adapter.fatal_error_message @pytest.mark.asyncio @@ -98,6 +100,8 @@ async def test_polling_conflict_retries_before_fatal(monkeypatch): ) builder = MagicMock() builder.token.return_value = builder + builder.request.return_value = builder + builder.get_updates_request.return_value = builder builder.build.return_value = app monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder))) @@ -172,6 +176,8 @@ async def test_polling_conflict_becomes_fatal_after_retries(monkeypatch): ) builder = MagicMock() builder.token.return_value = builder + builder.request.return_value = builder + builder.get_updates_request.return_value = builder builder.build.return_value = app monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder))) @@ -216,6 +222,8 @@ async def test_connect_marks_retryable_fatal_error_for_startup_network_failure(m builder = MagicMock() builder.token.return_value = builder + builder.request.return_value = builder + builder.get_updates_request.return_value = builder app = SimpleNamespace( bot=SimpleNamespace(delete_webhook=AsyncMock(), set_my_commands=AsyncMock()), updater=SimpleNamespace(), @@ -265,6 +273,8 @@ async def test_connect_clears_webhook_before_polling(monkeypatch): ) builder = MagicMock() builder.token.return_value = builder + builder.request.return_value = builder + builder.get_updates_request.return_value = builder builder.build.return_value = app monkeypatch.setattr( "gateway.platforms.telegram.Application", diff --git a/tests/tools/test_skill_manager_tool.py b/tests/tools/test_skill_manager_tool.py index 7b9e49d4f..dd0ae17f8 100644 --- a/tests/tools/test_skill_manager_tool.py +++ b/tests/tools/test_skill_manager_tool.py @@ -348,7 +348,7 @@ word word result = _patch_skill("my-skill", "old text", "new text", file_path="references/evil.md") assert result["success"] is False - assert "boundary" in result["error"].lower() + assert "escapes" in result["error"].lower() assert outside_file.read_text() == "old text here" @@ -412,7 +412,7 @@ class TestWriteFile: result = _write_file("my-skill", "references/escape/owned.md", "malicious") assert result["success"] is False - assert "boundary" in result["error"].lower() + assert "escapes" in result["error"].lower() assert not (outside_dir / "owned.md").exists() @@ -449,7 +449,7 @@ class TestRemoveFile: result = _remove_file("my-skill", "references/escape/keep.txt") assert result["success"] is False - assert "boundary" in result["error"].lower() + assert "escapes" in result["error"].lower() assert outside_file.exists() diff --git a/tools/credential_files.py b/tools/credential_files.py index 6ddcd0770..7998321e6 100644 --- a/tools/credential_files.py +++ b/tools/credential_files.py @@ -80,20 +80,18 @@ def register_credential_file( # Resolve symlinks and normalise ``..`` before the containment check so # that traversal like ``../. ssh/id_rsa`` cannot escape HERMES_HOME. - try: - resolved = host_path.resolve() - hermes_home_resolved = hermes_home.resolve() - resolved.relative_to(hermes_home_resolved) # raises ValueError if outside - except ValueError: + from tools.path_security import validate_within_dir + + containment_error = validate_within_dir(host_path, hermes_home) + if containment_error: logger.warning( - "credential_files: rejected path traversal %r " - "(resolves to %s, outside HERMES_HOME %s)", + "credential_files: rejected path traversal %r (%s)", relative_path, - resolved, - hermes_home_resolved, + containment_error, ) return False + resolved = host_path.resolve() if not resolved.is_file(): logger.debug("credential_files: skipping %s (not found)", resolved) return False @@ -142,7 +140,8 @@ def _load_config_files() -> List[Dict[str, str]]: cfg = read_raw_config() cred_files = cfg.get("terminal", {}).get("credential_files") if isinstance(cred_files, list): - hermes_home_resolved = hermes_home.resolve() + from tools.path_security import validate_within_dir + for item in cred_files: if isinstance(item, str) and item.strip(): rel = item.strip() @@ -151,20 +150,19 @@ def _load_config_files() -> List[Dict[str, str]]: "credential_files: rejected absolute config path %r", rel, ) continue - host_path = (hermes_home / rel).resolve() - try: - host_path.relative_to(hermes_home_resolved) - except ValueError: + host_path = hermes_home / rel + containment_error = validate_within_dir(host_path, hermes_home) + if containment_error: logger.warning( - "credential_files: rejected config path traversal %r " - "(resolves to %s, outside HERMES_HOME %s)", - rel, host_path, hermes_home_resolved, + "credential_files: rejected config path traversal %r (%s)", + rel, containment_error, ) continue - if host_path.is_file(): + resolved_path = host_path.resolve() + if resolved_path.is_file(): container_path = f"/root/.hermes/{rel}" result.append({ - "host_path": str(host_path), + "host_path": str(resolved_path), "container_path": container_path, }) except Exception as e: diff --git a/tools/cronjob_tools.py b/tools/cronjob_tools.py index 3018b8731..e2db93381 100644 --- a/tools/cronjob_tools.py +++ b/tools/cronjob_tools.py @@ -165,12 +165,12 @@ def _validate_cron_script_path(script: Optional[str]) -> Optional[str]: ) # Validate containment after resolution + from tools.path_security import validate_within_dir + scripts_dir = get_hermes_home() / "scripts" scripts_dir.mkdir(parents=True, exist_ok=True) - resolved = (scripts_dir / raw).resolve() - try: - resolved.relative_to(scripts_dir.resolve()) - except ValueError: + containment_error = validate_within_dir(scripts_dir / raw, scripts_dir) + if containment_error: return ( f"Script path escapes the scripts directory via traversal: {raw!r}" ) diff --git a/tools/path_security.py b/tools/path_security.py new file mode 100644 index 000000000..828011e5d --- /dev/null +++ b/tools/path_security.py @@ -0,0 +1,43 @@ +"""Shared path validation helpers for tool implementations. + +Extracts the ``resolve() + relative_to()`` and ``..`` traversal check +patterns previously duplicated across skill_manager_tool, skills_tool, +skills_hub, cronjob_tools, and credential_files. +""" + +import logging +from pathlib import Path +from typing import Optional + +logger = logging.getLogger(__name__) + + +def validate_within_dir(path: Path, root: Path) -> Optional[str]: + """Ensure *path* resolves to a location within *root*. + + Returns an error message string if validation fails, or ``None`` if the + path is safe. Uses ``Path.resolve()`` to follow symlinks and normalize + ``..`` components. + + Usage:: + + error = validate_within_dir(user_path, allowed_root) + if error: + return json.dumps({"error": error}) + """ + try: + resolved = path.resolve() + root_resolved = root.resolve() + resolved.relative_to(root_resolved) + except (ValueError, OSError) as exc: + return f"Path escapes allowed directory: {exc}" + return None + + +def has_traversal_component(path_str: str) -> bool: + """Return True if *path_str* contains ``..`` traversal components. + + Quick check for obvious traversal attempts before doing full resolution. + """ + parts = Path(path_str).parts + return ".." in parts diff --git a/tools/skill_manager_tool.py b/tools/skill_manager_tool.py index 2273d75fa..2b2625fa0 100644 --- a/tools/skill_manager_tool.py +++ b/tools/skill_manager_tool.py @@ -219,13 +219,15 @@ def _validate_file_path(file_path: str) -> Optional[str]: Validate a file path for write_file/remove_file. Must be under an allowed subdirectory and not escape the skill dir. """ + from tools.path_security import has_traversal_component + if not file_path: return "file_path is required." normalized = Path(file_path) # Prevent path traversal - if ".." in normalized.parts: + if has_traversal_component(file_path): return "Path traversal ('..') is not allowed." # Must be under an allowed subdirectory @@ -242,15 +244,12 @@ def _validate_file_path(file_path: str) -> Optional[str]: def _resolve_skill_target(skill_dir: Path, file_path: str) -> Tuple[Optional[Path], Optional[str]]: """Resolve a supporting-file path and ensure it stays within the skill directory.""" + from tools.path_security import validate_within_dir + target = skill_dir / file_path - try: - resolved = target.resolve(strict=False) - skill_dir_resolved = skill_dir.resolve() - resolved.relative_to(skill_dir_resolved) - except ValueError: - return None, "Path escapes skill directory boundary." - except OSError as e: - return None, f"Invalid file path '{file_path}': {e}" + error = validate_within_dir(target, skill_dir) + if error: + return None, error return target, None diff --git a/tools/skills_tool.py b/tools/skills_tool.py index 085ed0055..94b7c235b 100644 --- a/tools/skills_tool.py +++ b/tools/skills_tool.py @@ -447,17 +447,8 @@ def _get_category_from_path(skill_path: Path) -> Optional[str]: return None -def _estimate_tokens(content: str) -> int: - """ - Rough token estimate (4 chars per token average). - - Args: - content: Text content - - Returns: - Estimated token count - """ - return len(content) // 4 +# Token estimation — use the shared implementation from model_metadata. +from agent.model_metadata import estimate_tokens_rough as _estimate_tokens def _parse_tags(tags_value) -> List[str]: @@ -947,9 +938,10 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: # If a specific file path is requested, read that instead if file_path and skill_dir: + from tools.path_security import validate_within_dir, has_traversal_component + # Security: Prevent path traversal attacks - normalized_path = Path(file_path) - if ".." in normalized_path.parts: + if has_traversal_component(file_path): return json.dumps( { "success": False, @@ -962,24 +954,13 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: target_file = skill_dir / file_path # Security: Verify resolved path is still within skill directory - try: - resolved = target_file.resolve() - skill_dir_resolved = skill_dir.resolve() - if not resolved.is_relative_to(skill_dir_resolved): - return json.dumps( - { - "success": False, - "error": "Path escapes skill directory boundary.", - "hint": "Use a relative path within the skill directory", - }, - ensure_ascii=False, - ) - except (OSError, ValueError): + traversal_error = validate_within_dir(target_file, skill_dir) + if traversal_error: return json.dumps( { "success": False, - "error": f"Invalid file path: '{file_path}'", - "hint": "Use a valid relative path within the skill directory", + "error": traversal_error, + "hint": "Use a relative path within the skill directory", }, ensure_ascii=False, ) diff --git a/utils.py b/utils.py index 9a2105d54..bd2a6b70f 100644 --- a/utils.py +++ b/utils.py @@ -1,13 +1,16 @@ """Shared utility functions for hermes-agent.""" import json +import logging import os import tempfile from pathlib import Path -from typing import Any, Union +from typing import Any, List, Optional, Union import yaml +logger = logging.getLogger(__name__) + TRUTHY_STRINGS = frozenset({"1", "true", "yes", "on"}) @@ -124,3 +127,88 @@ def atomic_yaml_write( except OSError: pass raise + + +# ─── JSON Helpers ───────────────────────────────────────────────────────────── + + +def safe_json_loads(text: str, default: Any = None) -> Any: + """Parse JSON, returning *default* on any parse error. + + Replaces the ``try: json.loads(x) except (JSONDecodeError, TypeError)`` + pattern duplicated across display.py, anthropic_adapter.py, + auxiliary_client.py, and others. + """ + try: + return json.loads(text) + except (json.JSONDecodeError, TypeError, ValueError): + return default + + +def read_json_file(path: Path, default: Any = None) -> Any: + """Read and parse a JSON file, returning *default* on any error. + + Replaces the repeated ``try: json.loads(path.read_text()) except ...`` + pattern in anthropic_adapter.py, auxiliary_client.py, credential_pool.py, + and skill_utils.py. + """ + try: + return json.loads(Path(path).read_text(encoding="utf-8")) + except (json.JSONDecodeError, OSError, IOError, ValueError) as exc: + logger.debug("Failed to read %s: %s", path, exc) + return default + + +def read_jsonl(path: Path) -> List[dict]: + """Read a JSONL file (one JSON object per line). + + Returns a list of parsed objects, skipping blank lines. + """ + entries = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + entries.append(json.loads(line)) + return entries + + +def append_jsonl(path: Path, entry: dict) -> None: + """Append a single JSON object as a new line to a JSONL file.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "a", encoding="utf-8") as f: + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + + +# ─── Environment Variable Helpers ───────────────────────────────────────────── + + +def env_str(key: str, default: str = "") -> str: + """Read an environment variable, stripped of whitespace. + + Replaces the ``os.getenv("X", "").strip()`` pattern repeated 50+ times + across runtime_provider.py, anthropic_adapter.py, models.py, etc. + """ + return os.getenv(key, default).strip() + + +def env_lower(key: str, default: str = "") -> str: + """Read an environment variable, stripped and lowercased.""" + return os.getenv(key, default).strip().lower() + + +def env_int(key: str, default: int = 0) -> int: + """Read an environment variable as an integer, with fallback.""" + raw = os.getenv(key, "").strip() + if not raw: + return default + try: + return int(raw) + except (ValueError, TypeError): + return default + + +def env_bool(key: str, default: bool = False) -> bool: + """Read an environment variable as a boolean.""" + return is_truthy_value(os.getenv(key, ""), default=default) From 885123d44bd72330dc0afe044b81836a25ebdfaf Mon Sep 17 00:00:00 2001 From: Markus Corazzione Date: Sat, 11 Apr 2026 14:02:12 -0700 Subject: [PATCH 10/35] fix(weixin): add per-chunk retry with backoff for text delivery MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When sending multi-chunk responses, individual chunks can fail due to transient iLink API errors. Previously a single failure would abort the entire message. Now each chunk is retried with linear backoff before giving up, and the same client_id is reused across retries for server-side deduplication. Configurable via config.yaml (platforms.weixin.extra) or env vars: - send_chunk_delay_seconds (default 0.35s) — pacing between chunks - send_chunk_retries (default 2) — max retry attempts per chunk - send_chunk_retry_delay_seconds (default 1.0s) — base retry delay Replaces the hardcoded 0.3s inter-chunk delay from #7903. Salvaged from PR #7899 by @corazzione. Fixes #7836. --- gateway/platforms/weixin.py | 64 +++++++++++++++++++++++++++++++----- tests/gateway/test_weixin.py | 49 +++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 8 deletions(-) diff --git a/gateway/platforms/weixin.py b/gateway/platforms/weixin.py index 3a4a80540..5821d922f 100644 --- a/gateway/platforms/weixin.py +++ b/gateway/platforms/weixin.py @@ -1017,6 +1017,16 @@ class WeixinAdapter(BasePlatformAdapter): self._cdn_base_url = str( extra.get("cdn_base_url") or os.getenv("WEIXIN_CDN_BASE_URL", WEIXIN_CDN_BASE_URL) ).strip().rstrip("/") + self._send_chunk_delay_seconds = float( + extra.get("send_chunk_delay_seconds") or os.getenv("WEIXIN_SEND_CHUNK_DELAY_SECONDS", "0.35") + ) + self._send_chunk_retries = int( + extra.get("send_chunk_retries") or os.getenv("WEIXIN_SEND_CHUNK_RETRIES", "2") + ) + self._send_chunk_retry_delay_seconds = float( + extra.get("send_chunk_retry_delay_seconds") + or os.getenv("WEIXIN_SEND_CHUNK_RETRY_DELAY_SECONDS", "1.0") + ) self._dm_policy = str(extra.get("dm_policy") or os.getenv("WEIXIN_DM_POLICY", "open")).strip().lower() self._group_policy = str(extra.get("group_policy") or os.getenv("WEIXIN_GROUP_POLICY", "disabled")).strip().lower() allow_from = extra.get("allow_from") @@ -1346,6 +1356,47 @@ class WeixinAdapter(BasePlatformAdapter): content, self.MAX_MESSAGE_LENGTH, self._split_multiline_messages, ) + async def _send_text_chunk( + self, + *, + chat_id: str, + chunk: str, + context_token: Optional[str], + client_id: str, + ) -> None: + """Send a single text chunk with per-chunk retry and backoff.""" + last_error: Optional[Exception] = None + for attempt in range(self._send_chunk_retries + 1): + try: + await _send_message( + self._session, + base_url=self._base_url, + token=self._token, + to=chat_id, + text=chunk, + context_token=context_token, + client_id=client_id, + ) + return + except Exception as exc: + last_error = exc + if attempt >= self._send_chunk_retries: + break + wait = self._send_chunk_retry_delay_seconds * (attempt + 1) + logger.warning( + "[%s] send chunk failed to=%s attempt=%d/%d, retrying in %.2fs: %s", + self.name, + _safe_id(chat_id), + attempt + 1, + self._send_chunk_retries + 1, + wait, + exc, + ) + if wait > 0: + await asyncio.sleep(wait) + assert last_error is not None + raise last_error + async def send( self, chat_id: str, @@ -1360,19 +1411,16 @@ class WeixinAdapter(BasePlatformAdapter): try: chunks = self._split_text(self.format_message(content)) for idx, chunk in enumerate(chunks): - if idx > 0: - await asyncio.sleep(0.3) client_id = f"hermes-weixin-{uuid.uuid4().hex}" - await _send_message( - self._session, - base_url=self._base_url, - token=self._token, - to=chat_id, - text=chunk, + await self._send_text_chunk( + chat_id=chat_id, + chunk=chunk, context_token=context_token, client_id=client_id, ) last_message_id = client_id + if idx < len(chunks) - 1 and self._send_chunk_delay_seconds > 0: + await asyncio.sleep(self._send_chunk_delay_seconds) return SendResult(success=True, message_id=last_message_id) except Exception as exc: logger.error("[%s] send failed to=%s: %s", self.name, _safe_id(chat_id), exc) diff --git a/tests/gateway/test_weixin.py b/tests/gateway/test_weixin.py index 815ea75ef..bb439fa9a 100644 --- a/tests/gateway/test_weixin.py +++ b/tests/gateway/test_weixin.py @@ -283,6 +283,55 @@ class TestWeixinSendMessageIntegration: ) +class TestWeixinChunkDelivery: + def _connected_adapter(self) -> WeixinAdapter: + adapter = _make_adapter() + adapter._session = object() + adapter._token = "test-token" + adapter._base_url = "https://weixin.example.com" + adapter._token_store.get = lambda account_id, chat_id: "ctx-token" + return adapter + + @patch("gateway.platforms.weixin.asyncio.sleep", new_callable=AsyncMock) + @patch("gateway.platforms.weixin._send_message", new_callable=AsyncMock) + def test_send_waits_between_multiple_chunks(self, send_message_mock, sleep_mock): + adapter = self._connected_adapter() + adapter.MAX_MESSAGE_LENGTH = 12 + + # Use double newlines so _pack_markdown_blocks splits into 3 blocks + result = asyncio.run(adapter.send("wxid_test123", "first\n\nsecond\n\nthird")) + + assert result.success is True + assert send_message_mock.await_count == 3 + assert sleep_mock.await_count == 2 + + @patch("gateway.platforms.weixin.asyncio.sleep", new_callable=AsyncMock) + @patch("gateway.platforms.weixin._send_message", new_callable=AsyncMock) + def test_send_retries_failed_chunk_before_continuing(self, send_message_mock, sleep_mock): + adapter = self._connected_adapter() + adapter.MAX_MESSAGE_LENGTH = 12 + calls = {"count": 0} + + async def flaky_send(*args, **kwargs): + calls["count"] += 1 + if calls["count"] == 2: + raise RuntimeError("temporary iLink failure") + + send_message_mock.side_effect = flaky_send + + # Use double newlines so _pack_markdown_blocks splits into 3 blocks + result = asyncio.run(adapter.send("wxid_test123", "first\n\nsecond\n\nthird")) + + assert result.success is True + # 3 chunks, but chunk 2 fails once and retries → 4 _send_message calls total + assert send_message_mock.await_count == 4 + # The retried chunk should reuse the same client_id for deduplication + first_try = send_message_mock.await_args_list[1].kwargs + retry = send_message_mock.await_args_list[2].kwargs + assert first_try["text"] == retry["text"] + assert first_try["client_id"] == retry["client_id"] + + class TestWeixinRemoteMediaSafety: def test_download_remote_media_blocks_unsafe_urls(self): adapter = _make_adapter() From 75380de4301a74f96e74ee8f68a572b97b42d908 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sat, 11 Apr 2026 14:02:46 -0700 Subject: [PATCH 11/35] fix: reap orphaned browser sessions on startup (#7931) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a Python process exits uncleanly (SIGKILL, crash, gateway restart via hermes update), in-memory _active_sessions tracking is lost but the agent-browser node daemons and their Chromium child processes keep running indefinitely. On a long-running system this causes unbounded memory growth — 24 orphaned sessions consumed 7.6 GB on a production machine over 9 days. Add _reap_orphaned_browser_sessions() which scans the tmp directory for agent-browser-{h_*,cdp_*} socket dirs on cleanup thread startup. For each dir not tracked by the current process, reads the daemon PID file and sends SIGTERM if the daemon is still alive. Handles edge cases: dead PIDs, corrupt PID files, permission errors, foreign processes. The reaper runs once on thread startup (not every 30s) to avoid races with sessions being actively created by concurrent agents. --- tests/tools/test_browser_orphan_reaper.py | 158 ++++++++++++++++++++++ tools/browser_tool.py | 91 +++++++++++++ 2 files changed, 249 insertions(+) create mode 100644 tests/tools/test_browser_orphan_reaper.py diff --git a/tests/tools/test_browser_orphan_reaper.py b/tests/tools/test_browser_orphan_reaper.py new file mode 100644 index 000000000..254dad7db --- /dev/null +++ b/tests/tools/test_browser_orphan_reaper.py @@ -0,0 +1,158 @@ +"""Tests for _reap_orphaned_browser_sessions() — kills orphaned agent-browser +daemons whose Python parent exited without cleaning up.""" + +import os +import signal +import textwrap +from pathlib import Path +from unittest.mock import patch, MagicMock + +import pytest + + +@pytest.fixture +def fake_tmpdir(tmp_path): + """Patch _socket_safe_tmpdir to return a temp dir we control.""" + with patch("tools.browser_tool._socket_safe_tmpdir", return_value=str(tmp_path)): + yield tmp_path + + +@pytest.fixture(autouse=True) +def _isolate_sessions(): + """Ensure _active_sessions is empty for each test.""" + import tools.browser_tool as bt + orig = bt._active_sessions.copy() + bt._active_sessions.clear() + yield + bt._active_sessions.clear() + bt._active_sessions.update(orig) + + +def _make_socket_dir(tmpdir, session_name, pid=None): + """Create a fake agent-browser socket directory with optional PID file.""" + d = tmpdir / f"agent-browser-{session_name}" + d.mkdir() + if pid is not None: + (d / f"{session_name}.pid").write_text(str(pid)) + return d + + +class TestReapOrphanedBrowserSessions: + """Tests for the orphan reaper function.""" + + def test_no_socket_dirs_is_noop(self, fake_tmpdir): + """No socket dirs => nothing happens, no errors.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + _reap_orphaned_browser_sessions() # should not raise + + def test_stale_dir_without_pid_file_is_removed(self, fake_tmpdir): + """Socket dir with no PID file is cleaned up.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + d = _make_socket_dir(fake_tmpdir, "h_abc1234567") + assert d.exists() + _reap_orphaned_browser_sessions() + assert not d.exists() + + def test_stale_dir_with_dead_pid_is_removed(self, fake_tmpdir): + """Socket dir whose daemon PID is dead gets cleaned up.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + d = _make_socket_dir(fake_tmpdir, "h_dead123456", pid=999999999) + assert d.exists() + _reap_orphaned_browser_sessions() + assert not d.exists() + + def test_orphaned_alive_daemon_is_killed(self, fake_tmpdir): + """Alive daemon not tracked by _active_sessions gets SIGTERM.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + + d = _make_socket_dir(fake_tmpdir, "h_orphan12345", pid=12345) + + kill_calls = [] + original_kill = os.kill + + def mock_kill(pid, sig): + kill_calls.append((pid, sig)) + if sig == 0: + return # pretend process exists + # Don't actually kill anything + + with patch("os.kill", side_effect=mock_kill): + _reap_orphaned_browser_sessions() + + # Should have checked existence (sig 0) then killed (SIGTERM) + assert (12345, 0) in kill_calls + assert (12345, signal.SIGTERM) in kill_calls + + def test_tracked_session_is_not_reaped(self, fake_tmpdir): + """Sessions tracked in _active_sessions are left alone.""" + import tools.browser_tool as bt + from tools.browser_tool import _reap_orphaned_browser_sessions + + session_name = "h_tracked1234" + d = _make_socket_dir(fake_tmpdir, session_name, pid=12345) + + # Register the session as actively tracked + bt._active_sessions["some_task"] = {"session_name": session_name} + + kill_calls = [] + + def mock_kill(pid, sig): + kill_calls.append((pid, sig)) + + with patch("os.kill", side_effect=mock_kill): + _reap_orphaned_browser_sessions() + + # Should NOT have tried to kill anything + assert len(kill_calls) == 0 + # Dir should still exist + assert d.exists() + + def test_permission_error_on_kill_check_skips(self, fake_tmpdir): + """If we can't check the PID (PermissionError), skip it.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + + d = _make_socket_dir(fake_tmpdir, "h_perm1234567", pid=12345) + + def mock_kill(pid, sig): + if sig == 0: + raise PermissionError("not our process") + + with patch("os.kill", side_effect=mock_kill): + _reap_orphaned_browser_sessions() + + # Dir should still exist (we didn't touch someone else's process) + assert d.exists() + + def test_cdp_sessions_are_also_reaped(self, fake_tmpdir): + """CDP sessions (cdp_ prefix) are also scanned.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + + d = _make_socket_dir(fake_tmpdir, "cdp_abc1234567") + assert d.exists() + _reap_orphaned_browser_sessions() + # No PID file → cleaned up + assert not d.exists() + + def test_non_hermes_dirs_are_ignored(self, fake_tmpdir): + """Socket dirs that don't match our naming pattern are left alone.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + + # Create a dir that doesn't match h_* or cdp_* pattern + d = fake_tmpdir / "agent-browser-other_session" + d.mkdir() + (d / "other_session.pid").write_text("12345") + + _reap_orphaned_browser_sessions() + + # Should NOT be touched + assert d.exists() + + def test_corrupt_pid_file_is_cleaned(self, fake_tmpdir): + """PID file with non-integer content is cleaned up.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + + d = _make_socket_dir(fake_tmpdir, "h_corrupt1234") + (d / "h_corrupt1234.pid").write_text("not-a-number") + + _reap_orphaned_browser_sessions() + assert not d.exists() diff --git a/tools/browser_tool.py b/tools/browser_tool.py index ed3cfbb9b..bb2486606 100644 --- a/tools/browser_tool.py +++ b/tools/browser_tool.py @@ -473,13 +473,104 @@ def _cleanup_inactive_browser_sessions(): logger.warning("Error cleaning up inactive session %s: %s", task_id, e) +def _reap_orphaned_browser_sessions(): + """Scan for orphaned agent-browser daemon processes from previous runs. + + When the Python process that created a browser session exits uncleanly + (SIGKILL, crash, gateway restart), the in-memory ``_active_sessions`` + tracking is lost but the node + Chromium processes keep running. + + This function scans the tmp directory for ``agent-browser-*`` socket dirs + left behind by previous runs, reads the daemon PID files, and kills any + daemons that are still alive but not tracked by the current process. + + Called once on cleanup-thread startup — not every 30 seconds — to avoid + races with sessions being actively created. + """ + import glob + + tmpdir = _socket_safe_tmpdir() + pattern = os.path.join(tmpdir, "agent-browser-h_*") + socket_dirs = glob.glob(pattern) + # Also pick up CDP sessions + socket_dirs += glob.glob(os.path.join(tmpdir, "agent-browser-cdp_*")) + + if not socket_dirs: + return + + # Build set of session_names currently tracked by this process + with _cleanup_lock: + tracked_names = { + info.get("session_name") + for info in _active_sessions.values() + if info.get("session_name") + } + + reaped = 0 + for socket_dir in socket_dirs: + dir_name = os.path.basename(socket_dir) + # dir_name is "agent-browser-{session_name}" + session_name = dir_name.removeprefix("agent-browser-") + if not session_name: + continue + + # Skip sessions that we are actively tracking + if session_name in tracked_names: + continue + + pid_file = os.path.join(socket_dir, f"{session_name}.pid") + if not os.path.isfile(pid_file): + # No PID file — just a stale dir, remove it + shutil.rmtree(socket_dir, ignore_errors=True) + continue + + try: + daemon_pid = int(Path(pid_file).read_text().strip()) + except (ValueError, OSError): + shutil.rmtree(socket_dir, ignore_errors=True) + continue + + # Check if the daemon is still alive + try: + os.kill(daemon_pid, 0) # signal 0 = existence check + except ProcessLookupError: + # Already dead, just clean up the dir + shutil.rmtree(socket_dir, ignore_errors=True) + continue + except PermissionError: + # Alive but owned by someone else — leave it alone + continue + + # Daemon is alive and not tracked — orphan. Kill it. + try: + os.kill(daemon_pid, signal.SIGTERM) + logger.info("Reaped orphaned browser daemon PID %d (session %s)", + daemon_pid, session_name) + reaped += 1 + except (ProcessLookupError, PermissionError, OSError): + pass + + # Clean up the socket directory + shutil.rmtree(socket_dir, ignore_errors=True) + + if reaped: + logger.info("Reaped %d orphaned browser session(s) from previous run(s)", reaped) + + def _browser_cleanup_thread_worker(): """ Background thread that periodically cleans up inactive browser sessions. Runs every 30 seconds and checks for sessions that haven't been used within the BROWSER_SESSION_INACTIVITY_TIMEOUT period. + On first run, also reaps orphaned sessions from previous process lifetimes. """ + # One-time orphan reap on startup + try: + _reap_orphaned_browser_sessions() + except Exception as e: + logger.warning("Orphan reap error: %s", e) + while _cleanup_running: try: _cleanup_inactive_browser_sessions() From dfc820345d4e49d16fd70cdea2b22d1736229ad9 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sat, 11 Apr 2026 14:02:58 -0700 Subject: [PATCH 12/35] fix: scope tool interrupt signal per-thread to prevent cross-session leaks (#7930) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The interrupt mechanism in tools/interrupt.py used a process-global threading.Event. In the gateway, multiple agents run concurrently in the same process via run_in_executor. When any agent was interrupted (user sends a follow-up message), the global flag killed ALL agents' running tools — terminal commands, browser ops, web requests — across all sessions. Changes: - tools/interrupt.py: Replace single threading.Event with a set of interrupted thread IDs. set_interrupt() targets a specific thread; is_interrupted() checks the current thread. Includes a backward- compatible _ThreadAwareEventProxy for legacy _interrupt_event usage. - run_agent.py: Store execution thread ID at start of run_conversation(). interrupt() and clear_interrupt() pass it to set_interrupt() so only this agent's thread is affected. - tools/code_execution_tool.py: Use is_interrupted() instead of directly checking _interrupt_event.is_set(). - tools/process_registry.py: Same — use is_interrupted(). - tests: Update interrupt tests for per-thread semantics. Add new TestPerThreadInterruptIsolation with two tests verifying cross-thread isolation. --- run_agent.py | 16 +- tests/run_agent/test_interrupt_propagation.py | 148 ++++++++++++------ tests/tools/test_code_execution.py | 13 +- tools/code_execution_tool.py | 6 +- tools/interrupt.py | 74 +++++++-- tools/process_registry.py | 4 +- 6 files changed, 183 insertions(+), 78 deletions(-) diff --git a/run_agent.py b/run_agent.py index 21a896063..ba0a9f93d 100644 --- a/run_agent.py +++ b/run_agent.py @@ -739,6 +739,7 @@ class AIAgent: # Interrupt mechanism for breaking out of tool loops self._interrupt_requested = False self._interrupt_message = None # Optional message that triggered interrupt + self._execution_thread_id: int | None = None # Set at run_conversation() start self._client_lock = threading.RLock() # Subagent delegation state @@ -2832,8 +2833,10 @@ class AIAgent: """ self._interrupt_requested = True self._interrupt_message = message - # Signal all tools to abort any in-flight operations immediately - _set_interrupt(True) + # Signal all tools to abort any in-flight operations immediately. + # Scope the interrupt to this agent's execution thread so other + # agents running in the same process (gateway) are not affected. + _set_interrupt(True, self._execution_thread_id) # Propagate interrupt to any running child agents (subagent delegation) with self._active_children_lock: children_copy = list(self._active_children) @@ -2846,10 +2849,10 @@ class AIAgent: print("\n⚡ Interrupt requested" + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else "")) def clear_interrupt(self) -> None: - """Clear any pending interrupt request and the global tool interrupt signal.""" + """Clear any pending interrupt request and the per-thread tool interrupt signal.""" self._interrupt_requested = False self._interrupt_message = None - _set_interrupt(False) + _set_interrupt(False, self._execution_thread_id) def _touch_activity(self, desc: str) -> None: """Update the last-activity timestamp and description (thread-safe).""" @@ -7799,6 +7802,11 @@ class AIAgent: compression_attempts = 0 _turn_exit_reason = "unknown" # Diagnostic: why the loop ended + # Record the execution thread so interrupt()/clear_interrupt() can + # scope the tool-level interrupt signal to THIS agent's thread only. + # Must be set before clear_interrupt() which uses it. + self._execution_thread_id = threading.current_thread().ident + # Clear any stale interrupt state at start self.clear_interrupt() diff --git a/tests/run_agent/test_interrupt_propagation.py b/tests/run_agent/test_interrupt_propagation.py index 7f8cb01c3..a746efdac 100644 --- a/tests/run_agent/test_interrupt_propagation.py +++ b/tests/run_agent/test_interrupt_propagation.py @@ -22,23 +22,22 @@ class TestInterruptPropagationToChild(unittest.TestCase): def tearDown(self): set_interrupt(False) + def _make_bare_agent(self): + """Create a bare AIAgent via __new__ with all interrupt-related attrs.""" + from run_agent import AIAgent + agent = AIAgent.__new__(AIAgent) + agent._interrupt_requested = False + agent._interrupt_message = None + agent._execution_thread_id = None # defaults to current thread in set_interrupt + agent._active_children = [] + agent._active_children_lock = threading.Lock() + agent.quiet_mode = True + return agent + def test_parent_interrupt_sets_child_flag(self): """When parent.interrupt() is called, child._interrupt_requested should be set.""" - from run_agent import AIAgent - - parent = AIAgent.__new__(AIAgent) - parent._interrupt_requested = False - parent._interrupt_message = None - parent._active_children = [] - parent._active_children_lock = threading.Lock() - parent.quiet_mode = True - - child = AIAgent.__new__(AIAgent) - child._interrupt_requested = False - child._interrupt_message = None - child._active_children = [] - child._active_children_lock = threading.Lock() - child.quiet_mode = True + parent = self._make_bare_agent() + child = self._make_bare_agent() parent._active_children.append(child) @@ -49,40 +48,26 @@ class TestInterruptPropagationToChild(unittest.TestCase): assert child._interrupt_message == "new user message" assert is_interrupted() is True - def test_child_clear_interrupt_at_start_clears_global(self): - """child.clear_interrupt() at start of run_conversation clears the GLOBAL event. - - This is the intended behavior at startup, but verify it doesn't - accidentally clear an interrupt intended for a running child. + def test_child_clear_interrupt_at_start_clears_thread(self): + """child.clear_interrupt() at start of run_conversation clears the + per-thread interrupt flag for the current thread. """ - from run_agent import AIAgent - - child = AIAgent.__new__(AIAgent) + child = self._make_bare_agent() child._interrupt_requested = True child._interrupt_message = "msg" - child.quiet_mode = True - child._active_children = [] - child._active_children_lock = threading.Lock() - # Global is set + # Interrupt for current thread is set set_interrupt(True) assert is_interrupted() is True - # child.clear_interrupt() clears both + # child.clear_interrupt() clears both instance flag and thread flag child.clear_interrupt() assert child._interrupt_requested is False assert is_interrupted() is False def test_interrupt_during_child_api_call_detected(self): """Interrupt set during _interruptible_api_call is detected within 0.5s.""" - from run_agent import AIAgent - - child = AIAgent.__new__(AIAgent) - child._interrupt_requested = False - child._interrupt_message = None - child._active_children = [] - child._active_children_lock = threading.Lock() - child.quiet_mode = True + child = self._make_bare_agent() child.api_mode = "chat_completions" child.log_prefix = "" child._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1234"} @@ -117,21 +102,8 @@ class TestInterruptPropagationToChild(unittest.TestCase): def test_concurrent_interrupt_propagation(self): """Simulates exact CLI flow: parent runs delegate in thread, main thread interrupts.""" - from run_agent import AIAgent - - parent = AIAgent.__new__(AIAgent) - parent._interrupt_requested = False - parent._interrupt_message = None - parent._active_children = [] - parent._active_children_lock = threading.Lock() - parent.quiet_mode = True - - child = AIAgent.__new__(AIAgent) - child._interrupt_requested = False - child._interrupt_message = None - child._active_children = [] - child._active_children_lock = threading.Lock() - child.quiet_mode = True + parent = self._make_bare_agent() + child = self._make_bare_agent() # Register child (simulating what _run_single_child does) parent._active_children.append(child) @@ -157,5 +129,79 @@ class TestInterruptPropagationToChild(unittest.TestCase): set_interrupt(False) +class TestPerThreadInterruptIsolation(unittest.TestCase): + """Verify that interrupting one agent does NOT affect another agent's thread. + + This is the core fix for the gateway cross-session interrupt leak: + multiple agents run in separate threads within the same process, and + interrupting agent A must not kill agent B's running tools. + """ + + def setUp(self): + set_interrupt(False) + + def tearDown(self): + set_interrupt(False) + + def test_interrupt_only_affects_target_thread(self): + """set_interrupt(True, tid) only makes is_interrupted() True on that thread.""" + results = {} + barrier = threading.Barrier(2) + + def thread_a(): + """Agent A's execution thread — will be interrupted.""" + tid = threading.current_thread().ident + results["a_tid"] = tid + barrier.wait(timeout=5) # sync with thread B + time.sleep(0.2) # let the interrupt arrive + results["a_interrupted"] = is_interrupted() + + def thread_b(): + """Agent B's execution thread — should NOT be affected.""" + tid = threading.current_thread().ident + results["b_tid"] = tid + barrier.wait(timeout=5) # sync with thread A + time.sleep(0.2) + results["b_interrupted"] = is_interrupted() + + ta = threading.Thread(target=thread_a) + tb = threading.Thread(target=thread_b) + ta.start() + tb.start() + + # Wait for both threads to register their TIDs + time.sleep(0.05) + while "a_tid" not in results or "b_tid" not in results: + time.sleep(0.01) + + # Interrupt ONLY thread A (simulates gateway interrupting agent A) + set_interrupt(True, results["a_tid"]) + + ta.join(timeout=3) + tb.join(timeout=3) + + assert results["a_interrupted"] is True, "Thread A should see the interrupt" + assert results["b_interrupted"] is False, "Thread B must NOT see thread A's interrupt" + + def test_clear_interrupt_only_clears_target_thread(self): + """Clearing one thread's interrupt doesn't clear another's.""" + tid_a = 99990001 + tid_b = 99990002 + set_interrupt(True, tid_a) + set_interrupt(True, tid_b) + + # Clear only A + set_interrupt(False, tid_a) + + # Simulate checking from thread B's perspective + from tools.interrupt import _interrupted_threads, _lock + with _lock: + assert tid_a not in _interrupted_threads + assert tid_b in _interrupted_threads + + # Cleanup + set_interrupt(False, tid_b) + + if __name__ == "__main__": unittest.main() diff --git a/tests/tools/test_code_execution.py b/tests/tools/test_code_execution.py index 33653c360..e015e5d42 100644 --- a/tests/tools/test_code_execution.py +++ b/tests/tools/test_code_execution.py @@ -780,14 +780,18 @@ class TestLoadConfig(unittest.TestCase): @unittest.skipIf(sys.platform == "win32", "UDS not available on Windows") class TestInterruptHandling(unittest.TestCase): def test_interrupt_event_stops_execution(self): - """When _interrupt_event is set, execute_code should stop the script.""" + """When interrupt is set for the execution thread, execute_code should stop.""" code = "import time; time.sleep(60); print('should not reach')" + from tools.interrupt import set_interrupt + + # Capture the main thread ID so we can target the interrupt correctly. + # execute_code runs in the current thread; set_interrupt needs its ID. + main_tid = threading.current_thread().ident def set_interrupt_after_delay(): import time as _t _t.sleep(1) - from tools.terminal_tool import _interrupt_event - _interrupt_event.set() + set_interrupt(True, main_tid) t = threading.Thread(target=set_interrupt_after_delay, daemon=True) t.start() @@ -804,8 +808,7 @@ class TestInterruptHandling(unittest.TestCase): self.assertEqual(result["status"], "interrupted") self.assertIn("interrupted", result["output"]) finally: - from tools.terminal_tool import _interrupt_event - _interrupt_event.clear() + set_interrupt(False, main_tid) t.join(timeout=3) diff --git a/tools/code_execution_tool.py b/tools/code_execution_tool.py index 7837d70d6..d6c561e2c 100644 --- a/tools/code_execution_tool.py +++ b/tools/code_execution_tool.py @@ -924,8 +924,8 @@ def execute_code( # --- Local execution path (UDS) --- below this line is unchanged --- - # Import interrupt event from terminal_tool (cooperative cancellation) - from tools.terminal_tool import _interrupt_event + # Import per-thread interrupt check (cooperative cancellation) + from tools.interrupt import is_interrupted as _is_interrupted # Resolve config _cfg = _load_config() @@ -1114,7 +1114,7 @@ def execute_code( status = "success" while proc.poll() is None: - if _interrupt_event.is_set(): + if _is_interrupted(): _kill_process_group(proc) status = "interrupted" break diff --git a/tools/interrupt.py b/tools/interrupt.py index e5c9b1e27..9bc8b83ae 100644 --- a/tools/interrupt.py +++ b/tools/interrupt.py @@ -1,8 +1,12 @@ -"""Shared interrupt signaling for all tools. +"""Per-thread interrupt signaling for all tools. -Provides a global threading.Event that any tool can check to determine -if the user has requested an interrupt. The agent's interrupt() method -sets this event, and tools poll it during long-running operations. +Provides thread-scoped interrupt tracking so that interrupting one agent +session does not kill tools running in other sessions. This is critical +in the gateway where multiple agents run concurrently in the same process. + +The agent stores its execution thread ID at the start of run_conversation() +and passes it to set_interrupt()/clear_interrupt(). Tools call +is_interrupted() which checks the CURRENT thread — no argument needed. Usage in tools: from tools.interrupt import is_interrupted @@ -12,17 +16,61 @@ Usage in tools: import threading -_interrupt_event = threading.Event() +# Set of thread idents that have been interrupted. +_interrupted_threads: set[int] = set() +_lock = threading.Lock() -def set_interrupt(active: bool) -> None: - """Called by the agent to signal or clear the interrupt.""" - if active: - _interrupt_event.set() - else: - _interrupt_event.clear() +def set_interrupt(active: bool, thread_id: int | None = None) -> None: + """Set or clear interrupt for a specific thread. + + Args: + active: True to signal interrupt, False to clear it. + thread_id: Target thread ident. When None, targets the + current thread (backward compat for CLI/tests). + """ + tid = thread_id if thread_id is not None else threading.current_thread().ident + with _lock: + if active: + _interrupted_threads.add(tid) + else: + _interrupted_threads.discard(tid) def is_interrupted() -> bool: - """Check if an interrupt has been requested. Safe to call from any thread.""" - return _interrupt_event.is_set() + """Check if an interrupt has been requested for the current thread. + + Safe to call from any thread — each thread only sees its own + interrupt state. + """ + tid = threading.current_thread().ident + with _lock: + return tid in _interrupted_threads + + +# --------------------------------------------------------------------------- +# Backward-compatible _interrupt_event proxy +# --------------------------------------------------------------------------- +# Some legacy call sites (code_execution_tool, process_registry, tests) +# import _interrupt_event directly and call .is_set() / .set() / .clear(). +# This shim maps those calls to the per-thread functions above so existing +# code keeps working while the underlying mechanism is thread-scoped. + +class _ThreadAwareEventProxy: + """Drop-in proxy that maps threading.Event methods to per-thread state.""" + + def is_set(self) -> bool: + return is_interrupted() + + def set(self) -> None: # noqa: A003 + set_interrupt(True) + + def clear(self) -> None: + set_interrupt(False) + + def wait(self, timeout: float | None = None) -> bool: + """Not truly supported — returns current state immediately.""" + return self.is_set() + + +_interrupt_event = _ThreadAwareEventProxy() diff --git a/tools/process_registry.py b/tools/process_registry.py index 1761221f0..044a4e776 100644 --- a/tools/process_registry.py +++ b/tools/process_registry.py @@ -686,7 +686,7 @@ class ProcessRegistry: and output snapshot. """ from tools.ansi_strip import strip_ansi - from tools.terminal_tool import _interrupt_event + from tools.interrupt import is_interrupted as _is_interrupted try: default_timeout = int(os.getenv("TERMINAL_TIMEOUT", "180")) @@ -723,7 +723,7 @@ class ProcessRegistry: result["timeout_note"] = timeout_note return result - if _interrupt_event.is_set(): + if _is_interrupted(): result = { "status": "interrupted", "output": strip_ansi(session.output_buffer[-1000:]), From cc4b1f0007925f48233c96f9656975e6dfa00c11 Mon Sep 17 00:00:00 2001 From: Teknium Date: Sat, 11 Apr 2026 14:03:17 -0700 Subject: [PATCH 13/35] fix(whatsapp): pin Baileys to fix/abprops-abt-fetch for bad-request fix WhatsApp changed their server protocol for property queries, causing 400 bad-request errors in fetchProps/executeInitQueries on every reconnect (Baileys issue #2477). The fix in PR #2473 changes the IQ namespace from 'w' to 'abt' and protocol from '2' to '1'. Pin to the fix branch until the next Baileys release includes it. --- scripts/whatsapp-bridge/package-lock.json | 15 +++++++++++---- scripts/whatsapp-bridge/package.json | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/scripts/whatsapp-bridge/package-lock.json b/scripts/whatsapp-bridge/package-lock.json index 01af1c15a..23ea30a09 100644 --- a/scripts/whatsapp-bridge/package-lock.json +++ b/scripts/whatsapp-bridge/package-lock.json @@ -8,7 +8,7 @@ "name": "hermes-whatsapp-bridge", "version": "1.0.0", "dependencies": { - "@whiskeysockets/baileys": "7.0.0-rc.9", + "@whiskeysockets/baileys": "WhiskeySockets/Baileys#fix/abprops-abt-fetch", "express": "^4.21.0", "pino": "^9.0.0", "qrcode-terminal": "^0.12.0" @@ -730,21 +730,22 @@ } }, "node_modules/@whiskeysockets/baileys": { + "name": "baileys", "version": "7.0.0-rc.9", - "resolved": "https://registry.npmjs.org/@whiskeysockets/baileys/-/baileys-7.0.0-rc.9.tgz", - "integrity": "sha512-YFm5gKXfDP9byCXCW3OPHKXLzrAKzolzgVUlRosHHgwbnf2YOO3XknkMm6J7+F0ns8OA0uuSBhgkRHTDtqkacw==", + "resolved": "git+ssh://git@github.com/WhiskeySockets/Baileys.git#01047debd81beb20da7b7779b08edcb06aa03770", "hasInstallScript": true, "license": "MIT", "dependencies": { "@cacheable/node-cache": "^1.4.0", "@hapi/boom": "^9.1.3", "async-mutex": "^0.5.0", - "libsignal": "git+https://github.com/whiskeysockets/libsignal-node.git", + "libsignal": "git+https://github.com/whiskeysockets/libsignal-node", "lru-cache": "^11.1.0", "music-metadata": "^11.7.0", "p-queue": "^9.0.0", "pino": "^9.6", "protobufjs": "^7.2.4", + "whatsapp-rust-bridge": "0.5.2", "ws": "^8.13.0" }, "engines": { @@ -2125,6 +2126,12 @@ "node": ">= 0.8" } }, + "node_modules/whatsapp-rust-bridge": { + "version": "0.5.2", + "resolved": "https://registry.npmjs.org/whatsapp-rust-bridge/-/whatsapp-rust-bridge-0.5.2.tgz", + "integrity": "sha512-6KBRNvxg6WMIwZ/euA8qVzj16qxMBzLllfmaJIP1JGAAfSvwn6nr8JDOMXeqpXPEOl71UfOG+79JwKEoT2b1Fw==", + "license": "MIT" + }, "node_modules/win-guid": { "version": "0.2.1", "resolved": "https://registry.npmjs.org/win-guid/-/win-guid-0.2.1.tgz", diff --git a/scripts/whatsapp-bridge/package.json b/scripts/whatsapp-bridge/package.json index 7db81f699..2d32560f4 100644 --- a/scripts/whatsapp-bridge/package.json +++ b/scripts/whatsapp-bridge/package.json @@ -8,7 +8,7 @@ "start": "node bridge.js" }, "dependencies": { - "@whiskeysockets/baileys": "7.0.0-rc.9", + "@whiskeysockets/baileys": "WhiskeySockets/Baileys#fix/abprops-abt-fetch", "express": "^4.21.0", "qrcode-terminal": "^0.12.0", "pino": "^9.0.0" From c22bffc92e4b7ddd44b1c76ccbbac0809db00646 Mon Sep 17 00:00:00 2001 From: Mariano Nicolini Date: Sat, 11 Apr 2026 15:11:42 -0300 Subject: [PATCH 14/35] add basic twilio signature checking and tests --- gateway/platforms/sms.py | 66 +++++++++++- tests/gateway/test_sms.py | 206 +++++++++++++++++++++++++++++++++++++- 2 files changed, 267 insertions(+), 5 deletions(-) diff --git a/gateway/platforms/sms.py b/gateway/platforms/sms.py index 953ec5c5e..bdd64d179 100644 --- a/gateway/platforms/sms.py +++ b/gateway/platforms/sms.py @@ -10,6 +10,8 @@ Shares credentials with the optional telephony skill — same env vars: Gateway-specific env vars: - SMS_WEBHOOK_PORT (default 8080) + - SMS_WEBHOOK_HOST (default 0.0.0.0) + - SMS_WEBHOOK_URL (public URL for Twilio signature validation) - SMS_ALLOWED_USERS (comma-separated E.164 phone numbers) - SMS_ALLOW_ALL_USERS (true/false) - SMS_HOME_CHANNEL (phone number for cron delivery) @@ -17,6 +19,8 @@ Gateway-specific env vars: import asyncio import base64 +import hashlib +import hmac import logging import os import urllib.parse @@ -28,6 +32,7 @@ from gateway.platforms.base import ( MessageEvent, MessageType, SendResult, + is_network_accessible, ) from gateway.platforms.helpers import redact_phone, strip_markdown @@ -36,6 +41,7 @@ logger = logging.getLogger(__name__) TWILIO_API_BASE = "https://api.twilio.com/2010-04-01/Accounts" MAX_SMS_LENGTH = 1600 # ~10 SMS segments DEFAULT_WEBHOOK_PORT = 8080 +DEFAULT_WEBHOOK_HOST = "0.0.0.0" def check_sms_requirements() -> bool: @@ -65,6 +71,8 @@ class SmsAdapter(BasePlatformAdapter): self._webhook_port: int = int( os.getenv("SMS_WEBHOOK_PORT", str(DEFAULT_WEBHOOK_PORT)) ) + self._webhook_host: str = os.getenv("SMS_WEBHOOK_HOST", DEFAULT_WEBHOOK_HOST) + self._webhook_url: str = os.getenv("SMS_WEBHOOK_URL", "").strip() self._runner = None self._http_session: Optional["aiohttp.ClientSession"] = None @@ -86,13 +94,21 @@ class SmsAdapter(BasePlatformAdapter): logger.error("[sms] TWILIO_PHONE_NUMBER not set — cannot send replies") return False + if not self._webhook_url: + logger.warning( + "[sms] SMS_WEBHOOK_URL not set — Twilio signature validation is " + "DISABLED. Any client that can reach port %d can inject messages. " + "Set SMS_WEBHOOK_URL to enable signature validation.", + self._webhook_port, + ) + app = web.Application() app.router.add_post("/webhooks/twilio", self._handle_webhook) app.router.add_get("/health", lambda _: web.Response(text="ok")) self._runner = web.AppRunner(app) await self._runner.setup() - site = web.TCPSite(self._runner, "0.0.0.0", self._webhook_port) + site = web.TCPSite(self._runner, self._webhook_host, self._webhook_port) await site.start() self._http_session = aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=30), @@ -100,7 +116,8 @@ class SmsAdapter(BasePlatformAdapter): self._running = True logger.info( - "[sms] Twilio webhook server listening on port %d, from: %s", + "[sms] Twilio webhook server listening on %s:%d, from: %s", + self._webhook_host, self._webhook_port, redact_phone(self._from_number), ) @@ -182,6 +199,28 @@ class SmsAdapter(BasePlatformAdapter): """Strip markdown — SMS renders it as literal characters.""" return strip_markdown(content) + # ------------------------------------------------------------------ + # Twilio signature validation + # ------------------------------------------------------------------ + + def _validate_twilio_signature( + self, url: str, post_params: dict, signature: str, + ) -> bool: + """Validate ``X-Twilio-Signature`` header (HMAC-SHA1, base64). + + Algorithm: https://www.twilio.com/docs/usage/security#validating-requests + """ + data_to_sign = url + for key in sorted(post_params.keys()): + data_to_sign += key + post_params[key] + mac = hmac.new( + self._auth_token.encode("utf-8"), + data_to_sign.encode("utf-8"), + hashlib.sha1, + ) + computed = base64.b64encode(mac.digest()).decode("utf-8") + return hmac.compare_digest(computed, signature) + # ------------------------------------------------------------------ # Twilio webhook handler # ------------------------------------------------------------------ @@ -192,7 +231,7 @@ class SmsAdapter(BasePlatformAdapter): try: raw = await request.read() # Twilio sends form-encoded data, not JSON - form = urllib.parse.parse_qs(raw.decode("utf-8")) + form = urllib.parse.parse_qs(raw.decode("utf-8"), keep_blank_values=True) except Exception as e: logger.error("[sms] webhook parse error: %s", e) return web.Response( @@ -201,6 +240,27 @@ class SmsAdapter(BasePlatformAdapter): status=400, ) + # Validate Twilio request signature when SMS_WEBHOOK_URL is configured + if self._webhook_url: + twilio_sig = request.headers.get("X-Twilio-Signature", "") + if not twilio_sig: + logger.warning("[sms] Rejected: missing X-Twilio-Signature header") + return web.Response( + text='', + content_type="application/xml", + status=403, + ) + flat_params = {k: v[0] for k, v in form.items() if v} + if not self._validate_twilio_signature( + self._webhook_url, flat_params, twilio_sig + ): + logger.warning("[sms] Rejected: invalid Twilio signature") + return web.Response( + text='', + content_type="application/xml", + status=403, + ) + # Extract fields (parse_qs returns lists) from_number = (form.get("From", [""]))[0].strip() to_number = (form.get("To", [""]))[0].strip() diff --git a/tests/gateway/test_sms.py b/tests/gateway/test_sms.py index 54c1edf23..cfe06df98 100644 --- a/tests/gateway/test_sms.py +++ b/tests/gateway/test_sms.py @@ -1,11 +1,14 @@ """Tests for SMS (Twilio) platform integration. Covers config loading, format/truncate, echo prevention, -requirements check, and toolset verification. +requirements check, toolset verification, and Twilio signature validation. """ +import base64 +import hashlib +import hmac import os -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -213,3 +216,202 @@ class TestSmsToolset: from tools.cronjob_tools import CRONJOB_SCHEMA deliver_desc = CRONJOB_SCHEMA["parameters"]["properties"]["deliver"]["description"] assert "sms" in deliver_desc.lower() + + +# ── Webhook host configuration ───────────────────────────────────── + +class TestWebhookHostConfig: + """Verify SMS_WEBHOOK_HOST env var and default.""" + + def test_default_host_is_all_interfaces(self): + from gateway.platforms.sms import DEFAULT_WEBHOOK_HOST + assert DEFAULT_WEBHOOK_HOST == "0.0.0.0" + + def test_host_from_env(self): + from gateway.platforms.sms import SmsAdapter + + env = { + "TWILIO_ACCOUNT_SID": "ACtest", + "TWILIO_AUTH_TOKEN": "tok", + "TWILIO_PHONE_NUMBER": "+15550001111", + "SMS_WEBHOOK_HOST": "127.0.0.1", + } + with patch.dict(os.environ, env): + pc = PlatformConfig(enabled=True, api_key="tok") + adapter = SmsAdapter(pc) + assert adapter._webhook_host == "127.0.0.1" + + def test_webhook_url_from_env(self): + from gateway.platforms.sms import SmsAdapter + + env = { + "TWILIO_ACCOUNT_SID": "ACtest", + "TWILIO_AUTH_TOKEN": "tok", + "TWILIO_PHONE_NUMBER": "+15550001111", + "SMS_WEBHOOK_URL": "https://example.com/webhooks/twilio", + } + with patch.dict(os.environ, env): + pc = PlatformConfig(enabled=True, api_key="tok") + adapter = SmsAdapter(pc) + assert adapter._webhook_url == "https://example.com/webhooks/twilio" + + def test_webhook_url_stripped(self): + from gateway.platforms.sms import SmsAdapter + + env = { + "TWILIO_ACCOUNT_SID": "ACtest", + "TWILIO_AUTH_TOKEN": "tok", + "TWILIO_PHONE_NUMBER": "+15550001111", + "SMS_WEBHOOK_URL": " https://example.com/webhooks/twilio ", + } + with patch.dict(os.environ, env): + pc = PlatformConfig(enabled=True, api_key="tok") + adapter = SmsAdapter(pc) + assert adapter._webhook_url == "https://example.com/webhooks/twilio" + + +# ── Twilio signature validation ──────────────────────────────────── + +def _compute_twilio_signature(auth_token, url, params): + """Reference implementation of Twilio's signature algorithm.""" + data_to_sign = url + for key in sorted(params.keys()): + data_to_sign += key + params[key] + mac = hmac.new( + auth_token.encode("utf-8"), + data_to_sign.encode("utf-8"), + hashlib.sha1, + ) + return base64.b64encode(mac.digest()).decode("utf-8") + + +class TestTwilioSignatureValidation: + """Unit tests for SmsAdapter._validate_twilio_signature.""" + + def _make_adapter(self, auth_token="test_token_secret"): + from gateway.platforms.sms import SmsAdapter + + env = { + "TWILIO_ACCOUNT_SID": "ACtest", + "TWILIO_AUTH_TOKEN": auth_token, + "TWILIO_PHONE_NUMBER": "+15550001111", + } + with patch.dict(os.environ, env): + pc = PlatformConfig(enabled=True, api_key=auth_token) + adapter = SmsAdapter(pc) + return adapter + + def test_valid_signature_accepted(self): + adapter = self._make_adapter() + url = "https://example.com/webhooks/twilio" + params = {"From": "+15551234567", "Body": "hello", "To": "+15550001111"} + sig = _compute_twilio_signature("test_token_secret", url, params) + assert adapter._validate_twilio_signature(url, params, sig) is True + + def test_invalid_signature_rejected(self): + adapter = self._make_adapter() + url = "https://example.com/webhooks/twilio" + params = {"From": "+15551234567", "Body": "hello"} + assert adapter._validate_twilio_signature(url, params, "badsig") is False + + def test_wrong_token_rejected(self): + adapter = self._make_adapter(auth_token="correct_token") + url = "https://example.com/webhooks/twilio" + params = {"From": "+15551234567", "Body": "hello"} + sig = _compute_twilio_signature("wrong_token", url, params) + assert adapter._validate_twilio_signature(url, params, sig) is False + + def test_params_sorted_by_key(self): + """Signature must be computed with params sorted alphabetically.""" + adapter = self._make_adapter() + url = "https://example.com/webhooks/twilio" + params = {"Zebra": "last", "Alpha": "first", "Middle": "mid"} + sig = _compute_twilio_signature("test_token_secret", url, params) + assert adapter._validate_twilio_signature(url, params, sig) is True + + def test_empty_param_values_included(self): + """Blank values must be included in signature computation.""" + adapter = self._make_adapter() + url = "https://example.com/webhooks/twilio" + params = {"From": "+15551234567", "Body": "", "SmsStatus": "received"} + sig = _compute_twilio_signature("test_token_secret", url, params) + assert adapter._validate_twilio_signature(url, params, sig) is True + + def test_url_matters(self): + """Different URLs produce different signatures.""" + adapter = self._make_adapter() + params = {"Body": "hello"} + sig = _compute_twilio_signature( + "test_token_secret", "https://a.com/webhooks/twilio", params + ) + assert adapter._validate_twilio_signature( + "https://b.com/webhooks/twilio", params, sig + ) is False + + +# ── Webhook signature enforcement (handler-level) ────────────────── + +class TestWebhookSignatureEnforcement: + """Integration tests for signature validation in _handle_webhook.""" + + def _make_adapter(self, webhook_url=""): + from gateway.platforms.sms import SmsAdapter + + env = { + "TWILIO_ACCOUNT_SID": "ACtest", + "TWILIO_AUTH_TOKEN": "test_token_secret", + "TWILIO_PHONE_NUMBER": "+15550001111", + "SMS_WEBHOOK_URL": webhook_url, + } + with patch.dict(os.environ, env): + pc = PlatformConfig(enabled=True, api_key="test_token_secret") + adapter = SmsAdapter(pc) + adapter._message_handler = AsyncMock() + return adapter + + def _mock_request(self, body, headers=None): + request = MagicMock() + request.read = AsyncMock(return_value=body) + request.headers = headers or {} + return request + + @pytest.mark.asyncio + async def test_no_webhook_url_skips_validation(self): + """Without SMS_WEBHOOK_URL, all requests are accepted.""" + adapter = self._make_adapter(webhook_url="") + body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123" + request = self._mock_request(body) + resp = await adapter._handle_webhook(request) + assert resp.status == 200 + + @pytest.mark.asyncio + async def test_missing_signature_returns_403(self): + adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio") + body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123" + request = self._mock_request(body, headers={}) + resp = await adapter._handle_webhook(request) + assert resp.status == 403 + + @pytest.mark.asyncio + async def test_invalid_signature_returns_403(self): + adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio") + body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123" + request = self._mock_request(body, headers={"X-Twilio-Signature": "invalid"}) + resp = await adapter._handle_webhook(request) + assert resp.status == 403 + + @pytest.mark.asyncio + async def test_valid_signature_returns_200(self): + webhook_url = "https://example.com/webhooks/twilio" + adapter = self._make_adapter(webhook_url=webhook_url) + params = { + "From": "+15551234567", + "To": "+15550001111", + "Body": "hello", + "MessageSid": "SM123", + } + sig = _compute_twilio_signature("test_token_secret", webhook_url, params) + body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123" + request = self._mock_request(body, headers={"X-Twilio-Signature": sig}) + resp = await adapter._handle_webhook(request) + assert resp.status == 200 From ad1e8804a60e4548dc125bb2bb64a9c178aba3f0 Mon Sep 17 00:00:00 2001 From: Mariano Nicolini Date: Sat, 11 Apr 2026 15:15:33 -0300 Subject: [PATCH 15/35] handle port variants in Twilio signatures --- gateway/platforms/sms.py | 46 ++++++++++++++++++++++++++++ tests/gateway/test_sms.py | 63 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+) diff --git a/gateway/platforms/sms.py b/gateway/platforms/sms.py index bdd64d179..5bc66d881 100644 --- a/gateway/platforms/sms.py +++ b/gateway/platforms/sms.py @@ -208,8 +208,24 @@ class SmsAdapter(BasePlatformAdapter): ) -> bool: """Validate ``X-Twilio-Signature`` header (HMAC-SHA1, base64). + Tries both with and without the default port for the URL scheme, + since Twilio may sign with either variant. + Algorithm: https://www.twilio.com/docs/usage/security#validating-requests """ + if self._check_signature(url, post_params, signature): + return True + + variant = self._port_variant_url(url) + if variant and self._check_signature(variant, post_params, signature): + return True + + return False + + def _check_signature( + self, url: str, post_params: dict, signature: str, + ) -> bool: + """Compute and compare a single Twilio signature.""" data_to_sign = url for key in sorted(post_params.keys()): data_to_sign += key + post_params[key] @@ -221,6 +237,36 @@ class SmsAdapter(BasePlatformAdapter): computed = base64.b64encode(mac.digest()).decode("utf-8") return hmac.compare_digest(computed, signature) + @staticmethod + def _port_variant_url(url: str) -> str | None: + """Return the URL with the default port toggled, or None. + + Only toggles default ports (443 for https, 80 for http). + Non-standard ports are never modified. + """ + parsed = urllib.parse.urlparse(url) + default_ports = {"https": 443, "http": 80} + default_port = default_ports.get(parsed.scheme) + if default_port is None: + return None + + if parsed.port == default_port: + # Has explicit default port → strip it + return urllib.parse.urlunparse( + (parsed.scheme, parsed.hostname, parsed.path, + parsed.params, parsed.query, parsed.fragment) + ) + elif parsed.port is None: + # No port → add default + netloc = f"{parsed.hostname}:{default_port}" + return urllib.parse.urlunparse( + (parsed.scheme, netloc, parsed.path, + parsed.params, parsed.query, parsed.fragment) + ) + + # Non-standard port — no variant + return None + # ------------------------------------------------------------------ # Twilio webhook handler # ------------------------------------------------------------------ diff --git a/tests/gateway/test_sms.py b/tests/gateway/test_sms.py index cfe06df98..670e50693 100644 --- a/tests/gateway/test_sms.py +++ b/tests/gateway/test_sms.py @@ -348,6 +348,50 @@ class TestTwilioSignatureValidation: "https://b.com/webhooks/twilio", params, sig ) is False + def test_port_variant_443_matches_without_port(self): + """Signature for https URL with :443 validates against URL without port.""" + adapter = self._make_adapter() + params = {"From": "+15551234567", "Body": "hello"} + sig = _compute_twilio_signature( + "test_token_secret", "https://example.com:443/webhooks/twilio", params + ) + assert adapter._validate_twilio_signature( + "https://example.com/webhooks/twilio", params, sig + ) is True + + def test_port_variant_without_port_matches_443(self): + """Signature for https URL without port validates against URL with :443.""" + adapter = self._make_adapter() + params = {"From": "+15551234567", "Body": "hello"} + sig = _compute_twilio_signature( + "test_token_secret", "https://example.com/webhooks/twilio", params + ) + assert adapter._validate_twilio_signature( + "https://example.com:443/webhooks/twilio", params, sig + ) is True + + def test_non_standard_port_no_variant(self): + """Non-standard port must NOT match URL without port.""" + adapter = self._make_adapter() + params = {"From": "+15551234567", "Body": "hello"} + sig = _compute_twilio_signature( + "test_token_secret", "https://example.com/webhooks/twilio", params + ) + assert adapter._validate_twilio_signature( + "https://example.com:8080/webhooks/twilio", params, sig + ) is False + + def test_port_variant_http_80(self): + """Port variant also works for http with port 80.""" + adapter = self._make_adapter() + params = {"From": "+15551234567", "Body": "hello"} + sig = _compute_twilio_signature( + "test_token_secret", "http://example.com:80/webhooks/twilio", params + ) + assert adapter._validate_twilio_signature( + "http://example.com/webhooks/twilio", params, sig + ) is True + # ── Webhook signature enforcement (handler-level) ────────────────── @@ -415,3 +459,22 @@ class TestWebhookSignatureEnforcement: request = self._mock_request(body, headers={"X-Twilio-Signature": sig}) resp = await adapter._handle_webhook(request) assert resp.status == 200 + + @pytest.mark.asyncio + async def test_port_variant_signature_returns_200(self): + """Signature computed with :443 should pass when URL configured without port.""" + webhook_url = "https://example.com/webhooks/twilio" + adapter = self._make_adapter(webhook_url=webhook_url) + params = { + "From": "+15551234567", + "To": "+15550001111", + "Body": "hello", + "MessageSid": "SM123", + } + sig = _compute_twilio_signature( + "test_token_secret", "https://example.com:443/webhooks/twilio", params + ) + body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123" + request = self._mock_request(body, headers={"X-Twilio-Signature": sig}) + resp = await adapter._handle_webhook(request) + assert resp.status == 200 From 8ce6aaac235f654d1ecd5f2559d3d6075eb2b78d Mon Sep 17 00:00:00 2001 From: Mariano Nicolini Date: Sat, 11 Apr 2026 15:46:24 -0300 Subject: [PATCH 16/35] change Twilio signature verification from opt-in to opt-out --- gateway/platforms/sms.py | 23 +++++++++++++---- tests/gateway/test_sms.py | 52 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/gateway/platforms/sms.py b/gateway/platforms/sms.py index 5bc66d881..3246b3075 100644 --- a/gateway/platforms/sms.py +++ b/gateway/platforms/sms.py @@ -11,7 +11,8 @@ Shares credentials with the optional telephony skill — same env vars: Gateway-specific env vars: - SMS_WEBHOOK_PORT (default 8080) - SMS_WEBHOOK_HOST (default 0.0.0.0) - - SMS_WEBHOOK_URL (public URL for Twilio signature validation) + - SMS_WEBHOOK_URL (public URL for Twilio signature validation — required) + - SMS_INSECURE_NO_SIGNATURE (true to disable signature validation — dev only) - SMS_ALLOWED_USERS (comma-separated E.164 phone numbers) - SMS_ALLOW_ALL_USERS (true/false) - SMS_HOME_CHANNEL (phone number for cron delivery) @@ -94,11 +95,23 @@ class SmsAdapter(BasePlatformAdapter): logger.error("[sms] TWILIO_PHONE_NUMBER not set — cannot send replies") return False - if not self._webhook_url: + insecure_no_sig = os.getenv("SMS_INSECURE_NO_SIGNATURE", "").lower() == "true" + + if not self._webhook_url and not insecure_no_sig: + logger.error( + "[sms] Refusing to start: SMS_WEBHOOK_URL is required for Twilio " + "signature validation. Set it to the public URL configured in your " + "Twilio console (e.g. https://example.com/webhooks/twilio). " + "For local development without validation, set " + "SMS_INSECURE_NO_SIGNATURE=true (NOT recommended for production).", + ) + return False + + if insecure_no_sig: logger.warning( - "[sms] SMS_WEBHOOK_URL not set — Twilio signature validation is " - "DISABLED. Any client that can reach port %d can inject messages. " - "Set SMS_WEBHOOK_URL to enable signature validation.", + "[sms] SMS_INSECURE_NO_SIGNATURE=true — Twilio signature validation " + "is DISABLED. Any client that can reach port %d can inject messages. " + "Do NOT use this in production.", self._webhook_port, ) diff --git a/tests/gateway/test_sms.py b/tests/gateway/test_sms.py index 670e50693..7a310d109 100644 --- a/tests/gateway/test_sms.py +++ b/tests/gateway/test_sms.py @@ -270,6 +270,50 @@ class TestWebhookHostConfig: assert adapter._webhook_url == "https://example.com/webhooks/twilio" +# ── Startup guard (fail-closed) ──────────────────────────────────── + +class TestStartupGuard: + """Adapter must refuse to start without SMS_WEBHOOK_URL.""" + + def _make_adapter(self, extra_env=None): + from gateway.platforms.sms import SmsAdapter + + env = { + "TWILIO_ACCOUNT_SID": "ACtest", + "TWILIO_AUTH_TOKEN": "tok", + "TWILIO_PHONE_NUMBER": "+15550001111", + } + if extra_env: + env.update(extra_env) + with patch.dict(os.environ, env, clear=False): + pc = PlatformConfig(enabled=True, api_key="tok") + adapter = SmsAdapter(pc) + return adapter + + @pytest.mark.asyncio + async def test_refuses_start_without_webhook_url(self): + adapter = self._make_adapter() + result = await adapter.connect() + assert result is False + + @pytest.mark.asyncio + async def test_insecure_flag_allows_start_without_url(self): + with patch.dict(os.environ, {"SMS_INSECURE_NO_SIGNATURE": "true"}): + adapter = self._make_adapter() + result = await adapter.connect() + assert result is True + await adapter.disconnect() + + @pytest.mark.asyncio + async def test_webhook_url_allows_start(self): + adapter = self._make_adapter( + extra_env={"SMS_WEBHOOK_URL": "https://example.com/webhooks/twilio"} + ) + result = await adapter.connect() + assert result is True + await adapter.disconnect() + + # ── Twilio signature validation ──────────────────────────────────── def _compute_twilio_signature(auth_token, url, params): @@ -420,9 +464,11 @@ class TestWebhookSignatureEnforcement: return request @pytest.mark.asyncio - async def test_no_webhook_url_skips_validation(self): - """Without SMS_WEBHOOK_URL, all requests are accepted.""" - adapter = self._make_adapter(webhook_url="") + async def test_insecure_flag_skips_validation(self): + """With SMS_INSECURE_NO_SIGNATURE=true and no URL, requests are accepted.""" + env = {"SMS_INSECURE_NO_SIGNATURE": "true"} + with patch.dict(os.environ, env): + adapter = self._make_adapter(webhook_url="") body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123" request = self._mock_request(body) resp = await adapter._handle_webhook(request) From 0970f1de5048db71249de3e78798871ec509b1c9 Mon Sep 17 00:00:00 2001 From: Mariano Nicolini Date: Sat, 11 Apr 2026 16:09:27 -0300 Subject: [PATCH 17/35] update docks with changes made --- .../docs/reference/environment-variables.md | 5 ++- website/docs/user-guide/messaging/sms.md | 31 +++++++++++++++++-- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/website/docs/reference/environment-variables.md b/website/docs/reference/environment-variables.md index 958faa61f..a548a6ff6 100644 --- a/website/docs/reference/environment-variables.md +++ b/website/docs/reference/environment-variables.md @@ -195,9 +195,12 @@ For cloud sandbox backends, persistence is filesystem-oriented. `TERMINAL_LIFETI | `SIGNAL_IGNORE_STORIES` | Ignore Signal stories/status updates | | `SIGNAL_ALLOW_ALL_USERS` | Allow all Signal users without an allowlist | | `TWILIO_ACCOUNT_SID` | Twilio Account SID (shared with telephony skill) | -| `TWILIO_AUTH_TOKEN` | Twilio Auth Token (shared with telephony skill) | +| `TWILIO_AUTH_TOKEN` | Twilio Auth Token (shared with telephony skill; also used for webhook signature validation) | | `TWILIO_PHONE_NUMBER` | Twilio phone number in E.164 format (shared with telephony skill) | +| `SMS_WEBHOOK_URL` | Public URL for Twilio signature validation — must match the webhook URL in Twilio Console (required) | | `SMS_WEBHOOK_PORT` | Webhook listener port for inbound SMS (default: `8080`) | +| `SMS_WEBHOOK_HOST` | Webhook bind address (default: `0.0.0.0`) | +| `SMS_INSECURE_NO_SIGNATURE` | Set to `true` to disable Twilio signature validation (local dev only — not for production) | | `SMS_ALLOWED_USERS` | Comma-separated E.164 phone numbers allowed to chat | | `SMS_ALLOW_ALL_USERS` | Allow all SMS senders without an allowlist | | `SMS_HOME_CHANNEL` | Phone number for cron job / notification delivery | diff --git a/website/docs/user-guide/messaging/sms.md b/website/docs/user-guide/messaging/sms.md index 84a3b8fa2..c5b28cd6f 100644 --- a/website/docs/user-guide/messaging/sms.md +++ b/website/docs/user-guide/messaging/sms.md @@ -84,6 +84,13 @@ ngrok http 8080 Set the resulting public URL as your Twilio webhook. ::: +**Set `SMS_WEBHOOK_URL` to the same URL you configured in Twilio.** This is required for Twilio signature validation — the adapter will refuse to start without it: + +```bash +# Must match the webhook URL in your Twilio Console +SMS_WEBHOOK_URL=https://your-server:8080/webhooks/twilio +``` + The webhook port defaults to `8080`. Override with: ```bash @@ -101,9 +108,11 @@ hermes gateway You should see: ``` -[sms] Twilio webhook server listening on port 8080, from: +1555***4567 +[sms] Twilio webhook server listening on 0.0.0.0:8080, from: +1555***4567 ``` +If you see `Refusing to start: SMS_WEBHOOK_URL is required`, set `SMS_WEBHOOK_URL` to the public URL configured in your Twilio Console (see Step 3). + Text your Twilio number — Hermes will respond via SMS. --- @@ -113,9 +122,12 @@ Text your Twilio number — Hermes will respond via SMS. | Variable | Required | Description | |----------|----------|-------------| | `TWILIO_ACCOUNT_SID` | Yes | Twilio Account SID (starts with `AC`) | -| `TWILIO_AUTH_TOKEN` | Yes | Twilio Auth Token | +| `TWILIO_AUTH_TOKEN` | Yes | Twilio Auth Token (also used for webhook signature validation) | | `TWILIO_PHONE_NUMBER` | Yes | Your Twilio phone number (E.164 format) | +| `SMS_WEBHOOK_URL` | Yes | Public URL for Twilio signature validation — must match the webhook URL in your Twilio Console | | `SMS_WEBHOOK_PORT` | No | Webhook listener port (default: `8080`) | +| `SMS_WEBHOOK_HOST` | No | Webhook bind address (default: `0.0.0.0`) | +| `SMS_INSECURE_NO_SIGNATURE` | No | Set to `true` to disable signature validation (local dev only — **not for production**) | | `SMS_ALLOWED_USERS` | No | Comma-separated E.164 phone numbers allowed to chat | | `SMS_ALLOW_ALL_USERS` | No | Set to `true` to allow anyone (not recommended) | | `SMS_HOME_CHANNEL` | No | Phone number for cron job / notification delivery | @@ -134,6 +146,21 @@ Text your Twilio number — Hermes will respond via SMS. ## Security +### Webhook signature validation + +Hermes validates that inbound webhooks genuinely originate from Twilio by verifying the `X-Twilio-Signature` header (HMAC-SHA1). This prevents attackers from injecting forged messages. + +**`SMS_WEBHOOK_URL` is required.** Set it to the public URL configured in your Twilio Console. The adapter will refuse to start without it. + +For local development without a public URL, you can disable validation: + +```bash +# Local dev only — NOT for production +SMS_INSECURE_NO_SIGNATURE=true +``` + +### User allowlists + **The gateway denies all users by default.** Configure an allowlist: ```bash From d0538457036aeac1fa57d503aa634ea5c57e63a6 Mon Sep 17 00:00:00 2001 From: Mariano Nicolini Date: Sat, 11 Apr 2026 16:25:14 -0300 Subject: [PATCH 18/35] remove unused import and fix misleading log --- gateway/platforms/sms.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gateway/platforms/sms.py b/gateway/platforms/sms.py index 3246b3075..161949dab 100644 --- a/gateway/platforms/sms.py +++ b/gateway/platforms/sms.py @@ -33,7 +33,6 @@ from gateway.platforms.base import ( MessageEvent, MessageType, SendResult, - is_network_accessible, ) from gateway.platforms.helpers import redact_phone, strip_markdown @@ -107,7 +106,7 @@ class SmsAdapter(BasePlatformAdapter): ) return False - if insecure_no_sig: + if insecure_no_sig and not self._webhook_url: logger.warning( "[sms] SMS_INSECURE_NO_SIGNATURE=true — Twilio signature validation " "is DISABLED. Any client that can reach port %d can inject messages. " From 0a922bf218c4801abab14b9c9683903e91cd2e7c Mon Sep 17 00:00:00 2001 From: Mariano Nicolini Date: Sat, 11 Apr 2026 16:29:04 -0300 Subject: [PATCH 19/35] add new test covering edge case where both insecure_no_sig and _webhook_url are set --- tests/gateway/test_sms.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/gateway/test_sms.py b/tests/gateway/test_sms.py index 7a310d109..dbdb3b42b 100644 --- a/tests/gateway/test_sms.py +++ b/tests/gateway/test_sms.py @@ -474,6 +474,16 @@ class TestWebhookSignatureEnforcement: resp = await adapter._handle_webhook(request) assert resp.status == 200 + @pytest.mark.asyncio + async def test_insecure_flag_with_url_still_validates(self): + """When both SMS_WEBHOOK_URL and SMS_INSECURE_NO_SIGNATURE are set, + validation stays active (URL takes precedence).""" + adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio") + body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123" + request = self._mock_request(body, headers={}) + resp = await adapter._handle_webhook(request) + assert resp.status == 403 + @pytest.mark.asyncio async def test_missing_signature_returns_403(self): adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio") From b0892375cd260a8d3e40af15002a4d498a7d19c1 Mon Sep 17 00:00:00 2001 From: Teknium Date: Sat, 11 Apr 2026 14:03:51 -0700 Subject: [PATCH 20/35] fix: mock aiohttp server in startup guard tests to avoid port binding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The startup guard tests called connect() which bound a real aiohttp server on port 8080 — flaky in any environment where the port is in use. Mock AppRunner, TCPSite, and ClientSession instead. --- tests/gateway/test_sms.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/tests/gateway/test_sms.py b/tests/gateway/test_sms.py index dbdb3b42b..d8a1589bd 100644 --- a/tests/gateway/test_sms.py +++ b/tests/gateway/test_sms.py @@ -298,7 +298,14 @@ class TestStartupGuard: @pytest.mark.asyncio async def test_insecure_flag_allows_start_without_url(self): - with patch.dict(os.environ, {"SMS_INSECURE_NO_SIGNATURE": "true"}): + mock_session = AsyncMock() + with patch.dict(os.environ, {"SMS_INSECURE_NO_SIGNATURE": "true"}), \ + patch("aiohttp.web.AppRunner") as mock_runner_cls, \ + patch("aiohttp.web.TCPSite") as mock_site_cls, \ + patch("aiohttp.ClientSession", return_value=mock_session): + mock_runner_cls.return_value.setup = AsyncMock() + mock_runner_cls.return_value.cleanup = AsyncMock() + mock_site_cls.return_value.start = AsyncMock() adapter = self._make_adapter() result = await adapter.connect() assert result is True @@ -306,12 +313,19 @@ class TestStartupGuard: @pytest.mark.asyncio async def test_webhook_url_allows_start(self): - adapter = self._make_adapter( - extra_env={"SMS_WEBHOOK_URL": "https://example.com/webhooks/twilio"} - ) - result = await adapter.connect() - assert result is True - await adapter.disconnect() + mock_session = AsyncMock() + with patch("aiohttp.web.AppRunner") as mock_runner_cls, \ + patch("aiohttp.web.TCPSite") as mock_site_cls, \ + patch("aiohttp.ClientSession", return_value=mock_session): + mock_runner_cls.return_value.setup = AsyncMock() + mock_runner_cls.return_value.cleanup = AsyncMock() + mock_site_cls.return_value.start = AsyncMock() + adapter = self._make_adapter( + extra_env={"SMS_WEBHOOK_URL": "https://example.com/webhooks/twilio"} + ) + result = await adapter.connect() + assert result is True + await adapter.disconnect() # ── Twilio signature validation ──────────────────────────────────── From 0e6354df5077c7e020671e80c3c9f6e585f7e8b3 Mon Sep 17 00:00:00 2001 From: 0xFrank-eth <0xFrank-eth@users.noreply.github.com> Date: Sat, 11 Apr 2026 22:53:08 +0300 Subject: [PATCH 21/35] fix(custom-providers): propagate model field from config to runtime so API receives the correct model name MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #7828 When a custom_providers entry carries a `model` field, that value was silently dropped by `_get_named_custom_provider` and `_resolve_named_custom_runtime`. Callers received a runtime dict with `base_url`, `api_key`, and `api_mode` — but no `model`. As a result, `hermes chat --model ` sent the *provider name* (e.g. "my-dashscope-provider") as the model string to the API instead of the configured model (e.g. "qwen3.6-plus"), producing: Error code: 400 - {'error': {'message': 'Model Not Exist'}} Setting the provider as the *default* model in config.yaml worked because that path writes `model.default` and the agent reads it back directly, bypassing the broken runtime resolution path. Changes: 1. hermes_cli/runtime_provider.py — _get_named_custom_provider() Reads `entry.get("model")` and includes it in the result dict so the value is available to callers. 2. hermes_cli/runtime_provider.py — _resolve_named_custom_runtime() Propagates `custom_provider["model"]` into the returned runtime dict. 3. cli.py — _ensure_runtime_credentials() After resolving runtime, if `runtime["model"]` is set, assign it to `self.model` so the AIAgent is initialised with the correct model name rather than the provider name the user typed on the CLI. Co-Authored-By: Claude Sonnet 4.6 --- cli.py | 9 +++++++++ hermes_cli/runtime_provider.py | 10 +++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/cli.py b/cli.py index 0969a060b..ff80a49b8 100644 --- a/cli.py +++ b/cli.py @@ -2710,6 +2710,15 @@ class HermesCLI: self.api_key = api_key self.base_url = base_url + # When a custom_provider entry carries an explicit `model` field, + # use it as the effective model name. Without this, running + # `hermes chat --model ` sends the provider name + # (e.g. "my-provider") as the model string to the API instead of + # the configured model (e.g. "qwen3.6-plus"), causing 400 errors. + runtime_model = runtime.get("model") + if runtime_model and isinstance(runtime_model, str): + self.model = runtime_model + # Normalize model for the resolved provider (e.g. swap non-Codex # models when provider is openai-codex). Fixes #651. model_changed = self._normalize_model_for_provider(resolved_provider) diff --git a/hermes_cli/runtime_provider.py b/hermes_cli/runtime_provider.py index 3d1333c26..c3fcd3aae 100644 --- a/hermes_cli/runtime_provider.py +++ b/hermes_cli/runtime_provider.py @@ -304,6 +304,9 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An api_mode = _parse_api_mode(entry.get("api_mode")) if api_mode: result["api_mode"] = api_mode + model_name = str(entry.get("model", "") or "").strip() + if model_name: + result["model"] = model_name return result return None @@ -339,7 +342,7 @@ def _resolve_named_custom_runtime( ] api_key = next((candidate for candidate in api_key_candidates if has_usable_secret(candidate)), "") - return { + result = { "provider": "custom", "api_mode": custom_provider.get("api_mode") or _detect_api_mode_for_url(base_url) @@ -348,6 +351,11 @@ def _resolve_named_custom_runtime( "api_key": api_key or "no-key-required", "source": f"custom_provider:{custom_provider.get('name', requested_provider)}", } + # Propagate the model name so callers can override self.model when the + # provider name differs from the actual model string the API expects. + if custom_provider.get("model"): + result["model"] = custom_provider["model"] + return result def _resolve_openrouter_runtime( From 4bede272cf4879cd2922126d68342cef7069fecc Mon Sep 17 00:00:00 2001 From: Teknium Date: Sat, 11 Apr 2026 14:07:12 -0700 Subject: [PATCH 22/35] fix: propagate model through credential pool path + add tests The cherry-picked fix from PR #7916 placed model propagation after the credential pool early-return in _resolve_named_custom_runtime(), making it dead code when a pool is active (which happens whenever custom_providers has an api_key that auto-seeds the pool). - Inject model into pool_result before returning - Add 5 regression tests covering direct path, pool path, empty model, and absent model scenarios - Add 'model' to _VALID_CUSTOM_PROVIDER_FIELDS for config validation --- hermes_cli/config.py | 2 +- hermes_cli/runtime_provider.py | 5 + .../test_runtime_provider_resolution.py | 112 ++++++++++++++++++ 3 files changed, 118 insertions(+), 1 deletion(-) diff --git a/hermes_cli/config.py b/hermes_cli/config.py index c3cf0456e..1545d15aa 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -1497,7 +1497,7 @@ _KNOWN_ROOT_KEYS = { # Valid fields inside a custom_providers list entry _VALID_CUSTOM_PROVIDER_FIELDS = { - "name", "base_url", "api_key", "api_mode", "models", + "name", "base_url", "api_key", "api_mode", "model", "models", "context_length", "rate_limit_delay", } diff --git a/hermes_cli/runtime_provider.py b/hermes_cli/runtime_provider.py index c3fcd3aae..cd0b66722 100644 --- a/hermes_cli/runtime_provider.py +++ b/hermes_cli/runtime_provider.py @@ -332,6 +332,11 @@ def _resolve_named_custom_runtime( # Check if a credential pool exists for this custom endpoint pool_result = _try_resolve_from_custom_pool(base_url, "custom", custom_provider.get("api_mode")) if pool_result: + # Propagate the model name even when using pooled credentials — + # the pool doesn't know about the custom_providers model field. + model_name = custom_provider.get("model") + if model_name: + pool_result["model"] = model_name return pool_result api_key_candidates = [ diff --git a/tests/hermes_cli/test_runtime_provider_resolution.py b/tests/hermes_cli/test_runtime_provider_resolution.py index f46b2dd13..20486a805 100644 --- a/tests/hermes_cli/test_runtime_provider_resolution.py +++ b/tests/hermes_cli/test_runtime_provider_resolution.py @@ -1214,3 +1214,115 @@ def test_openrouter_provider_not_affected_by_custom_fix(monkeypatch): resolved = rp.resolve_runtime_provider(requested="openrouter") assert resolved["provider"] == "openrouter" + + +# ------------------------------------------------------------------ +# fix #7828 — custom_providers model field must propagate to runtime +# ------------------------------------------------------------------ + + +def test_get_named_custom_provider_includes_model(monkeypatch): + """_get_named_custom_provider should include the model field from config.""" + monkeypatch.setattr(rp, "load_config", lambda: { + "custom_providers": [{ + "name": "my-dashscope", + "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "api_key": "test-key", + "api_mode": "chat_completions", + "model": "qwen3.6-plus", + }], + }) + + result = rp._get_named_custom_provider("my-dashscope") + assert result is not None + assert result["model"] == "qwen3.6-plus" + + +def test_get_named_custom_provider_excludes_empty_model(monkeypatch): + """Empty or whitespace-only model field should not appear in result.""" + for model_val in ["", " ", None]: + entry = { + "name": "test-ep", + "base_url": "https://example.com/v1", + "api_key": "key", + } + if model_val is not None: + entry["model"] = model_val + + monkeypatch.setattr(rp, "load_config", lambda e=entry: { + "custom_providers": [e], + }) + + result = rp._get_named_custom_provider("test-ep") + assert result is not None + assert "model" not in result, ( + f"model field {model_val!r} should not be included in result" + ) + + +def test_named_custom_runtime_propagates_model_direct_path(monkeypatch): + """Model should propagate through the direct (non-pool) resolution path.""" + monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server") + monkeypatch.setattr( + rp, "_get_named_custom_provider", + lambda p: { + "name": "my-server", + "base_url": "http://localhost:8000/v1", + "api_key": "test-key", + "model": "qwen3.6-plus", + }, + ) + # Ensure pool doesn't intercept + monkeypatch.setattr(rp, "_try_resolve_from_custom_pool", lambda *a, **k: None) + + resolved = rp.resolve_runtime_provider(requested="my-server") + assert resolved["model"] == "qwen3.6-plus" + assert resolved["provider"] == "custom" + + +def test_named_custom_runtime_propagates_model_pool_path(monkeypatch): + """Model should propagate even when credential pool handles credentials.""" + monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server") + monkeypatch.setattr( + rp, "_get_named_custom_provider", + lambda p: { + "name": "my-server", + "base_url": "http://localhost:8000/v1", + "api_key": "test-key", + "model": "qwen3.6-plus", + }, + ) + # Pool returns a result (intercepting the normal path) + monkeypatch.setattr( + rp, "_try_resolve_from_custom_pool", + lambda *a, **k: { + "provider": "custom", + "api_mode": "chat_completions", + "base_url": "http://localhost:8000/v1", + "api_key": "pool-key", + "source": "pool:custom:my-server", + }, + ) + + resolved = rp.resolve_runtime_provider(requested="my-server") + assert resolved["model"] == "qwen3.6-plus", ( + "model must be injected into pool result" + ) + assert resolved["api_key"] == "pool-key", "pool credentials should be used" + + +def test_named_custom_runtime_no_model_when_absent(monkeypatch): + """When custom_providers entry has no model field, runtime should not either.""" + monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server") + monkeypatch.setattr( + rp, "_get_named_custom_provider", + lambda p: { + "name": "my-server", + "base_url": "http://localhost:8000/v1", + "api_key": "test-key", + }, + ) + monkeypatch.setattr(rp, "_try_resolve_from_custom_pool", lambda *a, **k: None) + + resolved = rp.resolve_runtime_provider(requested="my-server") + assert "model" not in resolved From 255f59de1891011cc6c270dba7dbe4bc1bfdfee4 Mon Sep 17 00:00:00 2001 From: Dusk1e Date: Sat, 11 Apr 2026 22:54:45 +0300 Subject: [PATCH 23/35] fix(tools): prevent command argument injection and path traversal in checkpoint manager This commit addresses a security vulnerability where unsanitized user inputs for commit_hash and file_path were passed directly to git commands in CheckpointManager.restore() and diff(). It validates commit hashes to be strictly hexadecimal characters without leading dashes (preventing flag injection like '--patch') and enforces file paths to stay within the working directory via root resolution. Regression tests test_restore_rejects_argument_injection, test_restore_rejects_invalid_hex_chars, and test_restore_rejects_path_traversal were added. --- tests/tools/test_checkpoint_manager.py | 65 ++++++++++++++++++++++++++ tools/checkpoint_manager.py | 61 ++++++++++++++++++++++++ 2 files changed, 126 insertions(+) diff --git a/tests/tools/test_checkpoint_manager.py b/tests/tools/test_checkpoint_manager.py index ef843465f..ae03dc31b 100644 --- a/tests/tools/test_checkpoint_manager.py +++ b/tests/tools/test_checkpoint_manager.py @@ -411,3 +411,68 @@ class TestErrorResilience: # Should not raise result = mgr.ensure_checkpoint(str(work_dir), "test") assert result is False + + +# ========================================================================= +# Security / Input validation +# ========================================================================= + +class TestSecurity: + def test_restore_rejects_argument_injection(self, mgr, work_dir): + mgr.ensure_checkpoint(str(work_dir), "initial") + # Try to pass a git flag as a commit hash + result = mgr.restore(str(work_dir), "--patch") + assert result["success"] is False + assert "Invalid commit hash" in result["error"] + assert "must not start with '-'" in result["error"] + + result = mgr.restore(str(work_dir), "-p") + assert result["success"] is False + assert "Invalid commit hash" in result["error"] + + def test_restore_rejects_invalid_hex_chars(self, mgr, work_dir): + mgr.ensure_checkpoint(str(work_dir), "initial") + # Git hashes should not contain characters like ;, &, | + result = mgr.restore(str(work_dir), "abc; rm -rf /") + assert result["success"] is False + assert "expected 4-64 hex characters" in result["error"] + + result = mgr.diff(str(work_dir), "abc&def") + assert result["success"] is False + assert "expected 4-64 hex characters" in result["error"] + + def test_restore_rejects_path_traversal(self, mgr, work_dir): + mgr.ensure_checkpoint(str(work_dir), "initial") + # Real commit hash but malicious path + checkpoints = mgr.list_checkpoints(str(work_dir)) + target_hash = checkpoints[0]["hash"] + + # Absolute path outside + result = mgr.restore(str(work_dir), target_hash, file_path="/etc/passwd") + assert result["success"] is False + assert "got absolute path" in result["error"] + + # Relative traversal outside path + result = mgr.restore(str(work_dir), target_hash, file_path="../outside_file.txt") + assert result["success"] is False + assert "escapes the working directory" in result["error"] + + def test_restore_accepts_valid_file_path(self, mgr, work_dir): + mgr.ensure_checkpoint(str(work_dir), "initial") + checkpoints = mgr.list_checkpoints(str(work_dir)) + target_hash = checkpoints[0]["hash"] + + # Valid path inside directory + result = mgr.restore(str(work_dir), target_hash, file_path="main.py") + assert result["success"] is True + + # Another valid path with subdirectories + (work_dir / "subdir").mkdir() + (work_dir / "subdir" / "test.txt").write_text("hello") + mgr.new_turn() + mgr.ensure_checkpoint(str(work_dir), "second") + checkpoints = mgr.list_checkpoints(str(work_dir)) + target_hash = checkpoints[0]["hash"] + + result = mgr.restore(str(work_dir), target_hash, file_path="subdir/test.txt") + assert result["success"] is True diff --git a/tools/checkpoint_manager.py b/tools/checkpoint_manager.py index c298aa0bb..3ea6b32fd 100644 --- a/tools/checkpoint_manager.py +++ b/tools/checkpoint_manager.py @@ -21,6 +21,7 @@ into the user's project directory. import hashlib import logging import os +import re import shutil import subprocess from pathlib import Path @@ -64,6 +65,49 @@ _GIT_TIMEOUT: int = max(10, min(60, int(os.getenv("HERMES_CHECKPOINT_TIMEOUT", " # Max files to snapshot — skip huge directories to avoid slowdowns. _MAX_FILES = 50_000 +# Valid git commit hash pattern: 4–40 hex chars (short or full SHA-1/SHA-256). +_COMMIT_HASH_RE = re.compile(r'^[0-9a-fA-F]{4,64}$') + + +# --------------------------------------------------------------------------- +# Input validation helpers +# --------------------------------------------------------------------------- + +def _validate_commit_hash(commit_hash: str) -> Optional[str]: + """Validate a commit hash to prevent git argument injection. + + Returns an error string if invalid, None if valid. + Values starting with '-' would be interpreted as git flags + (e.g., '--patch', '-p') instead of revision specifiers. + """ + if not commit_hash or not commit_hash.strip(): + return "Empty commit hash" + if commit_hash.startswith("-"): + return f"Invalid commit hash (must not start with '-'): {commit_hash!r}" + if not _COMMIT_HASH_RE.match(commit_hash): + return f"Invalid commit hash (expected 4-64 hex characters): {commit_hash!r}" + return None + + +def _validate_file_path(file_path: str, working_dir: str) -> Optional[str]: + """Validate a file path to prevent path traversal outside the working directory. + + Returns an error string if invalid, None if valid. + """ + if not file_path or not file_path.strip(): + return "Empty file path" + # Reject absolute paths — restore targets must be relative to the workdir + if os.path.isabs(file_path): + return f"File path must be relative, got absolute path: {file_path!r}" + # Resolve and check containment within working_dir + abs_workdir = Path(working_dir).resolve() + resolved = (abs_workdir / file_path).resolve() + try: + resolved.relative_to(abs_workdir) + except ValueError: + return f"File path escapes the working directory via traversal: {file_path!r}" + return None + # --------------------------------------------------------------------------- # Shadow repo helpers @@ -311,6 +355,11 @@ class CheckpointManager: Returns dict with success, diff text, and stat summary. """ + # Validate commit_hash to prevent git argument injection + hash_err = _validate_commit_hash(commit_hash) + if hash_err: + return {"success": False, "error": hash_err} + abs_dir = str(Path(working_dir).resolve()) shadow = _shadow_repo_path(abs_dir) @@ -364,7 +413,19 @@ class CheckpointManager: Returns dict with success/error info. """ + # Validate commit_hash to prevent git argument injection + hash_err = _validate_commit_hash(commit_hash) + if hash_err: + return {"success": False, "error": hash_err} + abs_dir = str(Path(working_dir).resolve()) + + # Validate file_path to prevent path traversal outside the working dir + if file_path: + path_err = _validate_file_path(file_path, abs_dir) + if path_err: + return {"success": False, "error": path_err} + shadow = _shadow_repo_path(abs_dir) if not (shadow / "HEAD").exists(): From f2893fe51a59e545ad05f459fb235296872c4561 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sat, 11 Apr 2026 14:26:11 -0700 Subject: [PATCH 24/35] fix(tools): neutralize shell injection in _write_to_sandbox via path quoting (#7940) _write_to_sandbox interpolated storage_dir and remote_path directly into a shell command passed to env.execute(). Paths containing shell metacharacters (spaces, semicolons, $(), backticks) could trigger arbitrary command execution inside the sandbox. Fix: wrap both paths with shlex.quote(). Clean paths (alphanumeric + slashes/hyphens/dots) are left unmodified by shlex.quote, so existing behavior is unchanged. Paths with unsafe characters get single-quoted. Tests added for spaces, $(command) substitution, and semicolon injection. --- tests/tools/test_tool_result_storage.py | 28 +++++++++++++++++++++++++ tools/tool_result_storage.py | 3 ++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/tests/tools/test_tool_result_storage.py b/tests/tools/test_tool_result_storage.py index f95b5dc08..0bbb95bbd 100644 --- a/tests/tools/test_tool_result_storage.py +++ b/tests/tools/test_tool_result_storage.py @@ -124,6 +124,34 @@ class TestWriteToSandbox: cmd = env.execute.call_args[0][0] assert "mkdir -p /data/data/com.termux/files/usr/tmp/hermes-results" in cmd + def test_path_with_spaces_is_quoted(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + remote_path = "/tmp/hermes results/abc file.txt" + _write_to_sandbox("content", remote_path, env) + cmd = env.execute.call_args[0][0] + assert "'/tmp/hermes results'" in cmd + assert "'/tmp/hermes results/abc file.txt'" in cmd + + def test_shell_metacharacters_neutralized(self): + """Paths with shell metacharacters must be quoted to prevent injection.""" + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + malicious_path = "/tmp/hermes-results/$(whoami).txt" + _write_to_sandbox("content", malicious_path, env) + cmd = env.execute.call_args[0][0] + # The $() must not appear unquoted — shlex.quote wraps it + assert "'/tmp/hermes-results/$(whoami).txt'" in cmd + + def test_semicolon_injection_neutralized(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + malicious_path = "/tmp/x; rm -rf /; echo .txt" + _write_to_sandbox("content", malicious_path, env) + cmd = env.execute.call_args[0][0] + # The semicolons must be inside quotes, not acting as command separators + assert "'/tmp/x; rm -rf /; echo .txt'" in cmd + class TestResolveStorageDir: def test_defaults_to_storage_dir_without_env(self): diff --git a/tools/tool_result_storage.py b/tools/tool_result_storage.py index a8ec5440b..434226448 100644 --- a/tools/tool_result_storage.py +++ b/tools/tool_result_storage.py @@ -24,6 +24,7 @@ Defense against context-window overflow operates at three levels: import logging import os +import shlex import uuid from tools.budget_config import ( @@ -79,7 +80,7 @@ def _write_to_sandbox(content: str, remote_path: str, env) -> bool: marker = _heredoc_marker(content) storage_dir = os.path.dirname(remote_path) cmd = ( - f"mkdir -p {storage_dir} && cat > {remote_path} << '{marker}'\n" + f"mkdir -p {shlex.quote(storage_dir)} && cat > {shlex.quote(remote_path)} << '{marker}'\n" f"{content}\n" f"{marker}" ) From ef73babea1c3460e75ccb469c88ab54314ea4565 Mon Sep 17 00:00:00 2001 From: willy-scr <187001140+willy-scr@users.noreply.github.com> Date: Sun, 12 Apr 2026 04:36:31 +0800 Subject: [PATCH 25/35] fix(gateway): use source.thread_id instead of undefined event in queued response MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In _run_agent(), the pending message handler references 'event' which is not defined in that scope — it only exists in the caller. This causes a NameError when sending the first response before processing a queued follow-up message. Replace getattr(event, 'metadata', None) with the established pattern using source.thread_id, consistent with lines 2625, 2810, 3678, 4410, 4566 in the same file. --- gateway/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gateway/run.py b/gateway/run.py index 469abe9ec..9b1e5c275 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -8131,7 +8131,7 @@ class GatewayRunner: if first_response and not _already_streamed: try: await adapter.send(source.chat_id, first_response, - metadata=getattr(event, "metadata", None)) + metadata={"thread_id": source.thread_id} if source.thread_id else None) except Exception as e: logger.warning("Failed to send first response before queued message: %s", e) # else: interrupted — discard the interrupted response ("Operation From dfe7386a58a7fed3e1d4f2567e13213044ff8168 Mon Sep 17 00:00:00 2001 From: sauljwu Date: Sat, 11 Apr 2026 17:17:19 -0400 Subject: [PATCH 26/35] fix: deduplicate reasoning items in Responses API input When replaying codex_reasoning_items from previous turns, duplicate item IDs (rs_*) could appear in the input array, causing HTTP 400 "Duplicate item found" errors from the OpenAI Responses API. Add seen_item_ids tracking in both _chat_messages_to_responses_input() and _preflight_codex_input_items() to skip already-added reasoning items by their ID. Co-Authored-By: Claude Opus 4.6 (1M context) --- run_agent.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/run_agent.py b/run_agent.py index ba0a9f93d..b65a8574a 100644 --- a/run_agent.py +++ b/run_agent.py @@ -3446,6 +3446,7 @@ class AIAgent: def _chat_messages_to_responses_input(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Convert internal chat-style messages to Responses input items.""" items: List[Dict[str, Any]] = [] + seen_item_ids: set = set() for msg in messages: if not isinstance(msg, dict): @@ -3466,7 +3467,12 @@ class AIAgent: if isinstance(codex_reasoning, list): for ri in codex_reasoning: if isinstance(ri, dict) and ri.get("encrypted_content"): + item_id = ri.get("id") + if item_id and item_id in seen_item_ids: + continue items.append(ri) + if item_id: + seen_item_ids.add(item_id) has_codex_reasoning = True if content_text.strip(): @@ -3546,6 +3552,7 @@ class AIAgent: raise ValueError("Codex Responses input must be a list of input items.") normalized: List[Dict[str, Any]] = [] + seen_ids: set = set() for idx, item in enumerate(raw_items): if not isinstance(item, dict): raise ValueError(f"Codex Responses input[{idx}] must be an object.") @@ -3598,8 +3605,12 @@ class AIAgent: if item_type == "reasoning": encrypted = item.get("encrypted_content") if isinstance(encrypted, str) and encrypted: - reasoning_item = {"type": "reasoning", "encrypted_content": encrypted} item_id = item.get("id") + if isinstance(item_id, str) and item_id: + if item_id in seen_ids: + continue + seen_ids.add(item_id) + reasoning_item = {"type": "reasoning", "encrypted_content": encrypted} if isinstance(item_id, str) and item_id: reasoning_item["id"] = item_id summary = item.get("summary") From 8160d7a03d9c203a6c9c67a72480fa0e246b83b3 Mon Sep 17 00:00:00 2001 From: Teknium Date: Sat, 11 Apr 2026 14:26:44 -0700 Subject: [PATCH 27/35] test: add dedup coverage for reasoning item ID deduplication Adds two tests verifying that duplicate reasoning item IDs across multi-turn Codex Responses conversations are correctly deduplicated in both _chat_messages_to_responses_input() and _preflight_codex_input_items(). --- .../test_run_agent_codex_responses.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/tests/run_agent/test_run_agent_codex_responses.py b/tests/run_agent/test_run_agent_codex_responses.py index 6756ed6fd..17a70624d 100644 --- a/tests/run_agent/test_run_agent_codex_responses.py +++ b/tests/run_agent/test_run_agent_codex_responses.py @@ -1104,3 +1104,58 @@ def test_duplicate_detection_distinguishes_different_codex_reasoning(monkeypatch ] assert "enc_first" in encrypted_contents assert "enc_second" in encrypted_contents + + +def test_chat_messages_to_responses_input_deduplicates_reasoning_ids(monkeypatch): + """Duplicate reasoning item IDs across multi-turn incomplete responses + must be deduplicated so the Responses API doesn't reject with HTTP 400.""" + agent = _build_agent(monkeypatch) + messages = [ + {"role": "user", "content": "think hard"}, + { + "role": "assistant", + "content": "", + "codex_reasoning_items": [ + {"type": "reasoning", "id": "rs_aaa", "encrypted_content": "enc_1"}, + {"type": "reasoning", "id": "rs_bbb", "encrypted_content": "enc_2"}, + ], + }, + { + "role": "assistant", + "content": "partial answer", + "codex_reasoning_items": [ + # rs_aaa is duplicated from the previous turn + {"type": "reasoning", "id": "rs_aaa", "encrypted_content": "enc_1"}, + {"type": "reasoning", "id": "rs_ccc", "encrypted_content": "enc_3"}, + ], + }, + ] + items = agent._chat_messages_to_responses_input(messages) + + reasoning_ids = [it["id"] for it in items if it.get("type") == "reasoning"] + # rs_aaa should appear only once (first occurrence kept) + assert reasoning_ids.count("rs_aaa") == 1 + # rs_bbb and rs_ccc should each appear once + assert reasoning_ids.count("rs_bbb") == 1 + assert reasoning_ids.count("rs_ccc") == 1 + assert len(reasoning_ids) == 3 + + +def test_preflight_codex_input_deduplicates_reasoning_ids(monkeypatch): + """_preflight_codex_input_items should also deduplicate reasoning items by ID.""" + agent = _build_agent(monkeypatch) + raw_input = [ + {"role": "user", "content": [{"type": "input_text", "text": "hello"}]}, + {"type": "reasoning", "id": "rs_xyz", "encrypted_content": "enc_a"}, + {"role": "assistant", "content": "ok"}, + {"type": "reasoning", "id": "rs_xyz", "encrypted_content": "enc_a"}, + {"type": "reasoning", "id": "rs_zzz", "encrypted_content": "enc_b"}, + {"role": "assistant", "content": "done"}, + ] + normalized = agent._preflight_codex_input_items(raw_input) + + reasoning_items = [it for it in normalized if it.get("type") == "reasoning"] + reasoning_ids = [it["id"] for it in reasoning_items] + assert reasoning_ids.count("rs_xyz") == 1 + assert reasoning_ids.count("rs_zzz") == 1 + assert len(reasoning_items) == 2 From 72b345e068eebc9c0c482542b18af79384825261 Mon Sep 17 00:00:00 2001 From: etcircle <33860762+etcircle@users.noreply.github.com> Date: Sat, 11 Apr 2026 21:36:05 +0100 Subject: [PATCH 28/35] fix(gateway): preserve queued voice events for STT --- gateway/run.py | 397 ++++++++++++------------ tests/gateway/test_queue_consumption.py | 21 ++ tests/gateway/test_stt_config.py | 47 ++- 3 files changed, 269 insertions(+), 196 deletions(-) diff --git a/gateway/run.py b/gateway/run.py index 9b1e5c275..5e1ed0e86 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -352,19 +352,14 @@ def _build_media_placeholder(event) -> str: return "\n".join(parts) -def _dequeue_pending_text(adapter, session_key: str) -> str | None: - """Consume and return the text of a pending queued message. +def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None: + """Consume and return the full pending event for a session. - Preserves media context for captionless photo/document events by - building a placeholder so the message isn't silently dropped. + Queued follow-ups must preserve their media metadata so they can re-enter + the normal image/STT/document preprocessing path instead of being reduced + to a placeholder string. """ - event = adapter.get_pending_message(session_key) - if not event: - return None - text = event.text - if not text and getattr(event, "media_urls", None): - text = _build_media_placeholder(event) - return text + return adapter.get_pending_message(session_key) def _check_unavailable_skill(command_name: str) -> str | None: @@ -2775,6 +2770,162 @@ class GatewayRunner: del self._running_agents[_quick_key] self._running_agents_ts.pop(_quick_key, None) + async def _prepare_inbound_message_text( + self, + *, + event: MessageEvent, + source: SessionSource, + history: List[Dict[str, Any]], + ) -> Optional[str]: + """Prepare inbound event text for the agent. + + Keep the normal inbound path and the queued follow-up path on the same + preprocessing pipeline so sender attribution, image enrichment, STT, + document notes, reply context, and @ references all behave the same. + """ + history = history or [] + message_text = event.text or "" + + _is_shared_thread = ( + source.chat_type != "dm" + and source.thread_id + and not getattr(self.config, "thread_sessions_per_user", False) + ) + if _is_shared_thread and source.user_name: + message_text = f"[{source.user_name}] {message_text}" + + if event.media_urls: + image_paths = [] + audio_paths = [] + for i, path in enumerate(event.media_urls): + mtype = event.media_types[i] if i < len(event.media_types) else "" + if mtype.startswith("image/") or event.message_type == MessageType.PHOTO: + image_paths.append(path) + if mtype.startswith("audio/") or event.message_type in (MessageType.VOICE, MessageType.AUDIO): + audio_paths.append(path) + + if image_paths: + message_text = await self._enrich_message_with_vision( + message_text, + image_paths, + ) + + if audio_paths: + message_text = await self._enrich_message_with_transcription( + message_text, + audio_paths, + ) + _stt_fail_markers = ( + "No STT provider", + "STT is disabled", + "can't listen", + "VOICE_TOOLS_OPENAI_KEY", + ) + if any(marker in message_text for marker in _stt_fail_markers): + _stt_adapter = self.adapters.get(source.platform) + _stt_meta = {"thread_id": source.thread_id} if source.thread_id else None + if _stt_adapter: + try: + _stt_msg = ( + "🎤 I received your voice message but can't transcribe it — " + "no speech-to-text provider is configured.\n\n" + "To enable voice: install faster-whisper " + "(`pip install faster-whisper` in the Hermes venv) " + "and set `stt.enabled: true` in config.yaml, " + "then /restart the gateway." + ) + if self._has_setup_skill(): + _stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`" + await _stt_adapter.send( + source.chat_id, + _stt_msg, + metadata=_stt_meta, + ) + except Exception: + pass + + if event.media_urls and event.message_type == MessageType.DOCUMENT: + import mimetypes as _mimetypes + + _TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"} + for i, path in enumerate(event.media_urls): + mtype = event.media_types[i] if i < len(event.media_types) else "" + if mtype in ("", "application/octet-stream"): + import os as _os2 + + _ext = _os2.path.splitext(path)[1].lower() + if _ext in _TEXT_EXTENSIONS: + mtype = "text/plain" + else: + guessed, _ = _mimetypes.guess_type(path) + if guessed: + mtype = guessed + if not mtype.startswith(("application/", "text/")): + continue + + import os as _os + import re as _re + + basename = _os.path.basename(path) + parts = basename.split("_", 2) + display_name = parts[2] if len(parts) >= 3 else basename + display_name = _re.sub(r'[^\w.\- ]', '_', display_name) + + if mtype.startswith("text/"): + context_note = ( + f"[The user sent a text document: '{display_name}'. " + f"Its content has been included below. " + f"The file is also saved at: {path}]" + ) + else: + context_note = ( + f"[The user sent a document: '{display_name}'. " + f"The file is saved at: {path}. " + f"Ask the user what they'd like you to do with it.]" + ) + message_text = f"{context_note}\n\n{message_text}" + + if getattr(event, "reply_to_text", None) and event.reply_to_message_id: + reply_snippet = event.reply_to_text[:500] + found_in_history = any( + reply_snippet[:200] in (msg.get("content") or "") + for msg in history + if msg.get("role") in ("assistant", "user", "tool") + ) + if not found_in_history: + message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}' + + if "@" in message_text: + try: + from agent.context_references import preprocess_context_references_async + from agent.model_metadata import get_model_context_length + + _msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~")) + _msg_ctx_len = get_model_context_length( + self._model, + base_url=self._base_url or "", + ) + _ctx_result = await preprocess_context_references_async( + message_text, + cwd=_msg_cwd, + context_length=_msg_ctx_len, + allowed_root=_msg_cwd, + ) + if _ctx_result.blocked: + _adapter = self.adapters.get(source.platform) + if _adapter: + await _adapter.send( + source.chat_id, + "\n".join(_ctx_result.warnings) or "Context injection refused.", + ) + return None + if _ctx_result.expanded: + message_text = _ctx_result.message + except Exception as exc: + logger.debug("@ context reference expansion failed: %s", exc) + + return message_text + async def _handle_message_with_agent(self, event, source, _quick_key: str): """Inner handler that runs under the _running_agents sentinel guard.""" _msg_start_time = time.time() @@ -3215,149 +3366,13 @@ class GatewayRunner: # attachments (documents, audio, etc.) are not sent to the vision # tool even when they appear in the same message. # ----------------------------------------------------------------- - message_text = event.text or "" - - # ----------------------------------------------------------------- - # Sender attribution for shared thread sessions. - # - # When multiple users share a single thread session (the default for - # threads), prefix each message with [sender name] so the agent can - # tell participants apart. Skip for DMs (single-user by nature) and - # when per-user thread isolation is explicitly enabled. - # ----------------------------------------------------------------- - _is_shared_thread = ( - source.chat_type != "dm" - and source.thread_id - and not getattr(self.config, "thread_sessions_per_user", False) + message_text = await self._prepare_inbound_message_text( + event=event, + source=source, + history=history, ) - if _is_shared_thread and source.user_name: - message_text = f"[{source.user_name}] {message_text}" - - if event.media_urls: - image_paths = [] - for i, path in enumerate(event.media_urls): - # Check media_types if available; otherwise infer from message type - mtype = event.media_types[i] if i < len(event.media_types) else "" - is_image = ( - mtype.startswith("image/") - or event.message_type == MessageType.PHOTO - ) - if is_image: - image_paths.append(path) - if image_paths: - message_text = await self._enrich_message_with_vision( - message_text, image_paths - ) - - # ----------------------------------------------------------------- - # Auto-transcribe voice/audio messages sent by the user - # ----------------------------------------------------------------- - if event.media_urls: - audio_paths = [] - for i, path in enumerate(event.media_urls): - mtype = event.media_types[i] if i < len(event.media_types) else "" - is_audio = ( - mtype.startswith("audio/") - or event.message_type in (MessageType.VOICE, MessageType.AUDIO) - ) - if is_audio: - audio_paths.append(path) - if audio_paths: - message_text = await self._enrich_message_with_transcription( - message_text, audio_paths - ) - # If STT failed, send a direct message to the user so they - # know voice isn't configured — don't rely on the agent to - # relay the error clearly. - _stt_fail_markers = ( - "No STT provider", - "STT is disabled", - "can't listen", - "VOICE_TOOLS_OPENAI_KEY", - ) - if any(m in message_text for m in _stt_fail_markers): - _stt_adapter = self.adapters.get(source.platform) - _stt_meta = {"thread_id": source.thread_id} if source.thread_id else None - if _stt_adapter: - try: - _stt_msg = ( - "🎤 I received your voice message but can't transcribe it — " - "no speech-to-text provider is configured.\n\n" - "To enable voice: install faster-whisper " - "(`pip install faster-whisper` in the Hermes venv) " - "and set `stt.enabled: true` in config.yaml, " - "then /restart the gateway." - ) - # Point to setup skill if it's installed - if self._has_setup_skill(): - _stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`" - await _stt_adapter.send( - source.chat_id, _stt_msg, - metadata=_stt_meta, - ) - except Exception: - pass - - # ----------------------------------------------------------------- - # Enrich document messages with context notes for the agent - # ----------------------------------------------------------------- - if event.media_urls and event.message_type == MessageType.DOCUMENT: - import mimetypes as _mimetypes - _TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"} - for i, path in enumerate(event.media_urls): - mtype = event.media_types[i] if i < len(event.media_types) else "" - # Fall back to extension-based detection when MIME type is unreliable. - if mtype in ("", "application/octet-stream"): - import os as _os2 - _ext = _os2.path.splitext(path)[1].lower() - if _ext in _TEXT_EXTENSIONS: - mtype = "text/plain" - else: - guessed, _ = _mimetypes.guess_type(path) - if guessed: - mtype = guessed - if not mtype.startswith(("application/", "text/")): - continue - # Extract display filename by stripping the doc_{uuid12}_ prefix - import os as _os - basename = _os.path.basename(path) - # Format: doc_<12hex>_ - parts = basename.split("_", 2) - display_name = parts[2] if len(parts) >= 3 else basename - # Sanitize to prevent prompt injection via filenames - import re as _re - display_name = _re.sub(r'[^\w.\- ]', '_', display_name) - - if mtype.startswith("text/"): - context_note = ( - f"[The user sent a text document: '{display_name}'. " - f"Its content has been included below. " - f"The file is also saved at: {path}]" - ) - else: - context_note = ( - f"[The user sent a document: '{display_name}'. " - f"The file is saved at: {path}. " - f"Ask the user what they'd like you to do with it.]" - ) - message_text = f"{context_note}\n\n{message_text}" - - # ----------------------------------------------------------------- - # Inject reply context when user replies to a message not in history. - # Telegram (and other platforms) let users reply to specific messages, - # but if the quoted message is from a previous session, cron delivery, - # or background task, the agent has no context about what's being - # referenced. Prepend the quoted text so the agent understands. (#1594) - # ----------------------------------------------------------------- - if getattr(event, 'reply_to_text', None) and event.reply_to_message_id: - reply_snippet = event.reply_to_text[:500] - found_in_history = any( - reply_snippet[:200] in (msg.get("content") or "") - for msg in history - if msg.get("role") in ("assistant", "user", "tool") - ) - if not found_in_history: - message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}' + if message_text is None: + return try: # Emit agent:start hook @@ -3369,30 +3384,6 @@ class GatewayRunner: } await self.hooks.emit("agent:start", hook_ctx) - # Expand @ context references (@file:, @folder:, @diff, etc.) - if "@" in message_text: - try: - from agent.context_references import preprocess_context_references_async - from agent.model_metadata import get_model_context_length - _msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~")) - _msg_ctx_len = get_model_context_length( - self._model, base_url=self._base_url or "") - _ctx_result = await preprocess_context_references_async( - message_text, cwd=_msg_cwd, - context_length=_msg_ctx_len, allowed_root=_msg_cwd) - if _ctx_result.blocked: - _adapter = self.adapters.get(source.platform) - if _adapter: - await _adapter.send( - source.chat_id, - "\n".join(_ctx_result.warnings) or "Context injection refused.", - ) - return - if _ctx_result.expanded: - message_text = _ctx_result.message - except Exception as exc: - logger.debug("@ context reference expansion failed: %s", exc) - # Run the agent agent_result = await self._run_agent( message=message_text, @@ -8057,17 +8048,16 @@ class GatewayRunner: # Get pending message from adapter. # Use session_key (not source.chat_id) to match adapter's storage keys. + pending_event = None pending = None if result and adapter and session_key: - if result.get("interrupted"): - pending = _dequeue_pending_text(adapter, session_key) - if not pending and result.get("interrupt_message"): - pending = result.get("interrupt_message") - else: - pending = _dequeue_pending_text(adapter, session_key) - if pending: - logger.debug("Processing queued message after agent completion: '%s...'", pending[:40]) - + pending_event = _dequeue_pending_event(adapter, session_key) + if result.get("interrupted") and not pending_event and result.get("interrupt_message"): + pending = result.get("interrupt_message") + elif pending_event: + pending = pending_event.text or _build_media_placeholder(pending_event) + logger.debug("Processing queued message after agent completion: '%s...'", pending[:40]) + # Safety net: if the pending text is a slash command (e.g. "/stop", # "/new"), discard it — commands should never be passed to the agent # as user input. The primary fix is in base.py (commands bypass the @@ -8085,27 +8075,29 @@ class GatewayRunner: "commands must not be passed as agent input", _pending_cmd_word, ) + pending_event = None pending = None except Exception: pass - if self._draining and pending: + if self._draining and (pending_event or pending): logger.info( "Discarding pending follow-up for session %s during gateway %s", session_key[:20] if session_key else "?", self._status_action_label(), ) + pending_event = None pending = None - if pending: + if pending_event or pending: logger.debug("Processing pending message: '%s...'", pending[:40]) - + # Clear the adapter's interrupt event so the next _run_agent call # doesn't immediately re-trigger the interrupt before the new agent # even makes its first API call (this was causing an infinite loop). if adapter and hasattr(adapter, '_active_sessions') and session_key and session_key in adapter._active_sessions: adapter._active_sessions[session_key].clear() - + # Cap recursion depth to prevent resource exhaustion when the # user sends multiple messages while the agent keeps failing. (#816) if _interrupt_depth >= self._MAX_INTERRUPT_DEPTH: @@ -8114,9 +8106,10 @@ class GatewayRunner: "queueing message instead of recursing.", _interrupt_depth, session_key, ) - # Queue the pending message for normal processing on next turn adapter = self.adapters.get(source.platform) - if adapter and hasattr(adapter, 'queue_message'): + if adapter and pending_event: + merge_pending_message_event(adapter._pending_messages, session_key, pending_event) + elif adapter and hasattr(adapter, 'queue_message'): adapter.queue_message(session_key, pending) return result_holder[0] or {"final_response": response, "messages": history} @@ -8138,16 +8131,30 @@ class GatewayRunner: # interrupted." is just noise; the user already knows they sent a # new message). - # Process the pending message with updated history updated_history = result.get("messages", history) + next_source = source + next_message = pending + next_message_id = None + if pending_event is not None: + next_source = getattr(pending_event, "source", None) or source + next_message = await self._prepare_inbound_message_text( + event=pending_event, + source=next_source, + history=updated_history, + ) + if next_message is None: + return result + next_message_id = getattr(pending_event, "message_id", None) + return await self._run_agent( - message=pending, + message=next_message, context_prompt=context_prompt, history=updated_history, - source=source, + source=next_source, session_id=session_id, session_key=session_key, _interrupt_depth=_interrupt_depth + 1, + event_message_id=next_message_id, ) finally: # Stop progress sender, interrupt monitor, and notification task diff --git a/tests/gateway/test_queue_consumption.py b/tests/gateway/test_queue_consumption.py index 2a4dd4ff0..50effc139 100644 --- a/tests/gateway/test_queue_consumption.py +++ b/tests/gateway/test_queue_consumption.py @@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from gateway.run import _dequeue_pending_event from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -79,6 +80,26 @@ class TestQueueMessageStorage: # Should be consumed (cleared) assert adapter.get_pending_message(session_key) is None + def test_dequeue_pending_event_preserves_voice_media_metadata(self): + adapter = _StubAdapter() + session_key = "telegram:user:voice" + event = MessageEvent( + text="", + message_type=MessageType.VOICE, + source=MagicMock(chat_id="123", platform=Platform.TELEGRAM), + message_id="voice-q1", + media_urls=["/tmp/voice.ogg"], + media_types=["audio/ogg"], + ) + adapter._pending_messages[session_key] = event + + retrieved = _dequeue_pending_event(adapter, session_key) + + assert retrieved is event + assert retrieved.media_urls == ["/tmp/voice.ogg"] + assert retrieved.media_types == ["audio/ogg"] + assert adapter.get_pending_message(session_key) is None + def test_queue_does_not_set_interrupt_event(self): """The whole point of /queue — no interrupt signal.""" adapter = _StubAdapter() diff --git a/tests/gateway/test_stt_config.py b/tests/gateway/test_stt_config.py index a49e40215..23ba06af2 100644 --- a/tests/gateway/test_stt_config.py +++ b/tests/gateway/test_stt_config.py @@ -6,7 +6,9 @@ from unittest.mock import AsyncMock, patch import pytest import yaml -from gateway.config import GatewayConfig, load_gateway_config +from gateway.config import GatewayConfig, Platform, load_gateway_config +from gateway.platforms.base import MessageEvent, MessageType +from gateway.session import SessionSource def test_gateway_config_stt_disabled_from_dict_nested(): @@ -69,3 +71,46 @@ async def test_enrich_message_with_transcription_avoids_bogus_no_provider_messag assert "No STT provider is configured" not in result assert "trouble transcribing" in result assert "caption" in result + + +@pytest.mark.asyncio +async def test_prepare_inbound_message_text_transcribes_queued_voice_event(): + from gateway.run import GatewayRunner + + runner = GatewayRunner.__new__(GatewayRunner) + runner.config = GatewayConfig(stt_enabled=True) + runner.adapters = {} + runner._model = "test-model" + runner._base_url = "" + runner._has_setup_skill = lambda: False + + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="123", + chat_type="dm", + ) + event = MessageEvent( + text="", + message_type=MessageType.VOICE, + source=source, + media_urls=["/tmp/queued-voice.ogg"], + media_types=["audio/ogg"], + ) + + with patch( + "tools.transcription_tools.transcribe_audio", + return_value={ + "success": True, + "transcript": "queued voice transcript", + "provider": "local_command", + }, + ): + result = await runner._prepare_inbound_message_text( + event=event, + source=source, + history=[], + ) + + assert result is not None + assert "queued voice transcript" in result + assert "voice message" in result.lower() From b80e3181681214b7197d50d54b6ef4336f1c0816 Mon Sep 17 00:00:00 2001 From: Dominic Grieco Date: Sat, 11 Apr 2026 17:31:22 -0300 Subject: [PATCH 29/35] fix: scope gateway status to the active profile --- hermes_cli/gateway.py | 109 +++++++++++++----- tests/hermes_cli/test_gateway_service.py | 18 +++ .../hermes_cli/test_update_gateway_restart.py | 38 ++++++ 3 files changed, 139 insertions(+), 26 deletions(-) diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index b29511dd5..8670b5a78 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -158,29 +158,43 @@ def _request_gateway_self_restart(pid: int) -> bool: def find_gateway_pids(exclude_pids: set | None = None) -> list: - """Find PIDs of running gateway processes. - - Args: - exclude_pids: PIDs to exclude from the result (e.g. service-managed - PIDs that should not be killed during a stale-process sweep). - """ - pids = [] + """Find PIDs of running gateway processes for the current Hermes profile.""" _exclude = exclude_pids or set() + pids = [pid for pid in _get_service_pids() if pid not in _exclude] patterns = [ "hermes_cli.main gateway", + "hermes_cli.main --profile", + "hermes_cli.main -p", "hermes_cli/main.py gateway", + "hermes_cli/main.py --profile", + "hermes_cli/main.py -p", "hermes gateway", "gateway/run.py", ] + current_home = str(get_hermes_home().resolve()) + current_profile_arg = _profile_arg(current_home) + current_profile_name = current_profile_arg.split()[-1] if current_profile_arg else "" + + def _matches_current_profile(command: str) -> bool: + if current_profile_name: + return ( + f"--profile {current_profile_name}" in command + or f"-p {current_profile_name}" in command + or f"HERMES_HOME={current_home}" in command + ) + + if "--profile " in command or " -p " in command: + return False + if "HERMES_HOME=" in command and f"HERMES_HOME={current_home}" not in command: + return False + return True try: if is_windows(): - # Windows: use wmic to search command lines result = subprocess.run( ["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"], capture_output=True, text=True, timeout=10 ) - # Parse WMIC LIST output: blocks of "CommandLine=...\nProcessId=...\n" current_cmd = "" for line in result.stdout.split('\n'): line = line.strip() @@ -188,7 +202,7 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list: current_cmd = line[len("CommandLine="):] elif line.startswith("ProcessId="): pid_str = line[len("ProcessId="):] - if any(p in current_cmd for p in patterns): + if any(p in current_cmd for p in patterns) and _matches_current_profile(current_cmd): try: pid = int(pid_str) if pid != os.getpid() and pid not in pids and pid not in _exclude: @@ -198,26 +212,39 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list: current_cmd = "" else: result = subprocess.run( - ["ps", "aux"], + ["ps", "eww", "-ax", "-o", "pid=,command="], capture_output=True, text=True, timeout=10, ) for line in result.stdout.split('\n'): - # Skip grep and current process - if 'grep' in line or str(os.getpid()) in line: + stripped = line.strip() + if not stripped or 'grep' in stripped: continue - for pattern in patterns: - if pattern in line: - parts = line.split() - if len(parts) > 1: - try: - pid = int(parts[1]) - if pid not in pids and pid not in _exclude: - pids.append(pid) - except ValueError: - continue - break + + pid = None + command = "" + + parts = stripped.split(None, 1) + if len(parts) == 2: + try: + pid = int(parts[0]) + command = parts[1] + except ValueError: + pid = None + + if pid is None: + aux_parts = stripped.split() + if len(aux_parts) > 10 and aux_parts[1].isdigit(): + pid = int(aux_parts[1]) + command = " ".join(aux_parts[10:]) + + if pid is None: + continue + if pid == os.getpid() or pid in pids or pid in _exclude: + continue + if any(pattern in command for pattern in patterns) and _matches_current_profile(command): + pids.append(pid) except Exception: pass @@ -633,6 +660,17 @@ def print_systemd_linger_guidance() -> None: print(" If you want the gateway user service to survive logout, run:") print(" sudo loginctl enable-linger $USER") +def _launchd_user_home() -> Path: + """Return the real macOS user home for launchd artifacts. + + Profile-mode Hermes often sets ``HOME`` to a profile-scoped directory, but + launchd user agents still live under the actual account home. + """ + import pwd + + return Path(pwd.getpwuid(os.getuid()).pw_dir) + + def get_launchd_plist_path() -> Path: """Return the launchd plist path, scoped per profile. @@ -641,7 +679,7 @@ def get_launchd_plist_path() -> Path: """ suffix = _profile_suffix() name = f"ai.hermes.gateway-{suffix}" if suffix else "ai.hermes.gateway" - return Path.home() / "Library" / "LaunchAgents" / f"{name}.plist" + return _launchd_user_home() / "Library" / "LaunchAgents" / f"{name}.plist" def _detect_venv_dir() -> Path | None: """Detect the active virtualenv directory. @@ -839,6 +877,25 @@ def _normalize_service_definition(text: str) -> str: return "\n".join(line.rstrip() for line in text.strip().splitlines()) +def _normalize_launchd_plist_for_comparison(text: str) -> str: + """Normalize launchd plist text for staleness checks. + + The generated plist intentionally captures a broad PATH assembled from the + invoking shell so user-installed tools remain reachable under launchd. + That makes raw text comparison unstable across shells, so ignore the PATH + payload when deciding whether the installed plist is stale. + """ + import re + + normalized = _normalize_service_definition(text) + return re.sub( + r'(PATH\s*)(.*?)()', + r'\1__HERMES_PATH__\3', + normalized, + flags=re.S, + ) + + def systemd_unit_is_current(system: bool = False) -> bool: unit_path = get_systemd_unit_path(system=system) if not unit_path.exists(): @@ -1220,7 +1277,7 @@ def launchd_plist_is_current() -> bool: installed = plist_path.read_text(encoding="utf-8") expected = generate_launchd_plist() - return _normalize_service_definition(installed) == _normalize_service_definition(expected) + return _normalize_launchd_plist_for_comparison(installed) == _normalize_launchd_plist_for_comparison(expected) def refresh_launchd_plist_if_needed() -> bool: diff --git a/tests/hermes_cli/test_gateway_service.py b/tests/hermes_cli/test_gateway_service.py index c5d4cb4f5..482fb4ea5 100644 --- a/tests/hermes_cli/test_gateway_service.py +++ b/tests/hermes_cli/test_gateway_service.py @@ -1,6 +1,7 @@ """Tests for gateway service management helpers.""" import os +import pwd from pathlib import Path from types import SimpleNamespace @@ -924,6 +925,23 @@ class TestProfileArg: assert "--profile" in plist assert "mybot" in plist + def test_launchd_plist_path_uses_real_user_home_not_profile_home(self, tmp_path, monkeypatch): + profile_dir = tmp_path / ".hermes" / "profiles" / "orcha" + profile_dir.mkdir(parents=True) + machine_home = tmp_path / "machine-home" + machine_home.mkdir() + profile_home = profile_dir / "home" + profile_home.mkdir() + + monkeypatch.setattr(Path, "home", lambda: profile_home) + monkeypatch.setenv("HERMES_HOME", str(profile_dir)) + monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir) + monkeypatch.setattr(pwd, "getpwuid", lambda uid: SimpleNamespace(pw_dir=str(machine_home))) + + plist_path = gateway_cli.get_launchd_plist_path() + + assert plist_path == machine_home / "Library" / "LaunchAgents" / "ai.hermes.gateway-orcha.plist" + class TestRemapPathForUser: """Unit tests for _remap_path_for_user().""" diff --git a/tests/hermes_cli/test_update_gateway_restart.py b/tests/hermes_cli/test_update_gateway_restart.py index ceb05f65c..1460f00ea 100644 --- a/tests/hermes_cli/test_update_gateway_restart.py +++ b/tests/hermes_cli/test_update_gateway_restart.py @@ -191,6 +191,19 @@ class TestLaunchdPlistPath: raise AssertionError("PATH key not found in plist") +class TestLaunchdPlistCurrentness: + def test_launchd_plist_is_current_ignores_path_drift(self, tmp_path, monkeypatch): + plist_path = tmp_path / "ai.hermes.gateway.plist" + monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path) + + monkeypatch.setenv("PATH", "/custom/bin:/usr/bin:/bin") + plist_path.write_text(gateway_cli.generate_launchd_plist(), encoding="utf-8") + + monkeypatch.setenv("PATH", "/opt/homebrew/bin:/usr/local/bin:/usr/bin:/bin") + + assert gateway_cli.launchd_plist_is_current() is True + + # --------------------------------------------------------------------------- # cmd_update — macOS launchd detection # --------------------------------------------------------------------------- @@ -760,3 +773,28 @@ class TestFindGatewayPidsExclude: pids = gateway_cli.find_gateway_pids() assert 100 in pids assert 200 in pids + + def test_filters_to_current_profile(self, monkeypatch, tmp_path): + profile_dir = tmp_path / ".hermes" / "profiles" / "orcha" + profile_dir.mkdir(parents=True) + monkeypatch.setattr(gateway_cli, "is_windows", lambda: False) + monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir) + + def fake_run(cmd, **kwargs): + return subprocess.CompletedProcess( + cmd, 0, + stdout=( + "100 /Users/dgrieco/.hermes/hermes-agent/venv/bin/python -m hermes_cli.main --profile orcha gateway run --replace\n" + "200 /Users/dgrieco/.hermes/hermes-agent/venv/bin/python -m hermes_cli.main --profile other gateway run --replace\n" + ), + stderr="", + ) + + monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run) + monkeypatch.setattr("os.getpid", lambda: 999) + monkeypatch.setattr(gateway_cli, "_get_service_pids", lambda: set()) + monkeypatch.setattr(gateway_cli, "_profile_arg", lambda hermes_home=None: "--profile orcha") + + pids = gateway_cli.find_gateway_pids() + + assert pids == [100] From d82580b25b8136a1cac6e8ea0179db5dce477d78 Mon Sep 17 00:00:00 2001 From: Teknium Date: Sat, 11 Apr 2026 14:30:29 -0700 Subject: [PATCH 30/35] fix: add all_profiles param + narrow exception handling - add all_profiles=False to find_gateway_pids() and kill_gateway_processes() so hermes update and gateway stop --all can still discover processes across all profiles - narrow bare 'except Exception' to (OSError, subprocess.TimeoutExpired) - update test mocks to match new signatures --- hermes_cli/gateway.py | 29 ++++++++++++++----- hermes_cli/main.py | 2 +- tests/hermes_cli/test_gateway.py | 2 +- tests/hermes_cli/test_gateway_service.py | 4 +-- .../hermes_cli/test_update_gateway_restart.py | 6 ++-- 5 files changed, 28 insertions(+), 15 deletions(-) diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index 8670b5a78..633deac29 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -157,8 +157,18 @@ def _request_gateway_self_restart(pid: int) -> bool: return True -def find_gateway_pids(exclude_pids: set | None = None) -> list: - """Find PIDs of running gateway processes for the current Hermes profile.""" +def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = False) -> list: + """Find PIDs of running gateway processes. + + Args: + exclude_pids: PIDs to exclude from the result (e.g. service-managed + PIDs that should not be killed during a stale-process sweep). + all_profiles: When ``True``, return gateway PIDs across **all** + profiles (the pre-7923 global behaviour). ``hermes update`` + needs this because a code update affects every profile. + When ``False`` (default), only PIDs belonging to the current + Hermes profile are returned. + """ _exclude = exclude_pids or set() pids = [pid for pid in _get_service_pids() if pid not in _exclude] patterns = [ @@ -202,7 +212,7 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list: current_cmd = line[len("CommandLine="):] elif line.startswith("ProcessId="): pid_str = line[len("ProcessId="):] - if any(p in current_cmd for p in patterns) and _matches_current_profile(current_cmd): + if any(p in current_cmd for p in patterns) and (all_profiles or _matches_current_profile(current_cmd)): try: pid = int(pid_str) if pid != os.getpid() and pid not in pids and pid not in _exclude: @@ -243,23 +253,26 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list: continue if pid == os.getpid() or pid in pids or pid in _exclude: continue - if any(pattern in command for pattern in patterns) and _matches_current_profile(command): + if any(pattern in command for pattern in patterns) and (all_profiles or _matches_current_profile(command)): pids.append(pid) - except Exception: + except (OSError, subprocess.TimeoutExpired): pass return pids -def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None) -> int: +def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None, + all_profiles: bool = False) -> int: """Kill any running gateway processes. Returns count killed. Args: force: Use the platform's force-kill mechanism instead of graceful terminate. exclude_pids: PIDs to skip (e.g. service-managed PIDs that were just restarted and should not be killed). + all_profiles: When ``True``, kill across all profiles. Passed + through to :func:`find_gateway_pids`. """ - pids = find_gateway_pids(exclude_pids=exclude_pids) + pids = find_gateway_pids(exclude_pids=exclude_pids, all_profiles=all_profiles) killed = 0 for pid in pids: @@ -2597,7 +2610,7 @@ def gateway_command(args): service_available = True except subprocess.CalledProcessError: pass - killed = kill_gateway_processes() + killed = kill_gateway_processes(all_profiles=True) total = killed + (1 if service_available else 0) if total: print(f"✓ Stopped {total} gateway process(es) across all profiles") diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 4b7dd600b..df87f9355 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -3876,7 +3876,7 @@ def cmd_update(args): # Exclude PIDs that belong to just-restarted services so we don't # immediately kill the process that systemd/launchd just spawned. service_pids = _get_service_pids() - manual_pids = find_gateway_pids(exclude_pids=service_pids) + manual_pids = find_gateway_pids(exclude_pids=service_pids, all_profiles=True) for pid in manual_pids: try: os.kill(pid, _signal.SIGTERM) diff --git a/tests/hermes_cli/test_gateway.py b/tests/hermes_cli/test_gateway.py index 955449547..fd88a26c6 100644 --- a/tests/hermes_cli/test_gateway.py +++ b/tests/hermes_cli/test_gateway.py @@ -260,7 +260,7 @@ class TestWaitForGatewayExit: def test_kill_gateway_processes_force_uses_helper(self, monkeypatch): calls = [] - monkeypatch.setattr(gateway, "find_gateway_pids", lambda exclude_pids=None: [11, 22]) + monkeypatch.setattr(gateway, "find_gateway_pids", lambda exclude_pids=None, all_profiles=False: [11, 22]) monkeypatch.setattr(gateway, "terminate_pid", lambda pid, force=False: calls.append((pid, force))) killed = gateway.kill_gateway_processes(force=True) diff --git a/tests/hermes_cli/test_gateway_service.py b/tests/hermes_cli/test_gateway_service.py index 482fb4ea5..cba3a8192 100644 --- a/tests/hermes_cli/test_gateway_service.py +++ b/tests/hermes_cli/test_gateway_service.py @@ -130,7 +130,7 @@ class TestGatewayStopCleanup: monkeypatch.setattr( gateway_cli, "kill_gateway_processes", - lambda force=False: kill_calls.append(force) or 2, + lambda force=False, all_profiles=False: kill_calls.append(force) or 2, ) gateway_cli.gateway_command(SimpleNamespace(gateway_command="stop")) @@ -156,7 +156,7 @@ class TestGatewayStopCleanup: monkeypatch.setattr( gateway_cli, "kill_gateway_processes", - lambda force=False: kill_calls.append(force) or 2, + lambda force=False, all_profiles=False: kill_calls.append(force) or 2, ) gateway_cli.gateway_command(SimpleNamespace(gateway_command="stop", **{"all": True})) diff --git a/tests/hermes_cli/test_update_gateway_restart.py b/tests/hermes_cli/test_update_gateway_restart.py index 1460f00ea..822b22742 100644 --- a/tests/hermes_cli/test_update_gateway_restart.py +++ b/tests/hermes_cli/test_update_gateway_restart.py @@ -549,7 +549,7 @@ class TestServicePidExclusion: gateway_cli, "_get_service_pids", return_value={SERVICE_PID} ), patch.object( gateway_cli, "find_gateway_pids", - side_effect=lambda exclude_pids=None: ( + side_effect=lambda exclude_pids=None, all_profiles=False: ( [SERVICE_PID] if not exclude_pids else [p for p in [SERVICE_PID] if p not in exclude_pids] ), @@ -592,7 +592,7 @@ class TestServicePidExclusion: gateway_cli, "_get_service_pids", return_value={SERVICE_PID} ), patch.object( gateway_cli, "find_gateway_pids", - side_effect=lambda exclude_pids=None: ( + side_effect=lambda exclude_pids=None, all_profiles=False: ( [SERVICE_PID] if not exclude_pids else [p for p in [SERVICE_PID] if p not in exclude_pids] ), @@ -631,7 +631,7 @@ class TestServicePidExclusion: launchctl_loaded=True, ) - def fake_find(exclude_pids=None): + def fake_find(exclude_pids=None, all_profiles=False): _exclude = exclude_pids or set() return [p for p in [SERVICE_PID, MANUAL_PID] if p not in _exclude] From 1e5056ec30f4ef03789499311be774d1a41dc3c1 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sat, 11 Apr 2026 14:44:51 -0700 Subject: [PATCH 31/35] feat(gateway): add all missing platforms to interactive setup wizard (#7949) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wire Signal, Email, SMS (Twilio), DingTalk, Feishu/Lark, and WeCom into the hermes setup gateway interactive wizard. These platforms all had working adapters and _PLATFORMS entries in gateway.py but were invisible in the setup checklist — users had to manually edit .env to configure them. Changes: - gateway.py: Add _setup_email/sms/dingtalk/feishu/wecom functions delegating to _setup_standard_platform (Signal already had a custom one) - setup.py: Add wrapper functions for all 6 new platforms - setup.py: Add all 6 to _GATEWAY_PLATFORMS checklist registry - setup.py: Add missing env vars to any_messaging check - setup.py: Add all missing platforms to _get_section_config_summary (was also missing Matrix, Mattermost, Weixin, Webhooks) - docs: Add FEISHU_ALLOWED_USERS and WECOM_ALLOWED_USERS examples Incorporates and extends the work from PR #7918 by bugmaker2. --- hermes_cli/gateway.py | 30 +++++++++ hermes_cli/setup.py | 71 +++++++++++++++++++++- website/docs/user-guide/messaging/index.md | 2 + 3 files changed, 101 insertions(+), 2 deletions(-) diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index 633deac29..505bad0b5 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -2051,6 +2051,36 @@ def _setup_whatsapp(): cmd_whatsapp(argparse.Namespace()) +def _setup_email(): + """Configure Email via the standard platform setup.""" + email_platform = next(p for p in _PLATFORMS if p["key"] == "email") + _setup_standard_platform(email_platform) + + +def _setup_sms(): + """Configure SMS (Twilio) via the standard platform setup.""" + sms_platform = next(p for p in _PLATFORMS if p["key"] == "sms") + _setup_standard_platform(sms_platform) + + +def _setup_dingtalk(): + """Configure DingTalk via the standard platform setup.""" + dingtalk_platform = next(p for p in _PLATFORMS if p["key"] == "dingtalk") + _setup_standard_platform(dingtalk_platform) + + +def _setup_feishu(): + """Configure Feishu / Lark via the standard platform setup.""" + feishu_platform = next(p for p in _PLATFORMS if p["key"] == "feishu") + _setup_standard_platform(feishu_platform) + + +def _setup_wecom(): + """Configure WeCom (Enterprise WeChat) via the standard platform setup.""" + wecom_platform = next(p for p in _PLATFORMS if p["key"] == "wecom") + _setup_standard_platform(wecom_platform) + + def _is_service_installed() -> bool: """Check if the gateway is installed as a system service.""" if supports_systemd_services(): diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index fb70d9081..a25ce8491 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -1969,6 +1969,42 @@ def _setup_weixin(): _gateway_setup_weixin() +def _setup_signal(): + """Configure Signal via gateway setup.""" + from hermes_cli.gateway import _setup_signal as _gateway_setup_signal + _gateway_setup_signal() + + +def _setup_email(): + """Configure Email via gateway setup.""" + from hermes_cli.gateway import _setup_email as _gateway_setup_email + _gateway_setup_email() + + +def _setup_sms(): + """Configure SMS (Twilio) via gateway setup.""" + from hermes_cli.gateway import _setup_sms as _gateway_setup_sms + _gateway_setup_sms() + + +def _setup_dingtalk(): + """Configure DingTalk via gateway setup.""" + from hermes_cli.gateway import _setup_dingtalk as _gateway_setup_dingtalk + _gateway_setup_dingtalk() + + +def _setup_feishu(): + """Configure Feishu / Lark via gateway setup.""" + from hermes_cli.gateway import _setup_feishu as _gateway_setup_feishu + _gateway_setup_feishu() + + +def _setup_wecom(): + """Configure WeCom (Enterprise WeChat) via gateway setup.""" + from hermes_cli.gateway import _setup_wecom as _gateway_setup_wecom + _gateway_setup_wecom() + + def _setup_bluebubbles(): """Configure BlueBubbles iMessage gateway.""" print_header("BlueBubbles (iMessage)") @@ -2085,9 +2121,15 @@ _GATEWAY_PLATFORMS = [ ("Telegram", "TELEGRAM_BOT_TOKEN", _setup_telegram), ("Discord", "DISCORD_BOT_TOKEN", _setup_discord), ("Slack", "SLACK_BOT_TOKEN", _setup_slack), + ("Signal", "SIGNAL_HTTP_URL", _setup_signal), + ("Email", "EMAIL_ADDRESS", _setup_email), + ("SMS (Twilio)", "TWILIO_ACCOUNT_SID", _setup_sms), ("Matrix", "MATRIX_ACCESS_TOKEN", _setup_matrix), ("Mattermost", "MATTERMOST_TOKEN", _setup_mattermost), ("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp), + ("DingTalk", "DINGTALK_CLIENT_ID", _setup_dingtalk), + ("Feishu / Lark", "FEISHU_APP_ID", _setup_feishu), + ("WeCom (Enterprise WeChat)", "WECOM_BOT_ID", _setup_wecom), ("Weixin (WeChat)", "WEIXIN_ACCOUNT_ID", _setup_weixin), ("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles), ("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", _setup_webhooks), @@ -2129,10 +2171,17 @@ def setup_gateway(config: dict): get_env_value("TELEGRAM_BOT_TOKEN") or get_env_value("DISCORD_BOT_TOKEN") or get_env_value("SLACK_BOT_TOKEN") + or get_env_value("SIGNAL_HTTP_URL") + or get_env_value("EMAIL_ADDRESS") + or get_env_value("TWILIO_ACCOUNT_SID") or get_env_value("MATTERMOST_TOKEN") or get_env_value("MATRIX_ACCESS_TOKEN") or get_env_value("MATRIX_PASSWORD") or get_env_value("WHATSAPP_ENABLED") + or get_env_value("DINGTALK_CLIENT_ID") + or get_env_value("FEISHU_APP_ID") + or get_env_value("WECOM_BOT_ID") + or get_env_value("WEIXIN_ACCOUNT_ID") or get_env_value("BLUEBUBBLES_SERVER_URL") or get_env_value("WEBHOOK_ENABLED") ) @@ -2321,12 +2370,30 @@ def _get_section_config_summary(config: dict, section_key: str) -> Optional[str] platforms.append("Discord") if get_env_value("SLACK_BOT_TOKEN"): platforms.append("Slack") - if get_env_value("WHATSAPP_PHONE_NUMBER_ID"): - platforms.append("WhatsApp") if get_env_value("SIGNAL_ACCOUNT"): platforms.append("Signal") + if get_env_value("EMAIL_ADDRESS"): + platforms.append("Email") + if get_env_value("TWILIO_ACCOUNT_SID"): + platforms.append("SMS") + if get_env_value("MATRIX_ACCESS_TOKEN") or get_env_value("MATRIX_PASSWORD"): + platforms.append("Matrix") + if get_env_value("MATTERMOST_TOKEN"): + platforms.append("Mattermost") + if get_env_value("WHATSAPP_PHONE_NUMBER_ID"): + platforms.append("WhatsApp") + if get_env_value("DINGTALK_CLIENT_ID"): + platforms.append("DingTalk") + if get_env_value("FEISHU_APP_ID"): + platforms.append("Feishu") + if get_env_value("WECOM_BOT_ID"): + platforms.append("WeCom") + if get_env_value("WEIXIN_ACCOUNT_ID"): + platforms.append("Weixin") if get_env_value("BLUEBUBBLES_SERVER_URL"): platforms.append("BlueBubbles") + if get_env_value("WEBHOOK_ENABLED"): + platforms.append("Webhooks") if platforms: return ", ".join(platforms) return None # No platforms configured — section must run diff --git a/website/docs/user-guide/messaging/index.md b/website/docs/user-guide/messaging/index.md index 335c6530b..41b031437 100644 --- a/website/docs/user-guide/messaging/index.md +++ b/website/docs/user-guide/messaging/index.md @@ -178,6 +178,8 @@ EMAIL_ALLOWED_USERS=trusted@example.com,colleague@work.com MATTERMOST_ALLOWED_USERS=3uo8dkh1p7g1mfk49ear5fzs5c MATRIX_ALLOWED_USERS=@alice:matrix.org DINGTALK_ALLOWED_USERS=user-id-1 +FEISHU_ALLOWED_USERS=ou_xxxxxxxx,ou_yyyyyyyy +WECOM_ALLOWED_USERS=user-id-1,user-id-2 # Or allow GATEWAY_ALLOWED_USERS=123456789,987654321 From 8c3935ebe82e91fadb561ad89e403beb66578bf0 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sat, 11 Apr 2026 14:46:18 -0700 Subject: [PATCH 32/35] fix: is_local_endpoint misses Docker/Podman DNS names (#7950) * fix(tools): neutralize shell injection in _write_to_sandbox via path quoting _write_to_sandbox interpolated storage_dir and remote_path directly into a shell command passed to env.execute(). Paths containing shell metacharacters (spaces, semicolons, $(), backticks) could trigger arbitrary command execution inside the sandbox. Fix: wrap both paths with shlex.quote(). Clean paths (alphanumeric + slashes/hyphens/dots) are left unmodified by shlex.quote, so existing behavior is unchanged. Paths with unsafe characters get single-quoted. Tests added for spaces, $(command) substitution, and semicolon injection. * fix: is_local_endpoint misses Docker/Podman DNS names host.docker.internal, host.containers.internal, gateway.docker.internal, and host.lima.internal are well-known DNS names that container runtimes use to resolve the host machine. Users running Ollama on the host with the agent in Docker/Podman hit the default 120s stream timeout instead of the bumped 1800s because these hostnames weren't recognized as local. Add _CONTAINER_LOCAL_SUFFIXES tuple and suffix check in is_local_endpoint(). Tests cover all three runtime families plus a negative case for domains that merely contain the suffix as a substring. --- agent/model_metadata.py | 9 ++++++ tests/agent/test_local_stream_timeout.py | 38 ++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/agent/model_metadata.py b/agent/model_metadata.py index 2ef6830e5..f12801777 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -179,6 +179,12 @@ _MAX_COMPLETION_KEYS = ( # Local server hostnames / address patterns _LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0") +# Docker / Podman / Lima DNS names that resolve to the host machine +_CONTAINER_LOCAL_SUFFIXES = ( + ".docker.internal", + ".containers.internal", + ".lima.internal", +) def _normalize_base_url(base_url: str) -> str: @@ -254,6 +260,9 @@ def is_local_endpoint(base_url: str) -> bool: return False if host in _LOCAL_HOSTS: return True + # Docker / Podman / Lima internal DNS names (e.g. host.docker.internal) + if any(host.endswith(suffix) for suffix in _CONTAINER_LOCAL_SUFFIXES): + return True # RFC-1918 private ranges and link-local import ipaddress try: diff --git a/tests/agent/test_local_stream_timeout.py b/tests/agent/test_local_stream_timeout.py index 929f2e3c8..8184dd2d4 100644 --- a/tests/agent/test_local_stream_timeout.py +++ b/tests/agent/test_local_stream_timeout.py @@ -22,6 +22,9 @@ class TestLocalStreamReadTimeout: "http://0.0.0.0:5000", "http://192.168.1.100:8000", "http://10.0.0.5:1234", + "http://host.docker.internal:11434", + "http://host.containers.internal:11434", + "http://host.lima.internal:11434", ]) def test_local_endpoint_bumps_read_timeout(self, base_url): """Local endpoint + default timeout -> bumps to base_timeout.""" @@ -68,3 +71,38 @@ class TestLocalStreamReadTimeout: if _stream_read_timeout == 120.0 and base_url and is_local_endpoint(base_url): _stream_read_timeout = _base_timeout assert _stream_read_timeout == 120.0 + + +class TestIsLocalEndpoint: + """Direct unit tests for is_local_endpoint.""" + + @pytest.mark.parametrize("url", [ + "http://localhost:11434", + "http://127.0.0.1:8080", + "http://0.0.0.0:5000", + "http://[::1]:11434", + "http://192.168.1.100:8000", + "http://10.0.0.5:1234", + "http://172.17.0.1:11434", + ]) + def test_classic_local_addresses(self, url): + assert is_local_endpoint(url) is True + + @pytest.mark.parametrize("url", [ + "http://host.docker.internal:11434", + "http://host.docker.internal:8080/v1", + "http://gateway.docker.internal:11434", + "http://host.containers.internal:11434", + "http://host.lima.internal:11434", + ]) + def test_container_dns_names(self, url): + assert is_local_endpoint(url) is True + + @pytest.mark.parametrize("url", [ + "https://api.openai.com", + "https://openrouter.ai/api", + "https://api.anthropic.com", + "https://evil.docker.internal.example.com", + ]) + def test_remote_endpoints(self, url): + assert is_local_endpoint(url) is False From b53f6819937533acf749fa961063687ca813b0f1 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sat, 11 Apr 2026 14:48:58 -0700 Subject: [PATCH 33/35] fix(cron): pass skip_context_files=True to AIAgent in run_job (#7958) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cron jobs run from whatever directory the scheduler process lives in (typically the hermes-agent install dir), so without this flag the agent picks up AGENTS.md, SOUL.md, or .cursorrules from that cwd — injecting irrelevant project context into the cron job's system prompt. batch_runner.py and gateway boot_md already pass skip_context_files=True for the same reason. This aligns cron with the established pattern for autonomous/headless agent runs. --- cron/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cron/scheduler.py b/cron/scheduler.py index 0e04fb047..870ebe141 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -722,6 +722,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: provider_sort=pr.get("sort"), disabled_toolsets=["cronjob", "messaging", "clarify"], quiet_mode=True, + skip_context_files=True, # Don't inject SOUL.md/AGENTS.md from scheduler cwd skip_memory=True, # Cron system prompts would corrupt user representations platform="cron", session_id=_cron_session_id, From ee39e88b037a8e85d949fead9abbd827d630bf90 Mon Sep 17 00:00:00 2001 From: SHL0MS Date: Sat, 11 Apr 2026 14:47:03 -0700 Subject: [PATCH 34/35] fix(claw): warn if gateway is running before migrating bot tokens When 'hermes claw migrate' copies Telegram/Discord/Slack bot tokens from OpenClaw while the Hermes gateway is already polling with those same tokens, the platforms conflict (e.g. Telegram 409). Add a pre-flight check that reads gateway_state.json via get_running_pid() + read_runtime_status(), warns the user, and lets them cancel or continue. Also improve the Telegram polling conflict error message to mention OpenClaw as a common cause and give the 'hermes start' restart command. Refs #7907 --- gateway/platforms/telegram.py | 6 ++++-- hermes_cli/claw.py | 39 +++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 884ef9c45..265329602 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -299,9 +299,11 @@ class TelegramAdapter(BasePlatformAdapter): # Exhausted retries — fatal message = ( - "Another Telegram bot poller is already using this token. " + "Another process is already polling this Telegram bot token " + "(possibly OpenClaw or another Hermes instance). " "Hermes stopped Telegram polling after %d retries. " - "Make sure only one gateway instance is running for this bot token." + "Only one poller can run per token — stop the other process " + "and restart with 'hermes start'." % MAX_CONFLICT_RETRIES ) logger.error("[%s] %s Original error: %s", self.name, message, error) diff --git a/hermes_cli/claw.py b/hermes_cli/claw.py index 3ab6bf9a8..d0bfd73d2 100644 --- a/hermes_cli/claw.py +++ b/hermes_cli/claw.py @@ -52,6 +52,41 @@ _OPENCLAW_SCRIPT_INSTALLED = ( # Known OpenClaw directory names (current + legacy) _OPENCLAW_DIR_NAMES = (".openclaw", ".clawdbot", ".moldbot") +def _warn_if_gateway_running(auto_yes: bool) -> None: + """Check if a Hermes gateway is running with connected platforms. + + Migrating bot tokens while the gateway is polling will cause conflicts + (e.g. Telegram 409 "terminated by other getUpdates request"). Warn the + user and let them decide whether to continue. + """ + from gateway.status import get_running_pid, read_runtime_status + + if not get_running_pid(): + return + + data = read_runtime_status() or {} + platforms = data.get("platforms") or {} + connected = [name for name, info in platforms.items() + if isinstance(info, dict) and info.get("state") == "connected"] + if not connected: + return + + print() + print_error( + "Hermes gateway is running with active connections: " + + ", ".join(connected) + ) + print_info( + "Migrating bot tokens while the gateway is active will cause " + "conflicts (Telegram, Discord, and Slack only allow one active " + "session per token)." + ) + print_info("Recommendation: stop the gateway first with 'hermes stop'.") + print() + if not auto_yes and not prompt_yes_no("Continue anyway?", default=False): + print_info("Migration cancelled. Stop the gateway and try again.") + sys.exit(0) + # State files commonly found in OpenClaw workspace directories that cause # confusion after migration (the agent discovers them and writes to them) _WORKSPACE_STATE_GLOBS = ( @@ -252,6 +287,10 @@ def _cmd_migrate(args): print_info(f"Workspace: {workspace_target}") print() + # Check if a gateway is running with connected platforms — migrating tokens + # while the gateway is active will cause conflicts (e.g. Telegram 409). + _warn_if_gateway_running(auto_yes) + # Ensure config.yaml exists before migration tries to read it config_path = get_config_path() if not config_path.exists(): From 90352b2adf30dadbf64e8fb74bc94b149f679581 Mon Sep 17 00:00:00 2001 From: faishal Date: Sat, 11 Apr 2026 14:48:51 -0700 Subject: [PATCH 35/35] fix: normalize checkpoint manager home-relative paths Adds _normalize_path() helper that calls expanduser().resolve() to properly handle tilde paths (e.g. ~/.hermes, ~/.config). Previously Path.resolve() alone treated ~ as a literal directory name, producing invalid paths like /root/~/.hermes. Also improves _run_git() error handling to distinguish missing working directories from missing git executable, and adds pre-flight directory validation. Cherry-picked from PR #7898 by faishal882. Fixes #7807 --- tests/tools/test_checkpoint_manager.py | 119 ++++++++++++++++++++++++- tools/checkpoint_manager.py | 49 +++++++--- 2 files changed, 150 insertions(+), 18 deletions(-) diff --git a/tests/tools/test_checkpoint_manager.py b/tests/tools/test_checkpoint_manager.py index ae03dc31b..ba9da6da1 100644 --- a/tests/tools/test_checkpoint_manager.py +++ b/tests/tools/test_checkpoint_manager.py @@ -1,9 +1,6 @@ """Tests for tools/checkpoint_manager.py — CheckpointManager.""" import logging -import os -import json -import shutil import subprocess import pytest from pathlib import Path @@ -42,6 +39,19 @@ def checkpoint_base(tmp_path): return tmp_path / "checkpoints" +@pytest.fixture() +def fake_home(tmp_path, monkeypatch): + """Set a deterministic fake home for expanduser/path-home behavior.""" + home = tmp_path / "home" + home.mkdir() + monkeypatch.setenv("HOME", str(home)) + monkeypatch.setenv("USERPROFILE", str(home)) + monkeypatch.delenv("HOMEDRIVE", raising=False) + monkeypatch.delenv("HOMEPATH", raising=False) + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + return home + + @pytest.fixture() def mgr(work_dir, checkpoint_base, monkeypatch): """CheckpointManager with redirected checkpoint base.""" @@ -78,6 +88,16 @@ class TestShadowRepoPath: p = _shadow_repo_path(str(work_dir)) assert str(p).startswith(str(checkpoint_base)) + def test_tilde_and_expanded_home_share_shadow_repo(self, fake_home, checkpoint_base, monkeypatch): + monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base) + project = fake_home / "project" + project.mkdir() + + tilde_path = f"~/{project.name}" + expanded_path = str(project) + + assert _shadow_repo_path(tilde_path) == _shadow_repo_path(expanded_path) + # ========================================================================= # Shadow repo init @@ -221,6 +241,20 @@ class TestListCheckpoints: assert result[0]["reason"] == "third" assert result[2]["reason"] == "first" + def test_tilde_path_lists_same_checkpoints_as_expanded_path(self, checkpoint_base, fake_home, monkeypatch): + monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base) + mgr = CheckpointManager(enabled=True, max_snapshots=50) + project = fake_home / "project" + project.mkdir() + (project / "main.py").write_text("v1\n") + + tilde_path = f"~/{project.name}" + assert mgr.ensure_checkpoint(tilde_path, "initial") is True + + listed = mgr.list_checkpoints(str(project)) + assert len(listed) == 1 + assert listed[0]["reason"] == "initial" + # ========================================================================= # CheckpointManager — restoring @@ -271,6 +305,28 @@ class TestRestore: assert len(all_cps) >= 2 assert "pre-rollback" in all_cps[0]["reason"] + def test_tilde_path_supports_diff_and_restore_flow(self, checkpoint_base, fake_home, monkeypatch): + monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base) + mgr = CheckpointManager(enabled=True, max_snapshots=50) + project = fake_home / "project" + project.mkdir() + file_path = project / "main.py" + file_path.write_text("original\n") + + tilde_path = f"~/{project.name}" + assert mgr.ensure_checkpoint(tilde_path, "initial") is True + mgr.new_turn() + + file_path.write_text("changed\n") + checkpoints = mgr.list_checkpoints(str(project)) + diff_result = mgr.diff(tilde_path, checkpoints[0]["hash"]) + assert diff_result["success"] is True + assert "main.py" in diff_result["diff"] + + restore_result = mgr.restore(tilde_path, checkpoints[0]["hash"]) + assert restore_result["success"] is True + assert file_path.read_text() == "original\n" + # ========================================================================= # CheckpointManager — working dir resolution @@ -310,6 +366,19 @@ class TestWorkingDirResolution: result = mgr.get_working_dir_for_path(str(filepath)) assert result == str(filepath.parent) + def test_resolves_tilde_path_to_project_root(self, fake_home): + mgr = CheckpointManager(enabled=True) + project = fake_home / "myproject" + project.mkdir() + (project / "pyproject.toml").write_text("[project]\n") + subdir = project / "src" + subdir.mkdir() + filepath = subdir / "main.py" + filepath.write_text("x\n") + + result = mgr.get_working_dir_for_path(f"~/{project.name}/src/main.py") + assert result == str(project) + # ========================================================================= # Git env isolation @@ -333,6 +402,14 @@ class TestGitEnvIsolation: env = _git_env(shadow, str(tmp_path)) assert "GIT_INDEX_FILE" not in env + def test_expands_tilde_in_work_tree(self, fake_home, tmp_path): + shadow = tmp_path / "shadow" + work = fake_home / "work" + work.mkdir() + + env = _git_env(shadow, f"~/{work.name}") + assert env["GIT_WORK_TREE"] == str(work.resolve()) + # ========================================================================= # format_checkpoint_list @@ -384,6 +461,8 @@ class TestErrorResilience: assert result is False def test_run_git_allows_expected_nonzero_without_error_log(self, tmp_path, caplog): + work = tmp_path / "work" + work.mkdir() completed = subprocess.CompletedProcess( args=["git", "diff", "--cached", "--quiet"], returncode=1, @@ -395,7 +474,7 @@ class TestErrorResilience: ok, stdout, stderr = _run_git( ["diff", "--cached", "--quiet"], tmp_path / "shadow", - str(tmp_path / "work"), + str(work), allowed_returncodes={1}, ) assert ok is False @@ -403,6 +482,38 @@ class TestErrorResilience: assert stderr == "" assert not caplog.records + def test_run_git_invalid_working_dir_reports_path_error(self, tmp_path, caplog): + missing = tmp_path / "missing" + with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"): + ok, stdout, stderr = _run_git( + ["status"], + tmp_path / "shadow", + str(missing), + ) + assert ok is False + assert stdout == "" + assert "working directory not found" in stderr + assert not any("Git executable not found" in r.getMessage() for r in caplog.records) + + def test_run_git_missing_git_reports_git_not_found(self, tmp_path, monkeypatch, caplog): + work = tmp_path / "work" + work.mkdir() + + def raise_missing_git(*args, **kwargs): + raise FileNotFoundError(2, "No such file or directory", "git") + + monkeypatch.setattr("tools.checkpoint_manager.subprocess.run", raise_missing_git) + with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"): + ok, stdout, stderr = _run_git( + ["status"], + tmp_path / "shadow", + str(work), + ) + assert ok is False + assert stdout == "" + assert stderr == "git not found" + assert any("Git executable not found" in r.getMessage() for r in caplog.records) + def test_checkpoint_failure_does_not_raise(self, mgr, work_dir, monkeypatch): """Checkpoint failures should never raise — they're silently logged.""" def broken_run_git(*args, **kwargs): diff --git a/tools/checkpoint_manager.py b/tools/checkpoint_manager.py index 3ea6b32fd..42900a643 100644 --- a/tools/checkpoint_manager.py +++ b/tools/checkpoint_manager.py @@ -100,7 +100,7 @@ def _validate_file_path(file_path: str, working_dir: str) -> Optional[str]: if os.path.isabs(file_path): return f"File path must be relative, got absolute path: {file_path!r}" # Resolve and check containment within working_dir - abs_workdir = Path(working_dir).resolve() + abs_workdir = _normalize_path(working_dir) resolved = (abs_workdir / file_path).resolve() try: resolved.relative_to(abs_workdir) @@ -113,18 +113,24 @@ def _validate_file_path(file_path: str, working_dir: str) -> Optional[str]: # Shadow repo helpers # --------------------------------------------------------------------------- +def _normalize_path(path_value: str) -> Path: + """Return a canonical absolute path for checkpoint operations.""" + return Path(path_value).expanduser().resolve() + + def _shadow_repo_path(working_dir: str) -> Path: """Deterministic shadow repo path: sha256(abs_path)[:16].""" - abs_path = str(Path(working_dir).resolve()) + abs_path = str(_normalize_path(working_dir)) dir_hash = hashlib.sha256(abs_path.encode()).hexdigest()[:16] return CHECKPOINT_BASE / dir_hash def _git_env(shadow_repo: Path, working_dir: str) -> dict: """Build env dict that redirects git to the shadow repo.""" + normalized_working_dir = _normalize_path(working_dir) env = os.environ.copy() env["GIT_DIR"] = str(shadow_repo) - env["GIT_WORK_TREE"] = str(Path(working_dir).resolve()) + env["GIT_WORK_TREE"] = str(normalized_working_dir) env.pop("GIT_INDEX_FILE", None) env.pop("GIT_NAMESPACE", None) env.pop("GIT_ALTERNATE_OBJECT_DIRECTORIES", None) @@ -144,7 +150,17 @@ def _run_git( exits while preserving the normal ``ok = (returncode == 0)`` contract. Example: ``git diff --cached --quiet`` returns 1 when changes exist. """ - env = _git_env(shadow_repo, working_dir) + normalized_working_dir = _normalize_path(working_dir) + if not normalized_working_dir.exists(): + msg = f"working directory not found: {normalized_working_dir}" + logger.error("Git command skipped: %s (%s)", " ".join(["git"] + list(args)), msg) + return False, "", msg + if not normalized_working_dir.is_dir(): + msg = f"working directory is not a directory: {normalized_working_dir}" + logger.error("Git command skipped: %s (%s)", " ".join(["git"] + list(args)), msg) + return False, "", msg + + env = _git_env(shadow_repo, str(normalized_working_dir)) cmd = ["git"] + list(args) allowed_returncodes = allowed_returncodes or set() try: @@ -154,7 +170,7 @@ def _run_git( text=True, timeout=timeout, env=env, - cwd=str(Path(working_dir).resolve()), + cwd=str(normalized_working_dir), ) ok = result.returncode == 0 stdout = result.stdout.strip() @@ -169,9 +185,14 @@ def _run_git( msg = f"git timed out after {timeout}s: {' '.join(cmd)}" logger.error(msg, exc_info=True) return False, "", msg - except FileNotFoundError: - logger.error("Git executable not found: %s", " ".join(cmd), exc_info=True) - return False, "", "git not found" + except FileNotFoundError as exc: + missing_target = getattr(exc, "filename", None) + if missing_target == "git": + logger.error("Git executable not found: %s", " ".join(cmd), exc_info=True) + return False, "", "git not found" + msg = f"working directory not found: {normalized_working_dir}" + logger.error("Git command failed before execution: %s (%s)", " ".join(cmd), msg, exc_info=True) + return False, "", msg except Exception as exc: logger.error("Unexpected git error running %s: %s", " ".join(cmd), exc, exc_info=True) return False, "", str(exc) @@ -198,7 +219,7 @@ def _init_shadow_repo(shadow_repo: Path, working_dir: str) -> Optional[str]: ) (shadow_repo / "HERMES_WORKDIR").write_text( - str(Path(working_dir).resolve()) + "\n", encoding="utf-8" + str(_normalize_path(working_dir)) + "\n", encoding="utf-8" ) logger.debug("Initialised checkpoint repo at %s for %s", shadow_repo, working_dir) @@ -273,7 +294,7 @@ class CheckpointManager: if not self._git_available: return False - abs_dir = str(Path(working_dir).resolve()) + abs_dir = str(_normalize_path(working_dir)) # Skip root, home, and other overly broad directories if abs_dir in ("/", str(Path.home())): @@ -298,7 +319,7 @@ class CheckpointManager: Returns a list of dicts with keys: hash, short_hash, timestamp, reason, files_changed, insertions, deletions. Most recent first. """ - abs_dir = str(Path(working_dir).resolve()) + abs_dir = str(_normalize_path(working_dir)) shadow = _shadow_repo_path(abs_dir) if not (shadow / "HEAD").exists(): @@ -360,7 +381,7 @@ class CheckpointManager: if hash_err: return {"success": False, "error": hash_err} - abs_dir = str(Path(working_dir).resolve()) + abs_dir = str(_normalize_path(working_dir)) shadow = _shadow_repo_path(abs_dir) if not (shadow / "HEAD").exists(): @@ -418,7 +439,7 @@ class CheckpointManager: if hash_err: return {"success": False, "error": hash_err} - abs_dir = str(Path(working_dir).resolve()) + abs_dir = str(_normalize_path(working_dir)) # Validate file_path to prevent path traversal outside the working dir if file_path: @@ -474,7 +495,7 @@ class CheckpointManager: (directory containing .git, pyproject.toml, package.json, etc.). Falls back to the file's parent directory. """ - path = Path(file_path).resolve() + path = _normalize_path(file_path) if path.is_dir(): candidate = path else: