Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor

This commit is contained in:
Brooklyn Nicholson
2026-04-13 18:32:13 -05:00
220 changed files with 23482 additions and 1959 deletions

View File

@@ -17,7 +17,6 @@ from agent.auxiliary_client import (
call_llm,
async_call_llm,
_read_codex_access_token,
_get_auxiliary_provider,
_get_provider_chain,
_is_payment_error,
_try_payment_fallback,
@@ -32,12 +31,6 @@ def _clean_env(monkeypatch):
"OPENROUTER_API_KEY", "OPENAI_BASE_URL", "OPENAI_API_KEY",
"OPENAI_MODEL", "LLM_MODEL", "NOUS_INFERENCE_BASE_URL",
"ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN",
# Per-task provider/model/direct-endpoint overrides
"AUXILIARY_VISION_PROVIDER", "AUXILIARY_VISION_MODEL",
"AUXILIARY_VISION_BASE_URL", "AUXILIARY_VISION_API_KEY",
"AUXILIARY_WEB_EXTRACT_PROVIDER", "AUXILIARY_WEB_EXTRACT_MODEL",
"AUXILIARY_WEB_EXTRACT_BASE_URL", "AUXILIARY_WEB_EXTRACT_API_KEY",
"CONTEXT_COMPRESSION_PROVIDER", "CONTEXT_COMPRESSION_MODEL",
):
monkeypatch.delenv(key, raising=False)
@@ -568,29 +561,6 @@ class TestGetTextAuxiliaryClient:
call_kwargs = mock_openai.call_args
assert call_kwargs.kwargs["base_url"] == "http://localhost:1234/v1"
def test_task_direct_endpoint_override(self, monkeypatch):
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_BASE_URL", "http://localhost:2345/v1")
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_API_KEY", "task-key")
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_MODEL", "task-model")
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = get_text_auxiliary_client("web_extract")
assert model == "task-model"
assert mock_openai.call_args.kwargs["base_url"] == "http://localhost:2345/v1"
assert mock_openai.call_args.kwargs["api_key"] == "task-key"
def test_task_direct_endpoint_without_openai_key_uses_placeholder(self, monkeypatch):
"""Local endpoints without an API key should use 'no-key-required' placeholder."""
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_BASE_URL", "http://localhost:2345/v1")
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_MODEL", "task-model")
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = get_text_auxiliary_client("web_extract")
assert client is not None
assert model == "task-model"
assert mock_openai.call_args.kwargs["api_key"] == "no-key-required"
assert mock_openai.call_args.kwargs["base_url"] == "http://localhost:2345/v1"
def test_custom_endpoint_uses_config_saved_base_url(self, monkeypatch):
config = {
"model": {
@@ -879,73 +849,9 @@ class TestAuxiliaryPoolAwareness:
class TestGetAuxiliaryProvider:
"""Tests for _get_auxiliary_provider env var resolution."""
def test_no_task_returns_auto(self):
assert _get_auxiliary_provider() == "auto"
assert _get_auxiliary_provider("") == "auto"
def test_auxiliary_prefix_takes_priority(self, monkeypatch):
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "openrouter")
assert _get_auxiliary_provider("vision") == "openrouter"
def test_context_prefix_fallback(self, monkeypatch):
monkeypatch.setenv("CONTEXT_COMPRESSION_PROVIDER", "nous")
assert _get_auxiliary_provider("compression") == "nous"
def test_auxiliary_prefix_over_context_prefix(self, monkeypatch):
monkeypatch.setenv("AUXILIARY_COMPRESSION_PROVIDER", "openrouter")
monkeypatch.setenv("CONTEXT_COMPRESSION_PROVIDER", "nous")
assert _get_auxiliary_provider("compression") == "openrouter"
def test_auto_value_treated_as_auto(self, monkeypatch):
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "auto")
assert _get_auxiliary_provider("vision") == "auto"
def test_whitespace_stripped(self, monkeypatch):
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", " openrouter ")
assert _get_auxiliary_provider("vision") == "openrouter"
def test_case_insensitive(self, monkeypatch):
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "OpenRouter")
assert _get_auxiliary_provider("vision") == "openrouter"
def test_main_provider(self, monkeypatch):
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_PROVIDER", "main")
assert _get_auxiliary_provider("web_extract") == "main"
class TestTaskSpecificOverrides:
"""Integration tests for per-task provider routing via get_text_auxiliary_client(task=...)."""
def test_text_with_vision_provider_override(self, monkeypatch):
"""AUXILIARY_VISION_PROVIDER should not affect text tasks."""
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "nous")
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
with patch("agent.auxiliary_client.OpenAI"):
client, model = get_text_auxiliary_client() # no task → auto
assert model == "google/gemini-3-flash-preview" # OpenRouter, not Nous
def test_compression_task_reads_context_prefix(self, monkeypatch):
"""Compression task should check CONTEXT_COMPRESSION_PROVIDER env var."""
monkeypatch.setenv("CONTEXT_COMPRESSION_PROVIDER", "nous")
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key") # would win in auto
with patch("agent.auxiliary_client._read_nous_auth") as mock_nous, \
patch("agent.auxiliary_client.OpenAI"):
mock_nous.return_value = {"access_token": "***"}
client, model = get_text_auxiliary_client("compression")
# Config-first: model comes from config.yaml summary_model default,
# but provider is forced to Nous via env var
assert client is not None
def test_web_extract_task_override(self, monkeypatch):
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_PROVIDER", "openrouter")
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
with patch("agent.auxiliary_client.OpenAI"):
client, model = get_text_auxiliary_client("web_extract")
assert model == "google/gemini-3-flash-preview"
def test_task_direct_endpoint_from_config(self, monkeypatch, tmp_path):
hermes_home = tmp_path / "hermes"
hermes_home.mkdir(parents=True, exist_ok=True)
@@ -979,8 +885,6 @@ class TestTaskSpecificOverrides:
"""model:
default: glm-5.1
provider: opencode-go
compression:
summary_provider: auto
"""
)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
@@ -1039,24 +943,45 @@ model:
"model": "gpt-5.4",
}
def test_compression_summary_base_url_from_config(self, monkeypatch, tmp_path):
"""compression.summary_base_url should produce a custom-endpoint client."""
hermes_home = tmp_path / "hermes"
hermes_home.mkdir(parents=True, exist_ok=True)
(hermes_home / "config.yaml").write_text(
"""compression:
summary_provider: custom
summary_model: glm-4.7
summary_base_url: https://api.z.ai/api/coding/paas/v4
"""
)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
# Custom endpoints need an API key to build the client
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = get_text_auxiliary_client("compression")
assert model == "glm-4.7"
assert mock_openai.call_args.kwargs["base_url"] == "https://api.z.ai/api/coding/paas/v4"
def test_resolve_provider_client_supports_copilot_acp_external_process():
fake_client = MagicMock()
with patch("agent.auxiliary_client._read_main_model", return_value="gpt-5.4-mini"), \
patch("agent.auxiliary_client.CodexAuxiliaryClient", MagicMock()), \
patch("agent.copilot_acp_client.CopilotACPClient", return_value=fake_client) as mock_acp, \
patch("hermes_cli.auth.resolve_external_process_provider_credentials", return_value={
"provider": "copilot-acp",
"api_key": "copilot-acp",
"base_url": "acp://copilot",
"command": "/usr/bin/copilot",
"args": ["--acp", "--stdio"],
}):
client, model = resolve_provider_client("copilot-acp")
assert client is fake_client
assert model == "gpt-5.4-mini"
assert mock_acp.call_args.kwargs["api_key"] == "copilot-acp"
assert mock_acp.call_args.kwargs["base_url"] == "acp://copilot"
assert mock_acp.call_args.kwargs["command"] == "/usr/bin/copilot"
assert mock_acp.call_args.kwargs["args"] == ["--acp", "--stdio"]
def test_resolve_provider_client_copilot_acp_requires_explicit_or_configured_model():
with patch("agent.auxiliary_client._read_main_model", return_value=""), \
patch("agent.copilot_acp_client.CopilotACPClient") as mock_acp, \
patch("hermes_cli.auth.resolve_external_process_provider_credentials", return_value={
"provider": "copilot-acp",
"api_key": "copilot-acp",
"base_url": "acp://copilot",
"command": "/usr/bin/copilot",
"args": ["--acp", "--stdio"],
}):
client, model = resolve_provider_client("copilot-acp")
assert client is None
assert model is None
mock_acp.assert_not_called()
class TestAuxiliaryMaxTokensParam:

View File

@@ -273,18 +273,6 @@ class TestDefaultConfigShape:
assert web["provider"] == "auto"
assert web["model"] == ""
def test_compression_provider_default(self):
from hermes_cli.config import DEFAULT_CONFIG
compression = DEFAULT_CONFIG["compression"]
assert "summary_provider" in compression
assert compression["summary_provider"] == "auto"
def test_compression_base_url_default(self):
from hermes_cli.config import DEFAULT_CONFIG
compression = DEFAULT_CONFIG["compression"]
assert "summary_base_url" in compression
assert compression["summary_base_url"] is None
# ── CLI defaults parity ─────────────────────────────────────────────────────

View File

@@ -12,17 +12,6 @@ def _isolate(tmp_path, monkeypatch):
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir()
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
for env_var in (
"AUXILIARY_VISION_PROVIDER",
"AUXILIARY_VISION_MODEL",
"AUXILIARY_VISION_BASE_URL",
"AUXILIARY_VISION_API_KEY",
"CONTEXT_VISION_PROVIDER",
"CONTEXT_VISION_MODEL",
"CONTEXT_VISION_BASE_URL",
"CONTEXT_VISION_API_KEY",
):
monkeypatch.delenv(env_var, raising=False)
# Write a minimal config so load_config doesn't fail
(hermes_home / "config.yaml").write_text("model:\n default: test-model\n")
@@ -69,6 +58,10 @@ class TestNormalizeVisionProvider:
assert _normalize_vision_provider("beans") == "beans"
assert _normalize_vision_provider("deepseek") == "deepseek"
def test_custom_colon_named_provider_preserved(self):
from agent.auxiliary_client import _normalize_vision_provider
assert _normalize_vision_provider("custom:beans") == "beans"
def test_codex_alias_still_works(self):
from agent.auxiliary_client import _normalize_vision_provider
assert _normalize_vision_provider("codex") == "openai-codex"
@@ -240,3 +233,22 @@ class TestResolveVisionProviderClientModelNormalization:
assert provider == "zai"
assert client is not None
assert model == "glm-5.1"
class TestVisionPathApiMode:
"""Vision path should propagate api_mode to _get_cached_client."""
def test_explicit_provider_passes_api_mode(self, tmp_path):
_write_config(tmp_path, {
"model": {"default": "test-model"},
"auxiliary": {"vision": {"api_mode": "chat_completions"}},
})
with patch("agent.auxiliary_client._get_cached_client") as mock_gcc:
mock_gcc.return_value = (MagicMock(), "test-model")
from agent.auxiliary_client import resolve_vision_provider_client
provider, client, model = resolve_vision_provider_client(provider="deepseek")
mock_gcc.assert_called_once()
_, kwargs = mock_gcc.call_args
assert kwargs.get("api_mode") == "chat_completions"

View File

@@ -308,6 +308,34 @@ class TestMinimaxPreserveDots:
from run_agent import AIAgent
assert AIAgent._anthropic_preserve_dots(agent) is False
def test_opencode_zen_provider_preserves_dots(self):
from types import SimpleNamespace
agent = SimpleNamespace(provider="opencode-zen", base_url="")
from run_agent import AIAgent
assert AIAgent._anthropic_preserve_dots(agent) is True
def test_opencode_zen_url_preserves_dots(self):
from types import SimpleNamespace
agent = SimpleNamespace(provider="custom", base_url="https://opencode.ai/zen/v1")
from run_agent import AIAgent
assert AIAgent._anthropic_preserve_dots(agent) is True
def test_zai_provider_preserves_dots(self):
from types import SimpleNamespace
agent = SimpleNamespace(provider="zai", base_url="")
from run_agent import AIAgent
assert AIAgent._anthropic_preserve_dots(agent) is True
def test_bigmodel_cn_url_preserves_dots(self):
from types import SimpleNamespace
agent = SimpleNamespace(provider="custom", base_url="https://open.bigmodel.cn/api/paas/v4")
from run_agent import AIAgent
assert AIAgent._anthropic_preserve_dots(agent) is True
def test_normalize_preserves_m25_free_dot(self):
from agent.anthropic_adapter import normalize_model_name
assert normalize_model_name("minimax-m2.5-free", preserve_dots=True) == "minimax-m2.5-free"
def test_normalize_preserves_m27_dot(self):
from agent.anthropic_adapter import normalize_model_name
assert normalize_model_name("MiniMax-M2.7", preserve_dots=True) == "MiniMax-M2.7"

View File

@@ -70,6 +70,44 @@ class TestQueryLocalContextLengthOllama:
assert result == 32768
def test_ollama_num_ctx_wins_over_model_info(self):
"""When both num_ctx (Modelfile) and model_info (GGUF) are present,
num_ctx wins because it's the *runtime* context Ollama actually
allocates KV cache for. The GGUF model_info.context_length is the
training max — using it would let Hermes grow conversations past
the runtime limit and Ollama would silently truncate.
Concrete example: hermes-brain:qwen3-14b-ctx32k is a Modelfile
derived from qwen3:14b with `num_ctx 32768`, but the underlying
GGUF reports `qwen3.context_length: 40960` (training max). If
Hermes used 40960 it would let the conversation grow past 32768
before compressing, and Ollama would truncate the prefix.
"""
from agent.model_metadata import _query_local_context_length
show_resp = self._make_resp(200, {
"model_info": {"qwen3.context_length": 40960},
"parameters": "num_ctx 32768\ntemperature 0.6\n",
})
models_resp = self._make_resp(404, {})
client_mock = MagicMock()
client_mock.__enter__ = lambda s: client_mock
client_mock.__exit__ = MagicMock(return_value=False)
client_mock.post.return_value = show_resp
client_mock.get.return_value = models_resp
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
patch("httpx.Client", return_value=client_mock):
result = _query_local_context_length(
"hermes-brain:qwen3-14b-ctx32k", "http://100.77.243.5:11434/v1"
)
assert result == 32768, (
f"Expected num_ctx (32768) to win over model_info (40960), got {result}. "
"If Hermes uses the GGUF training max, conversations will silently truncate."
)
def test_ollama_show_404_falls_through(self):
"""When /api/show returns 404, falls through to /v1/models/{model}."""
from agent.model_metadata import _query_local_context_length

View File

@@ -51,10 +51,10 @@ class TestSaveConfigValueAtomic:
def test_creates_nested_keys(self, config_env):
"""Dot-separated paths create intermediate dicts as needed."""
from cli import save_config_value
save_config_value("compression.summary_model", "google/gemini-3-flash-preview")
save_config_value("auxiliary.compression.model", "google/gemini-3-flash-preview")
result = yaml.safe_load(config_env.read_text())
assert result["compression"]["summary_model"] == "google/gemini-3-flash-preview"
assert result["auxiliary"]["compression"]["model"] == "google/gemini-3-flash-preview"
def test_overwrites_existing_value(self, config_env):
"""Updating an existing key replaces the value."""

View File

@@ -180,33 +180,71 @@ class TestDisplayResumedHistory:
assert 200 <= a_count <= 310 # roughly 300 chars (±panel padding)
def test_long_assistant_message_truncated(self):
"""Non-last assistant messages are still truncated."""
cli = _make_cli()
long_text = "B" * 400
cli.conversation_history = [
{"role": "user", "content": "Tell me a lot."},
{"role": "assistant", "content": long_text},
{"role": "user", "content": "And more?"},
{"role": "assistant", "content": "Short final reply."},
]
output = self._capture_display(cli)
assert "..." in output
# The non-last assistant message should be truncated
assert "B" * 400 not in output
# The last assistant message shown in full
assert "Short final reply." in output
def test_multiline_assistant_truncated(self):
"""Non-last multiline assistant messages are truncated to 3 lines."""
cli = _make_cli()
multi = "\n".join([f"Line {i}" for i in range(20)])
cli.conversation_history = [
{"role": "user", "content": "Show me lines."},
{"role": "assistant", "content": multi},
{"role": "user", "content": "What else?"},
{"role": "assistant", "content": "Done."},
]
output = self._capture_display(cli)
# First 3 lines should be there
# First 3 lines of non-last assistant should be there
assert "Line 0" in output
assert "Line 1" in output
assert "Line 2" in output
# Line 19 should NOT be there (truncated after 3 lines)
# Line 19 should NOT be in the truncated message
assert "Line 19" not in output
def test_last_assistant_response_shown_in_full(self):
"""The last assistant response is shown un-truncated so the user
knows where they left off without wasting tokens re-asking."""
cli = _make_cli()
long_text = "X" * 500
cli.conversation_history = [
{"role": "user", "content": "Tell me everything."},
{"role": "assistant", "content": long_text},
]
output = self._capture_display(cli)
# Full 500-char text should be present (may be line-wrapped by Rich)
x_count = output.count("X")
assert x_count >= 490 # allow small Rich formatting variance
def test_last_assistant_multiline_shown_in_full(self):
"""The last assistant response shows all lines, not just 3."""
cli = _make_cli()
multi = "\n".join([f"Line {i}" for i in range(20)])
cli.conversation_history = [
{"role": "user", "content": "Show me everything."},
{"role": "assistant", "content": multi},
]
output = self._capture_display(cli)
# All 20 lines should be present since it's the last response
assert "Line 0" in output
assert "Line 10" in output
assert "Line 19" in output
def test_large_history_shows_truncation_indicator(self):
cli = _make_cli()
cli.conversation_history = _large_history(n_exchanges=15)

View File

@@ -35,6 +35,7 @@ def make_restart_source(chat_id: str = "123456", chat_type: str = "dm") -> Sessi
platform=Platform.TELEGRAM,
chat_id=chat_id,
chat_type=chat_type,
user_id="u1",
)

View File

@@ -0,0 +1,87 @@
"""Tests for _normalize_chat_content in the API server adapter."""
from gateway.platforms.api_server import _normalize_chat_content
class TestNormalizeChatContent:
"""Content normalization converts array-based content parts to plain text."""
def test_none_returns_empty_string(self):
assert _normalize_chat_content(None) == ""
def test_plain_string_returned_as_is(self):
assert _normalize_chat_content("hello world") == "hello world"
def test_empty_string_returned_as_is(self):
assert _normalize_chat_content("") == ""
def test_text_content_part(self):
content = [{"type": "text", "text": "hello"}]
assert _normalize_chat_content(content) == "hello"
def test_input_text_content_part(self):
content = [{"type": "input_text", "text": "user input"}]
assert _normalize_chat_content(content) == "user input"
def test_output_text_content_part(self):
content = [{"type": "output_text", "text": "assistant output"}]
assert _normalize_chat_content(content) == "assistant output"
def test_multiple_text_parts_joined_with_newline(self):
content = [
{"type": "text", "text": "first"},
{"type": "text", "text": "second"},
]
assert _normalize_chat_content(content) == "first\nsecond"
def test_mixed_string_and_dict_parts(self):
content = ["plain string", {"type": "text", "text": "dict part"}]
assert _normalize_chat_content(content) == "plain string\ndict part"
def test_image_url_parts_silently_skipped(self):
content = [
{"type": "text", "text": "check this:"},
{"type": "image_url", "image_url": {"url": "https://example.com/img.png"}},
]
assert _normalize_chat_content(content) == "check this:"
def test_integer_content_converted(self):
assert _normalize_chat_content(42) == "42"
def test_boolean_content_converted(self):
assert _normalize_chat_content(True) == "True"
def test_deeply_nested_list_respects_depth_limit(self):
"""Nesting beyond max_depth returns empty string."""
content = [[[[[[[[[[[["deep"]]]]]]]]]]]]
result = _normalize_chat_content(content)
# The deep nesting should be truncated, not crash
assert isinstance(result, str)
def test_large_list_capped(self):
"""Lists beyond MAX_CONTENT_LIST_SIZE are truncated."""
content = [{"type": "text", "text": f"item{i}"} for i in range(2000)]
result = _normalize_chat_content(content)
# Should not contain all 2000 items
assert result.count("item") <= 1000
def test_oversized_string_truncated(self):
"""Strings beyond 64KB are truncated."""
huge = "x" * 100_000
result = _normalize_chat_content(huge)
assert len(result) == 65_536
def test_empty_text_parts_filtered(self):
content = [
{"type": "text", "text": ""},
{"type": "text", "text": "actual"},
{"type": "text", "text": ""},
]
assert _normalize_chat_content(content) == "actual"
def test_dict_without_type_skipped(self):
content = [{"foo": "bar"}, {"type": "text", "text": "real"}]
assert _normalize_chat_content(content) == "real"
def test_empty_list_returns_empty(self):
assert _normalize_chat_content([]) == ""

View File

@@ -359,3 +359,44 @@ async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeyp
await adapter._handle_message(message)
assert "777" in adapter._threads
@pytest.mark.asyncio
async def test_discord_voice_linked_channel_skips_mention_requirement_and_auto_thread(adapter, monkeypatch):
"""Active voice-linked text channels should behave like free-response channels."""
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
adapter._voice_text_channels[111] = 789
adapter._auto_create_thread = AsyncMock()
message = make_message(
channel=FakeTextChannel(channel_id=789),
content="follow-up from voice text chat",
)
await adapter._handle_message(message)
adapter._auto_create_thread.assert_not_awaited()
adapter.handle_message.assert_awaited_once()
event = adapter.handle_message.await_args.args[0]
assert event.text == "follow-up from voice text chat"
assert event.source.chat_type == "group"
@pytest.mark.asyncio
async def test_discord_voice_linked_parent_thread_still_requires_mention(adapter, monkeypatch):
"""Threads under a voice-linked channel should still require @mention."""
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
adapter._voice_text_channels[111] = 789
message = make_message(
channel=FakeThread(channel_id=790, parent=FakeTextChannel(channel_id=789)),
content="thread reply without mention",
)
await adapter._handle_message(message)
adapter.handle_message.assert_not_awaited()

View File

@@ -124,7 +124,7 @@ class TestSendWithReplyToMode:
@pytest.mark.asyncio
async def test_off_mode_no_reply_reference(self):
adapter, channel, ref_msg = _make_discord_adapter("off")
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2", "chunk3"]
await adapter.send("12345", "test content", reply_to="999")
@@ -137,7 +137,7 @@ class TestSendWithReplyToMode:
@pytest.mark.asyncio
async def test_first_mode_only_first_chunk_references(self):
adapter, channel, ref_msg = _make_discord_adapter("first")
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2", "chunk3"]
await adapter.send("12345", "test content", reply_to="999")
@@ -152,7 +152,7 @@ class TestSendWithReplyToMode:
@pytest.mark.asyncio
async def test_all_mode_all_chunks_reference(self):
adapter, channel, ref_msg = _make_discord_adapter("all")
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2", "chunk3"]
await adapter.send("12345", "test content", reply_to="999")
@@ -165,7 +165,7 @@ class TestSendWithReplyToMode:
@pytest.mark.asyncio
async def test_no_reply_to_param_no_reference(self):
adapter, channel, ref_msg = _make_discord_adapter("all")
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2"]
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2"]
await adapter.send("12345", "test content", reply_to=None)
@@ -176,7 +176,7 @@ class TestSendWithReplyToMode:
@pytest.mark.asyncio
async def test_single_chunk_respects_first_mode(self):
adapter, channel, ref_msg = _make_discord_adapter("first")
adapter.truncate_message = lambda content, max_len: ["single chunk"]
adapter.truncate_message = lambda content, max_len, **kw: ["single chunk"]
await adapter.send("12345", "test", reply_to="999")
@@ -187,7 +187,7 @@ class TestSendWithReplyToMode:
@pytest.mark.asyncio
async def test_single_chunk_off_mode(self):
adapter, channel, ref_msg = _make_discord_adapter("off")
adapter.truncate_message = lambda content, max_len: ["single chunk"]
adapter.truncate_message = lambda content, max_len, **kw: ["single chunk"]
await adapter.send("12345", "test", reply_to="999")
@@ -200,7 +200,7 @@ class TestSendWithReplyToMode:
async def test_invalid_mode_falls_back_to_first_behavior(self):
"""Invalid mode behaves like 'first' — only first chunk gets reference."""
adapter, channel, ref_msg = _make_discord_adapter("banana")
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2"]
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2"]
await adapter.send("12345", "test", reply_to="999")

View File

@@ -189,14 +189,14 @@ class TestPlatformDefaults:
"""Slack, Mattermost, Matrix default to 'new' tool progress."""
from gateway.display_config import resolve_display_setting
for plat in ("slack", "mattermost", "matrix", "feishu"):
for plat in ("slack", "mattermost", "matrix", "feishu", "whatsapp"):
assert resolve_display_setting({}, plat, "tool_progress") == "new", plat
def test_low_tier_platforms(self):
"""Signal, WhatsApp, etc. default to 'off' tool progress."""
"""Signal, BlueBubbles, etc. default to 'off' tool progress."""
from gateway.display_config import resolve_display_setting
for plat in ("signal", "whatsapp", "bluebubbles", "weixin", "wecom", "dingtalk"):
for plat in ("signal", "bluebubbles", "weixin", "wecom", "dingtalk"):
assert resolve_display_setting({}, plat, "tool_progress") == "off", plat
def test_minimal_tier_platforms(self):

View File

@@ -0,0 +1,438 @@
"""Tests for gateway.platforms.feishu — Feishu scan-to-create registration."""
import json
from unittest.mock import patch, MagicMock
import pytest
def _mock_urlopen(response_data, status=200):
"""Create a mock for urllib.request.urlopen that returns JSON response_data."""
mock_response = MagicMock()
mock_response.read.return_value = json.dumps(response_data).encode("utf-8")
mock_response.status = status
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = MagicMock(return_value=False)
return mock_response
class TestPostRegistration:
"""Tests for the low-level HTTP helper."""
@patch("gateway.platforms.feishu.urlopen")
def test_post_registration_returns_parsed_json(self, mock_urlopen_fn):
from gateway.platforms.feishu import _post_registration
mock_urlopen_fn.return_value = _mock_urlopen({"nonce": "abc", "supported_auth_methods": ["client_secret"]})
result = _post_registration("https://accounts.feishu.cn", {"action": "init"})
assert result["nonce"] == "abc"
assert "client_secret" in result["supported_auth_methods"]
@patch("gateway.platforms.feishu.urlopen")
def test_post_registration_sends_form_encoded_body(self, mock_urlopen_fn):
from gateway.platforms.feishu import _post_registration
mock_urlopen_fn.return_value = _mock_urlopen({})
_post_registration("https://accounts.feishu.cn", {"action": "init", "key": "val"})
call_args = mock_urlopen_fn.call_args
request = call_args[0][0]
body = request.data.decode("utf-8")
assert "action=init" in body
assert "key=val" in body
assert request.get_header("Content-type") == "application/x-www-form-urlencoded"
class TestInitRegistration:
"""Tests for the init step."""
@patch("gateway.platforms.feishu.urlopen")
def test_init_succeeds_when_client_secret_supported(self, mock_urlopen_fn):
from gateway.platforms.feishu import _init_registration
mock_urlopen_fn.return_value = _mock_urlopen({
"nonce": "abc",
"supported_auth_methods": ["client_secret"],
})
_init_registration("feishu")
@patch("gateway.platforms.feishu.urlopen")
def test_init_raises_when_client_secret_not_supported(self, mock_urlopen_fn):
from gateway.platforms.feishu import _init_registration
mock_urlopen_fn.return_value = _mock_urlopen({
"nonce": "abc",
"supported_auth_methods": ["other_method"],
})
with pytest.raises(RuntimeError, match="client_secret"):
_init_registration("feishu")
@patch("gateway.platforms.feishu.urlopen")
def test_init_uses_lark_url_for_lark_domain(self, mock_urlopen_fn):
from gateway.platforms.feishu import _init_registration
mock_urlopen_fn.return_value = _mock_urlopen({
"nonce": "abc",
"supported_auth_methods": ["client_secret"],
})
_init_registration("lark")
call_args = mock_urlopen_fn.call_args
request = call_args[0][0]
assert "larksuite.com" in request.full_url
class TestBeginRegistration:
"""Tests for the begin step."""
@patch("gateway.platforms.feishu.urlopen")
def test_begin_returns_device_code_and_qr_url(self, mock_urlopen_fn):
from gateway.platforms.feishu import _begin_registration
mock_urlopen_fn.return_value = _mock_urlopen({
"device_code": "dc_123",
"verification_uri_complete": "https://accounts.feishu.cn/qr/abc",
"user_code": "ABCD-1234",
"interval": 5,
"expire_in": 600,
})
result = _begin_registration("feishu")
assert result["device_code"] == "dc_123"
assert "qr_url" in result
assert "accounts.feishu.cn" in result["qr_url"]
assert result["user_code"] == "ABCD-1234"
assert result["interval"] == 5
assert result["expire_in"] == 600
@patch("gateway.platforms.feishu.urlopen")
def test_begin_sends_correct_archetype(self, mock_urlopen_fn):
from gateway.platforms.feishu import _begin_registration
mock_urlopen_fn.return_value = _mock_urlopen({
"device_code": "dc_123",
"verification_uri_complete": "https://example.com/qr",
"user_code": "X",
"interval": 5,
"expire_in": 600,
})
_begin_registration("feishu")
request = mock_urlopen_fn.call_args[0][0]
body = request.data.decode("utf-8")
assert "archetype=PersonalAgent" in body
assert "auth_method=client_secret" in body
class TestPollRegistration:
"""Tests for the poll step."""
@patch("gateway.platforms.feishu.time")
@patch("gateway.platforms.feishu.urlopen")
def test_poll_returns_credentials_on_success(self, mock_urlopen_fn, mock_time):
from gateway.platforms.feishu import _poll_registration
mock_time.time.side_effect = [0, 1]
mock_time.sleep = MagicMock()
mock_urlopen_fn.return_value = _mock_urlopen({
"client_id": "cli_app123",
"client_secret": "secret456",
"user_info": {"open_id": "ou_owner", "tenant_brand": "feishu"},
})
result = _poll_registration(
device_code="dc_123", interval=1, expire_in=60, domain="feishu"
)
assert result is not None
assert result["app_id"] == "cli_app123"
assert result["app_secret"] == "secret456"
assert result["domain"] == "feishu"
assert result["open_id"] == "ou_owner"
@patch("gateway.platforms.feishu.time")
@patch("gateway.platforms.feishu.urlopen")
def test_poll_switches_domain_on_lark_tenant_brand(self, mock_urlopen_fn, mock_time):
from gateway.platforms.feishu import _poll_registration
mock_time.time.side_effect = [0, 1, 2]
mock_time.sleep = MagicMock()
pending_resp = _mock_urlopen({
"error": "authorization_pending",
"user_info": {"tenant_brand": "lark"},
})
success_resp = _mock_urlopen({
"client_id": "cli_lark",
"client_secret": "secret_lark",
"user_info": {"open_id": "ou_lark", "tenant_brand": "lark"},
})
mock_urlopen_fn.side_effect = [pending_resp, success_resp]
result = _poll_registration(
device_code="dc_123", interval=0, expire_in=60, domain="feishu"
)
assert result is not None
assert result["domain"] == "lark"
@patch("gateway.platforms.feishu.time")
@patch("gateway.platforms.feishu.urlopen")
def test_poll_success_with_lark_brand_in_same_response(self, mock_urlopen_fn, mock_time):
"""Credentials and lark tenant_brand in one response must not be discarded."""
from gateway.platforms.feishu import _poll_registration
mock_time.time.side_effect = [0, 1]
mock_time.sleep = MagicMock()
mock_urlopen_fn.return_value = _mock_urlopen({
"client_id": "cli_lark_direct",
"client_secret": "secret_lark_direct",
"user_info": {"open_id": "ou_lark_direct", "tenant_brand": "lark"},
})
result = _poll_registration(
device_code="dc_123", interval=1, expire_in=60, domain="feishu"
)
assert result is not None
assert result["app_id"] == "cli_lark_direct"
assert result["domain"] == "lark"
assert result["open_id"] == "ou_lark_direct"
@patch("gateway.platforms.feishu.time")
@patch("gateway.platforms.feishu.urlopen")
def test_poll_returns_none_on_access_denied(self, mock_urlopen_fn, mock_time):
from gateway.platforms.feishu import _poll_registration
mock_time.time.side_effect = [0, 1]
mock_time.sleep = MagicMock()
mock_urlopen_fn.return_value = _mock_urlopen({
"error": "access_denied",
})
result = _poll_registration(
device_code="dc_123", interval=1, expire_in=60, domain="feishu"
)
assert result is None
@patch("gateway.platforms.feishu.time")
@patch("gateway.platforms.feishu.urlopen")
def test_poll_returns_none_on_timeout(self, mock_urlopen_fn, mock_time):
from gateway.platforms.feishu import _poll_registration
mock_time.time.side_effect = [0, 999]
mock_time.sleep = MagicMock()
mock_urlopen_fn.return_value = _mock_urlopen({
"error": "authorization_pending",
})
result = _poll_registration(
device_code="dc_123", interval=1, expire_in=1, domain="feishu"
)
assert result is None
class TestRenderQr:
"""Tests for QR code terminal rendering."""
@patch("gateway.platforms.feishu._qrcode_mod", create=True)
def test_render_qr_returns_true_on_success(self, mock_qrcode_mod):
from gateway.platforms.feishu import _render_qr
mock_qr = MagicMock()
mock_qrcode_mod.QRCode.return_value = mock_qr
assert _render_qr("https://example.com/qr") is True
mock_qr.add_data.assert_called_once_with("https://example.com/qr")
mock_qr.make.assert_called_once_with(fit=True)
mock_qr.print_ascii.assert_called_once()
def test_render_qr_returns_false_when_qrcode_missing(self):
from gateway.platforms.feishu import _render_qr
with patch("gateway.platforms.feishu._qrcode_mod", None):
assert _render_qr("https://example.com/qr") is False
class TestProbeBot:
"""Tests for bot connectivity verification."""
@patch("gateway.platforms.feishu.FEISHU_AVAILABLE", True)
def test_probe_returns_bot_info_on_success(self):
from gateway.platforms.feishu import probe_bot
with patch("gateway.platforms.feishu._probe_bot_sdk") as mock_sdk:
mock_sdk.return_value = {"bot_name": "TestBot", "bot_open_id": "ou_bot123"}
result = probe_bot("cli_app", "secret", "feishu")
assert result is not None
assert result["bot_name"] == "TestBot"
assert result["bot_open_id"] == "ou_bot123"
@patch("gateway.platforms.feishu.FEISHU_AVAILABLE", True)
def test_probe_returns_none_on_failure(self):
from gateway.platforms.feishu import probe_bot
with patch("gateway.platforms.feishu._probe_bot_sdk") as mock_sdk:
mock_sdk.return_value = None
result = probe_bot("bad_id", "bad_secret", "feishu")
assert result is None
@patch("gateway.platforms.feishu.FEISHU_AVAILABLE", False)
@patch("gateway.platforms.feishu.urlopen")
def test_http_fallback_when_sdk_unavailable(self, mock_urlopen_fn):
"""Without lark_oapi, probe falls back to raw HTTP."""
from gateway.platforms.feishu import probe_bot
token_resp = _mock_urlopen({"code": 0, "tenant_access_token": "t-123"})
bot_resp = _mock_urlopen({"code": 0, "bot": {"bot_name": "HttpBot", "open_id": "ou_http"}})
mock_urlopen_fn.side_effect = [token_resp, bot_resp]
result = probe_bot("cli_app", "secret", "feishu")
assert result is not None
assert result["bot_name"] == "HttpBot"
@patch("gateway.platforms.feishu.FEISHU_AVAILABLE", False)
@patch("gateway.platforms.feishu.urlopen")
def test_http_fallback_returns_none_on_network_error(self, mock_urlopen_fn):
from gateway.platforms.feishu import probe_bot
from urllib.error import URLError
mock_urlopen_fn.side_effect = URLError("connection refused")
result = probe_bot("cli_app", "secret", "feishu")
assert result is None
class TestQrRegister:
"""Tests for the public qr_register entry point."""
@patch("gateway.platforms.feishu.probe_bot")
@patch("gateway.platforms.feishu._render_qr")
@patch("gateway.platforms.feishu._poll_registration")
@patch("gateway.platforms.feishu._begin_registration")
@patch("gateway.platforms.feishu._init_registration")
def test_qr_register_success_flow(
self, mock_init, mock_begin, mock_poll, mock_render, mock_probe
):
from gateway.platforms.feishu import qr_register
mock_begin.return_value = {
"device_code": "dc_123",
"qr_url": "https://example.com/qr",
"user_code": "ABCD",
"interval": 1,
"expire_in": 60,
}
mock_poll.return_value = {
"app_id": "cli_app",
"app_secret": "secret",
"domain": "feishu",
"open_id": "ou_owner",
}
mock_probe.return_value = {"bot_name": "MyBot", "bot_open_id": "ou_bot"}
result = qr_register()
assert result is not None
assert result["app_id"] == "cli_app"
assert result["app_secret"] == "secret"
assert result["bot_name"] == "MyBot"
mock_init.assert_called_once()
mock_render.assert_called_once()
@patch("gateway.platforms.feishu._init_registration")
def test_qr_register_returns_none_on_init_failure(self, mock_init):
from gateway.platforms.feishu import qr_register
mock_init.side_effect = RuntimeError("not supported")
result = qr_register()
assert result is None
@patch("gateway.platforms.feishu._render_qr")
@patch("gateway.platforms.feishu._poll_registration")
@patch("gateway.platforms.feishu._begin_registration")
@patch("gateway.platforms.feishu._init_registration")
def test_qr_register_returns_none_on_poll_failure(
self, mock_init, mock_begin, mock_poll, mock_render
):
from gateway.platforms.feishu import qr_register
mock_begin.return_value = {
"device_code": "dc_123",
"qr_url": "https://example.com/qr",
"user_code": "ABCD",
"interval": 1,
"expire_in": 60,
}
mock_poll.return_value = None
result = qr_register()
assert result is None
# -- Contract: expected errors → None, unexpected errors → propagate --
@patch("gateway.platforms.feishu._init_registration")
def test_qr_register_returns_none_on_network_error(self, mock_init):
"""URLError (network down) is an expected failure → None."""
from gateway.platforms.feishu import qr_register
from urllib.error import URLError
mock_init.side_effect = URLError("DNS resolution failed")
result = qr_register()
assert result is None
@patch("gateway.platforms.feishu._init_registration")
def test_qr_register_returns_none_on_json_error(self, mock_init):
"""Malformed server response is an expected failure → None."""
from gateway.platforms.feishu import qr_register
mock_init.side_effect = json.JSONDecodeError("bad json", "", 0)
result = qr_register()
assert result is None
@patch("gateway.platforms.feishu._init_registration")
def test_qr_register_propagates_unexpected_errors(self, mock_init):
"""Bugs (e.g. AttributeError) must not be swallowed — they propagate."""
from gateway.platforms.feishu import qr_register
mock_init.side_effect = AttributeError("some internal bug")
with pytest.raises(AttributeError, match="some internal bug"):
qr_register()
# -- Negative paths: partial/malformed server responses --
@patch("gateway.platforms.feishu._render_qr")
@patch("gateway.platforms.feishu._begin_registration")
@patch("gateway.platforms.feishu._init_registration")
def test_qr_register_returns_none_when_begin_missing_device_code(
self, mock_init, mock_begin, mock_render
):
"""Server returns begin response without device_code → RuntimeError → None."""
from gateway.platforms.feishu import qr_register
mock_begin.side_effect = RuntimeError("Feishu registration did not return a device_code")
result = qr_register()
assert result is None
@patch("gateway.platforms.feishu.probe_bot")
@patch("gateway.platforms.feishu._render_qr")
@patch("gateway.platforms.feishu._poll_registration")
@patch("gateway.platforms.feishu._begin_registration")
@patch("gateway.platforms.feishu._init_registration")
def test_qr_register_succeeds_even_when_probe_fails(
self, mock_init, mock_begin, mock_poll, mock_render, mock_probe
):
"""Registration succeeds but probe fails → result with bot_name=None."""
from gateway.platforms.feishu import qr_register
mock_begin.return_value = {
"device_code": "dc_123",
"qr_url": "https://example.com/qr",
"user_code": "ABCD",
"interval": 1,
"expire_in": 60,
}
mock_poll.return_value = {
"app_id": "cli_app",
"app_secret": "secret",
"domain": "feishu",
"open_id": "ou_owner",
}
mock_probe.return_value = None # probe failed
result = qr_register()
assert result is not None
assert result["app_id"] == "cli_app"
assert result["bot_name"] is None
assert result["bot_open_id"] is None

View File

@@ -48,6 +48,7 @@ def _make_event(
room_id="!room1:example.org",
formatted_body=None,
thread_id=None,
mention_user_ids=None,
):
"""Create a fake room message event.
@@ -60,6 +61,9 @@ def _make_event(
content["formatted_body"] = formatted_body
content["format"] = "org.matrix.custom.html"
if mention_user_ids is not None:
content["m.mentions"] = {"user_ids": mention_user_ids}
relates_to = {}
if thread_id:
relates_to["rel_type"] = "m.thread"
@@ -108,6 +112,44 @@ class TestIsBotMentioned:
# "hermesbot" should not match word-boundary check for "hermes"
assert not self.adapter._is_bot_mentioned("hermesbot is here")
# m.mentions.user_ids — MSC3952 / Matrix v1.7 authoritative mentions
# Ported from openclaw/openclaw#64796
def test_m_mentions_user_ids_authoritative(self):
"""m.mentions.user_ids alone is sufficient — no body text needed."""
assert self.adapter._is_bot_mentioned(
"please reply", # no @hermes anywhere in body
mention_user_ids=["@hermes:example.org"],
)
def test_m_mentions_user_ids_with_body_mention(self):
"""Both m.mentions and body mention — should still be True."""
assert self.adapter._is_bot_mentioned(
"hey @hermes:example.org help",
mention_user_ids=["@hermes:example.org"],
)
def test_m_mentions_user_ids_other_user_only(self):
"""m.mentions with a different user — bot is NOT mentioned."""
assert not self.adapter._is_bot_mentioned(
"hello",
mention_user_ids=["@alice:example.org"],
)
def test_m_mentions_user_ids_empty_list(self):
"""Empty user_ids list — falls through to text detection."""
assert not self.adapter._is_bot_mentioned(
"hello everyone",
mention_user_ids=[],
)
def test_m_mentions_user_ids_none(self):
"""None mention_user_ids — falls through to text detection."""
assert not self.adapter._is_bot_mentioned(
"hello everyone",
mention_user_ids=None,
)
class TestStripMention:
def setup_method(self):
@@ -176,6 +218,44 @@ async def test_require_mention_html_pill(monkeypatch):
adapter.handle_message.assert_awaited_once()
@pytest.mark.asyncio
async def test_require_mention_m_mentions_user_ids(monkeypatch):
"""m.mentions.user_ids is authoritative per MSC3952 — no body mention needed.
Ported from openclaw/openclaw#64796.
"""
monkeypatch.delenv("MATRIX_REQUIRE_MENTION", raising=False)
monkeypatch.delenv("MATRIX_FREE_RESPONSE_ROOMS", raising=False)
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
adapter = _make_adapter()
# Body has NO mention, but m.mentions.user_ids includes the bot.
event = _make_event(
"please reply",
mention_user_ids=["@hermes:example.org"],
)
await adapter._on_room_message(event)
adapter.handle_message.assert_awaited_once()
@pytest.mark.asyncio
async def test_require_mention_m_mentions_other_user_ignored(monkeypatch):
"""m.mentions.user_ids mentioning another user should NOT activate the bot."""
monkeypatch.delenv("MATRIX_REQUIRE_MENTION", raising=False)
monkeypatch.delenv("MATRIX_FREE_RESPONSE_ROOMS", raising=False)
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
adapter = _make_adapter()
event = _make_event(
"hey alice check this",
mention_user_ids=["@alice:example.org"],
)
await adapter._on_room_message(event)
adapter.handle_message.assert_not_awaited()
@pytest.mark.asyncio
async def test_require_mention_dm_always_responds(monkeypatch):
"""DMs always respond regardless of mention setting."""

View File

@@ -9,6 +9,8 @@ from gateway.platforms.base import (
MessageEvent,
MessageType,
safe_url_for_log,
utf16_len,
_prefix_within_utf16_limit,
)
@@ -448,3 +450,135 @@ class TestGetHumanDelay:
with patch.dict(os.environ, env):
delay = BasePlatformAdapter._get_human_delay()
assert 0.1 <= delay <= 0.2
# ---------------------------------------------------------------------------
# utf16_len / _prefix_within_utf16_limit / truncate_message with len_fn
# ---------------------------------------------------------------------------
# Ported from nearai/ironclaw#2304 — Telegram counts message length in UTF-16
# code units, not Unicode code-points. Astral-plane characters (emoji, CJK
# Extension B) are surrogate pairs: 1 Python char but 2 UTF-16 units.
class TestUtf16Len:
"""Verify the UTF-16 length helper."""
def test_ascii(self):
assert utf16_len("hello") == 5
def test_bmp_cjk(self):
# CJK ideographs in the BMP are 1 code unit each
assert utf16_len("你好") == 2
def test_emoji_surrogate_pair(self):
# 😀 (U+1F600) is outside BMP → 2 UTF-16 code units
assert utf16_len("😀") == 2
def test_mixed(self):
# "hi😀" = 2 + 2 = 4 UTF-16 units
assert utf16_len("hi😀") == 4
def test_musical_symbol(self):
# 𝄞 (U+1D11E) — Musical Symbol G Clef, surrogate pair
assert utf16_len("𝄞") == 2
def test_empty(self):
assert utf16_len("") == 0
class TestPrefixWithinUtf16Limit:
"""Verify UTF-16-aware prefix truncation."""
def test_fits_entirely(self):
assert _prefix_within_utf16_limit("hello", 10) == "hello"
def test_ascii_truncation(self):
result = _prefix_within_utf16_limit("hello world", 5)
assert result == "hello"
assert utf16_len(result) <= 5
def test_does_not_split_surrogate_pair(self):
# "a😀b" = 1 + 2 + 1 = 4 UTF-16 units; limit 2 should give "a"
result = _prefix_within_utf16_limit("a😀b", 2)
assert result == "a"
assert utf16_len(result) <= 2
def test_emoji_at_limit(self):
# "😀" = 2 UTF-16 units; limit 2 should include it
result = _prefix_within_utf16_limit("😀x", 2)
assert result == "😀"
def test_all_emoji(self):
msg = "😀" * 10 # 20 UTF-16 units
result = _prefix_within_utf16_limit(msg, 6)
assert result == "😀😀😀"
assert utf16_len(result) == 6
def test_empty(self):
assert _prefix_within_utf16_limit("", 5) == ""
class TestTruncateMessageUtf16:
"""Verify truncate_message respects UTF-16 lengths when len_fn=utf16_len."""
def test_short_emoji_message_no_split(self):
"""A short message under the UTF-16 limit should not be split."""
msg = "Hello 😀 world"
chunks = BasePlatformAdapter.truncate_message(msg, 4096, len_fn=utf16_len)
assert len(chunks) == 1
assert chunks[0] == msg
def test_emoji_near_limit_triggers_split(self):
"""A message at 4096 codepoints but >4096 UTF-16 units must split."""
# 2049 emoji = 2049 codepoints but 4098 UTF-16 units → exceeds 4096
msg = "😀" * 2049
assert len(msg) == 2049 # Python len sees 2049 chars
assert utf16_len(msg) == 4098 # but it's 4098 UTF-16 units
# Without UTF-16 awareness, this would NOT split (2049 < 4096)
chunks_naive = BasePlatformAdapter.truncate_message(msg, 4096)
assert len(chunks_naive) == 1, "Without len_fn, no split expected"
# With UTF-16 awareness, it MUST split
chunks = BasePlatformAdapter.truncate_message(msg, 4096, len_fn=utf16_len)
assert len(chunks) > 1, "With utf16_len, message should be split"
# Each chunk must fit within the UTF-16 limit
for i, chunk in enumerate(chunks):
assert utf16_len(chunk) <= 4096, (
f"Chunk {i} exceeds 4096 UTF-16 units: {utf16_len(chunk)}"
)
def test_each_utf16_chunk_within_limit(self):
"""All chunks produced with utf16_len must fit the limit."""
# Mix of BMP and astral-plane characters
msg = ("Hello 😀 world 🎵 test 𝄞 " * 200).strip()
max_len = 200
chunks = BasePlatformAdapter.truncate_message(msg, max_len, len_fn=utf16_len)
for i, chunk in enumerate(chunks):
u16_len = utf16_len(chunk)
assert u16_len <= max_len + 20, (
f"Chunk {i} UTF-16 length {u16_len} exceeds {max_len}"
)
def test_all_content_preserved(self):
"""Splitting with utf16_len must not lose content."""
words = ["emoji😀", "music🎵", "cjk你好", "plain"] * 100
msg = " ".join(words)
chunks = BasePlatformAdapter.truncate_message(msg, 200, len_fn=utf16_len)
reassembled = " ".join(chunks)
for word in words:
assert word in reassembled, f"Word '{word}' lost during UTF-16 split"
def test_code_blocks_preserved_with_utf16(self):
"""Code block fence handling should work with utf16_len too."""
msg = "Before\n```python\n" + "x = '😀'\n" * 200 + "```\nAfter"
chunks = BasePlatformAdapter.truncate_message(msg, 300, len_fn=utf16_len)
assert len(chunks) > 1
# Each chunk should have balanced fences
for i, chunk in enumerate(chunks):
fence_count = chunk.count("```")
assert fence_count % 2 == 0, (
f"Chunk {i} has unbalanced fences ({fence_count})"
)

View File

@@ -0,0 +1,215 @@
"""Tests for /restart notification — the gateway notifies the requester on comeback."""
import asyncio
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
import gateway.run as gateway_run
from gateway.config import Platform
from gateway.platforms.base import MessageEvent, MessageType
from gateway.session import build_session_key
from tests.gateway.restart_test_helpers import (
make_restart_runner,
make_restart_source,
)
# ── _handle_restart_command writes .restart_notify.json ──────────────────
@pytest.mark.asyncio
async def test_restart_command_writes_notify_file(tmp_path, monkeypatch):
"""When /restart fires, the requester's routing info is persisted to disk."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
runner, _adapter = make_restart_runner()
runner.request_restart = MagicMock(return_value=True)
source = make_restart_source(chat_id="42")
event = MessageEvent(
text="/restart",
message_type=MessageType.TEXT,
source=source,
message_id="m1",
)
result = await runner._handle_restart_command(event)
assert "Restarting" in result
notify_path = tmp_path / ".restart_notify.json"
assert notify_path.exists()
data = json.loads(notify_path.read_text())
assert data["platform"] == "telegram"
assert data["chat_id"] == "42"
assert "thread_id" not in data # no thread → omitted
@pytest.mark.asyncio
async def test_restart_command_uses_service_restart_under_systemd(tmp_path, monkeypatch):
"""Under systemd (INVOCATION_ID set), /restart uses via_service=True."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.setenv("INVOCATION_ID", "abc123")
runner, _adapter = make_restart_runner()
runner.request_restart = MagicMock(return_value=True)
source = make_restart_source(chat_id="42")
event = MessageEvent(
text="/restart",
message_type=MessageType.TEXT,
source=source,
message_id="m1",
)
await runner._handle_restart_command(event)
runner.request_restart.assert_called_once_with(detached=False, via_service=True)
@pytest.mark.asyncio
async def test_restart_command_uses_detached_without_systemd(tmp_path, monkeypatch):
"""Without systemd, /restart uses the detached subprocess approach."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.delenv("INVOCATION_ID", raising=False)
runner, _adapter = make_restart_runner()
runner.request_restart = MagicMock(return_value=True)
source = make_restart_source(chat_id="42")
event = MessageEvent(
text="/restart",
message_type=MessageType.TEXT,
source=source,
message_id="m1",
)
await runner._handle_restart_command(event)
runner.request_restart.assert_called_once_with(detached=True, via_service=False)
@pytest.mark.asyncio
async def test_restart_command_preserves_thread_id(tmp_path, monkeypatch):
"""Thread ID is saved when the requester is in a threaded chat."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
runner, _adapter = make_restart_runner()
runner.request_restart = MagicMock(return_value=True)
source = make_restart_source(chat_id="99")
source.thread_id = "topic_7"
event = MessageEvent(
text="/restart",
message_type=MessageType.TEXT,
source=source,
message_id="m2",
)
await runner._handle_restart_command(event)
data = json.loads((tmp_path / ".restart_notify.json").read_text())
assert data["thread_id"] == "topic_7"
# ── _send_restart_notification ───────────────────────────────────────────
@pytest.mark.asyncio
async def test_send_restart_notification_delivers_and_cleans_up(tmp_path, monkeypatch):
"""On startup, the notification is sent and the file is removed."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
notify_path = tmp_path / ".restart_notify.json"
notify_path.write_text(json.dumps({
"platform": "telegram",
"chat_id": "42",
}))
runner, adapter = make_restart_runner()
adapter.send = AsyncMock()
await runner._send_restart_notification()
adapter.send.assert_called_once()
call_args = adapter.send.call_args
assert call_args[0][0] == "42" # chat_id
assert "restarted" in call_args[0][1].lower()
assert call_args[1].get("metadata") is None # no thread
assert not notify_path.exists()
@pytest.mark.asyncio
async def test_send_restart_notification_with_thread(tmp_path, monkeypatch):
"""Thread ID is passed as metadata so the message lands in the right topic."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
notify_path = tmp_path / ".restart_notify.json"
notify_path.write_text(json.dumps({
"platform": "telegram",
"chat_id": "99",
"thread_id": "topic_7",
}))
runner, adapter = make_restart_runner()
adapter.send = AsyncMock()
await runner._send_restart_notification()
call_args = adapter.send.call_args
assert call_args[1]["metadata"] == {"thread_id": "topic_7"}
assert not notify_path.exists()
@pytest.mark.asyncio
async def test_send_restart_notification_noop_when_no_file(tmp_path, monkeypatch):
"""Nothing happens if there's no pending restart notification."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
runner, adapter = make_restart_runner()
adapter.send = AsyncMock()
await runner._send_restart_notification()
adapter.send.assert_not_called()
@pytest.mark.asyncio
async def test_send_restart_notification_skips_when_adapter_missing(tmp_path, monkeypatch):
"""If the requester's platform isn't connected, clean up without crashing."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
notify_path = tmp_path / ".restart_notify.json"
notify_path.write_text(json.dumps({
"platform": "discord", # runner only has telegram adapter
"chat_id": "42",
}))
runner, _adapter = make_restart_runner()
await runner._send_restart_notification()
# File cleaned up even though we couldn't send
assert not notify_path.exists()
@pytest.mark.asyncio
async def test_send_restart_notification_cleans_up_on_send_failure(
tmp_path, monkeypatch
):
"""If the adapter.send() raises, the file is still cleaned up."""
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
notify_path = tmp_path / ".restart_notify.json"
notify_path.write_text(json.dumps({
"platform": "telegram",
"chat_id": "42",
}))
runner, adapter = make_restart_runner()
adapter.send = AsyncMock(side_effect=RuntimeError("network down"))
await runner._send_restart_notification()
assert not notify_path.exists() # cleaned up despite error

View File

@@ -396,6 +396,27 @@ class QueuedCommentaryAgent:
}
class VerboseAgent:
"""Agent that emits a tool call with args whose JSON exceeds 200 chars."""
LONG_CODE = "x" * 300
def __init__(self, **kwargs):
self.tool_progress_callback = kwargs.get("tool_progress_callback")
self.tools = []
def run_conversation(self, message, conversation_history=None, task_id=None):
self.tool_progress_callback(
"tool.started", "execute_code", None,
{"code": self.LONG_CODE},
)
time.sleep(0.35)
return {
"final_response": "done",
"messages": [],
"api_calls": 1,
}
async def _run_with_agent(
monkeypatch,
tmp_path,
@@ -575,3 +596,45 @@ async def test_run_agent_queued_message_does_not_treat_commentary_as_final(monke
assert result["final_response"] == "final response 2"
assert "I'll inspect the repo first." in sent_texts
assert "final response 1" in sent_texts
@pytest.mark.asyncio
async def test_verbose_mode_does_not_truncate_args_by_default(monkeypatch, tmp_path):
"""Verbose mode with default tool_preview_length (0) should NOT truncate args.
Previously, verbose mode capped args at 200 chars when tool_preview_length
was 0 (default). The user explicitly opted into verbose — show full detail.
"""
adapter, result = await _run_with_agent(
monkeypatch,
tmp_path,
VerboseAgent,
session_id="sess-verbose-no-truncate",
config_data={"display": {"tool_progress": "verbose", "tool_preview_length": 0}},
)
assert result["final_response"] == "done"
# The full 300-char 'x' string should be present, not truncated to 200
all_content = " ".join(call["content"] for call in adapter.sent)
all_content += " ".join(call["content"] for call in adapter.edits)
assert VerboseAgent.LONG_CODE in all_content
@pytest.mark.asyncio
async def test_verbose_mode_respects_explicit_tool_preview_length(monkeypatch, tmp_path):
"""When tool_preview_length is set to a positive value, verbose truncates to that."""
adapter, result = await _run_with_agent(
monkeypatch,
tmp_path,
VerboseAgent,
session_id="sess-verbose-explicit-cap",
config_data={"display": {"tool_progress": "verbose", "tool_preview_length": 50}},
)
assert result["final_response"] == "done"
all_content = " ".join(call["content"] for call in adapter.sent)
all_content += " ".join(call["content"] for call in adapter.edits)
# Should be truncated — full 300-char string NOT present
assert VerboseAgent.LONG_CODE not in all_content
# But should still contain the truncated portion with "..."
assert "..." in all_content

View File

@@ -552,6 +552,45 @@ class TestLoadTranscriptPreferLongerSource:
assert result[0]["content"] == "db-q"
class TestSessionStoreSwitchSession:
"""Regression coverage for gateway /resume session switching semantics."""
def test_switch_session_reopens_target_session_in_db(self, tmp_path):
from hermes_state import SessionDB
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
store = SessionStore(sessions_dir=tmp_path / "sessions", config=config)
db = SessionDB(db_path=tmp_path / "state.db")
store._db = db
store._loaded = True
source = SessionSource(
platform=Platform.FEISHU,
chat_id="chat-1",
chat_type="dm",
user_id="user-1",
user_name="tester",
)
current_entry = store.get_or_create_session(source)
current_session_id = current_entry.session_id
target_session_id = "old_session_abc"
db.create_session(target_session_id, source="feishu", user_id="user-1")
db.end_session(target_session_id, end_reason="user_exit")
assert db.get_session(target_session_id)["ended_at"] is not None
switched = store.switch_session(current_entry.session_key, target_session_id)
assert switched is not None
assert switched.session_id == target_session_id
assert db.get_session(current_session_id)["end_reason"] == "session_switch"
resumed = db.get_session(target_session_id)
assert resumed["ended_at"] is None
assert resumed["end_reason"] is None
db.close()
class TestWhatsAppDMSessionKeyConsistency:
"""Regression: all session-key construction must go through build_session_key
so DMs are isolated by chat_id across platforms."""

View File

@@ -60,7 +60,8 @@ def _make_runner():
def _make_event(text="hello", chat_id="12345"):
source = SessionSource(
platform=Platform.TELEGRAM, chat_id=chat_id, chat_type="dm"
platform=Platform.TELEGRAM, chat_id=chat_id, chat_type="dm",
user_id="u1",
)
return MessageEvent(text=text, message_type=MessageType.TEXT, source=source)
@@ -192,7 +193,8 @@ async def test_command_messages_do_not_leave_sentinel():
_handle_message. They must NOT leave a sentinel behind."""
runner = _make_runner()
source = SessionSource(
platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm"
platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm",
user_id="u1",
)
event = MessageEvent(
text="/help", message_type=MessageType.TEXT, source=source
@@ -240,9 +242,7 @@ async def test_stop_during_sentinel_force_cleans_session():
stop_event = _make_event(text="/stop")
result = await runner._handle_message(stop_event)
assert result is not None, "/stop during sentinel should return a message"
assert "force-stopped" in result.lower() or "unlocked" in result.lower()
# Sentinel must be cleaned up
assert "stopped" in result.lower()
assert session_key not in runner._running_agents, (
"/stop must remove sentinel so the session is unlocked"
)
@@ -268,7 +268,7 @@ async def test_stop_hard_kills_running_agent():
forever — showing 'writing...' but never producing output."""
runner = _make_runner()
session_key = build_session_key(
SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm")
SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm", user_id="u1")
)
# Simulate a running (possibly hung) agent
@@ -289,7 +289,7 @@ async def test_stop_hard_kills_running_agent():
# Must return a confirmation
assert result is not None
assert "force-stopped" in result.lower() or "unlocked" in result.lower()
assert "stopped" in result.lower()
# ------------------------------------------------------------------
@@ -301,7 +301,7 @@ async def test_stop_clears_pending_messages():
queued during the run must be discarded."""
runner = _make_runner()
session_key = build_session_key(
SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm")
SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm", user_id="u1")
)
fake_agent = MagicMock()

View File

@@ -0,0 +1,279 @@
"""Tests for _setup_feishu() in hermes_cli/gateway.py.
Verifies that the interactive setup writes env vars that correctly drive the
Feishu adapter: credentials, connection mode, DM policy, and group policy.
"""
import os
from unittest.mock import patch
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _run_setup_feishu(
*,
qr_result=None,
prompt_yes_no_responses=None,
prompt_choice_responses=None,
prompt_responses=None,
existing_env=None,
):
"""Run _setup_feishu() with mocked I/O and return the env vars that were saved.
Returns a dict of {env_var_name: value} for all save_env_value calls.
"""
existing_env = existing_env or {}
prompt_yes_no_responses = list(prompt_yes_no_responses or [True])
# QR path: method(0), dm(0), group(0) — 3 choices (no connection mode)
# Manual path: method(1), domain(0), connection(0), dm(0), group(0) — 5 choices
prompt_choice_responses = list(prompt_choice_responses or [0, 0, 0])
prompt_responses = list(prompt_responses or [""])
saved_env = {}
def mock_save(name, value):
saved_env[name] = value
def mock_get(name):
return existing_env.get(name, "")
with patch("hermes_cli.gateway.save_env_value", side_effect=mock_save), \
patch("hermes_cli.gateway.get_env_value", side_effect=mock_get), \
patch("hermes_cli.gateway.prompt_yes_no", side_effect=prompt_yes_no_responses), \
patch("hermes_cli.gateway.prompt_choice", side_effect=prompt_choice_responses), \
patch("hermes_cli.gateway.prompt", side_effect=prompt_responses), \
patch("hermes_cli.gateway.print_info"), \
patch("hermes_cli.gateway.print_success"), \
patch("hermes_cli.gateway.print_warning"), \
patch("hermes_cli.gateway.print_error"), \
patch("hermes_cli.gateway.color", side_effect=lambda t, c: t), \
patch("gateway.platforms.feishu.qr_register", return_value=qr_result):
from hermes_cli.gateway import _setup_feishu
_setup_feishu()
return saved_env
# ---------------------------------------------------------------------------
# QR scan-to-create path
# ---------------------------------------------------------------------------
class TestSetupFeishuQrPath:
"""Tests for the QR scan-to-create happy path."""
def test_qr_success_saves_core_credentials(self):
env = _run_setup_feishu(
qr_result={
"app_id": "cli_test",
"app_secret": "secret_test",
"domain": "feishu",
"open_id": "ou_owner",
"bot_name": "TestBot",
"bot_open_id": "ou_bot",
},
prompt_yes_no_responses=[True], # Start QR
prompt_choice_responses=[0, 0, 0], # method=QR, dm=pairing, group=open
prompt_responses=[""], # home channel: skip
)
assert env["FEISHU_APP_ID"] == "cli_test"
assert env["FEISHU_APP_SECRET"] == "secret_test"
assert env["FEISHU_DOMAIN"] == "feishu"
def test_qr_success_does_not_persist_bot_identity(self):
"""Bot identity is discovered at runtime by _hydrate_bot_identity — not persisted
in env, so it stays fresh if the user renames the bot later."""
env = _run_setup_feishu(
qr_result={
"app_id": "cli_test",
"app_secret": "secret_test",
"domain": "feishu",
"open_id": "ou_owner",
"bot_name": "TestBot",
"bot_open_id": "ou_bot",
},
prompt_yes_no_responses=[True],
prompt_choice_responses=[0, 0, 0],
prompt_responses=[""],
)
assert "FEISHU_BOT_OPEN_ID" not in env
assert "FEISHU_BOT_NAME" not in env
# ---------------------------------------------------------------------------
# Connection mode
# ---------------------------------------------------------------------------
class TestSetupFeishuConnectionMode:
"""Connection mode: QR always websocket, manual path lets user choose."""
def test_qr_path_defaults_to_websocket(self):
env = _run_setup_feishu(
qr_result={
"app_id": "cli_test", "app_secret": "s", "domain": "feishu",
"open_id": None, "bot_name": None, "bot_open_id": None,
},
prompt_choice_responses=[0, 0, 0], # method=QR, dm=pairing, group=open
prompt_responses=[""],
)
assert env["FEISHU_CONNECTION_MODE"] == "websocket"
@patch("gateway.platforms.feishu.probe_bot", return_value=None)
def test_manual_path_websocket(self, _mock_probe):
env = _run_setup_feishu(
qr_result=None,
prompt_choice_responses=[1, 0, 0, 0, 0], # method=manual, domain=feishu, connection=ws, dm=pairing, group=open
prompt_responses=["cli_manual", "secret_manual", ""], # app_id, app_secret, home_channel
)
assert env["FEISHU_CONNECTION_MODE"] == "websocket"
@patch("gateway.platforms.feishu.probe_bot", return_value=None)
def test_manual_path_webhook(self, _mock_probe):
env = _run_setup_feishu(
qr_result=None,
prompt_choice_responses=[1, 0, 1, 0, 0], # method=manual, domain=feishu, connection=webhook, dm=pairing, group=open
prompt_responses=["cli_manual", "secret_manual", ""], # app_id, app_secret, home_channel
)
assert env["FEISHU_CONNECTION_MODE"] == "webhook"
# ---------------------------------------------------------------------------
# DM security policy
# ---------------------------------------------------------------------------
class TestSetupFeishuDmPolicy:
"""DM policy must use platform-scoped FEISHU_ALLOW_ALL_USERS, not the global flag."""
def _run_with_dm_choice(self, dm_choice_idx, prompt_responses=None):
return _run_setup_feishu(
qr_result={
"app_id": "cli_test", "app_secret": "s", "domain": "feishu",
"open_id": "ou_owner", "bot_name": None, "bot_open_id": None,
},
prompt_yes_no_responses=[True],
prompt_choice_responses=[0, dm_choice_idx, 0], # method=QR, dm=<choice>, group=open
prompt_responses=prompt_responses or [""],
)
def test_pairing_sets_feishu_allow_all_false(self):
env = self._run_with_dm_choice(0)
assert env["FEISHU_ALLOW_ALL_USERS"] == "false"
assert env["FEISHU_ALLOWED_USERS"] == ""
assert "GATEWAY_ALLOW_ALL_USERS" not in env
def test_allow_all_sets_feishu_allow_all_true(self):
env = self._run_with_dm_choice(1)
assert env["FEISHU_ALLOW_ALL_USERS"] == "true"
assert env["FEISHU_ALLOWED_USERS"] == ""
assert "GATEWAY_ALLOW_ALL_USERS" not in env
def test_allowlist_sets_feishu_allow_all_false_with_list(self):
env = self._run_with_dm_choice(2, prompt_responses=["ou_user1,ou_user2", ""])
assert env["FEISHU_ALLOW_ALL_USERS"] == "false"
assert env["FEISHU_ALLOWED_USERS"] == "ou_user1,ou_user2"
assert "GATEWAY_ALLOW_ALL_USERS" not in env
def test_allowlist_prepopulates_with_scan_owner_open_id(self):
"""When open_id is available from QR scan, it should be the default allowlist value."""
# We return the owner's open_id from prompt (+ empty home channel).
env = self._run_with_dm_choice(2, prompt_responses=["ou_owner", ""])
assert env["FEISHU_ALLOWED_USERS"] == "ou_owner"
# ---------------------------------------------------------------------------
# Group policy
# ---------------------------------------------------------------------------
class TestSetupFeishuGroupPolicy:
def test_open_with_mention(self):
env = _run_setup_feishu(
qr_result={
"app_id": "cli_test", "app_secret": "s", "domain": "feishu",
"open_id": None, "bot_name": None, "bot_open_id": None,
},
prompt_yes_no_responses=[True],
prompt_choice_responses=[0, 0, 0], # method=QR, dm=pairing, group=open
prompt_responses=[""],
)
assert env["FEISHU_GROUP_POLICY"] == "open"
def test_disabled(self):
env = _run_setup_feishu(
qr_result={
"app_id": "cli_test", "app_secret": "s", "domain": "feishu",
"open_id": None, "bot_name": None, "bot_open_id": None,
},
prompt_yes_no_responses=[True],
prompt_choice_responses=[0, 0, 1], # method=QR, dm=pairing, group=disabled
prompt_responses=[""],
)
assert env["FEISHU_GROUP_POLICY"] == "disabled"
# ---------------------------------------------------------------------------
# Adapter integration: env vars → FeishuAdapterSettings
# ---------------------------------------------------------------------------
class TestSetupFeishuAdapterIntegration:
"""Verify that env vars written by _setup_feishu() produce a valid adapter config.
This bridges the gap between 'setup wrote the right env vars' and
'the adapter will actually initialize correctly from those vars'.
"""
def _make_env_from_setup(self, dm_idx=0, group_idx=0):
"""Run _setup_feishu via QR path and return the env vars it would write."""
return _run_setup_feishu(
qr_result={
"app_id": "cli_test_app",
"app_secret": "test_secret_value",
"domain": "feishu",
"open_id": "ou_owner",
"bot_name": "IntegrationBot",
"bot_open_id": "ou_bot_integration",
},
prompt_yes_no_responses=[True],
prompt_choice_responses=[0, dm_idx, group_idx], # method=QR, dm, group
prompt_responses=[""],
)
@patch.dict(os.environ, {}, clear=True)
def test_qr_env_produces_valid_adapter_settings(self):
"""QR setup → adapter initializes with websocket mode."""
env = self._make_env_from_setup()
with patch.dict(os.environ, env, clear=True):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
assert adapter._app_id == "cli_test_app"
assert adapter._app_secret == "test_secret_value"
assert adapter._domain_name == "feishu"
assert adapter._connection_mode == "websocket"
@patch.dict(os.environ, {}, clear=True)
def test_open_dm_env_sets_correct_adapter_state(self):
"""Setup with 'allow all DMs' → adapter sees allow-all flag."""
env = self._make_env_from_setup(dm_idx=1)
with patch.dict(os.environ, env, clear=True):
from gateway.platforms.feishu import FeishuAdapter
from gateway.config import PlatformConfig
# Verify adapter initializes without error and env var is correct.
FeishuAdapter(PlatformConfig())
assert os.getenv("FEISHU_ALLOW_ALL_USERS") == "true"
@patch.dict(os.environ, {}, clear=True)
def test_group_open_env_sets_adapter_group_policy(self):
"""Setup with 'open groups' → adapter group_policy is 'open'."""
env = self._make_env_from_setup(group_idx=0)
with patch.dict(os.environ, env, clear=True):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
assert adapter._group_policy == "open"

View File

@@ -209,6 +209,33 @@ class TestScopedLocks:
assert payload["pid"] == os.getpid()
assert payload["metadata"]["platform"] == "telegram"
def test_acquire_scoped_lock_recovers_empty_lock_file(self, tmp_path, monkeypatch):
"""Empty lock file (0 bytes) left by a crashed process should be treated as stale."""
monkeypatch.setenv("HERMES_GATEWAY_LOCK_DIR", str(tmp_path / "locks"))
lock_path = tmp_path / "locks" / "slack-app-token-2bb80d537b1da3e3.lock"
lock_path.parent.mkdir(parents=True, exist_ok=True)
lock_path.write_text("") # simulate crash between O_CREAT and json.dump
acquired, existing = status.acquire_scoped_lock("slack-app-token", "secret", metadata={"platform": "slack"})
assert acquired is True
payload = json.loads(lock_path.read_text())
assert payload["pid"] == os.getpid()
assert payload["metadata"]["platform"] == "slack"
def test_acquire_scoped_lock_recovers_corrupt_lock_file(self, tmp_path, monkeypatch):
"""Lock file with invalid JSON should be treated as stale."""
monkeypatch.setenv("HERMES_GATEWAY_LOCK_DIR", str(tmp_path / "locks"))
lock_path = tmp_path / "locks" / "slack-app-token-2bb80d537b1da3e3.lock"
lock_path.parent.mkdir(parents=True, exist_ok=True)
lock_path.write_text("{truncated") # simulate partial write
acquired, existing = status.acquire_scoped_lock("slack-app-token", "secret", metadata={"platform": "slack"})
assert acquired is True
payload = json.loads(lock_path.read_text())
assert payload["pid"] == os.getpid()
def test_release_scoped_lock_only_removes_current_owner(self, tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_GATEWAY_LOCK_DIR", str(tmp_path / "locks"))

View File

@@ -29,7 +29,7 @@ def _make_runner():
@pytest.mark.asyncio
async def test_handle_message_does_not_priority_interrupt_photo_followup():
runner = _make_runner()
source = SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm")
source = SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm", user_id="u1")
session_key = build_session_key(source)
running_agent = MagicMock()
runner._running_agents[session_key] = running_agent

View File

@@ -121,7 +121,7 @@ class TestSendWithReplyToMode:
adapter = adapter_factory(reply_to_mode="off")
adapter._bot = MagicMock()
adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1))
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2", "chunk3"]
await adapter.send("12345", "test content", reply_to="999")
@@ -133,7 +133,7 @@ class TestSendWithReplyToMode:
adapter = adapter_factory(reply_to_mode="first")
adapter._bot = MagicMock()
adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1))
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2", "chunk3"]
await adapter.send("12345", "test content", reply_to="999")
@@ -148,7 +148,7 @@ class TestSendWithReplyToMode:
adapter = adapter_factory(reply_to_mode="all")
adapter._bot = MagicMock()
adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1))
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2", "chunk3"]
await adapter.send("12345", "test content", reply_to="999")
@@ -162,7 +162,7 @@ class TestSendWithReplyToMode:
adapter = adapter_factory(reply_to_mode="all")
adapter._bot = MagicMock()
adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1))
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2"]
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2"]
await adapter.send("12345", "test content", reply_to=None)
@@ -175,7 +175,7 @@ class TestSendWithReplyToMode:
adapter = adapter_factory(reply_to_mode="first")
adapter._bot = MagicMock()
adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1))
adapter.truncate_message = lambda content, max_len: ["single chunk"]
adapter.truncate_message = lambda content, max_len, **kw: ["single chunk"]
await adapter.send("12345", "test", reply_to="999")

View File

@@ -417,6 +417,7 @@ class TestDiscordPlayTtsSkip:
adapter.config = config
adapter._voice_clients = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._voice_timeout_tasks = {}
adapter._voice_receivers = {}
adapter._voice_listen_tasks = {}
@@ -702,13 +703,18 @@ class TestVoiceChannelCommands:
mock_adapter.join_voice_channel = AsyncMock(return_value=True)
mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
mock_adapter._voice_text_channels = {}
mock_adapter._voice_sources = {}
mock_adapter._voice_input_callback = None
event = self._make_discord_event()
event.source.chat_type = "group"
event.source.chat_name = "Hermes Server / #general"
runner.adapters[event.source.platform] = mock_adapter
result = await runner._handle_voice_channel_join(event)
assert "joined" in result.lower()
assert "General" in result
assert runner._voice_mode["123"] == "all"
assert mock_adapter._voice_sources[111]["chat_id"] == "123"
assert mock_adapter._voice_sources[111]["chat_type"] == "group"
@pytest.mark.asyncio
async def test_join_failure(self, runner):
@@ -815,6 +821,7 @@ class TestVoiceChannelCommands:
from gateway.config import Platform
mock_adapter = AsyncMock()
mock_adapter._voice_text_channels = {111: 123}
mock_adapter._voice_sources = {}
mock_channel = AsyncMock()
mock_adapter._client = MagicMock()
mock_adapter._client.get_channel = MagicMock(return_value=mock_channel)
@@ -828,12 +835,45 @@ class TestVoiceChannelCommands:
assert event.source.chat_id == "123"
assert event.source.chat_type == "channel"
@pytest.mark.asyncio
async def test_input_reuses_bound_source_metadata(self, runner):
"""Voice input should share the linked text channel session metadata."""
from gateway.config import Platform
bound_source = SessionSource(
chat_id="123",
chat_name="Hermes Server / #general",
chat_type="group",
user_id="user1",
user_name="user1",
platform=Platform.DISCORD,
)
mock_adapter = AsyncMock()
mock_adapter._voice_text_channels = {111: 123}
mock_adapter._voice_sources = {111: bound_source.to_dict()}
mock_channel = AsyncMock()
mock_adapter._client = MagicMock()
mock_adapter._client.get_channel = MagicMock(return_value=mock_channel)
mock_adapter.handle_message = AsyncMock()
runner.adapters[Platform.DISCORD] = mock_adapter
await runner._handle_voice_channel_input(111, 42, "Hello from VC")
mock_adapter.handle_message.assert_called_once()
event = mock_adapter.handle_message.call_args[0][0]
assert event.source.chat_id == "123"
assert event.source.chat_type == "group"
assert event.source.chat_name == "Hermes Server / #general"
assert event.source.user_id == "42"
@pytest.mark.asyncio
async def test_input_posts_transcript_in_text_channel(self, runner):
"""Voice input sends transcript message to text channel."""
from gateway.config import Platform
mock_adapter = AsyncMock()
mock_adapter._voice_text_channels = {111: 123}
mock_adapter._voice_sources = {}
mock_channel = AsyncMock()
mock_adapter._client = MagicMock()
mock_adapter._client.get_channel = MagicMock(return_value=mock_channel)
@@ -892,6 +932,7 @@ class TestDiscordVoiceChannelMethods:
adapter._client = MagicMock()
adapter._voice_clients = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._voice_timeout_tasks = {}
adapter._voice_receivers = {}
adapter._voice_listen_tasks = {}
@@ -926,6 +967,7 @@ class TestDiscordVoiceChannelMethods:
mock_vc.disconnect = AsyncMock()
adapter._voice_clients[111] = mock_vc
adapter._voice_text_channels[111] = 123
adapter._voice_sources[111] = {"chat_id": "123", "chat_type": "group"}
mock_receiver = MagicMock()
adapter._voice_receivers[111] = mock_receiver
@@ -944,6 +986,7 @@ class TestDiscordVoiceChannelMethods:
mock_timeout.cancel.assert_called_once()
assert 111 not in adapter._voice_clients
assert 111 not in adapter._voice_text_channels
assert 111 not in adapter._voice_sources
assert 111 not in adapter._voice_receivers
@pytest.mark.asyncio
@@ -1670,6 +1713,7 @@ class TestVoiceTimeoutCleansRunnerState:
adapter.config = config
adapter._voice_clients = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._voice_timeout_tasks = {}
adapter._voice_receivers = {}
adapter._voice_listen_tasks = {}
@@ -1759,6 +1803,7 @@ class TestPlaybackTimeout:
adapter.config = config
adapter._voice_clients = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._voice_timeout_tasks = {}
adapter._voice_receivers = {}
adapter._voice_listen_tasks = {}
@@ -1939,6 +1984,7 @@ class TestVoiceChannelAwareness:
adapter = object.__new__(DiscordAdapter)
adapter._voice_clients = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._voice_receivers = {}
adapter._client = MagicMock()
adapter._client.user = SimpleNamespace(id=99999, name="HermesBot")
@@ -2408,6 +2454,7 @@ class TestVoiceTTSPlayback:
adapter.config = config
adapter._voice_clients = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._voice_receivers = {}
return adapter
@@ -2587,6 +2634,7 @@ class TestUDPKeepalive:
adapter.config = config
adapter._voice_clients = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._voice_receivers = {}
adapter._voice_listen_tasks = {}

View File

@@ -0,0 +1,141 @@
"""Tests for gateway weak credential rejection at startup.
Ported from openclaw/openclaw#64586: rejects known-weak placeholder
tokens at gateway startup instead of letting them silently fail
against platform APIs.
"""
import logging
import pytest
from gateway.config import PlatformConfig, Platform, _validate_gateway_config
# ---------------------------------------------------------------------------
# Helper: create a minimal GatewayConfig with one enabled platform
# ---------------------------------------------------------------------------
def _make_gateway_config(platform, token, enabled=True, **extra_kwargs):
"""Create a minimal GatewayConfig-like object for validation testing."""
from gateway.config import GatewayConfig
config = GatewayConfig(platforms={})
pconfig = PlatformConfig(enabled=enabled, token=token, **extra_kwargs)
config.platforms[platform] = pconfig
return config
def _validate_and_return(config):
"""Call _validate_gateway_config and return the config (mutated in place)."""
_validate_gateway_config(config)
return config
# ---------------------------------------------------------------------------
# Unit tests: platform token placeholder rejection
# ---------------------------------------------------------------------------
class TestPlatformTokenPlaceholderGuard:
"""Verify that _validate_gateway_config disables platforms with placeholder tokens."""
def test_rejects_triple_asterisk(self, caplog):
"""'***' is the .env.example placeholder — should be rejected."""
config = _make_gateway_config(Platform.TELEGRAM, "***")
with caplog.at_level(logging.ERROR):
_validate_and_return(config)
assert config.platforms[Platform.TELEGRAM].enabled is False
assert "placeholder" in caplog.text.lower()
def test_rejects_changeme(self, caplog):
config = _make_gateway_config(Platform.DISCORD, "changeme")
with caplog.at_level(logging.ERROR):
_validate_and_return(config)
assert config.platforms[Platform.DISCORD].enabled is False
def test_rejects_your_api_key(self, caplog):
config = _make_gateway_config(Platform.SLACK, "your_api_key")
with caplog.at_level(logging.ERROR):
_validate_and_return(config)
assert config.platforms[Platform.SLACK].enabled is False
def test_rejects_placeholder(self, caplog):
config = _make_gateway_config(Platform.MATRIX, "placeholder")
with caplog.at_level(logging.ERROR):
_validate_and_return(config)
assert config.platforms[Platform.MATRIX].enabled is False
def test_accepts_real_token(self, caplog):
"""A real-looking bot token should pass validation."""
config = _make_gateway_config(
Platform.TELEGRAM, "7123456789:AAHdqTcvCH1vGWJxfSeOfSAs0K5PALDsaw"
)
with caplog.at_level(logging.ERROR):
_validate_and_return(config)
assert config.platforms[Platform.TELEGRAM].enabled is True
assert "placeholder" not in caplog.text.lower()
def test_accepts_empty_token_without_error(self, caplog):
"""Empty tokens get a warning (existing behavior), not a placeholder error."""
config = _make_gateway_config(Platform.TELEGRAM, "")
with caplog.at_level(logging.WARNING):
_validate_and_return(config)
# Empty token doesn't trigger placeholder rejection — enabled stays True
# (the existing empty-token warning is separate)
assert config.platforms[Platform.TELEGRAM].enabled is True
def test_disabled_platform_not_checked(self, caplog):
"""Disabled platforms should not be validated."""
config = _make_gateway_config(Platform.TELEGRAM, "***", enabled=False)
with caplog.at_level(logging.ERROR):
_validate_and_return(config)
assert "placeholder" not in caplog.text.lower()
def test_rejects_whitespace_padded_placeholder(self, caplog):
"""Whitespace-padded placeholders should still be caught."""
config = _make_gateway_config(Platform.TELEGRAM, " *** ")
with caplog.at_level(logging.ERROR):
_validate_and_return(config)
assert config.platforms[Platform.TELEGRAM].enabled is False
# ---------------------------------------------------------------------------
# Integration test: API server placeholder key on network-accessible host
# ---------------------------------------------------------------------------
class TestAPIServerPlaceholderKeyGuard:
"""Verify that the API server rejects placeholder keys on network hosts."""
@pytest.mark.asyncio
async def test_refuses_wildcard_with_placeholder_key(self):
from gateway.platforms.api_server import APIServerAdapter
adapter = APIServerAdapter(
PlatformConfig(enabled=True, extra={"host": "0.0.0.0", "key": "changeme"})
)
result = await adapter.connect()
assert result is False
@pytest.mark.asyncio
async def test_refuses_wildcard_with_asterisk_key(self):
from gateway.platforms.api_server import APIServerAdapter
adapter = APIServerAdapter(
PlatformConfig(enabled=True, extra={"host": "0.0.0.0", "key": "***"})
)
result = await adapter.connect()
assert result is False
def test_allows_loopback_with_placeholder_key(self):
"""Loopback with a placeholder key is fine — not network-exposed."""
from gateway.platforms.api_server import APIServerAdapter
from gateway.platforms.base import is_network_accessible
adapter = APIServerAdapter(
PlatformConfig(enabled=True, extra={"host": "127.0.0.1", "key": "changeme"})
)
# On loopback the placeholder guard doesn't fire
assert is_network_accessible(adapter._host) is False

View File

@@ -30,7 +30,7 @@ class TestWeixinFormatting:
assert (
adapter.format_message(content)
== "【Title】\n\n**Plan**\n\nUse **bold** and [docs](https://example.com)."
== "【Title】\n\n**Plan**\n\nUse **bold** and docs (https://example.com)."
)
def test_format_message_rewrites_markdown_tables(self):
@@ -374,3 +374,149 @@ class TestWeixinRemoteMediaSafety:
assert "Blocked unsafe URL" in str(exc)
else:
raise AssertionError("expected ValueError for unsafe URL")
class TestWeixinMarkdownLinks:
"""Markdown links should be converted to plaintext since WeChat can't render them."""
def test_format_message_converts_markdown_links_to_plain_text(self):
adapter = _make_adapter()
content = "Check [the docs](https://example.com) and [GitHub](https://github.com) for details"
assert (
adapter.format_message(content)
== "Check the docs (https://example.com) and GitHub (https://github.com) for details"
)
def test_format_message_preserves_links_inside_code_blocks(self):
adapter = _make_adapter()
content = "See below:\n\n```\n[link](https://example.com)\n```\n\nDone."
result = adapter.format_message(content)
assert "[link](https://example.com)" in result
class TestWeixinBlankMessagePrevention:
"""Regression tests for the blank-bubble bugs.
Three separate guards now prevent a blank WeChat message from ever being
dispatched:
1. ``_split_text_for_weixin_delivery("")`` returns ``[]`` — not ``[""]``.
2. ``send()`` filters out empty/whitespace-only chunks before calling
``_send_text_chunk``.
3. ``_send_message()`` raises ``ValueError`` for empty text as a last-resort
safety net.
"""
def test_split_text_returns_empty_list_for_empty_string(self):
adapter = _make_adapter()
assert adapter._split_text("") == []
def test_split_text_returns_empty_list_for_empty_string_split_per_line(self):
adapter = WeixinAdapter(
PlatformConfig(
enabled=True,
extra={
"account_id": "acct",
"token": "test-tok",
"split_multiline_messages": True,
},
)
)
assert adapter._split_text("") == []
@patch("gateway.platforms.weixin._send_message", new_callable=AsyncMock)
def test_send_empty_content_does_not_call_send_message(self, send_message_mock):
adapter = _make_adapter()
adapter._session = object()
adapter._token = "test-token"
adapter._base_url = "https://weixin.example.com"
adapter._token_store.get = lambda account_id, chat_id: "ctx-token"
result = asyncio.run(adapter.send("wxid_test123", ""))
# Empty content → no chunks → no _send_message calls
assert result.success is True
send_message_mock.assert_not_awaited()
def test_send_message_rejects_empty_text(self):
"""_send_message raises ValueError for empty/whitespace text."""
import pytest
with pytest.raises(ValueError, match="text must not be empty"):
asyncio.run(
weixin._send_message(
AsyncMock(),
base_url="https://example.com",
token="tok",
to="wxid_test",
text="",
context_token=None,
client_id="cid",
)
)
class TestWeixinStreamingCursorSuppression:
"""WeChat doesn't support message editing — cursor must be suppressed."""
def test_supports_message_editing_is_false(self):
adapter = _make_adapter()
assert adapter.SUPPORTS_MESSAGE_EDITING is False
class TestWeixinMediaBuilder:
"""Media builder uses base64(hex_key), not base64(raw_bytes) for aes_key."""
def test_image_builder_aes_key_is_base64_of_hex(self):
import base64
adapter = _make_adapter()
media_type, builder = adapter._outbound_media_builder("photo.jpg")
assert media_type == weixin.MEDIA_IMAGE
fake_hex_key = "0123456789abcdef0123456789abcdef"
expected_aes = base64.b64encode(fake_hex_key.encode("ascii")).decode("ascii")
item = builder(
encrypt_query_param="eq",
aes_key_for_api=expected_aes,
ciphertext_size=1024,
plaintext_size=1000,
filename="photo.jpg",
rawfilemd5="abc123",
)
assert item["image_item"]["media"]["aes_key"] == expected_aes
def test_video_builder_includes_md5(self):
adapter = _make_adapter()
media_type, builder = adapter._outbound_media_builder("clip.mp4")
assert media_type == weixin.MEDIA_VIDEO
item = builder(
encrypt_query_param="eq",
aes_key_for_api="fakekey",
ciphertext_size=2048,
plaintext_size=2000,
filename="clip.mp4",
rawfilemd5="deadbeef",
)
assert item["video_item"]["video_md5"] == "deadbeef"
def test_voice_builder_for_audio_files(self):
adapter = _make_adapter()
media_type, builder = adapter._outbound_media_builder("note.mp3")
assert media_type == weixin.MEDIA_VOICE
item = builder(
encrypt_query_param="eq",
aes_key_for_api="fakekey",
ciphertext_size=512,
plaintext_size=500,
filename="note.mp3",
rawfilemd5="abc",
)
assert item["type"] == weixin.ITEM_VOICE
assert "voice_item" in item
def test_voice_builder_for_silk_files(self):
adapter = _make_adapter()
media_type, builder = adapter._outbound_media_builder("recording.silk")
assert media_type == weixin.MEDIA_VOICE

View File

@@ -0,0 +1,271 @@
"""Tests for WhatsApp message formatting and chunking.
Covers:
- format_message(): markdown → WhatsApp syntax conversion
- send(): message chunking for long responses
- MAX_MESSAGE_LENGTH: practical UX limit
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import Platform, PlatformConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_adapter():
"""Create a WhatsAppAdapter with test attributes (bypass __init__)."""
from gateway.platforms.whatsapp import WhatsAppAdapter
adapter = WhatsAppAdapter.__new__(WhatsAppAdapter)
adapter.platform = Platform.WHATSAPP
adapter.config = MagicMock()
adapter.config.extra = {}
adapter._bridge_port = 3000
adapter._bridge_script = "/tmp/test-bridge.js"
adapter._session_path = MagicMock()
adapter._bridge_log_fh = None
adapter._bridge_log = None
adapter._bridge_process = None
adapter._reply_prefix = None
adapter._running = True
adapter._message_handler = None
adapter._fatal_error_code = None
adapter._fatal_error_message = None
adapter._fatal_error_retryable = True
adapter._fatal_error_handler = None
adapter._active_sessions = {}
adapter._pending_messages = {}
adapter._background_tasks = set()
adapter._auto_tts_disabled_chats = set()
adapter._message_queue = asyncio.Queue()
adapter._http_session = MagicMock()
adapter._mention_patterns = []
return adapter
class _AsyncCM:
"""Minimal async context manager returning a fixed value."""
def __init__(self, value):
self.value = value
async def __aenter__(self):
return self.value
async def __aexit__(self, *exc):
return False
# ---------------------------------------------------------------------------
# format_message tests
# ---------------------------------------------------------------------------
class TestFormatMessage:
"""WhatsApp markdown conversion."""
def test_bold_double_asterisk(self):
adapter = _make_adapter()
assert adapter.format_message("**hello**") == "*hello*"
def test_bold_double_underscore(self):
adapter = _make_adapter()
assert adapter.format_message("__hello__") == "*hello*"
def test_strikethrough(self):
adapter = _make_adapter()
assert adapter.format_message("~~deleted~~") == "~deleted~"
def test_headers_converted_to_bold(self):
adapter = _make_adapter()
assert adapter.format_message("# Title") == "*Title*"
assert adapter.format_message("## Subtitle") == "*Subtitle*"
assert adapter.format_message("### Deep") == "*Deep*"
def test_links_converted(self):
adapter = _make_adapter()
result = adapter.format_message("[click here](https://example.com)")
assert result == "click here (https://example.com)"
def test_code_blocks_protected(self):
"""Code blocks should not have their content reformatted."""
adapter = _make_adapter()
content = "before **bold** ```python\n**not bold**\n``` after **bold**"
result = adapter.format_message(content)
assert "```python\n**not bold**\n```" in result
assert result.startswith("before *bold*")
assert result.endswith("after *bold*")
def test_inline_code_protected(self):
"""Inline code should not have its content reformatted."""
adapter = _make_adapter()
content = "use `**raw**` here"
result = adapter.format_message(content)
assert "`**raw**`" in result
assert result.startswith("use ")
def test_empty_content(self):
adapter = _make_adapter()
assert adapter.format_message("") == ""
assert adapter.format_message(None) is None
def test_plain_text_unchanged(self):
adapter = _make_adapter()
assert adapter.format_message("hello world") == "hello world"
def test_already_whatsapp_italic(self):
"""Single *italic* should pass through unchanged."""
adapter = _make_adapter()
# After bold conversion, *text* is WhatsApp italic
assert adapter.format_message("*italic*") == "*italic*"
def test_multiline_mixed(self):
adapter = _make_adapter()
content = "# Header\n\n**Bold text** and ~~strike~~\n\n```\ncode\n```"
result = adapter.format_message(content)
assert "*Header*" in result
assert "*Bold text*" in result
assert "~strike~" in result
assert "```\ncode\n```" in result
# ---------------------------------------------------------------------------
# MAX_MESSAGE_LENGTH tests
# ---------------------------------------------------------------------------
class TestMessageLimits:
"""WhatsApp message length limits."""
def test_max_message_length_is_practical(self):
from gateway.platforms.whatsapp import WhatsAppAdapter
assert WhatsAppAdapter.MAX_MESSAGE_LENGTH == 4096
# ---------------------------------------------------------------------------
# send() chunking tests
# ---------------------------------------------------------------------------
class TestSendChunking:
"""WhatsApp send() splits long messages into chunks."""
@pytest.mark.asyncio
async def test_short_message_single_send(self):
adapter = _make_adapter()
resp = MagicMock(status=200)
resp.json = AsyncMock(return_value={"messageId": "msg1"})
adapter._http_session.post = MagicMock(return_value=_AsyncCM(resp))
result = await adapter.send("chat1", "short message")
assert result.success
# Only one call to bridge /send
assert adapter._http_session.post.call_count == 1
@pytest.mark.asyncio
async def test_long_message_chunked(self):
adapter = _make_adapter()
resp = MagicMock(status=200)
resp.json = AsyncMock(return_value={"messageId": "msg1"})
adapter._http_session.post = MagicMock(return_value=_AsyncCM(resp))
# Create a message longer than MAX_MESSAGE_LENGTH (4096)
long_msg = "a " * 3000 # ~6000 chars
result = await adapter.send("chat1", long_msg)
assert result.success
# Should have made multiple calls
assert adapter._http_session.post.call_count > 1
@pytest.mark.asyncio
async def test_empty_message_no_send(self):
adapter = _make_adapter()
result = await adapter.send("chat1", "")
assert result.success
assert adapter._http_session.post.call_count == 0
@pytest.mark.asyncio
async def test_whitespace_only_no_send(self):
adapter = _make_adapter()
result = await adapter.send("chat1", " \n ")
assert result.success
assert adapter._http_session.post.call_count == 0
@pytest.mark.asyncio
async def test_format_applied_before_send(self):
"""Markdown should be converted to WhatsApp format before sending."""
adapter = _make_adapter()
resp = MagicMock(status=200)
resp.json = AsyncMock(return_value={"messageId": "msg1"})
adapter._http_session.post = MagicMock(return_value=_AsyncCM(resp))
await adapter.send("chat1", "**bold text**")
# Check the payload sent to the bridge
call_args = adapter._http_session.post.call_args
payload = call_args.kwargs.get("json") or call_args[1].get("json")
assert payload["message"] == "*bold text*"
@pytest.mark.asyncio
async def test_reply_to_only_on_first_chunk(self):
"""reply_to should only be set on the first chunk."""
adapter = _make_adapter()
resp = MagicMock(status=200)
resp.json = AsyncMock(return_value={"messageId": "msg1"})
adapter._http_session.post = MagicMock(return_value=_AsyncCM(resp))
long_msg = "word " * 2000 # ~10000 chars, multiple chunks
await adapter.send("chat1", long_msg, reply_to="orig123")
calls = adapter._http_session.post.call_args_list
assert len(calls) > 1
# First chunk should have replyTo
first_payload = calls[0].kwargs.get("json") or calls[0][1].get("json")
assert first_payload.get("replyTo") == "orig123"
# Subsequent chunks should NOT have replyTo
for call in calls[1:]:
payload = call.kwargs.get("json") or call[1].get("json")
assert "replyTo" not in payload
@pytest.mark.asyncio
async def test_bridge_error_returns_failure(self):
adapter = _make_adapter()
resp = MagicMock(status=500)
resp.text = AsyncMock(return_value="Internal Server Error")
adapter._http_session.post = MagicMock(return_value=_AsyncCM(resp))
result = await adapter.send("chat1", "hello")
assert not result.success
assert "Internal Server Error" in result.error
@pytest.mark.asyncio
async def test_not_connected_returns_failure(self):
adapter = _make_adapter()
adapter._running = False
result = await adapter.send("chat1", "hello")
assert not result.success
assert "Not connected" in result.error
# ---------------------------------------------------------------------------
# display_config tier classification
# ---------------------------------------------------------------------------
class TestWhatsAppTier:
"""WhatsApp should be classified as TIER_MEDIUM."""
def test_whatsapp_streaming_follows_global(self):
from gateway.display_config import resolve_display_setting
# TIER_MEDIUM has streaming: None (follow global), not False
assert resolve_display_setting({}, "whatsapp", "streaming") is None
def test_whatsapp_tool_progress_is_new(self):
from gateway.display_config import resolve_display_setting
assert resolve_display_setting({}, "whatsapp", "tool_progress") == "new"

View File

@@ -23,9 +23,9 @@ from hermes_cli.auth import (
get_auth_status,
AuthError,
KIMI_CODE_BASE_URL,
_try_gh_cli_token,
_resolve_kimi_base_url,
)
from hermes_cli.copilot_auth import _try_gh_cli_token
# =============================================================================
@@ -68,7 +68,7 @@ class TestProviderRegistry:
def test_copilot_env_vars(self):
pconfig = PROVIDER_REGISTRY["copilot"]
assert pconfig.api_key_env_vars == ("COPILOT_GITHUB_TOKEN", "GH_TOKEN", "GITHUB_TOKEN")
assert pconfig.base_url_env_var == ""
assert pconfig.base_url_env_var == "COPILOT_API_BASE_URL"
def test_kimi_env_vars(self):
pconfig = PROVIDER_REGISTRY["kimi-coding"]
@@ -381,13 +381,13 @@ class TestResolveApiKeyProviderCredentials:
assert creds["source"] == "gh auth token"
def test_try_gh_cli_token_uses_homebrew_path_when_not_on_path(self, monkeypatch):
monkeypatch.setattr("hermes_cli.auth.shutil.which", lambda command: None)
monkeypatch.setattr("hermes_cli.copilot_auth.shutil.which", lambda command: None)
monkeypatch.setattr(
"hermes_cli.auth.os.path.isfile",
"hermes_cli.copilot_auth.os.path.isfile",
lambda path: path == "/opt/homebrew/bin/gh",
)
monkeypatch.setattr(
"hermes_cli.auth.os.access",
"hermes_cli.copilot_auth.os.access",
lambda path, mode: path == "/opt/homebrew/bin/gh" and mode == os.X_OK,
)
@@ -397,11 +397,11 @@ class TestResolveApiKeyProviderCredentials:
returncode = 0
stdout = "gh-cli-secret\n"
def _fake_run(cmd, capture_output, text, timeout):
def _fake_run(cmd, **kwargs):
calls.append(cmd)
return _Result()
monkeypatch.setattr("hermes_cli.auth.subprocess.run", _fake_run)
monkeypatch.setattr("hermes_cli.copilot_auth.subprocess.run", _fake_run)
assert _try_gh_cli_token() == "gh-cli-secret"
assert calls == [["/opt/homebrew/bin/gh", "auth", "token"]]

View File

@@ -1,6 +1,8 @@
"""Tests for hermes backup and import commands."""
import json
import os
import sqlite3
import zipfile
from argparse import Namespace
from pathlib import Path
@@ -232,6 +234,44 @@ class TestBackup:
assert len(zips) == 1
# ---------------------------------------------------------------------------
# _validate_backup_zip tests
# ---------------------------------------------------------------------------
class TestValidateBackupZip:
def _make_zip(self, zip_path: Path, filenames: list[str]) -> None:
with zipfile.ZipFile(zip_path, "w") as zf:
for name in filenames:
zf.writestr(name, "dummy")
def test_state_db_passes(self, tmp_path):
"""A zip containing state.db is accepted as a valid Hermes backup."""
from hermes_cli.backup import _validate_backup_zip
zip_path = tmp_path / "backup.zip"
self._make_zip(zip_path, ["state.db", "sessions/abc.json"])
with zipfile.ZipFile(zip_path, "r") as zf:
ok, reason = _validate_backup_zip(zf)
assert ok, reason
def test_old_wrong_db_name_fails(self, tmp_path):
"""A zip with only hermes_state.db (old wrong name) is rejected."""
from hermes_cli.backup import _validate_backup_zip
zip_path = tmp_path / "old.zip"
self._make_zip(zip_path, ["hermes_state.db", "memory_store.db"])
with zipfile.ZipFile(zip_path, "r") as zf:
ok, reason = _validate_backup_zip(zf)
assert not ok
def test_config_yaml_passes(self, tmp_path):
"""A zip containing config.yaml is accepted (existing behaviour preserved)."""
from hermes_cli.backup import _validate_backup_zip
zip_path = tmp_path / "backup.zip"
self._make_zip(zip_path, ["config.yaml", "skills/x/SKILL.md"])
with zipfile.ZipFile(zip_path, "r") as zf:
ok, reason = _validate_backup_zip(zf)
assert ok, reason
# ---------------------------------------------------------------------------
# Import tests
# ---------------------------------------------------------------------------
@@ -895,3 +935,181 @@ class TestProfileRestoration:
# Files should still be restored even if wrappers can't be created
assert (hermes_home / "profiles" / "coder" / "config.yaml").exists()
# ---------------------------------------------------------------------------
# SQLite safe copy tests
# ---------------------------------------------------------------------------
class TestSafeCopyDb:
def test_copies_valid_database(self, tmp_path):
from hermes_cli.backup import _safe_copy_db
src = tmp_path / "test.db"
dst = tmp_path / "copy.db"
conn = sqlite3.connect(str(src))
conn.execute("CREATE TABLE t (x INTEGER)")
conn.execute("INSERT INTO t VALUES (42)")
conn.commit()
conn.close()
result = _safe_copy_db(src, dst)
assert result is True
conn = sqlite3.connect(str(dst))
rows = conn.execute("SELECT x FROM t").fetchall()
conn.close()
assert rows == [(42,)]
def test_copies_wal_mode_database(self, tmp_path):
from hermes_cli.backup import _safe_copy_db
src = tmp_path / "wal.db"
dst = tmp_path / "copy.db"
conn = sqlite3.connect(str(src))
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("CREATE TABLE t (x TEXT)")
conn.execute("INSERT INTO t VALUES ('wal-test')")
conn.commit()
conn.close()
result = _safe_copy_db(src, dst)
assert result is True
conn = sqlite3.connect(str(dst))
rows = conn.execute("SELECT x FROM t").fetchall()
conn.close()
assert rows == [("wal-test",)]
# ---------------------------------------------------------------------------
# Quick state snapshot tests
# ---------------------------------------------------------------------------
class TestQuickSnapshot:
@pytest.fixture
def hermes_home(self, tmp_path):
"""Create a fake HERMES_HOME with critical state files."""
home = tmp_path / ".hermes"
home.mkdir()
(home / "config.yaml").write_text("model:\n provider: openrouter\n")
(home / ".env").write_text("OPENROUTER_API_KEY=test-key-123\n")
(home / "auth.json").write_text('{"providers": {}}\n')
(home / "cron").mkdir()
(home / "cron" / "jobs.json").write_text('{"jobs": []}\n')
# Real SQLite database
db_path = home / "state.db"
conn = sqlite3.connect(str(db_path))
conn.execute("CREATE TABLE sessions (id TEXT PRIMARY KEY, data TEXT)")
conn.execute("INSERT INTO sessions VALUES ('s1', 'hello world')")
conn.commit()
conn.close()
return home
def test_creates_snapshot(self, hermes_home):
from hermes_cli.backup import create_quick_snapshot
snap_id = create_quick_snapshot(hermes_home=hermes_home)
assert snap_id is not None
snap_dir = hermes_home / "state-snapshots" / snap_id
assert snap_dir.is_dir()
assert (snap_dir / "manifest.json").exists()
def test_label_in_id(self, hermes_home):
from hermes_cli.backup import create_quick_snapshot
snap_id = create_quick_snapshot(label="before-upgrade", hermes_home=hermes_home)
assert "before-upgrade" in snap_id
def test_state_db_safely_copied(self, hermes_home):
from hermes_cli.backup import create_quick_snapshot
snap_id = create_quick_snapshot(hermes_home=hermes_home)
db_copy = hermes_home / "state-snapshots" / snap_id / "state.db"
assert db_copy.exists()
conn = sqlite3.connect(str(db_copy))
rows = conn.execute("SELECT * FROM sessions").fetchall()
conn.close()
assert len(rows) == 1
assert rows[0] == ("s1", "hello world")
def test_copies_nested_files(self, hermes_home):
from hermes_cli.backup import create_quick_snapshot
snap_id = create_quick_snapshot(hermes_home=hermes_home)
assert (hermes_home / "state-snapshots" / snap_id / "cron" / "jobs.json").exists()
def test_missing_files_skipped(self, hermes_home):
from hermes_cli.backup import create_quick_snapshot
snap_id = create_quick_snapshot(hermes_home=hermes_home)
with open(hermes_home / "state-snapshots" / snap_id / "manifest.json") as f:
meta = json.load(f)
# gateway_state.json etc. don't exist in fixture
assert "gateway_state.json" not in meta["files"]
def test_empty_home_returns_none(self, tmp_path):
from hermes_cli.backup import create_quick_snapshot
empty = tmp_path / "empty"
empty.mkdir()
assert create_quick_snapshot(hermes_home=empty) is None
def test_list_snapshots(self, hermes_home):
from hermes_cli.backup import create_quick_snapshot, list_quick_snapshots
id1 = create_quick_snapshot(label="first", hermes_home=hermes_home)
id2 = create_quick_snapshot(label="second", hermes_home=hermes_home)
snaps = list_quick_snapshots(hermes_home=hermes_home)
assert len(snaps) == 2
assert snaps[0]["id"] == id2 # most recent first
assert snaps[1]["id"] == id1
def test_list_limit(self, hermes_home):
from hermes_cli.backup import create_quick_snapshot, list_quick_snapshots
for i in range(5):
create_quick_snapshot(label=f"s{i}", hermes_home=hermes_home)
snaps = list_quick_snapshots(limit=3, hermes_home=hermes_home)
assert len(snaps) == 3
def test_restore_config(self, hermes_home):
from hermes_cli.backup import create_quick_snapshot, restore_quick_snapshot
snap_id = create_quick_snapshot(hermes_home=hermes_home)
(hermes_home / "config.yaml").write_text("model:\n provider: anthropic\n")
assert "anthropic" in (hermes_home / "config.yaml").read_text()
result = restore_quick_snapshot(snap_id, hermes_home=hermes_home)
assert result is True
assert "openrouter" in (hermes_home / "config.yaml").read_text()
def test_restore_state_db(self, hermes_home):
from hermes_cli.backup import create_quick_snapshot, restore_quick_snapshot
snap_id = create_quick_snapshot(hermes_home=hermes_home)
conn = sqlite3.connect(str(hermes_home / "state.db"))
conn.execute("INSERT INTO sessions VALUES ('s2', 'new')")
conn.commit()
conn.close()
restore_quick_snapshot(snap_id, hermes_home=hermes_home)
conn = sqlite3.connect(str(hermes_home / "state.db"))
rows = conn.execute("SELECT * FROM sessions").fetchall()
conn.close()
assert len(rows) == 1
def test_restore_nonexistent(self, hermes_home):
from hermes_cli.backup import restore_quick_snapshot
assert restore_quick_snapshot("nonexistent", hermes_home=hermes_home) is False
def test_auto_prune(self, hermes_home):
from hermes_cli.backup import create_quick_snapshot, list_quick_snapshots, _QUICK_DEFAULT_KEEP
for i in range(_QUICK_DEFAULT_KEEP + 5):
create_quick_snapshot(label=f"snap-{i:03d}", hermes_home=hermes_home)
snaps = list_quick_snapshots(limit=100, hermes_home=hermes_home)
assert len(snaps) <= _QUICK_DEFAULT_KEEP
def test_manual_prune(self, hermes_home):
from hermes_cli.backup import create_quick_snapshot, prune_quick_snapshots, list_quick_snapshots
for i in range(10):
create_quick_snapshot(label=f"s{i}", hermes_home=hermes_home)
deleted = prune_quick_snapshots(keep=3, hermes_home=hermes_home)
assert deleted == 7
assert len(list_quick_snapshots(hermes_home=hermes_home)) == 3

View File

@@ -1,6 +1,7 @@
"""Tests for hermes claw commands."""
from argparse import Namespace
import subprocess
from types import ModuleType
from unittest.mock import MagicMock, patch
@@ -197,6 +198,11 @@ class TestClawCommand:
class TestCmdMigrate:
"""Test the migrate command handler."""
@pytest.fixture(autouse=True)
def _mock_openclaw_running(self):
with patch.object(claw_mod, "_detect_openclaw_processes", return_value=[]):
yield
def test_error_when_source_missing(self, tmp_path, capsys):
args = Namespace(
source=str(tmp_path / "nonexistent"),
@@ -626,3 +632,120 @@ class TestPrintMigrationReport:
claw_mod._print_migration_report(report, dry_run=False)
captured = capsys.readouterr()
assert "Nothing to migrate" in captured.out
class TestDetectOpenclawProcesses:
def test_returns_match_when_pgrep_finds_openclaw(self):
with patch.object(claw_mod, "sys") as mock_sys:
mock_sys.platform = "linux"
with patch.object(claw_mod, "subprocess") as mock_subprocess:
# systemd check misses, pgrep finds openclaw
mock_subprocess.run.side_effect = [
MagicMock(returncode=1, stdout=""), # systemctl
MagicMock(returncode=0, stdout="1234\n"), # pgrep
]
mock_subprocess.TimeoutExpired = subprocess.TimeoutExpired
result = claw_mod._detect_openclaw_processes()
assert len(result) == 1
assert "1234" in result[0]
def test_returns_empty_when_pgrep_finds_nothing(self):
with patch.object(claw_mod, "sys") as mock_sys:
mock_sys.platform = "darwin"
with patch.object(claw_mod, "subprocess") as mock_subprocess:
mock_subprocess.run.side_effect = [
MagicMock(returncode=1, stdout=""), # systemctl (not found)
MagicMock(returncode=1, stdout=""), # pgrep
]
mock_subprocess.TimeoutExpired = subprocess.TimeoutExpired
result = claw_mod._detect_openclaw_processes()
assert result == []
def test_detects_systemd_service(self):
with patch.object(claw_mod, "sys") as mock_sys:
mock_sys.platform = "linux"
with patch.object(claw_mod, "subprocess") as mock_subprocess:
mock_subprocess.run.side_effect = [
MagicMock(returncode=0, stdout="active\n"), # systemctl
MagicMock(returncode=1, stdout=""), # pgrep
]
mock_subprocess.TimeoutExpired = subprocess.TimeoutExpired
result = claw_mod._detect_openclaw_processes()
assert len(result) == 1
assert "systemd" in result[0]
def test_returns_match_on_windows_when_openclaw_exe_running(self):
with patch.object(claw_mod, "sys") as mock_sys:
mock_sys.platform = "win32"
with patch.object(claw_mod, "subprocess") as mock_subprocess:
mock_subprocess.run.side_effect = [
MagicMock(returncode=0, stdout="openclaw.exe 1234 Console 1 45,056 K\n"),
]
result = claw_mod._detect_openclaw_processes()
assert len(result) >= 1
assert any("openclaw.exe" in r for r in result)
def test_returns_match_on_windows_when_node_exe_has_openclaw_in_cmdline(self):
with patch.object(claw_mod, "sys") as mock_sys:
mock_sys.platform = "win32"
with patch.object(claw_mod, "subprocess") as mock_subprocess:
mock_subprocess.run.side_effect = [
MagicMock(returncode=0, stdout=""), # tasklist openclaw.exe
MagicMock(returncode=0, stdout=""), # tasklist clawd.exe
MagicMock(returncode=0, stdout="1234\n"), # PowerShell
]
result = claw_mod._detect_openclaw_processes()
assert len(result) >= 1
assert any("node.exe" in r for r in result)
def test_returns_empty_on_windows_when_nothing_found(self):
with patch.object(claw_mod, "sys") as mock_sys:
mock_sys.platform = "win32"
with patch.object(claw_mod, "subprocess") as mock_subprocess:
mock_subprocess.run.side_effect = [
MagicMock(returncode=0, stdout=""),
MagicMock(returncode=0, stdout=""),
MagicMock(returncode=0, stdout=""),
]
result = claw_mod._detect_openclaw_processes()
assert result == []
class TestWarnIfOpenclawRunning:
def test_noop_when_not_running(self, capsys):
with patch.object(claw_mod, "_detect_openclaw_processes", return_value=[]):
claw_mod._warn_if_openclaw_running(auto_yes=False)
captured = capsys.readouterr()
assert captured.out == ""
def test_warns_and_exits_when_running_and_user_declines(self, capsys):
with patch.object(claw_mod, "_detect_openclaw_processes", return_value=["openclaw process(es) (PIDs: 1234)"]):
with patch.object(claw_mod, "prompt_yes_no", return_value=False):
with patch.object(claw_mod.sys.stdin, "isatty", return_value=True):
with pytest.raises(SystemExit) as exc_info:
claw_mod._warn_if_openclaw_running(auto_yes=False)
assert exc_info.value.code == 0
captured = capsys.readouterr()
assert "OpenClaw appears to be running" in captured.out
def test_warns_and_continues_when_running_and_user_accepts(self, capsys):
with patch.object(claw_mod, "_detect_openclaw_processes", return_value=["openclaw process(es) (PIDs: 1234)"]):
with patch.object(claw_mod, "prompt_yes_no", return_value=True):
with patch.object(claw_mod.sys.stdin, "isatty", return_value=True):
claw_mod._warn_if_openclaw_running(auto_yes=False)
captured = capsys.readouterr()
assert "OpenClaw appears to be running" in captured.out
def test_warns_and_continues_in_auto_yes_mode(self, capsys):
with patch.object(claw_mod, "_detect_openclaw_processes", return_value=["openclaw process(es) (PIDs: 1234)"]):
claw_mod._warn_if_openclaw_running(auto_yes=True)
captured = capsys.readouterr()
assert "OpenClaw appears to be running" in captured.out
def test_warns_and_continues_in_non_interactive_session(self, capsys):
with patch.object(claw_mod, "_detect_openclaw_processes", return_value=["openclaw process(es) (PIDs: 1234)"]):
with patch.object(claw_mod.sys.stdin, "isatty", return_value=False):
claw_mod._warn_if_openclaw_running(auto_yes=False)
captured = capsys.readouterr()
assert "OpenClaw appears to be running" in captured.out
assert "Non-interactive session" in captured.out

View File

@@ -10,6 +10,7 @@ from hermes_cli.config import (
DEFAULT_CONFIG,
get_hermes_home,
ensure_hermes_home,
get_compatible_custom_providers,
load_config,
load_env,
migrate_config,
@@ -424,6 +425,146 @@ class TestAnthropicTokenMigration:
assert load_env().get("ANTHROPIC_TOKEN") == "current-token"
class TestCustomProviderCompatibility:
"""Custom provider compatibility across legacy and v12+ config schemas."""
def test_v11_upgrade_moves_custom_providers_into_providers(self, tmp_path):
config_path = tmp_path / "config.yaml"
config_path.write_text(
yaml.safe_dump(
{
"_config_version": 11,
"model": {
"default": "openai/gpt-5.4",
"provider": "openrouter",
},
"custom_providers": [
{
"name": "OpenAI Direct",
"base_url": "https://api.openai.com/v1",
"api_key": "test-key",
"api_mode": "codex_responses",
"model": "gpt-5-mini",
}
],
"fallback_providers": [
{"provider": "openai-direct", "model": "gpt-5-mini"}
],
}
),
encoding="utf-8",
)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
migrate_config(interactive=False, quiet=True)
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
assert raw["_config_version"] == 17
assert raw["providers"]["openai-direct"] == {
"api": "https://api.openai.com/v1",
"api_key": "test-key",
"default_model": "gpt-5-mini",
"name": "OpenAI Direct",
"transport": "codex_responses",
}
# custom_providers removed by migration — runtime reads via compat layer
assert "custom_providers" not in raw
def test_providers_dict_resolves_at_runtime(self, tmp_path):
"""After migration deleted custom_providers, get_compatible_custom_providers
still finds entries from the providers dict."""
config_path = tmp_path / "config.yaml"
config_path.write_text(
yaml.safe_dump(
{
"_config_version": 17,
"providers": {
"openai-direct": {
"api": "https://api.openai.com/v1",
"api_key": "test-key",
"default_model": "gpt-5-mini",
"name": "OpenAI Direct",
"transport": "codex_responses",
}
},
}
),
encoding="utf-8",
)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
compatible = get_compatible_custom_providers()
assert len(compatible) == 1
assert compatible[0]["name"] == "OpenAI Direct"
assert compatible[0]["base_url"] == "https://api.openai.com/v1"
assert compatible[0]["provider_key"] == "openai-direct"
assert compatible[0]["api_mode"] == "codex_responses"
def test_compatible_custom_providers_prefers_api_then_url_then_base_url(self, tmp_path):
config_path = tmp_path / "config.yaml"
config_path.write_text(
yaml.safe_dump(
{
"_config_version": 17,
"providers": {
"my-provider": {
"name": "My Provider",
"api": "https://api.example.com/v1",
"url": "https://url.example.com/v1",
"base_url": "https://base.example.com/v1",
}
},
}
),
encoding="utf-8",
)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
compatible = get_compatible_custom_providers()
assert compatible == [
{
"name": "My Provider",
"base_url": "https://api.example.com/v1",
"provider_key": "my-provider",
}
]
def test_dedup_across_legacy_and_providers(self, tmp_path):
"""Same name+url in both schemas should not produce duplicates."""
config_path = tmp_path / "config.yaml"
config_path.write_text(
yaml.safe_dump(
{
"_config_version": 17,
"custom_providers": [
{
"name": "OpenAI Direct",
"base_url": "https://api.openai.com/v1",
"api_key": "legacy-key",
}
],
"providers": {
"openai-direct": {
"api": "https://api.openai.com/v1",
"api_key": "new-key",
"name": "OpenAI Direct",
}
},
}
),
encoding="utf-8",
)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
compatible = get_compatible_custom_providers()
assert len(compatible) == 1
# Legacy entry wins (read first)
assert compatible[0]["api_key"] == "legacy-key"
class TestInterimAssistantMessageConfig:
"""Test the explicit gateway interim-message config gate."""
@@ -441,6 +582,6 @@ class TestInterimAssistantMessageConfig:
migrate_config(interactive=False, quiet=True)
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
assert raw["_config_version"] == 16
assert raw["_config_version"] == 17
assert raw["display"]["tool_progress"] == "off"
assert raw["display"]["interim_assistant_messages"] is True

View File

@@ -12,49 +12,10 @@ from unittest.mock import MagicMock, patch
import pytest
from hermes_cli.config import (
_is_inside_container,
get_container_exec_info,
)
# =============================================================================
# _is_inside_container
# =============================================================================
def test_is_inside_container_dockerenv():
"""Detects /.dockerenv marker file."""
with patch("os.path.exists") as mock_exists:
mock_exists.side_effect = lambda p: p == "/.dockerenv"
assert _is_inside_container() is True
def test_is_inside_container_containerenv():
"""Detects Podman's /run/.containerenv marker."""
with patch("os.path.exists") as mock_exists:
mock_exists.side_effect = lambda p: p == "/run/.containerenv"
assert _is_inside_container() is True
def test_is_inside_container_cgroup_docker():
"""Detects 'docker' in /proc/1/cgroup."""
with patch("os.path.exists", return_value=False), \
patch("builtins.open", create=True) as mock_open:
mock_open.return_value.__enter__ = lambda s: s
mock_open.return_value.__exit__ = MagicMock(return_value=False)
mock_open.return_value.read = MagicMock(
return_value="12:memory:/docker/abc123\n"
)
assert _is_inside_container() is True
def test_is_inside_container_false_on_host():
"""Returns False when none of the container indicators are present."""
with patch("os.path.exists", return_value=False), \
patch("builtins.open", side_effect=OSError("no such file")):
assert _is_inside_container() is False
# =============================================================================
# get_container_exec_info
# =============================================================================
@@ -81,7 +42,7 @@ def container_env(tmp_path, monkeypatch):
def test_get_container_exec_info_returns_metadata(container_env):
"""Reads .container-mode and returns all fields including exec_user."""
with patch("hermes_cli.config._is_inside_container", return_value=False):
with patch("hermes_constants.is_container", return_value=False):
info = get_container_exec_info()
assert info is not None
@@ -93,7 +54,7 @@ def test_get_container_exec_info_returns_metadata(container_env):
def test_get_container_exec_info_none_inside_container(container_env):
"""Returns None when we're already inside a container."""
with patch("hermes_cli.config._is_inside_container", return_value=True):
with patch("hermes_constants.is_container", return_value=True):
info = get_container_exec_info()
assert info is None
@@ -106,7 +67,7 @@ def test_get_container_exec_info_none_without_file(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
monkeypatch.delenv("HERMES_DEV", raising=False)
with patch("hermes_cli.config._is_inside_container", return_value=False):
with patch("hermes_constants.is_container", return_value=False):
info = get_container_exec_info()
assert info is None
@@ -116,7 +77,7 @@ def test_get_container_exec_info_skipped_when_hermes_dev(container_env, monkeypa
"""Returns None when HERMES_DEV=1 is set (dev mode bypass)."""
monkeypatch.setenv("HERMES_DEV", "1")
with patch("hermes_cli.config._is_inside_container", return_value=False):
with patch("hermes_constants.is_container", return_value=False):
info = get_container_exec_info()
assert info is None
@@ -126,7 +87,7 @@ def test_get_container_exec_info_not_skipped_when_hermes_dev_zero(container_env,
"""HERMES_DEV=0 does NOT trigger bypass — only '1' does."""
monkeypatch.setenv("HERMES_DEV", "0")
with patch("hermes_cli.config._is_inside_container", return_value=False):
with patch("hermes_constants.is_container", return_value=False):
info = get_container_exec_info()
assert info is not None
@@ -143,7 +104,7 @@ def test_get_container_exec_info_defaults():
"# minimal file with no keys\n"
)
with patch("hermes_cli.config._is_inside_container", return_value=False), \
with patch("hermes_constants.is_container", return_value=False), \
patch("hermes_cli.config.get_hermes_home", return_value=hermes_home), \
patch.dict(os.environ, {}, clear=False):
os.environ.pop("HERMES_DEV", None)
@@ -165,7 +126,7 @@ def test_get_container_exec_info_docker_backend(container_env):
"hermes_bin=/opt/hermes/bin/hermes\n"
)
with patch("hermes_cli.config._is_inside_container", return_value=False):
with patch("hermes_constants.is_container", return_value=False):
info = get_container_exec_info()
assert info["backend"] == "docker"
@@ -176,7 +137,7 @@ def test_get_container_exec_info_docker_backend(container_env):
def test_get_container_exec_info_crashes_on_permission_error(container_env):
"""PermissionError propagates instead of being silently swallowed."""
with patch("hermes_cli.config._is_inside_container", return_value=False), \
with patch("hermes_constants.is_container", return_value=False), \
patch("builtins.open", side_effect=PermissionError("permission denied")):
with pytest.raises(PermissionError):
get_container_exec_info()

View File

@@ -0,0 +1,461 @@
"""Tests for ``hermes debug`` CLI command and debug utilities."""
import os
import sys
import urllib.error
from pathlib import Path
from unittest.mock import MagicMock, patch, call
import pytest
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def hermes_home(tmp_path, monkeypatch):
"""Set up an isolated HERMES_HOME with minimal logs."""
home = tmp_path / ".hermes"
home.mkdir()
monkeypatch.setenv("HERMES_HOME", str(home))
# Create log files
logs_dir = home / "logs"
logs_dir.mkdir()
(logs_dir / "agent.log").write_text(
"2026-04-12 17:00:00 INFO agent: session started\n"
"2026-04-12 17:00:01 INFO tools.terminal: running ls\n"
"2026-04-12 17:00:02 WARNING agent: high token usage\n"
)
(logs_dir / "errors.log").write_text(
"2026-04-12 17:00:05 ERROR gateway.run: connection lost\n"
)
(logs_dir / "gateway.log").write_text(
"2026-04-12 17:00:10 INFO gateway.run: started\n"
)
return home
# ---------------------------------------------------------------------------
# Unit tests for upload helpers
# ---------------------------------------------------------------------------
class TestUploadPasteRs:
"""Test paste.rs upload path."""
def test_upload_paste_rs_success(self):
from hermes_cli.debug import _upload_paste_rs
mock_resp = MagicMock()
mock_resp.read.return_value = b"https://paste.rs/abc123\n"
mock_resp.__enter__ = lambda s: s
mock_resp.__exit__ = MagicMock(return_value=False)
with patch("hermes_cli.debug.urllib.request.urlopen", return_value=mock_resp):
url = _upload_paste_rs("hello world")
assert url == "https://paste.rs/abc123"
def test_upload_paste_rs_bad_response(self):
from hermes_cli.debug import _upload_paste_rs
mock_resp = MagicMock()
mock_resp.read.return_value = b"<html>error</html>"
mock_resp.__enter__ = lambda s: s
mock_resp.__exit__ = MagicMock(return_value=False)
with patch("hermes_cli.debug.urllib.request.urlopen", return_value=mock_resp):
with pytest.raises(ValueError, match="Unexpected response"):
_upload_paste_rs("test")
def test_upload_paste_rs_network_error(self):
from hermes_cli.debug import _upload_paste_rs
with patch(
"hermes_cli.debug.urllib.request.urlopen",
side_effect=urllib.error.URLError("connection refused"),
):
with pytest.raises(urllib.error.URLError):
_upload_paste_rs("test")
class TestUploadDpasteCom:
"""Test dpaste.com fallback upload path."""
def test_upload_dpaste_com_success(self):
from hermes_cli.debug import _upload_dpaste_com
mock_resp = MagicMock()
mock_resp.read.return_value = b"https://dpaste.com/ABCDEFG\n"
mock_resp.__enter__ = lambda s: s
mock_resp.__exit__ = MagicMock(return_value=False)
with patch("hermes_cli.debug.urllib.request.urlopen", return_value=mock_resp):
url = _upload_dpaste_com("hello world", expiry_days=7)
assert url == "https://dpaste.com/ABCDEFG"
class TestUploadToPastebin:
"""Test the combined upload with fallback."""
def test_tries_paste_rs_first(self):
from hermes_cli.debug import upload_to_pastebin
with patch("hermes_cli.debug._upload_paste_rs",
return_value="https://paste.rs/test") as prs:
url = upload_to_pastebin("content")
assert url == "https://paste.rs/test"
prs.assert_called_once()
def test_falls_back_to_dpaste_com(self):
from hermes_cli.debug import upload_to_pastebin
with patch("hermes_cli.debug._upload_paste_rs",
side_effect=Exception("down")), \
patch("hermes_cli.debug._upload_dpaste_com",
return_value="https://dpaste.com/TEST") as dp:
url = upload_to_pastebin("content")
assert url == "https://dpaste.com/TEST"
dp.assert_called_once()
def test_raises_when_both_fail(self):
from hermes_cli.debug import upload_to_pastebin
with patch("hermes_cli.debug._upload_paste_rs",
side_effect=Exception("err1")), \
patch("hermes_cli.debug._upload_dpaste_com",
side_effect=Exception("err2")):
with pytest.raises(RuntimeError, match="Failed to upload"):
upload_to_pastebin("content")
# ---------------------------------------------------------------------------
# Log reading
# ---------------------------------------------------------------------------
class TestReadFullLog:
"""Test _read_full_log for standalone log uploads."""
def test_reads_small_file(self, hermes_home):
from hermes_cli.debug import _read_full_log
content = _read_full_log("agent")
assert content is not None
assert "session started" in content
def test_returns_none_for_missing(self, tmp_path, monkeypatch):
home = tmp_path / ".hermes"
home.mkdir()
monkeypatch.setenv("HERMES_HOME", str(home))
from hermes_cli.debug import _read_full_log
assert _read_full_log("agent") is None
def test_returns_none_for_empty(self, hermes_home):
# Truncate agent.log to empty
(hermes_home / "logs" / "agent.log").write_text("")
from hermes_cli.debug import _read_full_log
assert _read_full_log("agent") is None
def test_truncates_large_file(self, hermes_home):
"""Files larger than max_bytes get tail-truncated."""
from hermes_cli.debug import _read_full_log
# Write a file larger than 1KB
big_content = "x" * 100 + "\n"
(hermes_home / "logs" / "agent.log").write_text(big_content * 200)
content = _read_full_log("agent", max_bytes=1024)
assert content is not None
assert "truncated" in content
def test_unknown_log_returns_none(self, hermes_home):
from hermes_cli.debug import _read_full_log
assert _read_full_log("nonexistent") is None
def test_falls_back_to_rotated_file(self, hermes_home):
"""When gateway.log doesn't exist, falls back to gateway.log.1."""
from hermes_cli.debug import _read_full_log
logs_dir = hermes_home / "logs"
# Remove the primary (if any) and create a .1 rotation
(logs_dir / "gateway.log").unlink(missing_ok=True)
(logs_dir / "gateway.log.1").write_text(
"2026-04-12 10:00:00 INFO gateway.run: rotated content\n"
)
content = _read_full_log("gateway")
assert content is not None
assert "rotated content" in content
def test_prefers_primary_over_rotated(self, hermes_home):
"""Primary log is used when it exists, even if .1 also exists."""
from hermes_cli.debug import _read_full_log
logs_dir = hermes_home / "logs"
(logs_dir / "gateway.log").write_text("primary content\n")
(logs_dir / "gateway.log.1").write_text("rotated content\n")
content = _read_full_log("gateway")
assert "primary content" in content
assert "rotated" not in content
def test_falls_back_when_primary_empty(self, hermes_home):
"""Empty primary log falls back to .1 rotation."""
from hermes_cli.debug import _read_full_log
logs_dir = hermes_home / "logs"
(logs_dir / "agent.log").write_text("")
(logs_dir / "agent.log.1").write_text("rotated agent data\n")
content = _read_full_log("agent")
assert content is not None
assert "rotated agent data" in content
# ---------------------------------------------------------------------------
# Debug report collection
# ---------------------------------------------------------------------------
class TestCollectDebugReport:
"""Test the debug report builder."""
def test_report_includes_dump_output(self, hermes_home):
from hermes_cli.debug import collect_debug_report
with patch("hermes_cli.dump.run_dump") as mock_dump:
mock_dump.side_effect = lambda args: print(
"--- hermes dump ---\nversion: 0.8.0\n--- end dump ---"
)
report = collect_debug_report(log_lines=50)
assert "--- hermes dump ---" in report
assert "version: 0.8.0" in report
def test_report_includes_agent_log(self, hermes_home):
from hermes_cli.debug import collect_debug_report
with patch("hermes_cli.dump.run_dump"):
report = collect_debug_report(log_lines=50)
assert "--- agent.log" in report
assert "session started" in report
def test_report_includes_errors_log(self, hermes_home):
from hermes_cli.debug import collect_debug_report
with patch("hermes_cli.dump.run_dump"):
report = collect_debug_report(log_lines=50)
assert "--- errors.log" in report
assert "connection lost" in report
def test_report_includes_gateway_log(self, hermes_home):
from hermes_cli.debug import collect_debug_report
with patch("hermes_cli.dump.run_dump"):
report = collect_debug_report(log_lines=50)
assert "--- gateway.log" in report
def test_missing_logs_handled(self, tmp_path, monkeypatch):
home = tmp_path / ".hermes"
home.mkdir()
monkeypatch.setenv("HERMES_HOME", str(home))
from hermes_cli.debug import collect_debug_report
with patch("hermes_cli.dump.run_dump"):
report = collect_debug_report(log_lines=50)
assert "(file not found)" in report
# ---------------------------------------------------------------------------
# CLI entry point — run_debug_share
# ---------------------------------------------------------------------------
class TestRunDebugShare:
"""Test the run_debug_share CLI handler."""
def test_local_flag_prints_full_logs(self, hermes_home, capsys):
"""--local prints the report plus full log contents."""
from hermes_cli.debug import run_debug_share
args = MagicMock()
args.lines = 50
args.expire = 7
args.local = True
with patch("hermes_cli.dump.run_dump"):
run_debug_share(args)
out = capsys.readouterr().out
assert "--- agent.log" in out
assert "FULL agent.log" in out
assert "FULL gateway.log" in out
def test_share_uploads_three_pastes(self, hermes_home, capsys):
"""Successful share uploads report + agent.log + gateway.log."""
from hermes_cli.debug import run_debug_share
args = MagicMock()
args.lines = 50
args.expire = 7
args.local = False
call_count = [0]
uploaded_content = []
def _mock_upload(content, expiry_days=7):
call_count[0] += 1
uploaded_content.append(content)
return f"https://paste.rs/paste{call_count[0]}"
with patch("hermes_cli.dump.run_dump") as mock_dump, \
patch("hermes_cli.debug.upload_to_pastebin",
side_effect=_mock_upload):
mock_dump.side_effect = lambda a: print("--- hermes dump ---\nversion: test\n--- end dump ---")
run_debug_share(args)
out = capsys.readouterr().out
# Should have 3 uploads: report, agent.log, gateway.log
assert call_count[0] == 3
assert "paste.rs/paste1" in out # Report
assert "paste.rs/paste2" in out # agent.log
assert "paste.rs/paste3" in out # gateway.log
assert "Report" in out
assert "agent.log" in out
assert "gateway.log" in out
# Each log paste should start with the dump header
agent_paste = uploaded_content[1]
assert "--- hermes dump ---" in agent_paste
assert "--- full agent.log ---" in agent_paste
gateway_paste = uploaded_content[2]
assert "--- hermes dump ---" in gateway_paste
assert "--- full gateway.log ---" in gateway_paste
def test_share_skips_missing_logs(self, tmp_path, monkeypatch, capsys):
"""Only uploads logs that exist."""
home = tmp_path / ".hermes"
home.mkdir()
monkeypatch.setenv("HERMES_HOME", str(home))
from hermes_cli.debug import run_debug_share
args = MagicMock()
args.lines = 50
args.expire = 7
args.local = False
call_count = [0]
def _mock_upload(content, expiry_days=7):
call_count[0] += 1
return f"https://paste.rs/paste{call_count[0]}"
with patch("hermes_cli.dump.run_dump"), \
patch("hermes_cli.debug.upload_to_pastebin",
side_effect=_mock_upload):
run_debug_share(args)
out = capsys.readouterr().out
# Only the report should be uploaded (no log files exist)
assert call_count[0] == 1
assert "Report" in out
def test_share_continues_on_log_upload_failure(self, hermes_home, capsys):
"""Log upload failure doesn't stop the report from being shared."""
from hermes_cli.debug import run_debug_share
args = MagicMock()
args.lines = 50
args.expire = 7
args.local = False
call_count = [0]
def _mock_upload(content, expiry_days=7):
call_count[0] += 1
if call_count[0] > 1:
raise RuntimeError("upload failed")
return "https://paste.rs/report"
with patch("hermes_cli.dump.run_dump"), \
patch("hermes_cli.debug.upload_to_pastebin",
side_effect=_mock_upload):
run_debug_share(args)
out = capsys.readouterr().out
assert "Report" in out
assert "paste.rs/report" in out
assert "failed to upload" in out
def test_share_exits_on_report_upload_failure(self, hermes_home, capsys):
"""If the main report fails to upload, exit with code 1."""
from hermes_cli.debug import run_debug_share
args = MagicMock()
args.lines = 50
args.expire = 7
args.local = False
with patch("hermes_cli.dump.run_dump"), \
patch("hermes_cli.debug.upload_to_pastebin",
side_effect=RuntimeError("all failed")):
with pytest.raises(SystemExit) as exc_info:
run_debug_share(args)
assert exc_info.value.code == 1
out = capsys.readouterr()
assert "all failed" in out.err
# ---------------------------------------------------------------------------
# run_debug router
# ---------------------------------------------------------------------------
class TestRunDebug:
def test_no_subcommand_shows_usage(self, capsys):
from hermes_cli.debug import run_debug
args = MagicMock()
args.debug_command = None
run_debug(args)
out = capsys.readouterr().out
assert "hermes debug share" in out
def test_share_subcommand_routes(self, hermes_home):
from hermes_cli.debug import run_debug
args = MagicMock()
args.debug_command = "share"
args.lines = 200
args.expire = 7
args.local = True
with patch("hermes_cli.dump.run_dump"):
run_debug(args)
# ---------------------------------------------------------------------------
# Argparse integration
# ---------------------------------------------------------------------------
class TestArgparseIntegration:
def test_module_imports_clean(self):
from hermes_cli.debug import run_debug, run_debug_share
assert callable(run_debug)
assert callable(run_debug_share)
def test_cmd_debug_dispatches(self):
from hermes_cli.main import cmd_debug
args = MagicMock()
args.debug_command = None
cmd_debug(args)

View File

@@ -0,0 +1,91 @@
"""Tests for .env sanitization during load to prevent token duplication (#8908)."""
import tempfile
from pathlib import Path
from unittest.mock import patch
def test_load_env_sanitizes_concatenated_lines():
"""Verify load_env() splits concatenated KEY=VALUE pairs.
Reproduces the scenario from #8908 where a corrupted .env file
contained multiple tokens on a single line, causing the bot token
to be duplicated 8 times.
"""
from hermes_cli.config import load_env
token = "8356550917:AAGGEkzg06Hrc3Hjb3Sa1jkGVDOdU_lYy2Q"
# Simulate concatenated line: TOKEN=xxx followed immediately by another key
corrupted = f"TELEGRAM_BOT_TOKEN={token}ANTHROPIC_API_KEY=sk-ant-test123\n"
with tempfile.NamedTemporaryFile(
mode="w", suffix=".env", delete=False, encoding="utf-8"
) as f:
f.write(corrupted)
env_path = Path(f.name)
try:
with patch("hermes_cli.config.get_env_path", return_value=env_path):
result = load_env()
assert result.get("TELEGRAM_BOT_TOKEN") == token, (
f"Token should be exactly '{token}', got '{result.get('TELEGRAM_BOT_TOKEN')}'"
)
assert result.get("ANTHROPIC_API_KEY") == "sk-ant-test123"
finally:
env_path.unlink(missing_ok=True)
def test_load_env_normal_file_unchanged():
"""A well-formed .env file should be parsed identically."""
from hermes_cli.config import load_env
content = (
"TELEGRAM_BOT_TOKEN=mytoken123\n"
"ANTHROPIC_API_KEY=sk-ant-key\n"
"# comment\n"
"\n"
"OPENAI_API_KEY=sk-openai\n"
)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".env", delete=False, encoding="utf-8"
) as f:
f.write(content)
env_path = Path(f.name)
try:
with patch("hermes_cli.config.get_env_path", return_value=env_path):
result = load_env()
assert result["TELEGRAM_BOT_TOKEN"] == "mytoken123"
assert result["ANTHROPIC_API_KEY"] == "sk-ant-key"
assert result["OPENAI_API_KEY"] == "sk-openai"
finally:
env_path.unlink(missing_ok=True)
def test_env_loader_sanitizes_before_dotenv():
"""Verify env_loader._sanitize_env_file_if_needed fixes corrupted files."""
from hermes_cli.env_loader import _sanitize_env_file_if_needed
token = "8356550917:AAGGEkzg06Hrc3Hjb3Sa1jkGVDOdU_lYy2Q"
corrupted = f"TELEGRAM_BOT_TOKEN={token}ANTHROPIC_API_KEY=sk-ant-test\n"
with tempfile.NamedTemporaryFile(
mode="w", suffix=".env", delete=False, encoding="utf-8"
) as f:
f.write(corrupted)
env_path = Path(f.name)
try:
_sanitize_env_file_if_needed(env_path)
with open(env_path, encoding="utf-8") as f:
lines = f.readlines()
# Should be split into two separate lines
assert len(lines) == 2, f"Expected 2 lines, got {len(lines)}: {lines}"
assert lines[0].startswith("TELEGRAM_BOT_TOKEN=")
assert lines[1].startswith("ANTHROPIC_API_KEY=")
# Token should not contain the second key
parsed_token = lines[0].strip().split("=", 1)[1]
assert parsed_token == token
finally:
env_path.unlink(missing_ok=True)

View File

@@ -394,6 +394,21 @@ class TestLaunchdServiceRecovery:
class TestGatewayServiceDetection:
def test_supports_systemd_services_requires_systemctl_binary(self, monkeypatch):
monkeypatch.setattr(gateway_cli, "is_linux", lambda: True)
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
monkeypatch.setattr(gateway_cli.shutil, "which", lambda name: None)
assert gateway_cli.supports_systemd_services() is False
def test_supports_systemd_services_returns_true_when_systemctl_present(self, monkeypatch):
monkeypatch.setattr(gateway_cli, "is_linux", lambda: True)
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
monkeypatch.setattr(gateway_cli, "is_wsl", lambda: False)
monkeypatch.setattr(gateway_cli.shutil, "which", lambda name: "/usr/bin/systemctl")
assert gateway_cli.supports_systemd_services() is True
def test_is_service_running_checks_system_scope_when_user_scope_is_inactive(self, monkeypatch):
user_unit = SimpleNamespace(exists=lambda: True)
system_unit = SimpleNamespace(exists=lambda: True)
@@ -418,6 +433,23 @@ class TestGatewayServiceDetection:
assert gateway_cli._is_service_running() is True
def test_is_service_running_returns_false_when_systemctl_missing(self, monkeypatch):
unit = SimpleNamespace(exists=lambda: True)
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True)
monkeypatch.setattr(
gateway_cli,
"get_systemd_unit_path",
lambda system=False: unit,
)
def fake_run(*args, **kwargs):
raise FileNotFoundError("systemctl")
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
assert gateway_cli._is_service_running() is False
class TestGatewaySystemServiceRouting:
def test_systemd_restart_self_requests_graceful_restart_without_reload_or_restart(self, monkeypatch, capsys):
@@ -1001,3 +1033,91 @@ class TestSystemUnitPathRemapping:
# Target user paths should be present
assert "/home/alice" in unit
assert "WorkingDirectory=/home/alice/.hermes/hermes-agent" in unit
class TestDockerAwareGateway:
"""Tests for Docker container awareness in gateway commands."""
def test_run_systemctl_raises_runtimeerror_when_missing(self, monkeypatch):
"""_run_systemctl raises RuntimeError with container guidance when systemctl is absent."""
import pytest
def fake_run(cmd, **kwargs):
raise FileNotFoundError("systemctl")
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
with pytest.raises(RuntimeError, match="systemctl is not available"):
gateway_cli._run_systemctl(["start", "hermes-gateway"])
def test_run_systemctl_passes_through_on_success(self, monkeypatch):
"""_run_systemctl delegates to subprocess.run when systemctl exists."""
calls = []
def fake_run(cmd, **kwargs):
calls.append(cmd)
return SimpleNamespace(returncode=0, stdout="", stderr="")
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
result = gateway_cli._run_systemctl(["status", "hermes-gateway"])
assert result.returncode == 0
assert len(calls) == 1
assert "status" in calls[0]
def test_install_in_container_prints_docker_guidance(self, monkeypatch, capsys):
"""'hermes gateway install' inside Docker exits 0 with container guidance."""
import pytest
monkeypatch.setattr(gateway_cli, "is_managed", lambda: False)
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: False)
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
monkeypatch.setattr(gateway_cli, "is_wsl", lambda: False)
monkeypatch.setattr(gateway_cli, "is_container", lambda: True)
args = SimpleNamespace(gateway_command="install", force=False, system=False, run_as_user=None)
with pytest.raises(SystemExit) as exc_info:
gateway_cli.gateway_command(args)
assert exc_info.value.code == 0
out = capsys.readouterr().out
assert "Docker" in out or "docker" in out
assert "restart" in out.lower()
def test_uninstall_in_container_prints_docker_guidance(self, monkeypatch, capsys):
"""'hermes gateway uninstall' inside Docker exits 0 with container guidance."""
import pytest
monkeypatch.setattr(gateway_cli, "is_managed", lambda: False)
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: False)
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
monkeypatch.setattr(gateway_cli, "is_container", lambda: True)
args = SimpleNamespace(gateway_command="uninstall", system=False)
with pytest.raises(SystemExit) as exc_info:
gateway_cli.gateway_command(args)
assert exc_info.value.code == 0
out = capsys.readouterr().out
assert "docker" in out.lower()
def test_start_in_container_prints_docker_guidance(self, monkeypatch, capsys):
"""'hermes gateway start' inside Docker exits 0 with container guidance."""
import pytest
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: False)
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
monkeypatch.setattr(gateway_cli, "is_wsl", lambda: False)
monkeypatch.setattr(gateway_cli, "is_container", lambda: True)
args = SimpleNamespace(gateway_command="start", system=False)
with pytest.raises(SystemExit) as exc_info:
gateway_cli.gateway_command(args)
assert exc_info.value.code == 0
out = capsys.readouterr().out
assert "docker" in out.lower()
assert "hermes gateway run" in out

View File

@@ -54,14 +54,19 @@ class TestAnthropicDotToHyphen:
# ── OpenCode Zen regression ────────────────────────────────────────────
class TestOpenCodeZenDotToHyphen:
"""OpenCode Zen follows Anthropic convention (dots→hyphens)."""
class TestOpenCodeZenModelNormalization:
"""OpenCode Zen preserves dots for most models, but Claude stays hyphenated."""
@pytest.mark.parametrize("model,expected", [
("claude-sonnet-4.6", "claude-sonnet-4-6"),
("glm-4.5", "glm-4-5"),
("opencode-zen/claude-opus-4.5", "claude-opus-4-5"),
("glm-4.5", "glm-4.5"),
("glm-5.1", "glm-5.1"),
("gpt-5.4", "gpt-5.4"),
("minimax-m2.5-free", "minimax-m2.5-free"),
("kimi-k2.5", "kimi-k2.5"),
])
def test_zen_converts_dots(self, model, expected):
def test_zen_normalizes_models(self, model, expected):
result = normalize_model_for_provider(model, "opencode-zen")
assert result == expected
@@ -69,6 +74,10 @@ class TestOpenCodeZenDotToHyphen:
result = normalize_model_for_provider("opencode-zen/claude-sonnet-4.6", "opencode-zen")
assert result == "claude-sonnet-4-6"
def test_zen_strips_vendor_prefix_for_non_claude(self):
result = normalize_model_for_provider("opencode-zen/glm-5.1", "opencode-zen")
assert result == "glm-5.1"
# ── Copilot dot preservation (regression) ──────────────────────────────

View File

@@ -0,0 +1,84 @@
"""Tests for the Nous-Hermes-3/4 non-agentic warning detector.
Prior to this check, the warning fired on any model whose name contained
``"hermes"`` anywhere (case-insensitive). That false-positived on unrelated
local Modelfiles such as ``hermes-brain:qwen3-14b-ctx16k`` — a tool-capable
Qwen3 wrapper that happens to live under the "hermes" tag namespace.
``is_nous_hermes_non_agentic`` should only match the actual Nous Research
Hermes-3 / Hermes-4 chat family.
"""
from __future__ import annotations
import pytest
from hermes_cli.model_switch import (
_HERMES_MODEL_WARNING,
_check_hermes_model_warning,
is_nous_hermes_non_agentic,
)
@pytest.mark.parametrize(
"model_name",
[
"NousResearch/Hermes-3-Llama-3.1-70B",
"NousResearch/Hermes-3-Llama-3.1-405B",
"hermes-3",
"Hermes-3",
"hermes-4",
"hermes-4-405b",
"hermes_4_70b",
"openrouter/hermes3:70b",
"openrouter/nousresearch/hermes-4-405b",
"NousResearch/Hermes3",
"hermes-3.1",
],
)
def test_matches_real_nous_hermes_chat_models(model_name: str) -> None:
assert is_nous_hermes_non_agentic(model_name), (
f"expected {model_name!r} to be flagged as Nous Hermes 3/4"
)
assert _check_hermes_model_warning(model_name) == _HERMES_MODEL_WARNING
@pytest.mark.parametrize(
"model_name",
[
# Kyle's local Modelfile — qwen3:14b under a custom tag
"hermes-brain:qwen3-14b-ctx16k",
"hermes-brain:qwen3-14b-ctx32k",
"hermes-honcho:qwen3-8b-ctx8k",
# Plain unrelated models
"qwen3:14b",
"qwen3-coder:30b",
"qwen2.5:14b",
"claude-opus-4-6",
"anthropic/claude-sonnet-4.5",
"gpt-5",
"openai/gpt-4o",
"google/gemini-2.5-flash",
"deepseek-chat",
# Non-chat Hermes models we don't warn about
"hermes-llm-2",
"hermes2-pro",
"nous-hermes-2-mistral",
# Edge cases
"",
"hermes", # bare "hermes" isn't the 3/4 family
"hermes-brain",
"brain-hermes-3-impostor", # "3" not preceded by /: boundary
],
)
def test_does_not_match_unrelated_models(model_name: str) -> None:
assert not is_nous_hermes_non_agentic(model_name), (
f"expected {model_name!r} NOT to be flagged as Nous Hermes 3/4"
)
assert _check_hermes_model_warning(model_name) == ""
def test_none_like_inputs_are_safe() -> None:
assert is_nous_hermes_non_agentic("") is False
# Defensive: the helper shouldn't crash on None-ish falsy input either.
assert _check_hermes_model_warning("") == ""

View File

@@ -177,7 +177,8 @@ class TestCreateProfile:
# No error; optional files just not copied
assert not (profile_dir / "config.yaml").exists()
assert not (profile_dir / ".env").exists()
assert not (profile_dir / "SOUL.md").exists()
# SOUL.md is always seeded with the default even when clone source lacks it
assert (profile_dir / "SOUL.md").exists()
# ===================================================================

View File

@@ -119,6 +119,11 @@ def test_resolve_runtime_provider_falls_back_when_pool_empty(monkeypatch):
def test_resolve_runtime_provider_codex(monkeypatch):
monkeypatch.setattr(
rp,
"load_pool",
lambda provider: type("P", (), {"has_credentials": lambda self: False})(),
)
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openai-codex")
monkeypatch.setattr(
rp,
@@ -567,6 +572,87 @@ def test_named_custom_provider_uses_saved_credentials(monkeypatch):
assert resolved["source"] == "custom_provider:Local"
def test_named_custom_provider_uses_providers_dict_when_list_missing(monkeypatch):
"""After v11→v12 migration deletes custom_providers, resolution should
still find entries in the providers dict via get_compatible_custom_providers."""
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
monkeypatch.setattr(
rp,
"load_config",
lambda: {
"providers": {
"openai-direct-primary": {
"api": "https://api.openai.com/v1",
"api_key": "dir-key",
"default_model": "gpt-5-mini",
"name": "OpenAI Direct (Primary)",
"transport": "codex_responses",
}
}
},
)
monkeypatch.setattr(
rp,
"resolve_provider",
lambda *a, **k: (_ for _ in ()).throw(
AssertionError(
"resolve_provider should not be called for named custom providers"
)
),
)
resolved = rp.resolve_runtime_provider(requested="openai-direct-primary")
assert resolved["provider"] == "custom"
assert resolved["api_mode"] == "codex_responses"
assert resolved["base_url"] == "https://api.openai.com/v1"
assert resolved["api_key"] == "dir-key"
assert resolved["requested_provider"] == "openai-direct-primary"
assert resolved["source"] == "custom_provider:OpenAI Direct (Primary)"
assert resolved["model"] == "gpt-5-mini"
def test_named_custom_provider_uses_key_env_from_providers_dict(monkeypatch):
"""providers dict entries with key_env should resolve API key from env var."""
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
monkeypatch.setenv("MYCORP_API_KEY", "env-secret")
monkeypatch.setattr(
rp,
"load_config",
lambda: {
"providers": {
"mycorp-proxy": {
"base_url": "https://proxy.example.com/v1",
"default_model": "acme-large",
"key_env": "MYCORP_API_KEY",
"name": "MyCorp Proxy",
}
}
},
)
monkeypatch.setattr(
rp,
"resolve_provider",
lambda *a, **k: (_ for _ in ()).throw(
AssertionError(
"resolve_provider should not be called for named custom providers"
)
),
)
resolved = rp.resolve_runtime_provider(requested="mycorp-proxy")
assert resolved["provider"] == "custom"
assert resolved["api_mode"] == "chat_completions"
assert resolved["base_url"] == "https://proxy.example.com/v1"
assert resolved["api_key"] == "env-secret"
assert resolved["requested_provider"] == "mycorp-proxy"
assert resolved["source"] == "custom_provider:MyCorp Proxy"
assert resolved["model"] == "acme-large"
def test_named_custom_provider_falls_back_to_openai_api_key(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "env-openai-key")
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)

View File

@@ -1,5 +1,4 @@
"""Tests for setup_model_provider — verifies the delegation to
select_provider_and_model() and config dict sync."""
"""Tests for setup.py configuration flows."""
import json
import sys
import types
@@ -8,6 +7,7 @@ import pytest
from hermes_cli.auth import get_active_provider
from hermes_cli.config import load_config, save_config
from hermes_cli import setup as setup_mod
from hermes_cli.setup import setup_model_provider
@@ -144,6 +144,85 @@ def test_setup_custom_providers_synced(tmp_path, monkeypatch):
assert reloaded.get("custom_providers") == [{"name": "Local", "base_url": "http://localhost:8080/v1"}]
def test_setup_gateway_skips_service_install_when_systemctl_missing(monkeypatch, capsys):
env = {
"TELEGRAM_BOT_TOKEN": "",
"TELEGRAM_HOME_CHANNEL": "",
"DISCORD_BOT_TOKEN": "",
"DISCORD_HOME_CHANNEL": "",
"SLACK_BOT_TOKEN": "",
"SLACK_HOME_CHANNEL": "",
"MATRIX_HOMESERVER": "https://matrix.example.com",
"MATRIX_USER_ID": "@alice:example.com",
"MATRIX_PASSWORD": "",
"MATRIX_ACCESS_TOKEN": "token",
"BLUEBUBBLES_SERVER_URL": "",
"BLUEBUBBLES_HOME_CHANNEL": "",
"WHATSAPP_ENABLED": "",
"WEBHOOK_ENABLED": "",
}
monkeypatch.setattr(setup_mod, "get_env_value", lambda key: env.get(key, ""))
monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *args, **kwargs: False)
monkeypatch.setattr("platform.system", lambda: "Linux")
import hermes_cli.gateway as gateway_mod
monkeypatch.setattr(gateway_mod, "supports_systemd_services", lambda: False)
monkeypatch.setattr(gateway_mod, "is_macos", lambda: False)
monkeypatch.setattr(gateway_mod, "_is_service_installed", lambda: False)
monkeypatch.setattr(gateway_mod, "_is_service_running", lambda: False)
setup_mod.setup_gateway({})
out = capsys.readouterr().out
assert "Messaging platforms configured!" in out
assert "Start the gateway to bring your bots online:" in out
assert "hermes gateway" in out
def test_setup_gateway_in_container_shows_docker_guidance(monkeypatch, capsys):
"""setup_gateway() in a Docker container shows Docker-specific restart instructions."""
env = {
"TELEGRAM_BOT_TOKEN": "",
"TELEGRAM_HOME_CHANNEL": "",
"DISCORD_BOT_TOKEN": "",
"DISCORD_HOME_CHANNEL": "",
"SLACK_BOT_TOKEN": "",
"SLACK_HOME_CHANNEL": "",
"MATRIX_HOMESERVER": "https://matrix.example.com",
"MATRIX_USER_ID": "@alice:example.com",
"MATRIX_PASSWORD": "",
"MATRIX_ACCESS_TOKEN": "token",
"BLUEBUBBLES_SERVER_URL": "",
"BLUEBUBBLES_HOME_CHANNEL": "",
"WHATSAPP_ENABLED": "",
"WEBHOOK_ENABLED": "",
}
monkeypatch.setattr(setup_mod, "get_env_value", lambda key: env.get(key, ""))
monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *args, **kwargs: False)
monkeypatch.setattr("platform.system", lambda: "Linux")
import hermes_cli.gateway as gateway_mod
monkeypatch.setattr(gateway_mod, "supports_systemd_services", lambda: False)
monkeypatch.setattr(gateway_mod, "is_macos", lambda: False)
monkeypatch.setattr(gateway_mod, "_is_service_installed", lambda: False)
monkeypatch.setattr(gateway_mod, "_is_service_running", lambda: False)
# Patch is_container at the import location in setup.py
import hermes_constants
monkeypatch.setattr(hermes_constants, "is_container", lambda: True)
setup_mod.setup_gateway({})
out = capsys.readouterr().out
assert "Messaging platforms configured!" in out
assert "docker" in out.lower() or "Docker" in out
assert "restart" in out.lower()
def test_setup_syncs_custom_provider_removal_from_disk(tmp_path, monkeypatch):
"""Removing the last custom provider in model setup should persist."""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))

View File

@@ -119,8 +119,7 @@ def test_toolset_has_keys_for_vision_accepts_codex_auth(tmp_path, monkeypatch):
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.delenv("AUXILIARY_VISION_PROVIDER", raising=False)
monkeypatch.delenv("CONTEXT_VISION_PROVIDER", raising=False)
monkeypatch.setattr(
"agent.auxiliary_client.resolve_vision_provider_client",
lambda: ("openai-codex", object(), "gpt-4.1"),

View File

@@ -0,0 +1,280 @@
"""Tests for user-defined providers (providers: dict) in /model.
These tests ensure that providers defined in the config.yaml ``providers:`` section
are properly resolved for model switching and that their full ``models:`` lists
are exposed in the model picker.
"""
import pytest
from hermes_cli.model_switch import list_authenticated_providers, switch_model
from hermes_cli import runtime_provider as rp
# =============================================================================
# Tests for list_authenticated_providers including full models list
# =============================================================================
def test_list_authenticated_providers_includes_full_models_list_from_user_providers(monkeypatch):
"""User-defined providers should expose both default_model and full models list.
Regression test: previously only default_model was shown in /model picker.
"""
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {})
user_providers = {
"local-ollama": {
"name": "Local Ollama",
"api": "http://localhost:11434/v1",
"default_model": "minimax-m2.7:cloud",
"models": [
"minimax-m2.7:cloud",
"kimi-k2.5:cloud",
"glm-5.1:cloud",
"qwen3.5:cloud",
],
}
}
providers = list_authenticated_providers(
current_provider="local-ollama",
user_providers=user_providers,
custom_providers=[],
max_models=50,
)
# Find our user provider
user_prov = next(
(p for p in providers if p.get("is_user_defined") and p["slug"] == "local-ollama"),
None
)
assert user_prov is not None, "User provider 'local-ollama' should be in results"
assert user_prov["total_models"] == 4, f"Expected 4 models, got {user_prov['total_models']}"
assert "minimax-m2.7:cloud" in user_prov["models"]
assert "kimi-k2.5:cloud" in user_prov["models"]
assert "glm-5.1:cloud" in user_prov["models"]
assert "qwen3.5:cloud" in user_prov["models"]
def test_list_authenticated_providers_dedupes_models_when_default_in_list(monkeypatch):
"""When default_model is also in models list, don't duplicate."""
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {})
user_providers = {
"my-provider": {
"api": "http://example.com/v1",
"default_model": "model-a", # Included in models list below
"models": ["model-a", "model-b", "model-c"],
}
}
providers = list_authenticated_providers(
current_provider="my-provider",
user_providers=user_providers,
custom_providers=[],
)
user_prov = next(
(p for p in providers if p.get("is_user_defined")),
None
)
assert user_prov is not None
assert user_prov["total_models"] == 3, "Should have 3 unique models, not 4"
assert user_prov["models"].count("model-a") == 1, "model-a should not be duplicated"
def test_list_authenticated_providers_fallback_to_default_only(monkeypatch):
"""When no models array is provided, should fall back to default_model."""
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {})
user_providers = {
"simple-provider": {
"name": "Simple Provider",
"api": "http://example.com/v1",
"default_model": "single-model",
# No 'models' key
}
}
providers = list_authenticated_providers(
current_provider="",
user_providers=user_providers,
custom_providers=[],
)
user_prov = next(
(p for p in providers if p.get("is_user_defined")),
None
)
assert user_prov is not None
assert user_prov["total_models"] == 1
assert user_prov["models"] == ["single-model"]
# =============================================================================
# Tests for _get_named_custom_provider with providers: dict
# =============================================================================
def test_get_named_custom_provider_finds_user_providers_by_key(monkeypatch, tmp_path):
"""Should resolve providers from providers: dict (new-style), not just custom_providers."""
config = {
"providers": {
"local-localhost:11434": {
"api": "http://localhost:11434/v1",
"name": "Local (localhost:11434)",
"default_model": "minimax-m2.7:cloud",
}
}
}
import yaml
config_file = tmp_path / "config.yaml"
config_file.write_text(yaml.dump(config))
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
result = rp._get_named_custom_provider("local-localhost:11434")
assert result is not None
assert result["base_url"] == "http://localhost:11434/v1"
assert result["name"] == "Local (localhost:11434)"
def test_get_named_custom_provider_finds_by_display_name(monkeypatch, tmp_path):
"""Should match providers by their 'name' field as well as key."""
config = {
"providers": {
"my-ollama-xyz": {
"api": "http://ollama.example.com/v1",
"name": "My Production Ollama",
"default_model": "llama3",
}
}
}
import yaml
config_file = tmp_path / "config.yaml"
config_file.write_text(yaml.dump(config))
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
# Should find by display name (normalized)
result = rp._get_named_custom_provider("my-production-ollama")
assert result is not None
assert result["base_url"] == "http://ollama.example.com/v1"
def test_get_named_custom_provider_falls_back_to_legacy_format(monkeypatch, tmp_path):
"""Should still work with custom_providers: list format."""
config = {
"providers": {},
"custom_providers": [
{
"name": "Custom Endpoint",
"base_url": "http://custom.example.com/v1",
}
]
}
import yaml
config_file = tmp_path / "config.yaml"
config_file.write_text(yaml.dump(config))
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
result = rp._get_named_custom_provider("custom-endpoint")
assert result is not None
def test_get_named_custom_provider_returns_none_for_unknown(monkeypatch, tmp_path):
"""Should return None for providers that don't exist."""
config = {
"providers": {
"known-provider": {
"api": "http://known.example.com/v1",
}
}
}
import yaml
config_file = tmp_path / "config.yaml"
config_file.write_text(yaml.dump(config))
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
result = rp._get_named_custom_provider("other-provider")
# "unknown-provider" partial-matches "known-provider" because "unknown" doesn't match
# but our matching is loose (substring). Let's verify a truly non-matching provider
result = rp._get_named_custom_provider("completely-different-name")
assert result is None
def test_get_named_custom_provider_skips_empty_base_url(monkeypatch, tmp_path):
"""Should skip providers without a base_url."""
config = {
"providers": {
"incomplete-provider": {
"name": "Incomplete",
# No api/base_url field
}
}
}
import yaml
config_file = tmp_path / "config.yaml"
config_file.write_text(yaml.dump(config))
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
result = rp._get_named_custom_provider("incomplete-provider")
assert result is None
# =============================================================================
# Integration test for switch_model with user providers
# =============================================================================
def test_switch_model_resolves_user_provider_credentials(monkeypatch, tmp_path):
"""/model switch should resolve credentials for providers: dict providers."""
import yaml
config = {
"providers": {
"local-ollama": {
"api": "http://localhost:11434/v1",
"name": "Local Ollama",
"default_model": "minimax-m2.7:cloud",
}
}
}
config_file = tmp_path / "config.yaml"
config_file.write_text(yaml.dump(config))
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
# Mock validation to pass
monkeypatch.setattr(
"hermes_cli.models.validate_requested_model",
lambda *a, **k: {"accepted": True, "persist": True, "recognized": True, "message": None}
)
result = switch_model(
raw_input="kimi-k2.5:cloud",
current_provider="local-ollama",
current_model="minimax-m2.7:cloud",
current_base_url="http://localhost:11434/v1",
is_global=False,
user_providers=config["providers"],
)
assert result.success is True
assert result.error_message == ""

View File

@@ -0,0 +1,675 @@
"""Tests for hermes_cli.web_server and related config utilities."""
import os
import json
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
from hermes_cli.config import (
DEFAULT_CONFIG,
reload_env,
redact_key,
_EXTRA_ENV_KEYS,
OPTIONAL_ENV_VARS,
)
# ---------------------------------------------------------------------------
# reload_env tests
# ---------------------------------------------------------------------------
class TestReloadEnv:
"""Tests for reload_env() — re-reads .env into os.environ."""
def test_adds_new_vars(self, tmp_path):
"""reload_env() adds vars from .env that are not in os.environ."""
env_file = tmp_path / ".env"
env_file.write_text("TEST_RELOAD_VAR=hello123\n")
with patch("hermes_cli.config.get_env_path", return_value=env_file):
os.environ.pop("TEST_RELOAD_VAR", None)
count = reload_env()
assert count >= 1
assert os.environ.get("TEST_RELOAD_VAR") == "hello123"
os.environ.pop("TEST_RELOAD_VAR", None)
def test_updates_changed_vars(self, tmp_path):
"""reload_env() updates vars whose value changed on disk."""
env_file = tmp_path / ".env"
env_file.write_text("TEST_RELOAD_VAR=old_value\n")
with patch("hermes_cli.config.get_env_path", return_value=env_file):
os.environ["TEST_RELOAD_VAR"] = "old_value"
# Now change the file
env_file.write_text("TEST_RELOAD_VAR=new_value\n")
count = reload_env()
assert count >= 1
assert os.environ.get("TEST_RELOAD_VAR") == "new_value"
os.environ.pop("TEST_RELOAD_VAR", None)
def test_removes_deleted_known_vars(self, tmp_path):
"""reload_env() removes known Hermes vars not present in .env."""
env_file = tmp_path / ".env"
env_file.write_text("") # empty .env
# Pick a known key from OPTIONAL_ENV_VARS
known_key = next(iter(OPTIONAL_ENV_VARS.keys()))
with patch("hermes_cli.config.get_env_path", return_value=env_file):
os.environ[known_key] = "stale_value"
count = reload_env()
assert known_key not in os.environ
assert count >= 1
def test_does_not_remove_unknown_vars(self, tmp_path):
"""reload_env() preserves non-Hermes env vars even when absent from .env."""
env_file = tmp_path / ".env"
env_file.write_text("")
with patch("hermes_cli.config.get_env_path", return_value=env_file):
os.environ["MY_CUSTOM_UNRELATED_VAR"] = "keep_me"
reload_env()
assert os.environ.get("MY_CUSTOM_UNRELATED_VAR") == "keep_me"
os.environ.pop("MY_CUSTOM_UNRELATED_VAR", None)
# ---------------------------------------------------------------------------
# redact_key tests
# ---------------------------------------------------------------------------
class TestRedactKey:
def test_long_key_shows_prefix_suffix(self):
result = redact_key("sk-1234567890abcdef")
assert result.startswith("sk-1")
assert result.endswith("cdef")
assert "..." in result
def test_short_key_fully_masked(self):
assert redact_key("short") == "***"
def test_empty_key(self):
result = redact_key("")
assert "not set" in result.lower() or result == "***" or "\x1b" in result
# ---------------------------------------------------------------------------
# web_server tests (FastAPI endpoints)
# ---------------------------------------------------------------------------
class TestWebServerEndpoints:
"""Test the FastAPI REST endpoints using Starlette TestClient."""
@pytest.fixture(autouse=True)
def _setup_test_client(self):
"""Create a TestClient — import is deferred to avoid requiring fastapi."""
try:
from starlette.testclient import TestClient
except ImportError:
pytest.skip("fastapi/starlette not installed")
from hermes_cli.web_server import app
self.client = TestClient(app)
def test_get_status(self):
resp = self.client.get("/api/status")
assert resp.status_code == 200
data = resp.json()
assert "version" in data
assert "hermes_home" in data
assert "active_sessions" in data
def test_get_status_filters_unconfigured_gateway_platforms(self, monkeypatch):
import gateway.config as gateway_config
import hermes_cli.web_server as web_server
class _Platform:
def __init__(self, value):
self.value = value
class _GatewayConfig:
def get_connected_platforms(self):
return [_Platform("telegram")]
monkeypatch.setattr(web_server, "get_running_pid", lambda: 1234)
monkeypatch.setattr(
web_server,
"read_runtime_status",
lambda: {
"gateway_state": "running",
"updated_at": "2026-04-12T00:00:00+00:00",
"platforms": {
"telegram": {"state": "connected", "updated_at": "2026-04-12T00:00:00+00:00"},
"whatsapp": {"state": "retrying", "updated_at": "2026-04-12T00:00:00+00:00"},
"feishu": {"state": "connected", "updated_at": "2026-04-12T00:00:00+00:00"},
},
},
)
monkeypatch.setattr(web_server, "check_config_version", lambda: (1, 1))
monkeypatch.setattr(gateway_config, "load_gateway_config", lambda: _GatewayConfig())
resp = self.client.get("/api/status")
assert resp.status_code == 200
assert resp.json()["gateway_platforms"] == {
"telegram": {"state": "connected", "updated_at": "2026-04-12T00:00:00+00:00"},
}
def test_get_status_hides_stale_platforms_when_gateway_not_running(self, monkeypatch):
import gateway.config as gateway_config
import hermes_cli.web_server as web_server
class _GatewayConfig:
def get_connected_platforms(self):
return []
monkeypatch.setattr(web_server, "get_running_pid", lambda: None)
monkeypatch.setattr(
web_server,
"read_runtime_status",
lambda: {
"gateway_state": "startup_failed",
"updated_at": "2026-04-12T00:00:00+00:00",
"platforms": {
"whatsapp": {"state": "retrying", "updated_at": "2026-04-12T00:00:00+00:00"},
"feishu": {"state": "connected", "updated_at": "2026-04-12T00:00:00+00:00"},
},
},
)
monkeypatch.setattr(web_server, "check_config_version", lambda: (1, 1))
monkeypatch.setattr(gateway_config, "load_gateway_config", lambda: _GatewayConfig())
resp = self.client.get("/api/status")
assert resp.status_code == 200
assert resp.json()["gateway_state"] == "startup_failed"
assert resp.json()["gateway_platforms"] == {}
def test_get_config_schema(self):
resp = self.client.get("/api/config/schema")
assert resp.status_code == 200
data = resp.json()
assert "fields" in data
assert "category_order" in data
schema = data["fields"]
assert len(schema) > 100 # Should have 150+ fields
assert "model" in schema
# Verify category_order is a non-empty list
assert isinstance(data["category_order"], list)
assert len(data["category_order"]) > 0
assert "general" in data["category_order"]
def test_get_config_defaults(self):
resp = self.client.get("/api/config/defaults")
assert resp.status_code == 200
defaults = resp.json()
assert "model" in defaults
def test_get_env_vars(self):
resp = self.client.get("/api/env")
assert resp.status_code == 200
data = resp.json()
# Should contain known env var names
assert any(k.endswith("_API_KEY") or k.endswith("_TOKEN") for k in data.keys())
def test_reveal_env_var(self, tmp_path):
"""POST /api/env/reveal should return the real unredacted value."""
from hermes_cli.config import save_env_value
from hermes_cli.web_server import _SESSION_TOKEN
save_env_value("TEST_REVEAL_KEY", "super-secret-value-12345")
resp = self.client.post(
"/api/env/reveal",
json={"key": "TEST_REVEAL_KEY"},
headers={"Authorization": f"Bearer {_SESSION_TOKEN}"},
)
assert resp.status_code == 200
data = resp.json()
assert data["key"] == "TEST_REVEAL_KEY"
assert data["value"] == "super-secret-value-12345"
def test_reveal_env_var_not_found(self):
"""POST /api/env/reveal should 404 for unknown keys."""
from hermes_cli.web_server import _SESSION_TOKEN
resp = self.client.post(
"/api/env/reveal",
json={"key": "NONEXISTENT_KEY_XYZ"},
headers={"Authorization": f"Bearer {_SESSION_TOKEN}"},
)
assert resp.status_code == 404
def test_reveal_env_var_no_token(self, tmp_path):
"""POST /api/env/reveal without token should return 401."""
from hermes_cli.config import save_env_value
save_env_value("TEST_REVEAL_NOAUTH", "secret-value")
resp = self.client.post(
"/api/env/reveal",
json={"key": "TEST_REVEAL_NOAUTH"},
)
assert resp.status_code == 401
def test_reveal_env_var_bad_token(self, tmp_path):
"""POST /api/env/reveal with wrong token should return 401."""
from hermes_cli.config import save_env_value
save_env_value("TEST_REVEAL_BADAUTH", "secret-value")
resp = self.client.post(
"/api/env/reveal",
json={"key": "TEST_REVEAL_BADAUTH"},
headers={"Authorization": "Bearer wrong-token-here"},
)
assert resp.status_code == 401
def test_session_token_endpoint(self):
"""GET /api/auth/session-token should return a token."""
from hermes_cli.web_server import _SESSION_TOKEN
resp = self.client.get("/api/auth/session-token")
assert resp.status_code == 200
assert resp.json()["token"] == _SESSION_TOKEN
def test_path_traversal_blocked(self):
"""Verify URL-encoded path traversal is blocked."""
# %2e%2e = ..
resp = self.client.get("/%2e%2e/%2e%2e/etc/passwd")
# Should return 200 with index.html (SPA fallback), not the actual file
assert resp.status_code in (200, 404)
if resp.status_code == 200:
# Should be the SPA fallback, not the system file
assert "root:" not in resp.text
def test_path_traversal_dotdot_blocked(self):
"""Direct .. path traversal via encoded sequences."""
resp = self.client.get("/%2e%2e/hermes_cli/web_server.py")
assert resp.status_code in (200, 404)
if resp.status_code == 200:
assert "FastAPI" not in resp.text # Should not serve the actual source
# ---------------------------------------------------------------------------
# _build_schema_from_config tests
# ---------------------------------------------------------------------------
class TestBuildSchemaFromConfig:
def test_produces_expected_field_count(self):
from hermes_cli.web_server import CONFIG_SCHEMA
# DEFAULT_CONFIG has ~150+ leaf fields
assert len(CONFIG_SCHEMA) > 100
def test_schema_entries_have_required_fields(self):
from hermes_cli.web_server import CONFIG_SCHEMA
for key, entry in list(CONFIG_SCHEMA.items())[:10]:
assert "type" in entry, f"Missing type for {key}"
assert "category" in entry, f"Missing category for {key}"
def test_overrides_applied(self):
from hermes_cli.web_server import CONFIG_SCHEMA
# terminal.backend should be a select with options
if "terminal.backend" in CONFIG_SCHEMA:
entry = CONFIG_SCHEMA["terminal.backend"]
assert entry["type"] == "select"
assert "options" in entry
assert "local" in entry["options"]
def test_empty_prefix_produces_correct_keys(self):
from hermes_cli.web_server import _build_schema_from_config
test_config = {"model": "test", "nested": {"key": "val"}}
schema = _build_schema_from_config(test_config)
assert "model" in schema
assert "nested.key" in schema
def test_top_level_scalars_get_general_category(self):
"""Top-level scalar fields should be in 'general' category."""
from hermes_cli.web_server import CONFIG_SCHEMA
assert CONFIG_SCHEMA["model"]["category"] == "general"
def test_nested_keys_get_parent_category(self):
"""Nested fields should use the top-level parent as their category."""
from hermes_cli.web_server import CONFIG_SCHEMA
if "agent.max_turns" in CONFIG_SCHEMA:
assert CONFIG_SCHEMA["agent.max_turns"]["category"] == "agent"
def test_category_merge_applied(self):
"""Small categories should be merged into larger ones."""
from hermes_cli.web_server import CONFIG_SCHEMA
categories = {e["category"] for e in CONFIG_SCHEMA.values()}
# These should be merged away
assert "privacy" not in categories # merged into security
assert "context" not in categories # merged into agent
def test_no_single_field_categories(self):
"""After merging, no category should have just 1 field."""
from hermes_cli.web_server import CONFIG_SCHEMA
from collections import Counter
cats = Counter(e["category"] for e in CONFIG_SCHEMA.values())
for cat, count in cats.items():
assert count >= 2, f"Category '{cat}' has only {count} field(s) — should be merged"
# ---------------------------------------------------------------------------
# Config round-trip tests
# ---------------------------------------------------------------------------
class TestConfigRoundTrip:
"""Verify config survives GET → edit → PUT without data loss."""
@pytest.fixture(autouse=True)
def _setup(self):
try:
from starlette.testclient import TestClient
except ImportError:
pytest.skip("fastapi/starlette not installed")
from hermes_cli.web_server import app
self.client = TestClient(app)
def test_get_config_no_internal_keys(self):
"""GET /api/config should not expose _config_version or _model_meta."""
config = self.client.get("/api/config").json()
internal = [k for k in config if k.startswith("_")]
assert not internal, f"Internal keys leaked to frontend: {internal}"
def test_get_config_model_is_string(self):
"""GET /api/config should normalize model dict to a string."""
config = self.client.get("/api/config").json()
assert isinstance(config.get("model"), str), \
f"model should be string, got {type(config.get('model'))}"
def test_round_trip_preserves_model_subkeys(self):
"""Save and reload should not lose model.provider, model.base_url, etc."""
from hermes_cli.config import load_config, save_config
# Set up a config with model as a dict (the common user config form)
save_config({
"model": {
"default": "anthropic/claude-sonnet-4",
"provider": "openrouter",
"base_url": "https://openrouter.ai/api/v1",
"api_mode": "openai",
}
})
before = load_config()
assert isinstance(before.get("model"), dict)
original_keys = set(before["model"].keys())
# GET → PUT unchanged
web_config = self.client.get("/api/config").json()
assert isinstance(web_config.get("model"), str), "GET should normalize model to string"
self.client.put("/api/config", json={"config": web_config})
after = load_config()
assert isinstance(after.get("model"), dict), "model should still be a dict after save"
assert set(after["model"].keys()) >= original_keys, \
f"Lost model subkeys: {original_keys - set(after['model'].keys())}"
def test_edit_model_name_preserved(self):
"""Changing the model string should update model.default on disk."""
from hermes_cli.config import load_config
web_config = self.client.get("/api/config").json()
original_model = web_config["model"]
# Change model
web_config["model"] = "test/editing-model"
self.client.put("/api/config", json={"config": web_config})
after = load_config()
if isinstance(after.get("model"), dict):
assert after["model"]["default"] == "test/editing-model"
else:
assert after["model"] == "test/editing-model"
# Restore
web_config["model"] = original_model
self.client.put("/api/config", json={"config": web_config})
def test_edit_nested_value(self):
"""Editing a nested config value should persist correctly."""
from hermes_cli.config import load_config
web_config = self.client.get("/api/config").json()
original_turns = web_config.get("agent", {}).get("max_turns")
# Change max_turns
if "agent" not in web_config:
web_config["agent"] = {}
web_config["agent"]["max_turns"] = 42
self.client.put("/api/config", json={"config": web_config})
after = load_config()
assert after.get("agent", {}).get("max_turns") == 42
# Restore
web_config["agent"]["max_turns"] = original_turns
self.client.put("/api/config", json={"config": web_config})
def test_schema_types_match_config_values(self):
"""Every schema field should have a matching-type value in the config."""
config = self.client.get("/api/config").json()
schema_resp = self.client.get("/api/config/schema").json()
schema = schema_resp["fields"]
def get_nested(obj, path):
parts = path.split(".")
cur = obj
for p in parts:
if cur is None or not isinstance(cur, dict):
return None
cur = cur.get(p)
return cur
mismatches = []
for key, entry in schema.items():
val = get_nested(config, key)
if val is None:
continue # not set in user config — fine
expected = entry["type"]
if expected in ("string", "select") and not isinstance(val, str):
mismatches.append(f"{key}: expected str, got {type(val).__name__}")
elif expected == "number" and not isinstance(val, (int, float)):
mismatches.append(f"{key}: expected number, got {type(val).__name__}")
elif expected == "boolean" and not isinstance(val, bool):
mismatches.append(f"{key}: expected bool, got {type(val).__name__}")
elif expected == "list" and not isinstance(val, list):
mismatches.append(f"{key}: expected list, got {type(val).__name__}")
assert not mismatches, f"Type mismatches:\n" + "\n".join(mismatches)
# ---------------------------------------------------------------------------
# New feature endpoint tests
# ---------------------------------------------------------------------------
class TestNewEndpoints:
"""Tests for session detail, logs, cron, skills, tools, raw config, analytics."""
@pytest.fixture(autouse=True)
def _setup(self):
try:
from starlette.testclient import TestClient
except ImportError:
pytest.skip("fastapi/starlette not installed")
from hermes_cli.web_server import app
self.client = TestClient(app)
def test_get_logs_default(self):
resp = self.client.get("/api/logs")
assert resp.status_code == 200
data = resp.json()
assert "file" in data
assert "lines" in data
assert isinstance(data["lines"], list)
def test_get_logs_invalid_file(self):
resp = self.client.get("/api/logs?file=nonexistent")
assert resp.status_code == 400
def test_cron_list(self):
resp = self.client.get("/api/cron/jobs")
assert resp.status_code == 200
assert isinstance(resp.json(), list)
def test_cron_job_not_found(self):
resp = self.client.get("/api/cron/jobs/nonexistent-id")
assert resp.status_code == 404
def test_skills_list(self):
resp = self.client.get("/api/skills")
assert resp.status_code == 200
skills = resp.json()
assert isinstance(skills, list)
if skills:
assert "name" in skills[0]
assert "enabled" in skills[0]
def test_skills_list_includes_disabled_skills(self, monkeypatch):
import tools.skills_tool as skills_tool
import hermes_cli.skills_config as skills_config
import hermes_cli.web_server as web_server
def _fake_find_all_skills(*, skip_disabled=False):
if skip_disabled:
return [
{"name": "active-skill", "description": "active", "category": "demo"},
{"name": "disabled-skill", "description": "disabled", "category": "demo"},
]
return [
{"name": "active-skill", "description": "active", "category": "demo"},
]
monkeypatch.setattr(skills_tool, "_find_all_skills", _fake_find_all_skills)
monkeypatch.setattr(skills_config, "get_disabled_skills", lambda config: {"disabled-skill"})
monkeypatch.setattr(web_server, "load_config", lambda: {"skills": {"disabled": ["disabled-skill"]}})
resp = self.client.get("/api/skills")
assert resp.status_code == 200
assert resp.json() == [
{
"name": "active-skill",
"description": "active",
"category": "demo",
"enabled": True,
},
{
"name": "disabled-skill",
"description": "disabled",
"category": "demo",
"enabled": False,
},
]
def test_toolsets_list(self):
resp = self.client.get("/api/tools/toolsets")
assert resp.status_code == 200
toolsets = resp.json()
assert isinstance(toolsets, list)
if toolsets:
assert "name" in toolsets[0]
assert "label" in toolsets[0]
assert "enabled" in toolsets[0]
def test_toolsets_list_matches_cli_enabled_state(self, monkeypatch):
import hermes_cli.tools_config as tools_config
import toolsets as toolsets_module
import hermes_cli.web_server as web_server
monkeypatch.setattr(
tools_config,
"_get_effective_configurable_toolsets",
lambda: [
("web", "🔍 Web Search & Scraping", "web_search, web_extract"),
("skills", "📚 Skills", "list, view, manage"),
("memory", "💾 Memory", "persistent memory across sessions"),
],
)
monkeypatch.setattr(
tools_config,
"_get_platform_tools",
lambda config, platform, include_default_mcp_servers=False: {"web", "skills"},
)
monkeypatch.setattr(
tools_config,
"_toolset_has_keys",
lambda ts_key, config=None: ts_key != "web",
)
monkeypatch.setattr(
toolsets_module,
"resolve_toolset",
lambda name: {
"web": ["web_search", "web_extract"],
"skills": ["skills_list", "skill_view"],
"memory": ["memory_read"],
}[name],
)
monkeypatch.setattr(web_server, "load_config", lambda: {"platform_toolsets": {"cli": ["web", "skills"]}})
resp = self.client.get("/api/tools/toolsets")
assert resp.status_code == 200
assert resp.json() == [
{
"name": "web",
"label": "🔍 Web Search & Scraping",
"description": "web_search, web_extract",
"enabled": True,
"available": True,
"configured": False,
"tools": ["web_extract", "web_search"],
},
{
"name": "skills",
"label": "📚 Skills",
"description": "list, view, manage",
"enabled": True,
"available": True,
"configured": True,
"tools": ["skill_view", "skills_list"],
},
{
"name": "memory",
"label": "💾 Memory",
"description": "persistent memory across sessions",
"enabled": False,
"available": False,
"configured": True,
"tools": ["memory_read"],
},
]
def test_config_raw_get(self):
resp = self.client.get("/api/config/raw")
assert resp.status_code == 200
assert "yaml" in resp.json()
def test_config_raw_put_valid(self):
resp = self.client.put(
"/api/config/raw",
json={"yaml_text": "model: test\ntoolsets:\n - all\n"},
)
assert resp.status_code == 200
assert resp.json()["ok"] is True
def test_config_raw_put_invalid(self):
resp = self.client.put(
"/api/config/raw",
json={"yaml_text": "- this is a list not a dict"},
)
assert resp.status_code == 400
def test_analytics_usage(self):
resp = self.client.get("/api/analytics/usage?days=7")
assert resp.status_code == 200
data = resp.json()
assert "daily" in data
assert "by_model" in data
assert "totals" in data
assert isinstance(data["daily"], list)
assert "total_sessions" in data["totals"]
def test_session_token_endpoint(self):
from hermes_cli.web_server import _SESSION_TOKEN
resp = self.client.get("/api/auth/session-token")
assert resp.status_code == 200
assert resp.json()["token"] == _SESSION_TOKEN

View File

@@ -102,7 +102,19 @@ class _PromptTooLongError(Exception):
self.status_code = 400
class _FakeMessages:
"""Stub for client.messages.create() / client.messages.stream()."""
def create(self, **kwargs):
raise NotImplementedError("_FakeAnthropicClient.messages.create should not be called directly in tests")
def stream(self, **kwargs):
raise NotImplementedError("_FakeAnthropicClient.messages.stream should not be called directly in tests")
class _FakeAnthropicClient:
def __init__(self):
self.messages = _FakeMessages()
def close(self):
pass
@@ -131,13 +143,14 @@ def _make_agent_cls(error_cls, recover_after=None):
def run_conversation(self, user_message, conversation_history=None, task_id=None):
calls = {"n": 0}
def _fake_api_call(api_kwargs):
def _fake_api_call(api_kwargs, **kw):
calls["n"] += 1
if recover_after is not None and calls["n"] > recover_after:
return _anthropic_response("Recovered")
raise error_cls()
self._interruptible_api_call = _fake_api_call
self._interruptible_streaming_api_call = _fake_api_call
return super().run_conversation(
user_message, conversation_history=conversation_history, task_id=task_id
)
@@ -352,10 +365,11 @@ def test_401_refresh_fails_is_non_retryable(monkeypatch):
return False # Simulate failed credential refresh
def run_conversation(self, user_message, conversation_history=None, task_id=None):
def _fake_api_call(api_kwargs):
def _fake_api_call(api_kwargs, **kw):
raise _UnauthorizedError()
self._interruptible_api_call = _fake_api_call
self._interruptible_streaming_api_call = _fake_api_call
return super().run_conversation(
user_message, conversation_history=conversation_history, task_id=task_id
)
@@ -436,13 +450,14 @@ def test_prompt_too_long_triggers_compression(monkeypatch):
def run_conversation(self, user_message, conversation_history=None, task_id=None):
calls = {"n": 0}
def _fake_api_call(api_kwargs):
def _fake_api_call(api_kwargs, **kw):
calls["n"] += 1
if calls["n"] == 1:
raise _PromptTooLongError()
return _anthropic_response("Compressed and recovered")
self._interruptible_api_call = _fake_api_call
self._interruptible_streaming_api_call = _fake_api_call
return super().run_conversation(
user_message, conversation_history=conversation_history, task_id=task_id
)

View File

@@ -38,6 +38,7 @@ def _make_agent(
agent.status_callback = None
agent.tool_progress_callback = None
agent._compression_warning = None
agent.config = None
compressor = MagicMock(spec=ContextCompressor)
compressor.context_length = main_context
@@ -130,6 +131,64 @@ def test_feasibility_check_passes_live_main_runtime():
)
@patch("agent.model_metadata.get_model_context_length", return_value=1_000_000)
@patch("agent.auxiliary_client.get_text_auxiliary_client")
def test_feasibility_check_passes_config_context_length(mock_get_client, mock_ctx_len):
"""auxiliary.compression.context_length from config is forwarded to
get_model_context_length so custom endpoints that lack /models still
report the correct context window (fixes #8499)."""
agent = _make_agent(main_context=200_000, threshold_percent=0.85)
agent.config = {
"auxiliary": {
"compression": {
"context_length": 1_000_000,
},
},
}
mock_client = MagicMock()
mock_client.base_url = "http://custom-endpoint:8080/v1"
mock_client.api_key = "sk-custom"
mock_get_client.return_value = (mock_client, "custom/big-model")
agent._emit_status = lambda msg: None
agent._check_compression_model_feasibility()
mock_ctx_len.assert_called_once_with(
"custom/big-model",
base_url="http://custom-endpoint:8080/v1",
api_key="sk-custom",
config_context_length=1_000_000,
)
@patch("agent.model_metadata.get_model_context_length", return_value=128_000)
@patch("agent.auxiliary_client.get_text_auxiliary_client")
def test_feasibility_check_ignores_invalid_context_length(mock_get_client, mock_ctx_len):
"""Non-integer context_length in config is silently ignored."""
agent = _make_agent(main_context=200_000, threshold_percent=0.50)
agent.config = {
"auxiliary": {
"compression": {
"context_length": "not-a-number",
},
},
}
mock_client = MagicMock()
mock_client.base_url = "http://custom:8080/v1"
mock_client.api_key = "sk-test"
mock_get_client.return_value = (mock_client, "custom/model")
agent._emit_status = lambda msg: None
agent._check_compression_model_feasibility()
mock_ctx_len.assert_called_once_with(
"custom/model",
base_url="http://custom:8080/v1",
api_key="sk-test",
config_context_length=None,
)
@patch("agent.auxiliary_client.get_text_auxiliary_client")
def test_warns_when_no_auxiliary_provider(mock_get_client):
"""Warning emitted when no auxiliary provider is configured."""

View File

@@ -56,6 +56,7 @@ def _make_agent(monkeypatch, api_mode, provider, response_fn):
def run_conversation(self, msg, conversation_history=None, task_id=None):
self._interruptible_api_call = lambda kw: response_fn()
self._disable_streaming = True
return super().run_conversation(msg, conversation_history=conversation_history, task_id=task_id)
return _A(model="test-model", api_key="test-key", provider=provider, api_mode=api_mode)

View File

@@ -66,6 +66,7 @@ def test_tool_call_validation_accepts_dict_arguments(monkeypatch):
quiet_mode=True,
skip_memory=True,
)
agent._disable_streaming = True
result = agent.run_conversation("read the file")

View File

@@ -0,0 +1,89 @@
"""Tests that plugin context engines get update_model() called during init.
Regression test for #9071 — plugin engines were never initialized with
context_length, causing the CLI status bar to show 'ctx --'.
"""
from unittest.mock import MagicMock, patch
from agent.context_engine import ContextEngine
class _StubEngine(ContextEngine):
"""Minimal concrete context engine for testing."""
@property
def name(self) -> str:
return "stub"
def update_from_response(self, usage):
pass
def should_compress(self, prompt_tokens=None):
return False
def compress(self, messages, current_tokens=None):
return messages
def test_plugin_engine_gets_context_length_on_init():
"""Plugin context engine should have context_length set during AIAgent init."""
engine = _StubEngine()
assert engine.context_length == 0 # ABC default before fix
cfg = {"context": {"engine": "stub"}, "agent": {}}
with (
patch("hermes_cli.config.load_config", return_value=cfg),
patch("plugins.context_engine.load_context_engine", return_value=engine),
patch("agent.model_metadata.get_model_context_length", return_value=204_800),
patch("run_agent.get_tool_definitions", return_value=[]),
patch("run_agent.check_toolset_requirements", return_value={}),
patch("run_agent.OpenAI"),
):
from run_agent import AIAgent
agent = AIAgent(
api_key="test-key-1234567890",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
)
assert agent.context_compressor is engine
assert engine.context_length == 204_800
assert engine.threshold_tokens == int(204_800 * engine.threshold_percent)
def test_plugin_engine_update_model_args():
"""Verify update_model() receives model, context_length, base_url, api_key, provider."""
engine = _StubEngine()
engine.update_model = MagicMock()
cfg = {"context": {"engine": "stub"}, "agent": {}}
with (
patch("hermes_cli.config.load_config", return_value=cfg),
patch("plugins.context_engine.load_context_engine", return_value=engine),
patch("agent.model_metadata.get_model_context_length", return_value=131_072),
patch("run_agent.get_tool_definitions", return_value=[]),
patch("run_agent.check_toolset_requirements", return_value={}),
patch("run_agent.OpenAI"),
):
from run_agent import AIAgent
agent = AIAgent(
model="openrouter/auto",
api_key="test-key-1234567890",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
)
engine.update_model.assert_called_once()
kw = engine.update_model.call_args.kwargs
assert kw["context_length"] == 131_072
assert "model" in kw
assert "provider" in kw
# Should NOT pass api_mode — the ABC doesn't accept it
assert "api_mode" not in kw

View File

@@ -44,11 +44,11 @@ class _FakeOpenAI:
pass
def _make_agent(monkeypatch, provider, api_mode="chat_completions", base_url="https://openrouter.ai/api/v1"):
def _make_agent(monkeypatch, provider, api_mode="chat_completions", base_url="https://openrouter.ai/api/v1", model=None):
monkeypatch.setattr("run_agent.get_tool_definitions", lambda **kw: _tool_defs("web_search", "terminal"))
monkeypatch.setattr("run_agent.check_toolset_requirements", lambda: {})
monkeypatch.setattr("run_agent.OpenAI", _FakeOpenAI)
return AIAgent(
kwargs = dict(
api_key="test-key",
base_url=base_url,
provider=provider,
@@ -58,6 +58,9 @@ def _make_agent(monkeypatch, provider, api_mode="chat_completions", base_url="ht
skip_context_files=True,
skip_memory=True,
)
if model:
kwargs["model"] = model
return AIAgent(**kwargs)
# ── _build_api_kwargs tests ─────────────────────────────────────────────────
@@ -247,7 +250,7 @@ class TestBuildApiKwargsChatCompletionsServiceTier:
class TestBuildApiKwargsAIGateway:
def test_uses_chat_completions_format(self, monkeypatch):
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1")
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1", model="gpt-4o")
messages = [{"role": "user", "content": "hi"}]
kwargs = agent._build_api_kwargs(messages)
assert "messages" in kwargs
@@ -255,7 +258,7 @@ class TestBuildApiKwargsAIGateway:
assert kwargs["messages"][-1]["content"] == "hi"
def test_no_responses_api_fields(self, monkeypatch):
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1")
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1", model="gpt-4o")
messages = [{"role": "user", "content": "hi"}]
kwargs = agent._build_api_kwargs(messages)
assert "input" not in kwargs
@@ -263,7 +266,7 @@ class TestBuildApiKwargsAIGateway:
assert "store" not in kwargs
def test_includes_reasoning_in_extra_body(self, monkeypatch):
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1")
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1", model="gpt-4o")
messages = [{"role": "user", "content": "hi"}]
kwargs = agent._build_api_kwargs(messages)
extra = kwargs.get("extra_body", {})
@@ -271,7 +274,7 @@ class TestBuildApiKwargsAIGateway:
assert extra["reasoning"]["enabled"] is True
def test_includes_tools(self, monkeypatch):
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1")
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1", model="gpt-4o")
messages = [{"role": "user", "content": "hi"}]
kwargs = agent._build_api_kwargs(messages)
assert "tools" in kwargs

View File

@@ -76,7 +76,8 @@ class TestRealSubagentInterrupt(unittest.TestCase):
parent._delegate_spinner = None
parent.tool_progress_callback = None
parent.iteration_budget = IterationBudget(max_total=100)
parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}
parent._client_kwargs = {"api_key": "***", "base_url": "http://localhost:1"}
parent._execution_thread_id = None
from tools.delegate_tool import _run_single_child

View File

@@ -302,6 +302,17 @@ class TestStripThinkBlocks:
assert "<think>" not in result
assert "visible" in result
def test_thought_block_removed(self, agent):
"""Gemma 4 uses <thought> tags for inline reasoning."""
result = agent._strip_think_blocks("<thought>internal reasoning</thought> answer")
assert "internal reasoning" not in result
assert "<thought>" not in result
assert "answer" in result
def test_orphaned_thought_tag(self, agent):
result = agent._strip_think_blocks("<thought>orphaned reasoning without close")
assert "<thought>" not in result
class TestExtractReasoning:
def test_reasoning_field(self, agent):
@@ -869,6 +880,7 @@ class TestBuildApiKwargs:
assert kwargs["extra_body"]["reasoning"] == {"enabled": False}
def test_reasoning_not_sent_for_unsupported_openrouter_model(self, agent):
agent.base_url = "https://openrouter.ai/api/v1"
agent.model = "minimax/minimax-m2.5"
messages = [{"role": "user", "content": "hi"}]
kwargs = agent._build_api_kwargs(messages)
@@ -1564,6 +1576,7 @@ class TestHandleMaxIterations:
assert "API down" in result
def test_summary_skips_reasoning_for_unsupported_openrouter_model(self, agent):
agent.base_url = "https://openrouter.ai/api/v1"
agent.model = "minimax/minimax-m2.5"
resp = _mock_response(content="Summary")
agent.client.chat.completions.create.return_value = resp
@@ -1694,27 +1707,6 @@ class TestRunConversation:
assert result["completed"] is True
assert result["api_calls"] == 2
def test_inline_think_blocks_reasoning_only_accepted(self, agent):
"""Inline <think> reasoning-only responses accepted with (empty) content, no retries."""
self._setup_agent(agent)
empty_resp = _mock_response(
content="<think>internal reasoning</think>",
finish_reason="stop",
)
agent.client.chat.completions.create.side_effect = [empty_resp]
with (
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
):
result = agent.run_conversation("answer me")
assert result["completed"] is True
assert result["final_response"] == "(empty)"
assert result["api_calls"] == 1 # no retries
# Reasoning should be preserved in the assistant message
assistant_msgs = [m for m in result["messages"] if m.get("role") == "assistant"]
assert any(m.get("reasoning") for m in assistant_msgs)
def test_reasoning_only_local_resumed_no_compression_triggered(self, agent):
"""Reasoning-only responses no longer trigger compression — prefill then accepted."""
self._setup_agent(agent)
@@ -1730,9 +1722,9 @@ class TestRunConversation:
{"role": "assistant", "content": "old answer"},
]
# 3 responses: original + 2 prefill continuations (structured reasoning triggers prefill)
# 6 responses: original + 2 prefill + 3 retries after prefill exhaustion
with (
patch.object(agent, "_interruptible_api_call", side_effect=[empty_resp, empty_resp, empty_resp]),
patch.object(agent, "_interruptible_api_call", side_effect=[empty_resp] * 6),
patch.object(agent, "_compress_context") as mock_compress,
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
@@ -1743,18 +1735,18 @@ class TestRunConversation:
mock_compress.assert_not_called() # no compression triggered
assert result["completed"] is True
assert result["final_response"] == "(empty)"
assert result["api_calls"] == 3 # 1 original + 2 prefill continuations
assert result["api_calls"] == 6 # 1 original + 2 prefill + 3 retries
def test_reasoning_only_response_prefill_then_empty(self, agent):
"""Structured reasoning-only triggers prefill continuation (up to 2), then falls through to (empty)."""
"""Structured reasoning-only triggers prefill (2), then retries (3), then (empty)."""
self._setup_agent(agent)
empty_resp = _mock_response(
content=None,
finish_reason="stop",
reasoning_content="structured reasoning answer",
)
# 3 responses: original + 2 prefill continuations, all reasoning-only
agent.client.chat.completions.create.side_effect = [empty_resp, empty_resp, empty_resp]
# 6 responses: 1 original + 2 prefill + 3 retries after prefill exhaustion
agent.client.chat.completions.create.side_effect = [empty_resp] * 6
with (
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
@@ -1763,7 +1755,7 @@ class TestRunConversation:
result = agent.run_conversation("answer me")
assert result["completed"] is True
assert result["final_response"] == "(empty)"
assert result["api_calls"] == 3 # 1 original + 2 prefill continuations
assert result["api_calls"] == 6 # 1 original + 2 prefill + 3 retries
def test_reasoning_only_prefill_succeeds_on_continuation(self, agent):
"""When prefill continuation produces content, it becomes the final response."""
@@ -1938,6 +1930,88 @@ class TestRunConversation:
failure_msgs = [m for m in status_messages if "no content" in m.lower() or "no fallback" in m.lower()]
assert len(failure_msgs) >= 1, f"Expected at least 1 failure status, got: {status_messages}"
def test_partial_stream_recovery_uses_streamed_content(self, agent):
"""When streaming fails after partial delivery, recovered partial content becomes final response."""
self._setup_agent(agent)
# Simulate a partial-stream-stub response: content recovered from streaming
partial_resp = _mock_response(
content="Here is the partial answer that was stream",
finish_reason="stop",
)
agent.client.chat.completions.create.return_value = partial_resp
# Simulate that streaming had already delivered this text
agent._current_streamed_assistant_text = "Here is the partial answer that was stream"
with (
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
):
result = agent.run_conversation("explain something")
# The partial content should be used as-is (not empty, not retried)
assert result["completed"] is True
assert result["final_response"] == "Here is the partial answer that was stream"
assert result["api_calls"] == 1 # No retries
def test_partial_stream_recovery_on_empty_stub(self, agent):
"""When stub response has no content but text was streamed, use streamed text."""
self._setup_agent(agent)
# Stub response with no content (old behavior before fix)
empty_stub = _mock_response(content=None, finish_reason="stop")
def _fake_api_call(api_kwargs):
# Simulate what streaming does: accumulate text before returning
# a stub with no content (connection died mid-stream)
agent._current_streamed_assistant_text = "The answer to your question is that"
return empty_stub
status_messages = []
def _capture_status(msg):
status_messages.append(msg)
with (
patch.object(agent, "_interruptible_api_call", side_effect=_fake_api_call),
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
patch.object(agent, "_emit_status", side_effect=_capture_status),
):
result = agent.run_conversation("ask me")
# Should recover partial streamed content, not fall through to (empty)
assert result["completed"] is True
assert result["final_response"] == "The answer to your question is that"
assert result["api_calls"] == 1 # No wasted retries
# Should emit the stream-interrupted status, NOT the empty-retry status
recovery_msgs = [m for m in status_messages if "stream interrupted" in m.lower()]
assert len(recovery_msgs) >= 1, f"Expected stream recovery status, got: {status_messages}"
# Should NOT have retry statuses
retry_msgs = [m for m in status_messages if "retrying" in m.lower()]
assert len(retry_msgs) == 0, f"Should not retry when stream content exists: {status_messages}"
def test_partial_stream_recovery_preempts_prior_turn_fallback(self, agent):
"""Partial streamed content takes priority over _last_content_with_tools fallback."""
self._setup_agent(agent)
# Set up the prior-turn fallback content (from a previous turn with tool calls)
agent._last_content_with_tools = "Old content from prior turn with tools"
# Stub response with no content
empty_stub = _mock_response(content=None, finish_reason="stop")
def _fake_api_call(api_kwargs):
# Simulate partial streaming before connection death
agent._current_streamed_assistant_text = "Fresh partial content from this turn"
return empty_stub
with (
patch.object(agent, "_interruptible_api_call", side_effect=_fake_api_call),
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
):
result = agent.run_conversation("question")
# Should use the streamed content, not the old prior-turn fallback
assert result["final_response"] == "Fresh partial content from this turn"
assert result["api_calls"] == 1
def test_nous_401_refreshes_after_remint_and_retries(self, agent):
self._setup_agent(agent)
agent.provider = "nous"
@@ -3426,8 +3500,8 @@ class TestStreamingApiCall:
call_kwargs = agent.client.chat.completions.create.call_args
assert call_kwargs[1].get("stream") is True or call_kwargs.kwargs.get("stream") is True
def test_api_exception_falls_back_to_non_streaming(self, agent):
"""When streaming fails before any deltas, fallback to non-streaming is attempted."""
def test_api_exception_propagates_no_non_streaming_fallback(self, agent):
"""When streaming fails before any deltas, error propagates to the main retry loop."""
agent.client.chat.completions.create.side_effect = ConnectionError("fail")
# Prevent stream retry logic from replacing the mock client
with patch.object(agent, "_replace_primary_openai_client", return_value=False):

View File

@@ -243,6 +243,22 @@ def test_api_mode_respects_explicit_openrouter_provider_over_codex_url(monkeypat
assert agent.provider == "openrouter"
def test_copilot_acp_stays_on_chat_completions_for_gpt_5_models(monkeypatch):
_patch_agent_bootstrap(monkeypatch)
agent = run_agent.AIAgent(
model="gpt-5.4-mini",
base_url="acp://copilot",
provider="copilot-acp",
api_key="copilot-acp",
quiet_mode=True,
max_iterations=1,
skip_context_files=True,
skip_memory=True,
)
assert agent.provider == "copilot-acp"
assert agent.api_mode == "chat_completions"
def test_build_api_kwargs_codex(monkeypatch):
agent = _build_agent(monkeypatch)
kwargs = agent._build_api_kwargs(

View File

@@ -291,6 +291,38 @@ class TestStreamingCallbacks:
assert len(first_delta_calls) == 1
@patch("run_agent.AIAgent._create_request_openai_client")
@patch("run_agent.AIAgent._close_request_openai_client")
def test_chat_stream_refreshes_activity_on_every_chunk(self, mock_close, mock_create):
"""Each streamed chat chunk should refresh the activity timestamp."""
from run_agent import AIAgent
chunks = [
_make_stream_chunk(content="a"),
_make_stream_chunk(content="b"),
_make_stream_chunk(finish_reason="stop"),
]
mock_client = MagicMock()
mock_client.chat.completions.create.return_value = iter(chunks)
mock_create.return_value = mock_client
agent = AIAgent(
model="test/model",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
)
agent.api_mode = "chat_completions"
agent._interrupt_requested = False
touch_calls = []
agent._touch_activity = lambda desc: touch_calls.append(desc)
agent._interruptible_streaming_api_call({})
assert touch_calls.count("receiving stream response") == len(chunks)
@patch("run_agent.AIAgent._create_request_openai_client")
@patch("run_agent.AIAgent._close_request_openai_client")
def test_tool_only_does_not_fire_callback(self, mock_close, mock_create):
@@ -374,13 +406,19 @@ class TestStreamingCallbacks:
class TestStreamingFallback:
"""Verify fallback to non-streaming on ANY streaming error."""
"""Verify streaming errors propagate to the main retry loop.
Previously, streaming errors triggered an inline fallback to
non-streaming. Now they propagate so the main retry loop can apply
richer recovery (credential rotation, provider fallback, backoff).
The only special case: 'stream not supported' sets _disable_streaming
so the *next* main-loop retry uses non-streaming automatically.
"""
@patch("run_agent.AIAgent._interruptible_api_call")
@patch("run_agent.AIAgent._create_request_openai_client")
@patch("run_agent.AIAgent._close_request_openai_client")
def test_stream_error_falls_back(self, mock_close, mock_create, mock_non_stream):
"""'not supported' error triggers fallback to non-streaming."""
def test_stream_not_supported_sets_flag_and_raises(self, mock_close, mock_create):
"""'not supported' error sets _disable_streaming and propagates."""
from run_agent import AIAgent
mock_client = MagicMock()
@@ -389,23 +427,6 @@ class TestStreamingFallback:
)
mock_create.return_value = mock_client
fallback_response = SimpleNamespace(
id="fallback",
model="test",
choices=[SimpleNamespace(
index=0,
message=SimpleNamespace(
role="assistant",
content="fallback response",
tool_calls=None,
reasoning_content=None,
),
finish_reason="stop",
)],
usage=None,
)
mock_non_stream.return_value = fallback_response
agent = AIAgent(
model="test/model",
quiet_mode=True,
@@ -415,16 +436,16 @@ class TestStreamingFallback:
agent.api_mode = "chat_completions"
agent._interrupt_requested = False
response = agent._interruptible_streaming_api_call({})
with pytest.raises(Exception, match="Streaming is not supported"):
agent._interruptible_streaming_api_call({})
assert response.choices[0].message.content == "fallback response"
mock_non_stream.assert_called_once()
# The flag should be set so the main retry loop switches to non-streaming
assert agent._disable_streaming is True
@patch("run_agent.AIAgent._interruptible_api_call")
@patch("run_agent.AIAgent._create_request_openai_client")
@patch("run_agent.AIAgent._close_request_openai_client")
def test_any_stream_error_falls_back(self, mock_close, mock_create, mock_non_stream):
"""ANY streaming error triggers fallback — not just specific messages."""
def test_non_transport_error_propagates(self, mock_close, mock_create):
"""Non-transport streaming errors propagate to the main retry loop."""
from run_agent import AIAgent
mock_client = MagicMock()
@@ -433,23 +454,6 @@ class TestStreamingFallback:
)
mock_create.return_value = mock_client
fallback_response = SimpleNamespace(
id="fallback",
model="test",
choices=[SimpleNamespace(
index=0,
message=SimpleNamespace(
role="assistant",
content="fallback after connection error",
tool_calls=None,
reasoning_content=None,
),
finish_reason="stop",
)],
usage=None,
)
mock_non_stream.return_value = fallback_response
agent = AIAgent(
model="test/model",
quiet_mode=True,
@@ -459,24 +463,19 @@ class TestStreamingFallback:
agent.api_mode = "chat_completions"
agent._interrupt_requested = False
response = agent._interruptible_streaming_api_call({})
with pytest.raises(Exception, match="Connection reset by peer"):
agent._interruptible_streaming_api_call({})
assert response.choices[0].message.content == "fallback after connection error"
mock_non_stream.assert_called_once()
@patch("run_agent.AIAgent._interruptible_api_call")
@patch("run_agent.AIAgent._create_request_openai_client")
@patch("run_agent.AIAgent._close_request_openai_client")
def test_fallback_error_propagates(self, mock_close, mock_create, mock_non_stream):
"""When both streaming AND fallback fail, the fallback error propagates."""
def test_stream_error_propagates_original(self, mock_close, mock_create):
"""The original streaming error propagates (not a fallback error)."""
from run_agent import AIAgent
mock_client = MagicMock()
mock_client.chat.completions.create.side_effect = Exception("stream broke")
mock_create.return_value = mock_client
mock_non_stream.side_effect = Exception("Rate limit exceeded")
agent = AIAgent(
model="test/model",
quiet_mode=True,
@@ -486,14 +485,13 @@ class TestStreamingFallback:
agent.api_mode = "chat_completions"
agent._interrupt_requested = False
with pytest.raises(Exception, match="Rate limit exceeded"):
with pytest.raises(Exception, match="stream broke"):
agent._interruptible_streaming_api_call({})
@patch("run_agent.AIAgent._interruptible_api_call")
@patch("run_agent.AIAgent._create_request_openai_client")
@patch("run_agent.AIAgent._close_request_openai_client")
def test_exhausted_transient_stream_error_falls_back(self, mock_close, mock_create, mock_non_stream):
"""Transient stream errors retry first, then fall back after retries are exhausted."""
def test_exhausted_transient_stream_error_propagates(self, mock_close, mock_create):
"""Transient stream errors retry first, then propagate after retries exhausted."""
from run_agent import AIAgent
import httpx
@@ -501,23 +499,6 @@ class TestStreamingFallback:
mock_client.chat.completions.create.side_effect = httpx.ConnectError("socket closed")
mock_create.return_value = mock_client
fallback_response = SimpleNamespace(
id="fallback",
model="test",
choices=[SimpleNamespace(
index=0,
message=SimpleNamespace(
role="assistant",
content="fallback after retries exhausted",
tool_calls=None,
reasoning_content=None,
),
finish_reason="stop",
)],
usage=None,
)
mock_non_stream.return_value = fallback_response
agent = AIAgent(
model="test/model",
quiet_mode=True,
@@ -527,23 +508,22 @@ class TestStreamingFallback:
agent.api_mode = "chat_completions"
agent._interrupt_requested = False
response = agent._interruptible_streaming_api_call({})
with pytest.raises(httpx.ConnectError, match="socket closed"):
agent._interruptible_streaming_api_call({})
assert response.choices[0].message.content == "fallback after retries exhausted"
# Should have retried 3 times (default HERMES_STREAM_RETRIES=2 → 3 attempts)
assert mock_client.chat.completions.create.call_count == 3
mock_non_stream.assert_called_once()
assert mock_close.call_count >= 1
@patch("run_agent.AIAgent._interruptible_api_call")
@patch("run_agent.AIAgent._create_request_openai_client")
@patch("run_agent.AIAgent._close_request_openai_client")
def test_sse_connection_lost_retried_as_transient(self, mock_close, mock_create, mock_non_stream):
def test_sse_connection_lost_retried_as_transient(self, mock_close, mock_create):
"""SSE 'Network connection lost' (APIError w/ no status_code) retries like httpx errors.
OpenRouter sends {"error":{"message":"Network connection lost."}} as an SSE
event when the upstream stream drops. The OpenAI SDK raises APIError from
this. It should be retried at the streaming level, same as httpx connection
errors, before falling back to non-streaming.
errors, then propagate to the main retry loop after exhaustion.
"""
from run_agent import AIAgent
import httpx
@@ -561,23 +541,6 @@ class TestStreamingFallback:
mock_client.chat.completions.create.side_effect = sse_error
mock_create.return_value = mock_client
fallback_response = SimpleNamespace(
id="fallback",
model="test",
choices=[SimpleNamespace(
index=0,
message=SimpleNamespace(
role="assistant",
content="fallback after SSE retries",
tool_calls=None,
reasoning_content=None,
),
finish_reason="stop",
)],
usage=None,
)
mock_non_stream.return_value = fallback_response
agent = AIAgent(
model="test/model",
quiet_mode=True,
@@ -587,21 +550,18 @@ class TestStreamingFallback:
agent.api_mode = "chat_completions"
agent._interrupt_requested = False
response = agent._interruptible_streaming_api_call({})
with pytest.raises(OAIAPIError):
agent._interruptible_streaming_api_call({})
assert response.choices[0].message.content == "fallback after SSE retries"
# Should retry 3 times (default HERMES_STREAM_RETRIES=2 → 3 attempts)
# before falling back to non-streaming
assert mock_client.chat.completions.create.call_count == 3
mock_non_stream.assert_called_once()
# Connection cleanup should happen for each failed retry
assert mock_close.call_count >= 2
@patch("run_agent.AIAgent._interruptible_api_call")
@patch("run_agent.AIAgent._create_request_openai_client")
@patch("run_agent.AIAgent._close_request_openai_client")
def test_sse_non_connection_error_falls_back_immediately(self, mock_close, mock_create, mock_non_stream):
"""SSE errors that aren't connection-related still fall back immediately (no stream retry)."""
def test_sse_non_connection_error_propagates_immediately(self, mock_close, mock_create):
"""SSE errors that aren't connection-related propagate immediately (no stream retry)."""
from run_agent import AIAgent
import httpx
@@ -616,23 +576,6 @@ class TestStreamingFallback:
mock_client.chat.completions.create.side_effect = sse_error
mock_create.return_value = mock_client
fallback_response = SimpleNamespace(
id="fallback",
model="test",
choices=[SimpleNamespace(
index=0,
message=SimpleNamespace(
role="assistant",
content="fallback no retry",
tool_calls=None,
reasoning_content=None,
),
finish_reason="stop",
)],
usage=None,
)
mock_non_stream.return_value = fallback_response
agent = AIAgent(
model="test/model",
quiet_mode=True,
@@ -642,12 +585,11 @@ class TestStreamingFallback:
agent.api_mode = "chat_completions"
agent._interrupt_requested = False
response = agent._interruptible_streaming_api_call({})
with pytest.raises(OAIAPIError):
agent._interruptible_streaming_api_call({})
assert response.choices[0].message.content == "fallback no retry"
# Should NOT retry — goes straight to non-streaming fallback
# Should NOT retry — propagates immediately
assert mock_client.chat.completions.create.call_count == 1
mock_non_stream.assert_called_once()
# ── Test: Reasoning Streaming ────────────────────────────────────────────
@@ -783,6 +725,55 @@ class TestCodexStreamCallbacks:
response = agent._run_codex_stream({}, client=mock_client)
assert "Hello from Codex!" in deltas
def test_codex_stream_refreshes_activity_on_every_event(self):
from run_agent import AIAgent
agent = AIAgent(
model="test/model",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
)
agent.api_mode = "codex_responses"
agent._interrupt_requested = False
touch_calls = []
agent._touch_activity = lambda desc: touch_calls.append(desc)
mock_event_text_1 = SimpleNamespace(
type="response.output_text.delta",
delta="Hello",
)
mock_event_text_2 = SimpleNamespace(
type="response.output_text.delta",
delta=" world",
)
mock_event_done = SimpleNamespace(
type="response.completed",
delta="",
)
mock_stream = MagicMock()
mock_stream.__enter__ = MagicMock(return_value=mock_stream)
mock_stream.__exit__ = MagicMock(return_value=False)
mock_stream.__iter__ = MagicMock(
return_value=iter([mock_event_text_1, mock_event_text_2, mock_event_done])
)
mock_stream.get_final_response.return_value = SimpleNamespace(
output=[SimpleNamespace(
type="message",
content=[SimpleNamespace(type="output_text", text="Hello world")],
)],
status="completed",
)
mock_client = MagicMock()
mock_client.responses.stream.return_value = mock_stream
agent._run_codex_stream({}, client=mock_client)
assert touch_calls.count("receiving stream response") == 3
def test_codex_remote_protocol_error_falls_back_to_create_stream(self):
from run_agent import AIAgent
import httpx
@@ -814,3 +805,102 @@ class TestCodexStreamCallbacks:
assert response is fallback_response
mock_fallback.assert_called_once_with({}, client=mock_client)
def test_codex_create_stream_fallback_refreshes_activity_on_every_event(self):
from run_agent import AIAgent
agent = AIAgent(
model="test/model",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
)
agent.api_mode = "codex_responses"
touch_calls = []
agent._touch_activity = lambda desc: touch_calls.append(desc)
events = [
SimpleNamespace(type="response.output_text.delta", delta="Hello"),
SimpleNamespace(type="response.output_item.done", item=SimpleNamespace(type="message")),
SimpleNamespace(
type="response.completed",
response=SimpleNamespace(
output=[SimpleNamespace(
type="message",
content=[SimpleNamespace(type="output_text", text="Hello")],
)]
),
),
]
class _FakeCreateStream:
def __iter__(self_inner):
return iter(events)
def close(self_inner):
return None
mock_stream = _FakeCreateStream()
mock_client = MagicMock()
mock_client.responses.create.return_value = mock_stream
agent._run_codex_create_stream_fallback(
{"model": "test/model", "instructions": "hi", "input": []},
client=mock_client,
)
assert touch_calls.count("receiving stream response") == len(events)
class TestAnthropicStreamCallbacks:
"""Verify Anthropic streaming refreshes activity on every event."""
def test_anthropic_stream_refreshes_activity_on_every_event(self):
from run_agent import AIAgent
agent = AIAgent(
model="test/model",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
)
agent.api_mode = "anthropic_messages"
agent._interrupt_requested = False
touch_calls = []
agent._touch_activity = lambda desc: touch_calls.append(desc)
events = [
SimpleNamespace(
type="content_block_delta",
delta=SimpleNamespace(type="text_delta", text="Hello"),
),
SimpleNamespace(
type="content_block_delta",
delta=SimpleNamespace(type="thinking_delta", thinking="thinking"),
),
SimpleNamespace(
type="content_block_start",
content_block=SimpleNamespace(type="tool_use", name="terminal"),
),
]
final_message = SimpleNamespace(
content=[],
stop_reason="end_turn",
)
mock_stream = MagicMock()
mock_stream.__enter__ = MagicMock(return_value=mock_stream)
mock_stream.__exit__ = MagicMock(return_value=False)
mock_stream.__iter__ = MagicMock(return_value=iter(events))
mock_stream.get_final_message.return_value = final_message
agent._anthropic_client = MagicMock()
agent._anthropic_client.messages.stream.return_value = mock_stream
agent._interruptible_streaming_api_call({})
assert touch_calls.count("receiving stream response") == len(events)

View File

@@ -9,6 +9,8 @@ import pytest
from run_agent import (
_strip_non_ascii,
_sanitize_messages_non_ascii,
_sanitize_structure_non_ascii,
_sanitize_tools_non_ascii,
_sanitize_messages_surrogates,
)
@@ -138,3 +140,66 @@ class TestSurrogateVsAsciiSanitization:
"""When no surrogates present, _sanitize_messages_surrogates returns False."""
messages = [{"role": "user", "content": "hello ⚕ world"}]
assert _sanitize_messages_surrogates(messages) is False
class TestSanitizeToolsNonAscii:
"""Tests for _sanitize_tools_non_ascii."""
def test_sanitizes_tool_description_and_parameter_descriptions(self):
tools = [
{
"type": "function",
"function": {
"name": "read_file",
"description": "Print structured output │ with emoji 🤖",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "File path │ with unicode",
}
},
},
},
}
]
assert _sanitize_tools_non_ascii(tools) is True
assert tools[0]["function"]["description"] == "Print structured output with emoji "
assert tools[0]["function"]["parameters"]["properties"]["path"]["description"] == "File path with unicode"
def test_no_change_for_ascii_only_tools(self):
tools = [
{
"type": "function",
"function": {
"name": "read_file",
"description": "Read file content",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "File path",
}
},
},
},
}
]
assert _sanitize_tools_non_ascii(tools) is False
class TestSanitizeStructureNonAscii:
def test_sanitizes_nested_dict_structure(self):
payload = {
"default_headers": {
"X-Title": "Hermes │ Agent",
"User-Agent": "Hermes/1.0 🤖",
}
}
assert _sanitize_structure_non_ascii(payload) is True
assert payload["default_headers"]["X-Title"] == "Hermes Agent"
assert payload["default_headers"]["User-Agent"] == "Hermes/1.0 "

View File

@@ -179,6 +179,7 @@ class TestEphemeralMaxOutputTokens:
return_value=[{"role": "user", "content": "hi"}]
)
agent._anthropic_preserve_dots = MagicMock(return_value=False)
agent.request_overrides = {}
return agent
def test_ephemeral_override_is_used_on_first_call(self):
@@ -253,6 +254,7 @@ class TestContextNotHalvedOnOutputCapError:
)
agent._anthropic_preserve_dots = MagicMock(return_value=False)
agent._vprint = MagicMock()
agent.request_overrides = {}
return agent
def test_output_cap_error_sets_ephemeral_not_context_length(self):

View File

@@ -6,7 +6,8 @@ from unittest.mock import patch
import pytest
from hermes_constants import get_default_hermes_root
import hermes_constants
from hermes_constants import get_default_hermes_root, is_container
class TestGetDefaultHermesRoot:
@@ -60,3 +61,53 @@ class TestGetDefaultHermesRoot:
monkeypatch.setattr(Path, "home", lambda: tmp_path)
monkeypatch.setenv("HERMES_HOME", str(profile))
assert get_default_hermes_root() == docker_root
class TestIsContainer:
"""Tests for is_container() — Docker/Podman detection."""
def _reset_cache(self, monkeypatch):
"""Reset the cached detection result before each test."""
monkeypatch.setattr(hermes_constants, "_container_detected", None)
def test_detects_dockerenv(self, monkeypatch, tmp_path):
"""/.dockerenv triggers container detection."""
self._reset_cache(monkeypatch)
monkeypatch.setattr(os.path, "exists", lambda p: p == "/.dockerenv")
assert is_container() is True
def test_detects_containerenv(self, monkeypatch, tmp_path):
"""/run/.containerenv triggers container detection (Podman)."""
self._reset_cache(monkeypatch)
monkeypatch.setattr(os.path, "exists", lambda p: p == "/run/.containerenv")
assert is_container() is True
def test_detects_cgroup_docker(self, monkeypatch, tmp_path):
"""/proc/1/cgroup containing 'docker' triggers detection."""
import builtins
self._reset_cache(monkeypatch)
monkeypatch.setattr(os.path, "exists", lambda p: False)
cgroup_file = tmp_path / "cgroup"
cgroup_file.write_text("12:memory:/docker/abc123\n")
_real_open = builtins.open
monkeypatch.setattr("builtins.open", lambda p, *a, **kw: _real_open(str(cgroup_file), *a, **kw) if p == "/proc/1/cgroup" else _real_open(p, *a, **kw))
assert is_container() is True
def test_negative_case(self, monkeypatch, tmp_path):
"""Returns False on a regular Linux host."""
import builtins
self._reset_cache(monkeypatch)
monkeypatch.setattr(os.path, "exists", lambda p: False)
cgroup_file = tmp_path / "cgroup"
cgroup_file.write_text("12:memory:/\n")
_real_open = builtins.open
monkeypatch.setattr("builtins.open", lambda p, *a, **kw: _real_open(str(cgroup_file), *a, **kw) if p == "/proc/1/cgroup" else _real_open(p, *a, **kw))
assert is_container() is False
def test_caches_result(self, monkeypatch):
"""Second call uses cached value without re-probing."""
monkeypatch.setattr(hermes_constants, "_container_detected", True)
assert is_container() is True
# Even if we make os.path.exists return False, cached value wins
monkeypatch.setattr(os.path, "exists", lambda p: False)
assert is_container() is True

View File

@@ -298,8 +298,17 @@ class TestGatewayMode:
"""agent.log (catch-all) still receives gateway AND tool records."""
hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
logging.getLogger("gateway.run").info("gateway msg")
logging.getLogger("tools.file_tools").info("file msg")
gw_logger = logging.getLogger("gateway.run")
file_logger = logging.getLogger("tools.file_tools")
# Ensure propagation and levels are clean (cross-test pollution defense)
gw_logger.propagate = True
file_logger.propagate = True
logging.getLogger("tools").propagate = True
file_logger.setLevel(logging.NOTSET)
logging.getLogger("tools").setLevel(logging.NOTSET)
gw_logger.info("gateway msg")
file_logger.info("file msg")
for h in logging.getLogger().handlers:
h.flush()

View File

@@ -103,7 +103,7 @@ class TestSourceLineVerification:
if "self.async_client = AsyncOpenAI(" in line and "_get_async_client" not in lines[max(0,i-3):i+1]:
# Allow it inside _get_async_client method
# Check if we're inside _get_async_client by looking at context
context = "\n".join(lines[max(0,i-10):i+1])
context = "\n".join(lines[max(0,i-20):i+1])
if "_get_async_client" not in context:
pytest.fail(
f"Line {i}: AsyncOpenAI created eagerly outside _get_async_client()"

View File

@@ -64,4 +64,4 @@ class TestCamofoxConfigDefaults:
# The current schema version is tracked globally; unrelated default
# options may bump it after browser defaults are added.
assert DEFAULT_CONFIG["_config_version"] == 15
assert DEFAULT_CONFIG["_config_version"] == 17

View File

@@ -79,5 +79,33 @@ class TestSafeWriteRoot:
assert _is_write_denied(os.path.expanduser("~/.ssh/id_rsa")) is True
class TestCheckSensitivePathMacOSBypass:
"""Verify _check_sensitive_path blocks /private/etc paths (issue #8734)."""
def test_etc_hosts_blocked(self):
from tools.file_tools import _check_sensitive_path
assert _check_sensitive_path("/etc/hosts") is not None
def test_private_etc_hosts_blocked(self):
from tools.file_tools import _check_sensitive_path
assert _check_sensitive_path("/private/etc/hosts") is not None
def test_private_etc_ssh_config_blocked(self):
from tools.file_tools import _check_sensitive_path
assert _check_sensitive_path("/private/etc/ssh/sshd_config") is not None
def test_private_var_blocked(self):
from tools.file_tools import _check_sensitive_path
assert _check_sensitive_path("/private/var/db/something") is not None
def test_boot_still_blocked(self):
from tools.file_tools import _check_sensitive_path
assert _check_sensitive_path("/boot/grub/grub.cfg") is not None
def test_safe_path_allowed(self):
from tools.file_tools import _check_sensitive_path
assert _check_sensitive_path("/tmp/safe_file.txt") is None
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -5,6 +5,7 @@ handler validation, and availability gating.
"""
import json
from unittest.mock import patch
import pytest
@@ -18,6 +19,7 @@ from tools.homeassistant_tool import (
_handle_call_service,
_BLOCKED_DOMAINS,
_ENTITY_ID_RE,
_SERVICE_NAME_RE,
)
@@ -303,6 +305,147 @@ class TestEntityIdValidation:
assert "Invalid entity_id" not in result["error"]
# ---------------------------------------------------------------------------
# String-data deserialization (XML tool calling workaround)
# ---------------------------------------------------------------------------
class TestCallServiceStringData:
"""data param may arrive as a JSON string (XML tool calling mode)."""
@patch("tools.homeassistant_tool._run_async", return_value={"success": True})
def test_string_data_deserialized(self, mock_run):
"""JSON string data is parsed into a dict before dispatch."""
_handle_call_service({
"domain": "climate",
"service": "set_hvac_mode",
"entity_id": "climate.living_room",
"data": '{"hvac_mode": "heat"}',
})
call_args = mock_run.call_args[0][0] # the coroutine arg
# _run_async was called, meaning we got past validation
@patch("tools.homeassistant_tool._run_async", return_value={"success": True})
def test_dict_data_passthrough(self, mock_run):
"""Dict data (JSON tool calling mode) still works unchanged."""
_handle_call_service({
"domain": "light",
"service": "turn_on",
"entity_id": "light.bedroom",
"data": {"brightness": 255},
})
mock_run.assert_called_once()
def test_invalid_json_string_returns_error(self):
"""Malformed JSON string in data returns a clear error."""
result = json.loads(_handle_call_service({
"domain": "light",
"service": "turn_on",
"entity_id": "light.bedroom",
"data": "{not valid json}",
}))
assert "error" in result
assert "Invalid JSON" in result["error"]
@patch("tools.homeassistant_tool._run_async", return_value={"success": True})
def test_empty_string_data_becomes_none(self, mock_run):
"""Empty/whitespace string data is treated as None."""
_handle_call_service({
"domain": "light",
"service": "turn_on",
"entity_id": "light.bedroom",
"data": " ",
})
mock_run.assert_called_once()
# ---------------------------------------------------------------------------
# Security: domain/service name format validation
# ---------------------------------------------------------------------------
class TestServiceNameValidation:
"""Verify domain/service format validation prevents path traversal in URL.
The domain and service parameters are interpolated into
/api/services/{domain}/{service}, so allowing arbitrary strings would
enable SSRF via path traversal or blocked-domain bypass.
"""
def test_valid_domain_names(self):
assert _SERVICE_NAME_RE.match("light")
assert _SERVICE_NAME_RE.match("switch")
assert _SERVICE_NAME_RE.match("climate")
assert _SERVICE_NAME_RE.match("shell_command")
assert _SERVICE_NAME_RE.match("media_player")
def test_valid_service_names(self):
assert _SERVICE_NAME_RE.match("turn_on")
assert _SERVICE_NAME_RE.match("turn_off")
assert _SERVICE_NAME_RE.match("set_temperature")
assert _SERVICE_NAME_RE.match("toggle")
def test_path_traversal_in_domain_rejected(self):
assert _SERVICE_NAME_RE.match("../../api/config") is None
assert _SERVICE_NAME_RE.match("light/../../../etc") is None
assert _SERVICE_NAME_RE.match("../config") is None
def test_path_traversal_in_service_rejected(self):
assert _SERVICE_NAME_RE.match("../../api/config") is None
assert _SERVICE_NAME_RE.match("turn_on/../../config") is None
def test_blocked_domain_bypass_via_traversal_rejected(self):
"""Ensure shell_command/../light is rejected, not just checked against blocklist."""
assert _SERVICE_NAME_RE.match("shell_command/../light") is None
assert _SERVICE_NAME_RE.match("python_script/../scene") is None
assert _SERVICE_NAME_RE.match("hassio/../automation") is None
def test_slashes_rejected(self):
assert _SERVICE_NAME_RE.match("light/turn_on") is None
assert _SERVICE_NAME_RE.match("a/b/c") is None
def test_dots_rejected(self):
assert _SERVICE_NAME_RE.match("light.turn_on") is None
assert _SERVICE_NAME_RE.match("..") is None
def test_uppercase_rejected(self):
assert _SERVICE_NAME_RE.match("LIGHT") is None
assert _SERVICE_NAME_RE.match("Turn_On") is None
def test_special_chars_rejected(self):
assert _SERVICE_NAME_RE.match("light;rm") is None
assert _SERVICE_NAME_RE.match("light&cmd") is None
assert _SERVICE_NAME_RE.match("light cmd") is None
def test_handler_rejects_traversal_domain(self):
"""_handle_call_service must reject domain with path traversal."""
result = json.loads(_handle_call_service({
"domain": "../../api/config",
"service": "turn_on",
}))
assert "error" in result
assert "Invalid domain" in result["error"]
def test_handler_rejects_traversal_service(self):
"""_handle_call_service must reject service with path traversal."""
result = json.loads(_handle_call_service({
"domain": "light",
"service": "../../api/config",
}))
assert "error" in result
assert "Invalid service" in result["error"]
def test_handler_rejects_blocklist_bypass_traversal(self):
"""Blocklist bypass via shell_command/../light must be caught by format validation."""
result = json.loads(_handle_call_service({
"domain": "shell_command/../light",
"service": "turn_on",
}))
assert "error" in result
# Must be rejected as "Invalid domain", not slip through the blocklist
assert "Invalid domain" in result["error"]
# ---------------------------------------------------------------------------
# Availability check
# ---------------------------------------------------------------------------

View File

@@ -28,7 +28,7 @@ class TestInterruptModule:
assert not is_interrupted()
def test_thread_safety(self):
"""Set from one thread, check from another."""
"""Set from one thread targeting another thread's ident."""
from tools.interrupt import set_interrupt, is_interrupted
set_interrupt(False)
@@ -45,11 +45,12 @@ class TestInterruptModule:
time.sleep(0.05)
assert not seen["value"]
set_interrupt(True)
# Target the checker thread's ident so it sees the interrupt
set_interrupt(True, thread_id=t.ident)
t.join(timeout=1)
assert seen["value"]
set_interrupt(False)
set_interrupt(False, thread_id=t.ident)
# ---------------------------------------------------------------------------
@@ -189,10 +190,10 @@ class TestSIGKILLEscalation:
t.start()
time.sleep(0.5)
set_interrupt(True)
set_interrupt(True, thread_id=t.ident)
t.join(timeout=5)
set_interrupt(False)
set_interrupt(False, thread_id=t.ident)
assert result_holder["value"] is not None
assert result_holder["value"]["returncode"] == 130

View File

@@ -146,6 +146,40 @@ class TestTruncateAroundMatches:
result = _truncate_around_matches(text, "KEYWORD")
assert "KEYWORD" in result
def test_multiword_phrase_match_beats_individual_term(self):
"""Full phrase deep in text should be found even when a single term
appears much earlier in boilerplate."""
boilerplate = "The project setup is complex. " * 500 # ~15K, has 'project' early
filler = "x" * (MAX_SESSION_CHARS + 20000)
target = "We reviewed the keystone project roadmap in detail."
text = boilerplate + filler + target + filler
result = _truncate_around_matches(text, "keystone project")
assert "keystone project" in result.lower()
def test_multiword_proximity_cooccurrence(self):
"""When exact phrase is absent, terms co-occurring within proximity
should be preferred over a lone early term."""
early = "project " + "a" * (MAX_SESSION_CHARS + 20000)
# Place 'keystone' and 'project' near each other (but not as exact phrase)
cooccur = "this keystone initiative for the project was pivotal"
tail = "b" * (MAX_SESSION_CHARS + 20000)
text = early + cooccur + tail
result = _truncate_around_matches(text, "keystone project")
assert "keystone" in result.lower()
assert "project" in result.lower()
def test_multiword_window_maximises_coverage(self):
"""Sliding window should capture as many match clusters as possible."""
# Place two phrase matches: one at ~50K, one at ~60K, both should fit
pre = "z" * 50000
match1 = " alpha beta "
gap = "z" * 10000
match2 = " alpha beta "
post = "z" * (MAX_SESSION_CHARS + 40000)
text = pre + match1 + gap + match2 + post
result = _truncate_around_matches(text, "alpha beta")
assert result.lower().count("alpha beta") == 2
# =========================================================================
# session_search (dispatcher)

View File

@@ -0,0 +1,145 @@
"""Tests for TTS speed configuration across providers."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@pytest.fixture(autouse=True)
def clean_env(monkeypatch):
for key in ("OPENAI_API_KEY", "MINIMAX_API_KEY", "HERMES_SESSION_PLATFORM"):
monkeypatch.delenv(key, raising=False)
# ---------------------------------------------------------------------------
# Edge TTS speed
# ---------------------------------------------------------------------------
class TestEdgeTtsSpeed:
def _run(self, tts_config, tmp_path):
mock_comm = MagicMock()
mock_comm.save = AsyncMock()
mock_edge = MagicMock()
mock_edge.Communicate = MagicMock(return_value=mock_comm)
with patch("tools.tts_tool._import_edge_tts", return_value=mock_edge):
from tools.tts_tool import _generate_edge_tts
asyncio.run(_generate_edge_tts("Hello", str(tmp_path / "out.mp3"), tts_config))
return mock_edge.Communicate
def test_default_no_rate_kwarg(self, tmp_path):
"""No speed config => no rate kwarg passed to Communicate."""
comm_cls = self._run({}, tmp_path)
kwargs = comm_cls.call_args[1]
assert "rate" not in kwargs
def test_global_speed_applied(self, tmp_path):
"""Global tts.speed used as fallback."""
comm_cls = self._run({"speed": 1.5}, tmp_path)
kwargs = comm_cls.call_args[1]
assert kwargs["rate"] == "+50%"
def test_provider_speed_overrides_global(self, tmp_path):
"""tts.edge.speed takes precedence over tts.speed."""
comm_cls = self._run({"speed": 1.5, "edge": {"speed": 2.0}}, tmp_path)
kwargs = comm_cls.call_args[1]
assert kwargs["rate"] == "+100%"
def test_speed_below_one(self, tmp_path):
"""Speed < 1.0 produces a negative rate string."""
comm_cls = self._run({"speed": 0.5}, tmp_path)
kwargs = comm_cls.call_args[1]
assert kwargs["rate"] == "-50%"
def test_speed_exactly_one_no_rate(self, tmp_path):
"""Explicit speed=1.0 should not pass rate kwarg."""
comm_cls = self._run({"speed": 1.0}, tmp_path)
kwargs = comm_cls.call_args[1]
assert "rate" not in kwargs
# ---------------------------------------------------------------------------
# OpenAI TTS speed
# ---------------------------------------------------------------------------
class TestOpenaiTtsSpeed:
def _run(self, tts_config, tmp_path, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
mock_response = MagicMock()
mock_client = MagicMock()
mock_client.audio.speech.create.return_value = mock_response
mock_cls = MagicMock(return_value=mock_client)
with patch("tools.tts_tool._import_openai_client", return_value=mock_cls), \
patch("tools.tts_tool._resolve_openai_audio_client_config",
return_value=("test-key", None)):
from tools.tts_tool import _generate_openai_tts
_generate_openai_tts("Hello", str(tmp_path / "out.mp3"), tts_config)
return mock_client.audio.speech.create
def test_default_no_speed_kwarg(self, tmp_path, monkeypatch):
"""No speed config => no speed kwarg in create call."""
create = self._run({}, tmp_path, monkeypatch)
kwargs = create.call_args[1]
assert "speed" not in kwargs
def test_global_speed_applied(self, tmp_path, monkeypatch):
"""Global tts.speed used as fallback."""
create = self._run({"speed": 1.5}, tmp_path, monkeypatch)
kwargs = create.call_args[1]
assert kwargs["speed"] == 1.5
def test_provider_speed_overrides_global(self, tmp_path, monkeypatch):
"""tts.openai.speed takes precedence over tts.speed."""
create = self._run({"speed": 1.5, "openai": {"speed": 2.0}}, tmp_path, monkeypatch)
kwargs = create.call_args[1]
assert kwargs["speed"] == 2.0
def test_speed_clamped_low(self, tmp_path, monkeypatch):
"""Speed below 0.25 is clamped to 0.25."""
create = self._run({"speed": 0.1}, tmp_path, monkeypatch)
kwargs = create.call_args[1]
assert kwargs["speed"] == 0.25
def test_speed_clamped_high(self, tmp_path, monkeypatch):
"""Speed above 4.0 is clamped to 4.0."""
create = self._run({"speed": 10.0}, tmp_path, monkeypatch)
kwargs = create.call_args[1]
assert kwargs["speed"] == 4.0
# ---------------------------------------------------------------------------
# MiniMax TTS speed (global fallback wired)
# ---------------------------------------------------------------------------
class TestMinimaxTtsSpeed:
def _run(self, tts_config, tmp_path, monkeypatch):
monkeypatch.setenv("MINIMAX_API_KEY", "test-key")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": {"audio": "deadbeef"},
"base_resp": {"status_code": 0, "status_msg": "success"},
"extra_info": {"audio_size": 8},
}
# requests is imported locally inside _generate_minimax_tts
with patch("requests.post", return_value=mock_response) as mock_post:
from tools.tts_tool import _generate_minimax_tts
_generate_minimax_tts("Hello", str(tmp_path / "out.mp3"), tts_config)
return mock_post
def test_global_speed_fallback(self, tmp_path, monkeypatch):
"""Global tts.speed used when minimax.speed not set."""
mock_post = self._run({"speed": 1.5}, tmp_path, monkeypatch)
payload = mock_post.call_args[1]["json"]
assert payload["voice_setting"]["speed"] == 1.5
def test_provider_speed_overrides_global(self, tmp_path, monkeypatch):
"""tts.minimax.speed takes precedence over tts.speed."""
mock_post = self._run(
{"speed": 1.5, "minimax": {"speed": 2.0}}, tmp_path, monkeypatch
)
payload = mock_post.call_args[1]["json"]
assert payload["voice_setting"]["speed"] == 2.0

View File

@@ -463,8 +463,6 @@ class TestVisionRequirements:
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.delenv("AUXILIARY_VISION_PROVIDER", raising=False)
monkeypatch.delenv("CONTEXT_VISION_PROVIDER", raising=False)
assert check_vision_requirements() is True

View File

@@ -32,6 +32,7 @@ def _make_voice_cli(**overrides):
cli._voice_tts_done.set()
cli._pending_input = queue.Queue()
cli._app = None
cli._attached_images = []
cli.console = SimpleNamespace(width=80)
for k, v in overrides.items():
setattr(cli, k, v)

View File

@@ -190,17 +190,38 @@ class TestGatewayCleanupWiring:
def test_gateway_stop_calls_close(self):
"""gateway stop() should call close() on all running agents."""
import asyncio
from unittest.mock import MagicMock, patch
import threading
from unittest.mock import AsyncMock, MagicMock, patch
runner = MagicMock()
from gateway.run import GatewayRunner
runner = object.__new__(GatewayRunner)
runner._running = True
runner._running_agents = {}
runner._running_agents_ts = {}
runner.adapters = {}
runner._background_tasks = set()
runner._pending_messages = {}
runner._pending_approvals = {}
runner._pending_model_notes = {}
runner._shutdown_event = asyncio.Event()
runner._exit_reason = None
runner._exit_code = None
runner._stop_task = None
runner._draining = False
runner._restart_requested = False
runner._restart_task_started = False
runner._restart_detached = False
runner._restart_via_service = False
runner._restart_drain_timeout = 5.0
runner._voice_mode = {}
runner._session_model_overrides = {}
runner._update_prompt_pending = {}
runner._busy_input_mode = "interrupt"
runner._agent_cache = {}
runner._agent_cache_lock = threading.Lock()
runner._shutdown_all_gateway_honcho = lambda: None
runner._update_runtime_status = MagicMock()
mock_agent_1 = MagicMock()
mock_agent_2 = MagicMock()
@@ -209,8 +230,6 @@ class TestGatewayCleanupWiring:
"session-2": mock_agent_2,
}
from gateway.run import GatewayRunner
loop = asyncio.new_event_loop()
try:
with patch("gateway.status.remove_pid_file"), \