Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor
This commit is contained in:
@@ -17,7 +17,6 @@ from agent.auxiliary_client import (
|
||||
call_llm,
|
||||
async_call_llm,
|
||||
_read_codex_access_token,
|
||||
_get_auxiliary_provider,
|
||||
_get_provider_chain,
|
||||
_is_payment_error,
|
||||
_try_payment_fallback,
|
||||
@@ -32,12 +31,6 @@ def _clean_env(monkeypatch):
|
||||
"OPENROUTER_API_KEY", "OPENAI_BASE_URL", "OPENAI_API_KEY",
|
||||
"OPENAI_MODEL", "LLM_MODEL", "NOUS_INFERENCE_BASE_URL",
|
||||
"ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN",
|
||||
# Per-task provider/model/direct-endpoint overrides
|
||||
"AUXILIARY_VISION_PROVIDER", "AUXILIARY_VISION_MODEL",
|
||||
"AUXILIARY_VISION_BASE_URL", "AUXILIARY_VISION_API_KEY",
|
||||
"AUXILIARY_WEB_EXTRACT_PROVIDER", "AUXILIARY_WEB_EXTRACT_MODEL",
|
||||
"AUXILIARY_WEB_EXTRACT_BASE_URL", "AUXILIARY_WEB_EXTRACT_API_KEY",
|
||||
"CONTEXT_COMPRESSION_PROVIDER", "CONTEXT_COMPRESSION_MODEL",
|
||||
):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
@@ -568,29 +561,6 @@ class TestGetTextAuxiliaryClient:
|
||||
call_kwargs = mock_openai.call_args
|
||||
assert call_kwargs.kwargs["base_url"] == "http://localhost:1234/v1"
|
||||
|
||||
def test_task_direct_endpoint_override(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_BASE_URL", "http://localhost:2345/v1")
|
||||
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_API_KEY", "task-key")
|
||||
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_MODEL", "task-model")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_text_auxiliary_client("web_extract")
|
||||
assert model == "task-model"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "http://localhost:2345/v1"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "task-key"
|
||||
|
||||
def test_task_direct_endpoint_without_openai_key_uses_placeholder(self, monkeypatch):
|
||||
"""Local endpoints without an API key should use 'no-key-required' placeholder."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_BASE_URL", "http://localhost:2345/v1")
|
||||
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_MODEL", "task-model")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_text_auxiliary_client("web_extract")
|
||||
assert client is not None
|
||||
assert model == "task-model"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "no-key-required"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "http://localhost:2345/v1"
|
||||
|
||||
def test_custom_endpoint_uses_config_saved_base_url(self, monkeypatch):
|
||||
config = {
|
||||
"model": {
|
||||
@@ -879,73 +849,9 @@ class TestAuxiliaryPoolAwareness:
|
||||
|
||||
|
||||
|
||||
class TestGetAuxiliaryProvider:
|
||||
"""Tests for _get_auxiliary_provider env var resolution."""
|
||||
|
||||
def test_no_task_returns_auto(self):
|
||||
assert _get_auxiliary_provider() == "auto"
|
||||
assert _get_auxiliary_provider("") == "auto"
|
||||
|
||||
def test_auxiliary_prefix_takes_priority(self, monkeypatch):
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "openrouter")
|
||||
assert _get_auxiliary_provider("vision") == "openrouter"
|
||||
|
||||
def test_context_prefix_fallback(self, monkeypatch):
|
||||
monkeypatch.setenv("CONTEXT_COMPRESSION_PROVIDER", "nous")
|
||||
assert _get_auxiliary_provider("compression") == "nous"
|
||||
|
||||
def test_auxiliary_prefix_over_context_prefix(self, monkeypatch):
|
||||
monkeypatch.setenv("AUXILIARY_COMPRESSION_PROVIDER", "openrouter")
|
||||
monkeypatch.setenv("CONTEXT_COMPRESSION_PROVIDER", "nous")
|
||||
assert _get_auxiliary_provider("compression") == "openrouter"
|
||||
|
||||
def test_auto_value_treated_as_auto(self, monkeypatch):
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "auto")
|
||||
assert _get_auxiliary_provider("vision") == "auto"
|
||||
|
||||
def test_whitespace_stripped(self, monkeypatch):
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", " openrouter ")
|
||||
assert _get_auxiliary_provider("vision") == "openrouter"
|
||||
|
||||
def test_case_insensitive(self, monkeypatch):
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "OpenRouter")
|
||||
assert _get_auxiliary_provider("vision") == "openrouter"
|
||||
|
||||
def test_main_provider(self, monkeypatch):
|
||||
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_PROVIDER", "main")
|
||||
assert _get_auxiliary_provider("web_extract") == "main"
|
||||
|
||||
|
||||
class TestTaskSpecificOverrides:
|
||||
"""Integration tests for per-task provider routing via get_text_auxiliary_client(task=...)."""
|
||||
|
||||
def test_text_with_vision_provider_override(self, monkeypatch):
|
||||
"""AUXILIARY_VISION_PROVIDER should not affect text tasks."""
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "nous")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
with patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = get_text_auxiliary_client() # no task → auto
|
||||
assert model == "google/gemini-3-flash-preview" # OpenRouter, not Nous
|
||||
|
||||
def test_compression_task_reads_context_prefix(self, monkeypatch):
|
||||
"""Compression task should check CONTEXT_COMPRESSION_PROVIDER env var."""
|
||||
monkeypatch.setenv("CONTEXT_COMPRESSION_PROVIDER", "nous")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key") # would win in auto
|
||||
with patch("agent.auxiliary_client._read_nous_auth") as mock_nous, \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
mock_nous.return_value = {"access_token": "***"}
|
||||
client, model = get_text_auxiliary_client("compression")
|
||||
# Config-first: model comes from config.yaml summary_model default,
|
||||
# but provider is forced to Nous via env var
|
||||
assert client is not None
|
||||
|
||||
def test_web_extract_task_override(self, monkeypatch):
|
||||
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_PROVIDER", "openrouter")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
with patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = get_text_auxiliary_client("web_extract")
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
|
||||
def test_task_direct_endpoint_from_config(self, monkeypatch, tmp_path):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
@@ -979,8 +885,6 @@ class TestTaskSpecificOverrides:
|
||||
"""model:
|
||||
default: glm-5.1
|
||||
provider: opencode-go
|
||||
compression:
|
||||
summary_provider: auto
|
||||
"""
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
@@ -1039,24 +943,45 @@ model:
|
||||
"model": "gpt-5.4",
|
||||
}
|
||||
|
||||
def test_compression_summary_base_url_from_config(self, monkeypatch, tmp_path):
|
||||
"""compression.summary_base_url should produce a custom-endpoint client."""
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
"""compression:
|
||||
summary_provider: custom
|
||||
summary_model: glm-4.7
|
||||
summary_base_url: https://api.z.ai/api/coding/paas/v4
|
||||
"""
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
# Custom endpoints need an API key to build the client
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_text_auxiliary_client("compression")
|
||||
assert model == "glm-4.7"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "https://api.z.ai/api/coding/paas/v4"
|
||||
|
||||
def test_resolve_provider_client_supports_copilot_acp_external_process():
|
||||
fake_client = MagicMock()
|
||||
|
||||
with patch("agent.auxiliary_client._read_main_model", return_value="gpt-5.4-mini"), \
|
||||
patch("agent.auxiliary_client.CodexAuxiliaryClient", MagicMock()), \
|
||||
patch("agent.copilot_acp_client.CopilotACPClient", return_value=fake_client) as mock_acp, \
|
||||
patch("hermes_cli.auth.resolve_external_process_provider_credentials", return_value={
|
||||
"provider": "copilot-acp",
|
||||
"api_key": "copilot-acp",
|
||||
"base_url": "acp://copilot",
|
||||
"command": "/usr/bin/copilot",
|
||||
"args": ["--acp", "--stdio"],
|
||||
}):
|
||||
client, model = resolve_provider_client("copilot-acp")
|
||||
|
||||
assert client is fake_client
|
||||
assert model == "gpt-5.4-mini"
|
||||
assert mock_acp.call_args.kwargs["api_key"] == "copilot-acp"
|
||||
assert mock_acp.call_args.kwargs["base_url"] == "acp://copilot"
|
||||
assert mock_acp.call_args.kwargs["command"] == "/usr/bin/copilot"
|
||||
assert mock_acp.call_args.kwargs["args"] == ["--acp", "--stdio"]
|
||||
|
||||
|
||||
def test_resolve_provider_client_copilot_acp_requires_explicit_or_configured_model():
|
||||
with patch("agent.auxiliary_client._read_main_model", return_value=""), \
|
||||
patch("agent.copilot_acp_client.CopilotACPClient") as mock_acp, \
|
||||
patch("hermes_cli.auth.resolve_external_process_provider_credentials", return_value={
|
||||
"provider": "copilot-acp",
|
||||
"api_key": "copilot-acp",
|
||||
"base_url": "acp://copilot",
|
||||
"command": "/usr/bin/copilot",
|
||||
"args": ["--acp", "--stdio"],
|
||||
}):
|
||||
client, model = resolve_provider_client("copilot-acp")
|
||||
|
||||
assert client is None
|
||||
assert model is None
|
||||
mock_acp.assert_not_called()
|
||||
|
||||
|
||||
class TestAuxiliaryMaxTokensParam:
|
||||
|
||||
@@ -273,18 +273,6 @@ class TestDefaultConfigShape:
|
||||
assert web["provider"] == "auto"
|
||||
assert web["model"] == ""
|
||||
|
||||
def test_compression_provider_default(self):
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
compression = DEFAULT_CONFIG["compression"]
|
||||
assert "summary_provider" in compression
|
||||
assert compression["summary_provider"] == "auto"
|
||||
|
||||
def test_compression_base_url_default(self):
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
compression = DEFAULT_CONFIG["compression"]
|
||||
assert "summary_base_url" in compression
|
||||
assert compression["summary_base_url"] is None
|
||||
|
||||
|
||||
# ── CLI defaults parity ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -12,17 +12,6 @@ def _isolate(tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
for env_var in (
|
||||
"AUXILIARY_VISION_PROVIDER",
|
||||
"AUXILIARY_VISION_MODEL",
|
||||
"AUXILIARY_VISION_BASE_URL",
|
||||
"AUXILIARY_VISION_API_KEY",
|
||||
"CONTEXT_VISION_PROVIDER",
|
||||
"CONTEXT_VISION_MODEL",
|
||||
"CONTEXT_VISION_BASE_URL",
|
||||
"CONTEXT_VISION_API_KEY",
|
||||
):
|
||||
monkeypatch.delenv(env_var, raising=False)
|
||||
# Write a minimal config so load_config doesn't fail
|
||||
(hermes_home / "config.yaml").write_text("model:\n default: test-model\n")
|
||||
|
||||
@@ -69,6 +58,10 @@ class TestNormalizeVisionProvider:
|
||||
assert _normalize_vision_provider("beans") == "beans"
|
||||
assert _normalize_vision_provider("deepseek") == "deepseek"
|
||||
|
||||
def test_custom_colon_named_provider_preserved(self):
|
||||
from agent.auxiliary_client import _normalize_vision_provider
|
||||
assert _normalize_vision_provider("custom:beans") == "beans"
|
||||
|
||||
def test_codex_alias_still_works(self):
|
||||
from agent.auxiliary_client import _normalize_vision_provider
|
||||
assert _normalize_vision_provider("codex") == "openai-codex"
|
||||
@@ -240,3 +233,22 @@ class TestResolveVisionProviderClientModelNormalization:
|
||||
assert provider == "zai"
|
||||
assert client is not None
|
||||
assert model == "glm-5.1"
|
||||
|
||||
|
||||
class TestVisionPathApiMode:
|
||||
"""Vision path should propagate api_mode to _get_cached_client."""
|
||||
|
||||
def test_explicit_provider_passes_api_mode(self, tmp_path):
|
||||
_write_config(tmp_path, {
|
||||
"model": {"default": "test-model"},
|
||||
"auxiliary": {"vision": {"api_mode": "chat_completions"}},
|
||||
})
|
||||
with patch("agent.auxiliary_client._get_cached_client") as mock_gcc:
|
||||
mock_gcc.return_value = (MagicMock(), "test-model")
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client(provider="deepseek")
|
||||
|
||||
mock_gcc.assert_called_once()
|
||||
_, kwargs = mock_gcc.call_args
|
||||
assert kwargs.get("api_mode") == "chat_completions"
|
||||
|
||||
@@ -308,6 +308,34 @@ class TestMinimaxPreserveDots:
|
||||
from run_agent import AIAgent
|
||||
assert AIAgent._anthropic_preserve_dots(agent) is False
|
||||
|
||||
def test_opencode_zen_provider_preserves_dots(self):
|
||||
from types import SimpleNamespace
|
||||
agent = SimpleNamespace(provider="opencode-zen", base_url="")
|
||||
from run_agent import AIAgent
|
||||
assert AIAgent._anthropic_preserve_dots(agent) is True
|
||||
|
||||
def test_opencode_zen_url_preserves_dots(self):
|
||||
from types import SimpleNamespace
|
||||
agent = SimpleNamespace(provider="custom", base_url="https://opencode.ai/zen/v1")
|
||||
from run_agent import AIAgent
|
||||
assert AIAgent._anthropic_preserve_dots(agent) is True
|
||||
|
||||
def test_zai_provider_preserves_dots(self):
|
||||
from types import SimpleNamespace
|
||||
agent = SimpleNamespace(provider="zai", base_url="")
|
||||
from run_agent import AIAgent
|
||||
assert AIAgent._anthropic_preserve_dots(agent) is True
|
||||
|
||||
def test_bigmodel_cn_url_preserves_dots(self):
|
||||
from types import SimpleNamespace
|
||||
agent = SimpleNamespace(provider="custom", base_url="https://open.bigmodel.cn/api/paas/v4")
|
||||
from run_agent import AIAgent
|
||||
assert AIAgent._anthropic_preserve_dots(agent) is True
|
||||
|
||||
def test_normalize_preserves_m25_free_dot(self):
|
||||
from agent.anthropic_adapter import normalize_model_name
|
||||
assert normalize_model_name("minimax-m2.5-free", preserve_dots=True) == "minimax-m2.5-free"
|
||||
|
||||
def test_normalize_preserves_m27_dot(self):
|
||||
from agent.anthropic_adapter import normalize_model_name
|
||||
assert normalize_model_name("MiniMax-M2.7", preserve_dots=True) == "MiniMax-M2.7"
|
||||
|
||||
@@ -70,6 +70,44 @@ class TestQueryLocalContextLengthOllama:
|
||||
|
||||
assert result == 32768
|
||||
|
||||
def test_ollama_num_ctx_wins_over_model_info(self):
|
||||
"""When both num_ctx (Modelfile) and model_info (GGUF) are present,
|
||||
num_ctx wins because it's the *runtime* context Ollama actually
|
||||
allocates KV cache for. The GGUF model_info.context_length is the
|
||||
training max — using it would let Hermes grow conversations past
|
||||
the runtime limit and Ollama would silently truncate.
|
||||
|
||||
Concrete example: hermes-brain:qwen3-14b-ctx32k is a Modelfile
|
||||
derived from qwen3:14b with `num_ctx 32768`, but the underlying
|
||||
GGUF reports `qwen3.context_length: 40960` (training max). If
|
||||
Hermes used 40960 it would let the conversation grow past 32768
|
||||
before compressing, and Ollama would truncate the prefix.
|
||||
"""
|
||||
from agent.model_metadata import _query_local_context_length
|
||||
|
||||
show_resp = self._make_resp(200, {
|
||||
"model_info": {"qwen3.context_length": 40960},
|
||||
"parameters": "num_ctx 32768\ntemperature 0.6\n",
|
||||
})
|
||||
models_resp = self._make_resp(404, {})
|
||||
|
||||
client_mock = MagicMock()
|
||||
client_mock.__enter__ = lambda s: client_mock
|
||||
client_mock.__exit__ = MagicMock(return_value=False)
|
||||
client_mock.post.return_value = show_resp
|
||||
client_mock.get.return_value = models_resp
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
|
||||
patch("httpx.Client", return_value=client_mock):
|
||||
result = _query_local_context_length(
|
||||
"hermes-brain:qwen3-14b-ctx32k", "http://100.77.243.5:11434/v1"
|
||||
)
|
||||
|
||||
assert result == 32768, (
|
||||
f"Expected num_ctx (32768) to win over model_info (40960), got {result}. "
|
||||
"If Hermes uses the GGUF training max, conversations will silently truncate."
|
||||
)
|
||||
|
||||
def test_ollama_show_404_falls_through(self):
|
||||
"""When /api/show returns 404, falls through to /v1/models/{model}."""
|
||||
from agent.model_metadata import _query_local_context_length
|
||||
|
||||
@@ -51,10 +51,10 @@ class TestSaveConfigValueAtomic:
|
||||
def test_creates_nested_keys(self, config_env):
|
||||
"""Dot-separated paths create intermediate dicts as needed."""
|
||||
from cli import save_config_value
|
||||
save_config_value("compression.summary_model", "google/gemini-3-flash-preview")
|
||||
save_config_value("auxiliary.compression.model", "google/gemini-3-flash-preview")
|
||||
|
||||
result = yaml.safe_load(config_env.read_text())
|
||||
assert result["compression"]["summary_model"] == "google/gemini-3-flash-preview"
|
||||
assert result["auxiliary"]["compression"]["model"] == "google/gemini-3-flash-preview"
|
||||
|
||||
def test_overwrites_existing_value(self, config_env):
|
||||
"""Updating an existing key replaces the value."""
|
||||
|
||||
@@ -180,33 +180,71 @@ class TestDisplayResumedHistory:
|
||||
assert 200 <= a_count <= 310 # roughly 300 chars (±panel padding)
|
||||
|
||||
def test_long_assistant_message_truncated(self):
|
||||
"""Non-last assistant messages are still truncated."""
|
||||
cli = _make_cli()
|
||||
long_text = "B" * 400
|
||||
cli.conversation_history = [
|
||||
{"role": "user", "content": "Tell me a lot."},
|
||||
{"role": "assistant", "content": long_text},
|
||||
{"role": "user", "content": "And more?"},
|
||||
{"role": "assistant", "content": "Short final reply."},
|
||||
]
|
||||
output = self._capture_display(cli)
|
||||
|
||||
assert "..." in output
|
||||
# The non-last assistant message should be truncated
|
||||
assert "B" * 400 not in output
|
||||
# The last assistant message shown in full
|
||||
assert "Short final reply." in output
|
||||
|
||||
def test_multiline_assistant_truncated(self):
|
||||
"""Non-last multiline assistant messages are truncated to 3 lines."""
|
||||
cli = _make_cli()
|
||||
multi = "\n".join([f"Line {i}" for i in range(20)])
|
||||
cli.conversation_history = [
|
||||
{"role": "user", "content": "Show me lines."},
|
||||
{"role": "assistant", "content": multi},
|
||||
{"role": "user", "content": "What else?"},
|
||||
{"role": "assistant", "content": "Done."},
|
||||
]
|
||||
output = self._capture_display(cli)
|
||||
|
||||
# First 3 lines should be there
|
||||
# First 3 lines of non-last assistant should be there
|
||||
assert "Line 0" in output
|
||||
assert "Line 1" in output
|
||||
assert "Line 2" in output
|
||||
# Line 19 should NOT be there (truncated after 3 lines)
|
||||
# Line 19 should NOT be in the truncated message
|
||||
assert "Line 19" not in output
|
||||
|
||||
def test_last_assistant_response_shown_in_full(self):
|
||||
"""The last assistant response is shown un-truncated so the user
|
||||
knows where they left off without wasting tokens re-asking."""
|
||||
cli = _make_cli()
|
||||
long_text = "X" * 500
|
||||
cli.conversation_history = [
|
||||
{"role": "user", "content": "Tell me everything."},
|
||||
{"role": "assistant", "content": long_text},
|
||||
]
|
||||
output = self._capture_display(cli)
|
||||
|
||||
# Full 500-char text should be present (may be line-wrapped by Rich)
|
||||
x_count = output.count("X")
|
||||
assert x_count >= 490 # allow small Rich formatting variance
|
||||
|
||||
def test_last_assistant_multiline_shown_in_full(self):
|
||||
"""The last assistant response shows all lines, not just 3."""
|
||||
cli = _make_cli()
|
||||
multi = "\n".join([f"Line {i}" for i in range(20)])
|
||||
cli.conversation_history = [
|
||||
{"role": "user", "content": "Show me everything."},
|
||||
{"role": "assistant", "content": multi},
|
||||
]
|
||||
output = self._capture_display(cli)
|
||||
|
||||
# All 20 lines should be present since it's the last response
|
||||
assert "Line 0" in output
|
||||
assert "Line 10" in output
|
||||
assert "Line 19" in output
|
||||
|
||||
def test_large_history_shows_truncation_indicator(self):
|
||||
cli = _make_cli()
|
||||
cli.conversation_history = _large_history(n_exchanges=15)
|
||||
|
||||
@@ -35,6 +35,7 @@ def make_restart_source(chat_id: str = "123456", chat_type: str = "dm") -> Sessi
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
chat_type=chat_type,
|
||||
user_id="u1",
|
||||
)
|
||||
|
||||
|
||||
|
||||
87
tests/gateway/test_api_server_normalize.py
Normal file
87
tests/gateway/test_api_server_normalize.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Tests for _normalize_chat_content in the API server adapter."""
|
||||
|
||||
from gateway.platforms.api_server import _normalize_chat_content
|
||||
|
||||
|
||||
class TestNormalizeChatContent:
|
||||
"""Content normalization converts array-based content parts to plain text."""
|
||||
|
||||
def test_none_returns_empty_string(self):
|
||||
assert _normalize_chat_content(None) == ""
|
||||
|
||||
def test_plain_string_returned_as_is(self):
|
||||
assert _normalize_chat_content("hello world") == "hello world"
|
||||
|
||||
def test_empty_string_returned_as_is(self):
|
||||
assert _normalize_chat_content("") == ""
|
||||
|
||||
def test_text_content_part(self):
|
||||
content = [{"type": "text", "text": "hello"}]
|
||||
assert _normalize_chat_content(content) == "hello"
|
||||
|
||||
def test_input_text_content_part(self):
|
||||
content = [{"type": "input_text", "text": "user input"}]
|
||||
assert _normalize_chat_content(content) == "user input"
|
||||
|
||||
def test_output_text_content_part(self):
|
||||
content = [{"type": "output_text", "text": "assistant output"}]
|
||||
assert _normalize_chat_content(content) == "assistant output"
|
||||
|
||||
def test_multiple_text_parts_joined_with_newline(self):
|
||||
content = [
|
||||
{"type": "text", "text": "first"},
|
||||
{"type": "text", "text": "second"},
|
||||
]
|
||||
assert _normalize_chat_content(content) == "first\nsecond"
|
||||
|
||||
def test_mixed_string_and_dict_parts(self):
|
||||
content = ["plain string", {"type": "text", "text": "dict part"}]
|
||||
assert _normalize_chat_content(content) == "plain string\ndict part"
|
||||
|
||||
def test_image_url_parts_silently_skipped(self):
|
||||
content = [
|
||||
{"type": "text", "text": "check this:"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/img.png"}},
|
||||
]
|
||||
assert _normalize_chat_content(content) == "check this:"
|
||||
|
||||
def test_integer_content_converted(self):
|
||||
assert _normalize_chat_content(42) == "42"
|
||||
|
||||
def test_boolean_content_converted(self):
|
||||
assert _normalize_chat_content(True) == "True"
|
||||
|
||||
def test_deeply_nested_list_respects_depth_limit(self):
|
||||
"""Nesting beyond max_depth returns empty string."""
|
||||
content = [[[[[[[[[[[["deep"]]]]]]]]]]]]
|
||||
result = _normalize_chat_content(content)
|
||||
# The deep nesting should be truncated, not crash
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_large_list_capped(self):
|
||||
"""Lists beyond MAX_CONTENT_LIST_SIZE are truncated."""
|
||||
content = [{"type": "text", "text": f"item{i}"} for i in range(2000)]
|
||||
result = _normalize_chat_content(content)
|
||||
# Should not contain all 2000 items
|
||||
assert result.count("item") <= 1000
|
||||
|
||||
def test_oversized_string_truncated(self):
|
||||
"""Strings beyond 64KB are truncated."""
|
||||
huge = "x" * 100_000
|
||||
result = _normalize_chat_content(huge)
|
||||
assert len(result) == 65_536
|
||||
|
||||
def test_empty_text_parts_filtered(self):
|
||||
content = [
|
||||
{"type": "text", "text": ""},
|
||||
{"type": "text", "text": "actual"},
|
||||
{"type": "text", "text": ""},
|
||||
]
|
||||
assert _normalize_chat_content(content) == "actual"
|
||||
|
||||
def test_dict_without_type_skipped(self):
|
||||
content = [{"foo": "bar"}, {"type": "text", "text": "real"}]
|
||||
assert _normalize_chat_content(content) == "real"
|
||||
|
||||
def test_empty_list_returns_empty(self):
|
||||
assert _normalize_chat_content([]) == ""
|
||||
@@ -359,3 +359,44 @@ async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeyp
|
||||
await adapter._handle_message(message)
|
||||
|
||||
assert "777" in adapter._threads
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_voice_linked_channel_skips_mention_requirement_and_auto_thread(adapter, monkeypatch):
|
||||
"""Active voice-linked text channels should behave like free-response channels."""
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
|
||||
|
||||
adapter._voice_text_channels[111] = 789
|
||||
adapter._auto_create_thread = AsyncMock()
|
||||
|
||||
message = make_message(
|
||||
channel=FakeTextChannel(channel_id=789),
|
||||
content="follow-up from voice text chat",
|
||||
)
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter._auto_create_thread.assert_not_awaited()
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "follow-up from voice text chat"
|
||||
assert event.source.chat_type == "group"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_voice_linked_parent_thread_still_requires_mention(adapter, monkeypatch):
|
||||
"""Threads under a voice-linked channel should still require @mention."""
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
|
||||
adapter._voice_text_channels[111] = 789
|
||||
message = make_message(
|
||||
channel=FakeThread(channel_id=790, parent=FakeTextChannel(channel_id=789)),
|
||||
content="thread reply without mention",
|
||||
)
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_not_awaited()
|
||||
|
||||
@@ -124,7 +124,7 @@ class TestSendWithReplyToMode:
|
||||
@pytest.mark.asyncio
|
||||
async def test_off_mode_no_reply_reference(self):
|
||||
adapter, channel, ref_msg = _make_discord_adapter("off")
|
||||
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]
|
||||
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2", "chunk3"]
|
||||
|
||||
await adapter.send("12345", "test content", reply_to="999")
|
||||
|
||||
@@ -137,7 +137,7 @@ class TestSendWithReplyToMode:
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_mode_only_first_chunk_references(self):
|
||||
adapter, channel, ref_msg = _make_discord_adapter("first")
|
||||
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]
|
||||
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2", "chunk3"]
|
||||
|
||||
await adapter.send("12345", "test content", reply_to="999")
|
||||
|
||||
@@ -152,7 +152,7 @@ class TestSendWithReplyToMode:
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_mode_all_chunks_reference(self):
|
||||
adapter, channel, ref_msg = _make_discord_adapter("all")
|
||||
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]
|
||||
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2", "chunk3"]
|
||||
|
||||
await adapter.send("12345", "test content", reply_to="999")
|
||||
|
||||
@@ -165,7 +165,7 @@ class TestSendWithReplyToMode:
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_reply_to_param_no_reference(self):
|
||||
adapter, channel, ref_msg = _make_discord_adapter("all")
|
||||
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2"]
|
||||
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2"]
|
||||
|
||||
await adapter.send("12345", "test content", reply_to=None)
|
||||
|
||||
@@ -176,7 +176,7 @@ class TestSendWithReplyToMode:
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_chunk_respects_first_mode(self):
|
||||
adapter, channel, ref_msg = _make_discord_adapter("first")
|
||||
adapter.truncate_message = lambda content, max_len: ["single chunk"]
|
||||
adapter.truncate_message = lambda content, max_len, **kw: ["single chunk"]
|
||||
|
||||
await adapter.send("12345", "test", reply_to="999")
|
||||
|
||||
@@ -187,7 +187,7 @@ class TestSendWithReplyToMode:
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_chunk_off_mode(self):
|
||||
adapter, channel, ref_msg = _make_discord_adapter("off")
|
||||
adapter.truncate_message = lambda content, max_len: ["single chunk"]
|
||||
adapter.truncate_message = lambda content, max_len, **kw: ["single chunk"]
|
||||
|
||||
await adapter.send("12345", "test", reply_to="999")
|
||||
|
||||
@@ -200,7 +200,7 @@ class TestSendWithReplyToMode:
|
||||
async def test_invalid_mode_falls_back_to_first_behavior(self):
|
||||
"""Invalid mode behaves like 'first' — only first chunk gets reference."""
|
||||
adapter, channel, ref_msg = _make_discord_adapter("banana")
|
||||
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2"]
|
||||
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2"]
|
||||
|
||||
await adapter.send("12345", "test", reply_to="999")
|
||||
|
||||
|
||||
@@ -189,14 +189,14 @@ class TestPlatformDefaults:
|
||||
"""Slack, Mattermost, Matrix default to 'new' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("slack", "mattermost", "matrix", "feishu"):
|
||||
for plat in ("slack", "mattermost", "matrix", "feishu", "whatsapp"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "new", plat
|
||||
|
||||
def test_low_tier_platforms(self):
|
||||
"""Signal, WhatsApp, etc. default to 'off' tool progress."""
|
||||
"""Signal, BlueBubbles, etc. default to 'off' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("signal", "whatsapp", "bluebubbles", "weixin", "wecom", "dingtalk"):
|
||||
for plat in ("signal", "bluebubbles", "weixin", "wecom", "dingtalk"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "off", plat
|
||||
|
||||
def test_minimal_tier_platforms(self):
|
||||
|
||||
438
tests/gateway/test_feishu_onboard.py
Normal file
438
tests/gateway/test_feishu_onboard.py
Normal file
@@ -0,0 +1,438 @@
|
||||
"""Tests for gateway.platforms.feishu — Feishu scan-to-create registration."""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
|
||||
|
||||
def _mock_urlopen(response_data, status=200):
|
||||
"""Create a mock for urllib.request.urlopen that returns JSON response_data."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = json.dumps(response_data).encode("utf-8")
|
||||
mock_response.status = status
|
||||
mock_response.__enter__ = lambda s: s
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
return mock_response
|
||||
|
||||
|
||||
class TestPostRegistration:
|
||||
"""Tests for the low-level HTTP helper."""
|
||||
|
||||
@patch("gateway.platforms.feishu.urlopen")
|
||||
def test_post_registration_returns_parsed_json(self, mock_urlopen_fn):
|
||||
from gateway.platforms.feishu import _post_registration
|
||||
|
||||
mock_urlopen_fn.return_value = _mock_urlopen({"nonce": "abc", "supported_auth_methods": ["client_secret"]})
|
||||
result = _post_registration("https://accounts.feishu.cn", {"action": "init"})
|
||||
assert result["nonce"] == "abc"
|
||||
assert "client_secret" in result["supported_auth_methods"]
|
||||
|
||||
@patch("gateway.platforms.feishu.urlopen")
|
||||
def test_post_registration_sends_form_encoded_body(self, mock_urlopen_fn):
|
||||
from gateway.platforms.feishu import _post_registration
|
||||
|
||||
mock_urlopen_fn.return_value = _mock_urlopen({})
|
||||
_post_registration("https://accounts.feishu.cn", {"action": "init", "key": "val"})
|
||||
call_args = mock_urlopen_fn.call_args
|
||||
request = call_args[0][0]
|
||||
body = request.data.decode("utf-8")
|
||||
assert "action=init" in body
|
||||
assert "key=val" in body
|
||||
assert request.get_header("Content-type") == "application/x-www-form-urlencoded"
|
||||
|
||||
|
||||
class TestInitRegistration:
|
||||
"""Tests for the init step."""
|
||||
|
||||
@patch("gateway.platforms.feishu.urlopen")
|
||||
def test_init_succeeds_when_client_secret_supported(self, mock_urlopen_fn):
|
||||
from gateway.platforms.feishu import _init_registration
|
||||
|
||||
mock_urlopen_fn.return_value = _mock_urlopen({
|
||||
"nonce": "abc",
|
||||
"supported_auth_methods": ["client_secret"],
|
||||
})
|
||||
_init_registration("feishu")
|
||||
|
||||
@patch("gateway.platforms.feishu.urlopen")
|
||||
def test_init_raises_when_client_secret_not_supported(self, mock_urlopen_fn):
|
||||
from gateway.platforms.feishu import _init_registration
|
||||
|
||||
mock_urlopen_fn.return_value = _mock_urlopen({
|
||||
"nonce": "abc",
|
||||
"supported_auth_methods": ["other_method"],
|
||||
})
|
||||
with pytest.raises(RuntimeError, match="client_secret"):
|
||||
_init_registration("feishu")
|
||||
|
||||
@patch("gateway.platforms.feishu.urlopen")
|
||||
def test_init_uses_lark_url_for_lark_domain(self, mock_urlopen_fn):
|
||||
from gateway.platforms.feishu import _init_registration
|
||||
|
||||
mock_urlopen_fn.return_value = _mock_urlopen({
|
||||
"nonce": "abc",
|
||||
"supported_auth_methods": ["client_secret"],
|
||||
})
|
||||
_init_registration("lark")
|
||||
call_args = mock_urlopen_fn.call_args
|
||||
request = call_args[0][0]
|
||||
assert "larksuite.com" in request.full_url
|
||||
|
||||
|
||||
class TestBeginRegistration:
|
||||
"""Tests for the begin step."""
|
||||
|
||||
@patch("gateway.platforms.feishu.urlopen")
|
||||
def test_begin_returns_device_code_and_qr_url(self, mock_urlopen_fn):
|
||||
from gateway.platforms.feishu import _begin_registration
|
||||
|
||||
mock_urlopen_fn.return_value = _mock_urlopen({
|
||||
"device_code": "dc_123",
|
||||
"verification_uri_complete": "https://accounts.feishu.cn/qr/abc",
|
||||
"user_code": "ABCD-1234",
|
||||
"interval": 5,
|
||||
"expire_in": 600,
|
||||
})
|
||||
result = _begin_registration("feishu")
|
||||
assert result["device_code"] == "dc_123"
|
||||
assert "qr_url" in result
|
||||
assert "accounts.feishu.cn" in result["qr_url"]
|
||||
assert result["user_code"] == "ABCD-1234"
|
||||
assert result["interval"] == 5
|
||||
assert result["expire_in"] == 600
|
||||
|
||||
@patch("gateway.platforms.feishu.urlopen")
|
||||
def test_begin_sends_correct_archetype(self, mock_urlopen_fn):
|
||||
from gateway.platforms.feishu import _begin_registration
|
||||
|
||||
mock_urlopen_fn.return_value = _mock_urlopen({
|
||||
"device_code": "dc_123",
|
||||
"verification_uri_complete": "https://example.com/qr",
|
||||
"user_code": "X",
|
||||
"interval": 5,
|
||||
"expire_in": 600,
|
||||
})
|
||||
_begin_registration("feishu")
|
||||
request = mock_urlopen_fn.call_args[0][0]
|
||||
body = request.data.decode("utf-8")
|
||||
assert "archetype=PersonalAgent" in body
|
||||
assert "auth_method=client_secret" in body
|
||||
|
||||
|
||||
class TestPollRegistration:
|
||||
"""Tests for the poll step."""
|
||||
|
||||
@patch("gateway.platforms.feishu.time")
|
||||
@patch("gateway.platforms.feishu.urlopen")
|
||||
def test_poll_returns_credentials_on_success(self, mock_urlopen_fn, mock_time):
|
||||
from gateway.platforms.feishu import _poll_registration
|
||||
|
||||
mock_time.time.side_effect = [0, 1]
|
||||
mock_time.sleep = MagicMock()
|
||||
|
||||
mock_urlopen_fn.return_value = _mock_urlopen({
|
||||
"client_id": "cli_app123",
|
||||
"client_secret": "secret456",
|
||||
"user_info": {"open_id": "ou_owner", "tenant_brand": "feishu"},
|
||||
})
|
||||
result = _poll_registration(
|
||||
device_code="dc_123", interval=1, expire_in=60, domain="feishu"
|
||||
)
|
||||
assert result is not None
|
||||
assert result["app_id"] == "cli_app123"
|
||||
assert result["app_secret"] == "secret456"
|
||||
assert result["domain"] == "feishu"
|
||||
assert result["open_id"] == "ou_owner"
|
||||
|
||||
@patch("gateway.platforms.feishu.time")
|
||||
@patch("gateway.platforms.feishu.urlopen")
|
||||
def test_poll_switches_domain_on_lark_tenant_brand(self, mock_urlopen_fn, mock_time):
|
||||
from gateway.platforms.feishu import _poll_registration
|
||||
|
||||
mock_time.time.side_effect = [0, 1, 2]
|
||||
mock_time.sleep = MagicMock()
|
||||
|
||||
pending_resp = _mock_urlopen({
|
||||
"error": "authorization_pending",
|
||||
"user_info": {"tenant_brand": "lark"},
|
||||
})
|
||||
success_resp = _mock_urlopen({
|
||||
"client_id": "cli_lark",
|
||||
"client_secret": "secret_lark",
|
||||
"user_info": {"open_id": "ou_lark", "tenant_brand": "lark"},
|
||||
})
|
||||
mock_urlopen_fn.side_effect = [pending_resp, success_resp]
|
||||
|
||||
result = _poll_registration(
|
||||
device_code="dc_123", interval=0, expire_in=60, domain="feishu"
|
||||
)
|
||||
assert result is not None
|
||||
assert result["domain"] == "lark"
|
||||
|
||||
@patch("gateway.platforms.feishu.time")
|
||||
@patch("gateway.platforms.feishu.urlopen")
|
||||
def test_poll_success_with_lark_brand_in_same_response(self, mock_urlopen_fn, mock_time):
|
||||
"""Credentials and lark tenant_brand in one response must not be discarded."""
|
||||
from gateway.platforms.feishu import _poll_registration
|
||||
|
||||
mock_time.time.side_effect = [0, 1]
|
||||
mock_time.sleep = MagicMock()
|
||||
|
||||
mock_urlopen_fn.return_value = _mock_urlopen({
|
||||
"client_id": "cli_lark_direct",
|
||||
"client_secret": "secret_lark_direct",
|
||||
"user_info": {"open_id": "ou_lark_direct", "tenant_brand": "lark"},
|
||||
})
|
||||
result = _poll_registration(
|
||||
device_code="dc_123", interval=1, expire_in=60, domain="feishu"
|
||||
)
|
||||
assert result is not None
|
||||
assert result["app_id"] == "cli_lark_direct"
|
||||
assert result["domain"] == "lark"
|
||||
assert result["open_id"] == "ou_lark_direct"
|
||||
|
||||
@patch("gateway.platforms.feishu.time")
|
||||
@patch("gateway.platforms.feishu.urlopen")
|
||||
def test_poll_returns_none_on_access_denied(self, mock_urlopen_fn, mock_time):
|
||||
from gateway.platforms.feishu import _poll_registration
|
||||
|
||||
mock_time.time.side_effect = [0, 1]
|
||||
mock_time.sleep = MagicMock()
|
||||
|
||||
mock_urlopen_fn.return_value = _mock_urlopen({
|
||||
"error": "access_denied",
|
||||
})
|
||||
result = _poll_registration(
|
||||
device_code="dc_123", interval=1, expire_in=60, domain="feishu"
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@patch("gateway.platforms.feishu.time")
|
||||
@patch("gateway.platforms.feishu.urlopen")
|
||||
def test_poll_returns_none_on_timeout(self, mock_urlopen_fn, mock_time):
|
||||
from gateway.platforms.feishu import _poll_registration
|
||||
|
||||
mock_time.time.side_effect = [0, 999]
|
||||
mock_time.sleep = MagicMock()
|
||||
|
||||
mock_urlopen_fn.return_value = _mock_urlopen({
|
||||
"error": "authorization_pending",
|
||||
})
|
||||
result = _poll_registration(
|
||||
device_code="dc_123", interval=1, expire_in=1, domain="feishu"
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestRenderQr:
|
||||
"""Tests for QR code terminal rendering."""
|
||||
|
||||
@patch("gateway.platforms.feishu._qrcode_mod", create=True)
|
||||
def test_render_qr_returns_true_on_success(self, mock_qrcode_mod):
|
||||
from gateway.platforms.feishu import _render_qr
|
||||
|
||||
mock_qr = MagicMock()
|
||||
mock_qrcode_mod.QRCode.return_value = mock_qr
|
||||
assert _render_qr("https://example.com/qr") is True
|
||||
mock_qr.add_data.assert_called_once_with("https://example.com/qr")
|
||||
mock_qr.make.assert_called_once_with(fit=True)
|
||||
mock_qr.print_ascii.assert_called_once()
|
||||
|
||||
def test_render_qr_returns_false_when_qrcode_missing(self):
|
||||
from gateway.platforms.feishu import _render_qr
|
||||
|
||||
with patch("gateway.platforms.feishu._qrcode_mod", None):
|
||||
assert _render_qr("https://example.com/qr") is False
|
||||
|
||||
|
||||
class TestProbeBot:
|
||||
"""Tests for bot connectivity verification."""
|
||||
|
||||
@patch("gateway.platforms.feishu.FEISHU_AVAILABLE", True)
|
||||
def test_probe_returns_bot_info_on_success(self):
|
||||
from gateway.platforms.feishu import probe_bot
|
||||
|
||||
with patch("gateway.platforms.feishu._probe_bot_sdk") as mock_sdk:
|
||||
mock_sdk.return_value = {"bot_name": "TestBot", "bot_open_id": "ou_bot123"}
|
||||
result = probe_bot("cli_app", "secret", "feishu")
|
||||
|
||||
assert result is not None
|
||||
assert result["bot_name"] == "TestBot"
|
||||
assert result["bot_open_id"] == "ou_bot123"
|
||||
|
||||
@patch("gateway.platforms.feishu.FEISHU_AVAILABLE", True)
|
||||
def test_probe_returns_none_on_failure(self):
|
||||
from gateway.platforms.feishu import probe_bot
|
||||
|
||||
with patch("gateway.platforms.feishu._probe_bot_sdk") as mock_sdk:
|
||||
mock_sdk.return_value = None
|
||||
result = probe_bot("bad_id", "bad_secret", "feishu")
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch("gateway.platforms.feishu.FEISHU_AVAILABLE", False)
|
||||
@patch("gateway.platforms.feishu.urlopen")
|
||||
def test_http_fallback_when_sdk_unavailable(self, mock_urlopen_fn):
|
||||
"""Without lark_oapi, probe falls back to raw HTTP."""
|
||||
from gateway.platforms.feishu import probe_bot
|
||||
|
||||
token_resp = _mock_urlopen({"code": 0, "tenant_access_token": "t-123"})
|
||||
bot_resp = _mock_urlopen({"code": 0, "bot": {"bot_name": "HttpBot", "open_id": "ou_http"}})
|
||||
mock_urlopen_fn.side_effect = [token_resp, bot_resp]
|
||||
|
||||
result = probe_bot("cli_app", "secret", "feishu")
|
||||
assert result is not None
|
||||
assert result["bot_name"] == "HttpBot"
|
||||
|
||||
@patch("gateway.platforms.feishu.FEISHU_AVAILABLE", False)
|
||||
@patch("gateway.platforms.feishu.urlopen")
|
||||
def test_http_fallback_returns_none_on_network_error(self, mock_urlopen_fn):
|
||||
from gateway.platforms.feishu import probe_bot
|
||||
from urllib.error import URLError
|
||||
|
||||
mock_urlopen_fn.side_effect = URLError("connection refused")
|
||||
result = probe_bot("cli_app", "secret", "feishu")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestQrRegister:
|
||||
"""Tests for the public qr_register entry point."""
|
||||
|
||||
@patch("gateway.platforms.feishu.probe_bot")
|
||||
@patch("gateway.platforms.feishu._render_qr")
|
||||
@patch("gateway.platforms.feishu._poll_registration")
|
||||
@patch("gateway.platforms.feishu._begin_registration")
|
||||
@patch("gateway.platforms.feishu._init_registration")
|
||||
def test_qr_register_success_flow(
|
||||
self, mock_init, mock_begin, mock_poll, mock_render, mock_probe
|
||||
):
|
||||
from gateway.platforms.feishu import qr_register
|
||||
|
||||
mock_begin.return_value = {
|
||||
"device_code": "dc_123",
|
||||
"qr_url": "https://example.com/qr",
|
||||
"user_code": "ABCD",
|
||||
"interval": 1,
|
||||
"expire_in": 60,
|
||||
}
|
||||
mock_poll.return_value = {
|
||||
"app_id": "cli_app",
|
||||
"app_secret": "secret",
|
||||
"domain": "feishu",
|
||||
"open_id": "ou_owner",
|
||||
}
|
||||
mock_probe.return_value = {"bot_name": "MyBot", "bot_open_id": "ou_bot"}
|
||||
|
||||
result = qr_register()
|
||||
assert result is not None
|
||||
assert result["app_id"] == "cli_app"
|
||||
assert result["app_secret"] == "secret"
|
||||
assert result["bot_name"] == "MyBot"
|
||||
mock_init.assert_called_once()
|
||||
mock_render.assert_called_once()
|
||||
|
||||
@patch("gateway.platforms.feishu._init_registration")
|
||||
def test_qr_register_returns_none_on_init_failure(self, mock_init):
|
||||
from gateway.platforms.feishu import qr_register
|
||||
|
||||
mock_init.side_effect = RuntimeError("not supported")
|
||||
result = qr_register()
|
||||
assert result is None
|
||||
|
||||
@patch("gateway.platforms.feishu._render_qr")
|
||||
@patch("gateway.platforms.feishu._poll_registration")
|
||||
@patch("gateway.platforms.feishu._begin_registration")
|
||||
@patch("gateway.platforms.feishu._init_registration")
|
||||
def test_qr_register_returns_none_on_poll_failure(
|
||||
self, mock_init, mock_begin, mock_poll, mock_render
|
||||
):
|
||||
from gateway.platforms.feishu import qr_register
|
||||
|
||||
mock_begin.return_value = {
|
||||
"device_code": "dc_123",
|
||||
"qr_url": "https://example.com/qr",
|
||||
"user_code": "ABCD",
|
||||
"interval": 1,
|
||||
"expire_in": 60,
|
||||
}
|
||||
mock_poll.return_value = None
|
||||
|
||||
result = qr_register()
|
||||
assert result is None
|
||||
|
||||
# -- Contract: expected errors → None, unexpected errors → propagate --
|
||||
|
||||
@patch("gateway.platforms.feishu._init_registration")
|
||||
def test_qr_register_returns_none_on_network_error(self, mock_init):
|
||||
"""URLError (network down) is an expected failure → None."""
|
||||
from gateway.platforms.feishu import qr_register
|
||||
from urllib.error import URLError
|
||||
|
||||
mock_init.side_effect = URLError("DNS resolution failed")
|
||||
result = qr_register()
|
||||
assert result is None
|
||||
|
||||
@patch("gateway.platforms.feishu._init_registration")
|
||||
def test_qr_register_returns_none_on_json_error(self, mock_init):
|
||||
"""Malformed server response is an expected failure → None."""
|
||||
from gateway.platforms.feishu import qr_register
|
||||
|
||||
mock_init.side_effect = json.JSONDecodeError("bad json", "", 0)
|
||||
result = qr_register()
|
||||
assert result is None
|
||||
|
||||
@patch("gateway.platforms.feishu._init_registration")
|
||||
def test_qr_register_propagates_unexpected_errors(self, mock_init):
|
||||
"""Bugs (e.g. AttributeError) must not be swallowed — they propagate."""
|
||||
from gateway.platforms.feishu import qr_register
|
||||
|
||||
mock_init.side_effect = AttributeError("some internal bug")
|
||||
with pytest.raises(AttributeError, match="some internal bug"):
|
||||
qr_register()
|
||||
|
||||
# -- Negative paths: partial/malformed server responses --
|
||||
|
||||
@patch("gateway.platforms.feishu._render_qr")
|
||||
@patch("gateway.platforms.feishu._begin_registration")
|
||||
@patch("gateway.platforms.feishu._init_registration")
|
||||
def test_qr_register_returns_none_when_begin_missing_device_code(
|
||||
self, mock_init, mock_begin, mock_render
|
||||
):
|
||||
"""Server returns begin response without device_code → RuntimeError → None."""
|
||||
from gateway.platforms.feishu import qr_register
|
||||
|
||||
mock_begin.side_effect = RuntimeError("Feishu registration did not return a device_code")
|
||||
result = qr_register()
|
||||
assert result is None
|
||||
|
||||
@patch("gateway.platforms.feishu.probe_bot")
|
||||
@patch("gateway.platforms.feishu._render_qr")
|
||||
@patch("gateway.platforms.feishu._poll_registration")
|
||||
@patch("gateway.platforms.feishu._begin_registration")
|
||||
@patch("gateway.platforms.feishu._init_registration")
|
||||
def test_qr_register_succeeds_even_when_probe_fails(
|
||||
self, mock_init, mock_begin, mock_poll, mock_render, mock_probe
|
||||
):
|
||||
"""Registration succeeds but probe fails → result with bot_name=None."""
|
||||
from gateway.platforms.feishu import qr_register
|
||||
|
||||
mock_begin.return_value = {
|
||||
"device_code": "dc_123",
|
||||
"qr_url": "https://example.com/qr",
|
||||
"user_code": "ABCD",
|
||||
"interval": 1,
|
||||
"expire_in": 60,
|
||||
}
|
||||
mock_poll.return_value = {
|
||||
"app_id": "cli_app",
|
||||
"app_secret": "secret",
|
||||
"domain": "feishu",
|
||||
"open_id": "ou_owner",
|
||||
}
|
||||
mock_probe.return_value = None # probe failed
|
||||
|
||||
result = qr_register()
|
||||
assert result is not None
|
||||
assert result["app_id"] == "cli_app"
|
||||
assert result["bot_name"] is None
|
||||
assert result["bot_open_id"] is None
|
||||
@@ -48,6 +48,7 @@ def _make_event(
|
||||
room_id="!room1:example.org",
|
||||
formatted_body=None,
|
||||
thread_id=None,
|
||||
mention_user_ids=None,
|
||||
):
|
||||
"""Create a fake room message event.
|
||||
|
||||
@@ -60,6 +61,9 @@ def _make_event(
|
||||
content["formatted_body"] = formatted_body
|
||||
content["format"] = "org.matrix.custom.html"
|
||||
|
||||
if mention_user_ids is not None:
|
||||
content["m.mentions"] = {"user_ids": mention_user_ids}
|
||||
|
||||
relates_to = {}
|
||||
if thread_id:
|
||||
relates_to["rel_type"] = "m.thread"
|
||||
@@ -108,6 +112,44 @@ class TestIsBotMentioned:
|
||||
# "hermesbot" should not match word-boundary check for "hermes"
|
||||
assert not self.adapter._is_bot_mentioned("hermesbot is here")
|
||||
|
||||
# m.mentions.user_ids — MSC3952 / Matrix v1.7 authoritative mentions
|
||||
# Ported from openclaw/openclaw#64796
|
||||
|
||||
def test_m_mentions_user_ids_authoritative(self):
|
||||
"""m.mentions.user_ids alone is sufficient — no body text needed."""
|
||||
assert self.adapter._is_bot_mentioned(
|
||||
"please reply", # no @hermes anywhere in body
|
||||
mention_user_ids=["@hermes:example.org"],
|
||||
)
|
||||
|
||||
def test_m_mentions_user_ids_with_body_mention(self):
|
||||
"""Both m.mentions and body mention — should still be True."""
|
||||
assert self.adapter._is_bot_mentioned(
|
||||
"hey @hermes:example.org help",
|
||||
mention_user_ids=["@hermes:example.org"],
|
||||
)
|
||||
|
||||
def test_m_mentions_user_ids_other_user_only(self):
|
||||
"""m.mentions with a different user — bot is NOT mentioned."""
|
||||
assert not self.adapter._is_bot_mentioned(
|
||||
"hello",
|
||||
mention_user_ids=["@alice:example.org"],
|
||||
)
|
||||
|
||||
def test_m_mentions_user_ids_empty_list(self):
|
||||
"""Empty user_ids list — falls through to text detection."""
|
||||
assert not self.adapter._is_bot_mentioned(
|
||||
"hello everyone",
|
||||
mention_user_ids=[],
|
||||
)
|
||||
|
||||
def test_m_mentions_user_ids_none(self):
|
||||
"""None mention_user_ids — falls through to text detection."""
|
||||
assert not self.adapter._is_bot_mentioned(
|
||||
"hello everyone",
|
||||
mention_user_ids=None,
|
||||
)
|
||||
|
||||
|
||||
class TestStripMention:
|
||||
def setup_method(self):
|
||||
@@ -176,6 +218,44 @@ async def test_require_mention_html_pill(monkeypatch):
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_mention_m_mentions_user_ids(monkeypatch):
|
||||
"""m.mentions.user_ids is authoritative per MSC3952 — no body mention needed.
|
||||
|
||||
Ported from openclaw/openclaw#64796.
|
||||
"""
|
||||
monkeypatch.delenv("MATRIX_REQUIRE_MENTION", raising=False)
|
||||
monkeypatch.delenv("MATRIX_FREE_RESPONSE_ROOMS", raising=False)
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
adapter = _make_adapter()
|
||||
# Body has NO mention, but m.mentions.user_ids includes the bot.
|
||||
event = _make_event(
|
||||
"please reply",
|
||||
mention_user_ids=["@hermes:example.org"],
|
||||
)
|
||||
|
||||
await adapter._on_room_message(event)
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_mention_m_mentions_other_user_ignored(monkeypatch):
|
||||
"""m.mentions.user_ids mentioning another user should NOT activate the bot."""
|
||||
monkeypatch.delenv("MATRIX_REQUIRE_MENTION", raising=False)
|
||||
monkeypatch.delenv("MATRIX_FREE_RESPONSE_ROOMS", raising=False)
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
adapter = _make_adapter()
|
||||
event = _make_event(
|
||||
"hey alice check this",
|
||||
mention_user_ids=["@alice:example.org"],
|
||||
)
|
||||
|
||||
await adapter._on_room_message(event)
|
||||
adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_mention_dm_always_responds(monkeypatch):
|
||||
"""DMs always respond regardless of mention setting."""
|
||||
|
||||
@@ -9,6 +9,8 @@ from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
safe_url_for_log,
|
||||
utf16_len,
|
||||
_prefix_within_utf16_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -448,3 +450,135 @@ class TestGetHumanDelay:
|
||||
with patch.dict(os.environ, env):
|
||||
delay = BasePlatformAdapter._get_human_delay()
|
||||
assert 0.1 <= delay <= 0.2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# utf16_len / _prefix_within_utf16_limit / truncate_message with len_fn
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ported from nearai/ironclaw#2304 — Telegram counts message length in UTF-16
|
||||
# code units, not Unicode code-points. Astral-plane characters (emoji, CJK
|
||||
# Extension B) are surrogate pairs: 1 Python char but 2 UTF-16 units.
|
||||
|
||||
|
||||
class TestUtf16Len:
|
||||
"""Verify the UTF-16 length helper."""
|
||||
|
||||
def test_ascii(self):
|
||||
assert utf16_len("hello") == 5
|
||||
|
||||
def test_bmp_cjk(self):
|
||||
# CJK ideographs in the BMP are 1 code unit each
|
||||
assert utf16_len("你好") == 2
|
||||
|
||||
def test_emoji_surrogate_pair(self):
|
||||
# 😀 (U+1F600) is outside BMP → 2 UTF-16 code units
|
||||
assert utf16_len("😀") == 2
|
||||
|
||||
def test_mixed(self):
|
||||
# "hi😀" = 2 + 2 = 4 UTF-16 units
|
||||
assert utf16_len("hi😀") == 4
|
||||
|
||||
def test_musical_symbol(self):
|
||||
# 𝄞 (U+1D11E) — Musical Symbol G Clef, surrogate pair
|
||||
assert utf16_len("𝄞") == 2
|
||||
|
||||
def test_empty(self):
|
||||
assert utf16_len("") == 0
|
||||
|
||||
|
||||
class TestPrefixWithinUtf16Limit:
|
||||
"""Verify UTF-16-aware prefix truncation."""
|
||||
|
||||
def test_fits_entirely(self):
|
||||
assert _prefix_within_utf16_limit("hello", 10) == "hello"
|
||||
|
||||
def test_ascii_truncation(self):
|
||||
result = _prefix_within_utf16_limit("hello world", 5)
|
||||
assert result == "hello"
|
||||
assert utf16_len(result) <= 5
|
||||
|
||||
def test_does_not_split_surrogate_pair(self):
|
||||
# "a😀b" = 1 + 2 + 1 = 4 UTF-16 units; limit 2 should give "a"
|
||||
result = _prefix_within_utf16_limit("a😀b", 2)
|
||||
assert result == "a"
|
||||
assert utf16_len(result) <= 2
|
||||
|
||||
def test_emoji_at_limit(self):
|
||||
# "😀" = 2 UTF-16 units; limit 2 should include it
|
||||
result = _prefix_within_utf16_limit("😀x", 2)
|
||||
assert result == "😀"
|
||||
|
||||
def test_all_emoji(self):
|
||||
msg = "😀" * 10 # 20 UTF-16 units
|
||||
result = _prefix_within_utf16_limit(msg, 6)
|
||||
assert result == "😀😀😀"
|
||||
assert utf16_len(result) == 6
|
||||
|
||||
def test_empty(self):
|
||||
assert _prefix_within_utf16_limit("", 5) == ""
|
||||
|
||||
|
||||
class TestTruncateMessageUtf16:
|
||||
"""Verify truncate_message respects UTF-16 lengths when len_fn=utf16_len."""
|
||||
|
||||
def test_short_emoji_message_no_split(self):
|
||||
"""A short message under the UTF-16 limit should not be split."""
|
||||
msg = "Hello 😀 world"
|
||||
chunks = BasePlatformAdapter.truncate_message(msg, 4096, len_fn=utf16_len)
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0] == msg
|
||||
|
||||
def test_emoji_near_limit_triggers_split(self):
|
||||
"""A message at 4096 codepoints but >4096 UTF-16 units must split."""
|
||||
# 2049 emoji = 2049 codepoints but 4098 UTF-16 units → exceeds 4096
|
||||
msg = "😀" * 2049
|
||||
assert len(msg) == 2049 # Python len sees 2049 chars
|
||||
assert utf16_len(msg) == 4098 # but it's 4098 UTF-16 units
|
||||
|
||||
# Without UTF-16 awareness, this would NOT split (2049 < 4096)
|
||||
chunks_naive = BasePlatformAdapter.truncate_message(msg, 4096)
|
||||
assert len(chunks_naive) == 1, "Without len_fn, no split expected"
|
||||
|
||||
# With UTF-16 awareness, it MUST split
|
||||
chunks = BasePlatformAdapter.truncate_message(msg, 4096, len_fn=utf16_len)
|
||||
assert len(chunks) > 1, "With utf16_len, message should be split"
|
||||
|
||||
# Each chunk must fit within the UTF-16 limit
|
||||
for i, chunk in enumerate(chunks):
|
||||
assert utf16_len(chunk) <= 4096, (
|
||||
f"Chunk {i} exceeds 4096 UTF-16 units: {utf16_len(chunk)}"
|
||||
)
|
||||
|
||||
def test_each_utf16_chunk_within_limit(self):
|
||||
"""All chunks produced with utf16_len must fit the limit."""
|
||||
# Mix of BMP and astral-plane characters
|
||||
msg = ("Hello 😀 world 🎵 test 𝄞 " * 200).strip()
|
||||
max_len = 200
|
||||
chunks = BasePlatformAdapter.truncate_message(msg, max_len, len_fn=utf16_len)
|
||||
for i, chunk in enumerate(chunks):
|
||||
u16_len = utf16_len(chunk)
|
||||
assert u16_len <= max_len + 20, (
|
||||
f"Chunk {i} UTF-16 length {u16_len} exceeds {max_len}"
|
||||
)
|
||||
|
||||
def test_all_content_preserved(self):
|
||||
"""Splitting with utf16_len must not lose content."""
|
||||
words = ["emoji😀", "music🎵", "cjk你好", "plain"] * 100
|
||||
msg = " ".join(words)
|
||||
chunks = BasePlatformAdapter.truncate_message(msg, 200, len_fn=utf16_len)
|
||||
reassembled = " ".join(chunks)
|
||||
for word in words:
|
||||
assert word in reassembled, f"Word '{word}' lost during UTF-16 split"
|
||||
|
||||
def test_code_blocks_preserved_with_utf16(self):
|
||||
"""Code block fence handling should work with utf16_len too."""
|
||||
msg = "Before\n```python\n" + "x = '😀'\n" * 200 + "```\nAfter"
|
||||
chunks = BasePlatformAdapter.truncate_message(msg, 300, len_fn=utf16_len)
|
||||
assert len(chunks) > 1
|
||||
# Each chunk should have balanced fences
|
||||
for i, chunk in enumerate(chunks):
|
||||
fence_count = chunk.count("```")
|
||||
assert fence_count % 2 == 0, (
|
||||
f"Chunk {i} has unbalanced fences ({fence_count})"
|
||||
)
|
||||
|
||||
|
||||
215
tests/gateway/test_restart_notification.py
Normal file
215
tests/gateway/test_restart_notification.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""Tests for /restart notification — the gateway notifies the requester on comeback."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import gateway.run as gateway_run
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.session import build_session_key
|
||||
from tests.gateway.restart_test_helpers import (
|
||||
make_restart_runner,
|
||||
make_restart_source,
|
||||
)
|
||||
|
||||
|
||||
# ── _handle_restart_command writes .restart_notify.json ──────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restart_command_writes_notify_file(tmp_path, monkeypatch):
|
||||
"""When /restart fires, the requester's routing info is persisted to disk."""
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
|
||||
runner, _adapter = make_restart_runner()
|
||||
runner.request_restart = MagicMock(return_value=True)
|
||||
|
||||
source = make_restart_source(chat_id="42")
|
||||
event = MessageEvent(
|
||||
text="/restart",
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
result = await runner._handle_restart_command(event)
|
||||
assert "Restarting" in result
|
||||
|
||||
notify_path = tmp_path / ".restart_notify.json"
|
||||
assert notify_path.exists()
|
||||
data = json.loads(notify_path.read_text())
|
||||
assert data["platform"] == "telegram"
|
||||
assert data["chat_id"] == "42"
|
||||
assert "thread_id" not in data # no thread → omitted
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restart_command_uses_service_restart_under_systemd(tmp_path, monkeypatch):
|
||||
"""Under systemd (INVOCATION_ID set), /restart uses via_service=True."""
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setenv("INVOCATION_ID", "abc123")
|
||||
|
||||
runner, _adapter = make_restart_runner()
|
||||
runner.request_restart = MagicMock(return_value=True)
|
||||
|
||||
source = make_restart_source(chat_id="42")
|
||||
event = MessageEvent(
|
||||
text="/restart",
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
await runner._handle_restart_command(event)
|
||||
runner.request_restart.assert_called_once_with(detached=False, via_service=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restart_command_uses_detached_without_systemd(tmp_path, monkeypatch):
|
||||
"""Without systemd, /restart uses the detached subprocess approach."""
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.delenv("INVOCATION_ID", raising=False)
|
||||
|
||||
runner, _adapter = make_restart_runner()
|
||||
runner.request_restart = MagicMock(return_value=True)
|
||||
|
||||
source = make_restart_source(chat_id="42")
|
||||
event = MessageEvent(
|
||||
text="/restart",
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
await runner._handle_restart_command(event)
|
||||
runner.request_restart.assert_called_once_with(detached=True, via_service=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restart_command_preserves_thread_id(tmp_path, monkeypatch):
|
||||
"""Thread ID is saved when the requester is in a threaded chat."""
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
|
||||
runner, _adapter = make_restart_runner()
|
||||
runner.request_restart = MagicMock(return_value=True)
|
||||
|
||||
source = make_restart_source(chat_id="99")
|
||||
source.thread_id = "topic_7"
|
||||
|
||||
event = MessageEvent(
|
||||
text="/restart",
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id="m2",
|
||||
)
|
||||
|
||||
await runner._handle_restart_command(event)
|
||||
|
||||
data = json.loads((tmp_path / ".restart_notify.json").read_text())
|
||||
assert data["thread_id"] == "topic_7"
|
||||
|
||||
|
||||
# ── _send_restart_notification ───────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_restart_notification_delivers_and_cleans_up(tmp_path, monkeypatch):
|
||||
"""On startup, the notification is sent and the file is removed."""
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
|
||||
notify_path = tmp_path / ".restart_notify.json"
|
||||
notify_path.write_text(json.dumps({
|
||||
"platform": "telegram",
|
||||
"chat_id": "42",
|
||||
}))
|
||||
|
||||
runner, adapter = make_restart_runner()
|
||||
adapter.send = AsyncMock()
|
||||
|
||||
await runner._send_restart_notification()
|
||||
|
||||
adapter.send.assert_called_once()
|
||||
call_args = adapter.send.call_args
|
||||
assert call_args[0][0] == "42" # chat_id
|
||||
assert "restarted" in call_args[0][1].lower()
|
||||
assert call_args[1].get("metadata") is None # no thread
|
||||
assert not notify_path.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_restart_notification_with_thread(tmp_path, monkeypatch):
|
||||
"""Thread ID is passed as metadata so the message lands in the right topic."""
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
|
||||
notify_path = tmp_path / ".restart_notify.json"
|
||||
notify_path.write_text(json.dumps({
|
||||
"platform": "telegram",
|
||||
"chat_id": "99",
|
||||
"thread_id": "topic_7",
|
||||
}))
|
||||
|
||||
runner, adapter = make_restart_runner()
|
||||
adapter.send = AsyncMock()
|
||||
|
||||
await runner._send_restart_notification()
|
||||
|
||||
call_args = adapter.send.call_args
|
||||
assert call_args[1]["metadata"] == {"thread_id": "topic_7"}
|
||||
assert not notify_path.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_restart_notification_noop_when_no_file(tmp_path, monkeypatch):
|
||||
"""Nothing happens if there's no pending restart notification."""
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
|
||||
runner, adapter = make_restart_runner()
|
||||
adapter.send = AsyncMock()
|
||||
|
||||
await runner._send_restart_notification()
|
||||
|
||||
adapter.send.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_restart_notification_skips_when_adapter_missing(tmp_path, monkeypatch):
|
||||
"""If the requester's platform isn't connected, clean up without crashing."""
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
|
||||
notify_path = tmp_path / ".restart_notify.json"
|
||||
notify_path.write_text(json.dumps({
|
||||
"platform": "discord", # runner only has telegram adapter
|
||||
"chat_id": "42",
|
||||
}))
|
||||
|
||||
runner, _adapter = make_restart_runner()
|
||||
|
||||
await runner._send_restart_notification()
|
||||
|
||||
# File cleaned up even though we couldn't send
|
||||
assert not notify_path.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_restart_notification_cleans_up_on_send_failure(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
"""If the adapter.send() raises, the file is still cleaned up."""
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
|
||||
notify_path = tmp_path / ".restart_notify.json"
|
||||
notify_path.write_text(json.dumps({
|
||||
"platform": "telegram",
|
||||
"chat_id": "42",
|
||||
}))
|
||||
|
||||
runner, adapter = make_restart_runner()
|
||||
adapter.send = AsyncMock(side_effect=RuntimeError("network down"))
|
||||
|
||||
await runner._send_restart_notification()
|
||||
|
||||
assert not notify_path.exists() # cleaned up despite error
|
||||
@@ -396,6 +396,27 @@ class QueuedCommentaryAgent:
|
||||
}
|
||||
|
||||
|
||||
class VerboseAgent:
|
||||
"""Agent that emits a tool call with args whose JSON exceeds 200 chars."""
|
||||
LONG_CODE = "x" * 300
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.tool_progress_callback = kwargs.get("tool_progress_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
self.tool_progress_callback(
|
||||
"tool.started", "execute_code", None,
|
||||
{"code": self.LONG_CODE},
|
||||
)
|
||||
time.sleep(0.35)
|
||||
return {
|
||||
"final_response": "done",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
async def _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
@@ -575,3 +596,45 @@ async def test_run_agent_queued_message_does_not_treat_commentary_as_final(monke
|
||||
assert result["final_response"] == "final response 2"
|
||||
assert "I'll inspect the repo first." in sent_texts
|
||||
assert "final response 1" in sent_texts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verbose_mode_does_not_truncate_args_by_default(monkeypatch, tmp_path):
|
||||
"""Verbose mode with default tool_preview_length (0) should NOT truncate args.
|
||||
|
||||
Previously, verbose mode capped args at 200 chars when tool_preview_length
|
||||
was 0 (default). The user explicitly opted into verbose — show full detail.
|
||||
"""
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
VerboseAgent,
|
||||
session_id="sess-verbose-no-truncate",
|
||||
config_data={"display": {"tool_progress": "verbose", "tool_preview_length": 0}},
|
||||
)
|
||||
|
||||
assert result["final_response"] == "done"
|
||||
# The full 300-char 'x' string should be present, not truncated to 200
|
||||
all_content = " ".join(call["content"] for call in adapter.sent)
|
||||
all_content += " ".join(call["content"] for call in adapter.edits)
|
||||
assert VerboseAgent.LONG_CODE in all_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verbose_mode_respects_explicit_tool_preview_length(monkeypatch, tmp_path):
|
||||
"""When tool_preview_length is set to a positive value, verbose truncates to that."""
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
VerboseAgent,
|
||||
session_id="sess-verbose-explicit-cap",
|
||||
config_data={"display": {"tool_progress": "verbose", "tool_preview_length": 50}},
|
||||
)
|
||||
|
||||
assert result["final_response"] == "done"
|
||||
all_content = " ".join(call["content"] for call in adapter.sent)
|
||||
all_content += " ".join(call["content"] for call in adapter.edits)
|
||||
# Should be truncated — full 300-char string NOT present
|
||||
assert VerboseAgent.LONG_CODE not in all_content
|
||||
# But should still contain the truncated portion with "..."
|
||||
assert "..." in all_content
|
||||
|
||||
@@ -552,6 +552,45 @@ class TestLoadTranscriptPreferLongerSource:
|
||||
assert result[0]["content"] == "db-q"
|
||||
|
||||
|
||||
class TestSessionStoreSwitchSession:
|
||||
"""Regression coverage for gateway /resume session switching semantics."""
|
||||
|
||||
def test_switch_session_reopens_target_session_in_db(self, tmp_path):
|
||||
from hermes_state import SessionDB
|
||||
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path / "sessions", config=config)
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
store._db = db
|
||||
store._loaded = True
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.FEISHU,
|
||||
chat_id="chat-1",
|
||||
chat_type="dm",
|
||||
user_id="user-1",
|
||||
user_name="tester",
|
||||
)
|
||||
current_entry = store.get_or_create_session(source)
|
||||
current_session_id = current_entry.session_id
|
||||
|
||||
target_session_id = "old_session_abc"
|
||||
db.create_session(target_session_id, source="feishu", user_id="user-1")
|
||||
db.end_session(target_session_id, end_reason="user_exit")
|
||||
assert db.get_session(target_session_id)["ended_at"] is not None
|
||||
|
||||
switched = store.switch_session(current_entry.session_key, target_session_id)
|
||||
|
||||
assert switched is not None
|
||||
assert switched.session_id == target_session_id
|
||||
assert db.get_session(current_session_id)["end_reason"] == "session_switch"
|
||||
resumed = db.get_session(target_session_id)
|
||||
assert resumed["ended_at"] is None
|
||||
assert resumed["end_reason"] is None
|
||||
db.close()
|
||||
|
||||
|
||||
class TestWhatsAppDMSessionKeyConsistency:
|
||||
"""Regression: all session-key construction must go through build_session_key
|
||||
so DMs are isolated by chat_id across platforms."""
|
||||
|
||||
@@ -60,7 +60,8 @@ def _make_runner():
|
||||
|
||||
def _make_event(text="hello", chat_id="12345"):
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM, chat_id=chat_id, chat_type="dm"
|
||||
platform=Platform.TELEGRAM, chat_id=chat_id, chat_type="dm",
|
||||
user_id="u1",
|
||||
)
|
||||
return MessageEvent(text=text, message_type=MessageType.TEXT, source=source)
|
||||
|
||||
@@ -192,7 +193,8 @@ async def test_command_messages_do_not_leave_sentinel():
|
||||
_handle_message. They must NOT leave a sentinel behind."""
|
||||
runner = _make_runner()
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm"
|
||||
platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm",
|
||||
user_id="u1",
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="/help", message_type=MessageType.TEXT, source=source
|
||||
@@ -240,9 +242,7 @@ async def test_stop_during_sentinel_force_cleans_session():
|
||||
stop_event = _make_event(text="/stop")
|
||||
result = await runner._handle_message(stop_event)
|
||||
assert result is not None, "/stop during sentinel should return a message"
|
||||
assert "force-stopped" in result.lower() or "unlocked" in result.lower()
|
||||
|
||||
# Sentinel must be cleaned up
|
||||
assert "stopped" in result.lower()
|
||||
assert session_key not in runner._running_agents, (
|
||||
"/stop must remove sentinel so the session is unlocked"
|
||||
)
|
||||
@@ -268,7 +268,7 @@ async def test_stop_hard_kills_running_agent():
|
||||
forever — showing 'writing...' but never producing output."""
|
||||
runner = _make_runner()
|
||||
session_key = build_session_key(
|
||||
SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm")
|
||||
SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm", user_id="u1")
|
||||
)
|
||||
|
||||
# Simulate a running (possibly hung) agent
|
||||
@@ -289,7 +289,7 @@ async def test_stop_hard_kills_running_agent():
|
||||
|
||||
# Must return a confirmation
|
||||
assert result is not None
|
||||
assert "force-stopped" in result.lower() or "unlocked" in result.lower()
|
||||
assert "stopped" in result.lower()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -301,7 +301,7 @@ async def test_stop_clears_pending_messages():
|
||||
queued during the run must be discarded."""
|
||||
runner = _make_runner()
|
||||
session_key = build_session_key(
|
||||
SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm")
|
||||
SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm", user_id="u1")
|
||||
)
|
||||
|
||||
fake_agent = MagicMock()
|
||||
|
||||
279
tests/gateway/test_setup_feishu.py
Normal file
279
tests/gateway/test_setup_feishu.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""Tests for _setup_feishu() in hermes_cli/gateway.py.
|
||||
|
||||
Verifies that the interactive setup writes env vars that correctly drive the
|
||||
Feishu adapter: credentials, connection mode, DM policy, and group policy.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _run_setup_feishu(
|
||||
*,
|
||||
qr_result=None,
|
||||
prompt_yes_no_responses=None,
|
||||
prompt_choice_responses=None,
|
||||
prompt_responses=None,
|
||||
existing_env=None,
|
||||
):
|
||||
"""Run _setup_feishu() with mocked I/O and return the env vars that were saved.
|
||||
|
||||
Returns a dict of {env_var_name: value} for all save_env_value calls.
|
||||
"""
|
||||
existing_env = existing_env or {}
|
||||
prompt_yes_no_responses = list(prompt_yes_no_responses or [True])
|
||||
# QR path: method(0), dm(0), group(0) — 3 choices (no connection mode)
|
||||
# Manual path: method(1), domain(0), connection(0), dm(0), group(0) — 5 choices
|
||||
prompt_choice_responses = list(prompt_choice_responses or [0, 0, 0])
|
||||
prompt_responses = list(prompt_responses or [""])
|
||||
|
||||
saved_env = {}
|
||||
|
||||
def mock_save(name, value):
|
||||
saved_env[name] = value
|
||||
|
||||
def mock_get(name):
|
||||
return existing_env.get(name, "")
|
||||
|
||||
with patch("hermes_cli.gateway.save_env_value", side_effect=mock_save), \
|
||||
patch("hermes_cli.gateway.get_env_value", side_effect=mock_get), \
|
||||
patch("hermes_cli.gateway.prompt_yes_no", side_effect=prompt_yes_no_responses), \
|
||||
patch("hermes_cli.gateway.prompt_choice", side_effect=prompt_choice_responses), \
|
||||
patch("hermes_cli.gateway.prompt", side_effect=prompt_responses), \
|
||||
patch("hermes_cli.gateway.print_info"), \
|
||||
patch("hermes_cli.gateway.print_success"), \
|
||||
patch("hermes_cli.gateway.print_warning"), \
|
||||
patch("hermes_cli.gateway.print_error"), \
|
||||
patch("hermes_cli.gateway.color", side_effect=lambda t, c: t), \
|
||||
patch("gateway.platforms.feishu.qr_register", return_value=qr_result):
|
||||
|
||||
from hermes_cli.gateway import _setup_feishu
|
||||
_setup_feishu()
|
||||
|
||||
return saved_env
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QR scan-to-create path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSetupFeishuQrPath:
|
||||
"""Tests for the QR scan-to-create happy path."""
|
||||
|
||||
def test_qr_success_saves_core_credentials(self):
|
||||
env = _run_setup_feishu(
|
||||
qr_result={
|
||||
"app_id": "cli_test",
|
||||
"app_secret": "secret_test",
|
||||
"domain": "feishu",
|
||||
"open_id": "ou_owner",
|
||||
"bot_name": "TestBot",
|
||||
"bot_open_id": "ou_bot",
|
||||
},
|
||||
prompt_yes_no_responses=[True], # Start QR
|
||||
prompt_choice_responses=[0, 0, 0], # method=QR, dm=pairing, group=open
|
||||
prompt_responses=[""], # home channel: skip
|
||||
)
|
||||
assert env["FEISHU_APP_ID"] == "cli_test"
|
||||
assert env["FEISHU_APP_SECRET"] == "secret_test"
|
||||
assert env["FEISHU_DOMAIN"] == "feishu"
|
||||
|
||||
def test_qr_success_does_not_persist_bot_identity(self):
|
||||
"""Bot identity is discovered at runtime by _hydrate_bot_identity — not persisted
|
||||
in env, so it stays fresh if the user renames the bot later."""
|
||||
env = _run_setup_feishu(
|
||||
qr_result={
|
||||
"app_id": "cli_test",
|
||||
"app_secret": "secret_test",
|
||||
"domain": "feishu",
|
||||
"open_id": "ou_owner",
|
||||
"bot_name": "TestBot",
|
||||
"bot_open_id": "ou_bot",
|
||||
},
|
||||
prompt_yes_no_responses=[True],
|
||||
prompt_choice_responses=[0, 0, 0],
|
||||
prompt_responses=[""],
|
||||
)
|
||||
assert "FEISHU_BOT_OPEN_ID" not in env
|
||||
assert "FEISHU_BOT_NAME" not in env
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Connection mode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSetupFeishuConnectionMode:
|
||||
"""Connection mode: QR always websocket, manual path lets user choose."""
|
||||
|
||||
def test_qr_path_defaults_to_websocket(self):
|
||||
env = _run_setup_feishu(
|
||||
qr_result={
|
||||
"app_id": "cli_test", "app_secret": "s", "domain": "feishu",
|
||||
"open_id": None, "bot_name": None, "bot_open_id": None,
|
||||
},
|
||||
prompt_choice_responses=[0, 0, 0], # method=QR, dm=pairing, group=open
|
||||
prompt_responses=[""],
|
||||
)
|
||||
assert env["FEISHU_CONNECTION_MODE"] == "websocket"
|
||||
|
||||
@patch("gateway.platforms.feishu.probe_bot", return_value=None)
|
||||
def test_manual_path_websocket(self, _mock_probe):
|
||||
env = _run_setup_feishu(
|
||||
qr_result=None,
|
||||
prompt_choice_responses=[1, 0, 0, 0, 0], # method=manual, domain=feishu, connection=ws, dm=pairing, group=open
|
||||
prompt_responses=["cli_manual", "secret_manual", ""], # app_id, app_secret, home_channel
|
||||
)
|
||||
assert env["FEISHU_CONNECTION_MODE"] == "websocket"
|
||||
|
||||
@patch("gateway.platforms.feishu.probe_bot", return_value=None)
|
||||
def test_manual_path_webhook(self, _mock_probe):
|
||||
env = _run_setup_feishu(
|
||||
qr_result=None,
|
||||
prompt_choice_responses=[1, 0, 1, 0, 0], # method=manual, domain=feishu, connection=webhook, dm=pairing, group=open
|
||||
prompt_responses=["cli_manual", "secret_manual", ""], # app_id, app_secret, home_channel
|
||||
)
|
||||
assert env["FEISHU_CONNECTION_MODE"] == "webhook"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DM security policy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSetupFeishuDmPolicy:
|
||||
"""DM policy must use platform-scoped FEISHU_ALLOW_ALL_USERS, not the global flag."""
|
||||
|
||||
def _run_with_dm_choice(self, dm_choice_idx, prompt_responses=None):
|
||||
return _run_setup_feishu(
|
||||
qr_result={
|
||||
"app_id": "cli_test", "app_secret": "s", "domain": "feishu",
|
||||
"open_id": "ou_owner", "bot_name": None, "bot_open_id": None,
|
||||
},
|
||||
prompt_yes_no_responses=[True],
|
||||
prompt_choice_responses=[0, dm_choice_idx, 0], # method=QR, dm=<choice>, group=open
|
||||
prompt_responses=prompt_responses or [""],
|
||||
)
|
||||
|
||||
def test_pairing_sets_feishu_allow_all_false(self):
|
||||
env = self._run_with_dm_choice(0)
|
||||
assert env["FEISHU_ALLOW_ALL_USERS"] == "false"
|
||||
assert env["FEISHU_ALLOWED_USERS"] == ""
|
||||
assert "GATEWAY_ALLOW_ALL_USERS" not in env
|
||||
|
||||
def test_allow_all_sets_feishu_allow_all_true(self):
|
||||
env = self._run_with_dm_choice(1)
|
||||
assert env["FEISHU_ALLOW_ALL_USERS"] == "true"
|
||||
assert env["FEISHU_ALLOWED_USERS"] == ""
|
||||
assert "GATEWAY_ALLOW_ALL_USERS" not in env
|
||||
|
||||
def test_allowlist_sets_feishu_allow_all_false_with_list(self):
|
||||
env = self._run_with_dm_choice(2, prompt_responses=["ou_user1,ou_user2", ""])
|
||||
assert env["FEISHU_ALLOW_ALL_USERS"] == "false"
|
||||
assert env["FEISHU_ALLOWED_USERS"] == "ou_user1,ou_user2"
|
||||
assert "GATEWAY_ALLOW_ALL_USERS" not in env
|
||||
|
||||
def test_allowlist_prepopulates_with_scan_owner_open_id(self):
|
||||
"""When open_id is available from QR scan, it should be the default allowlist value."""
|
||||
# We return the owner's open_id from prompt (+ empty home channel).
|
||||
env = self._run_with_dm_choice(2, prompt_responses=["ou_owner", ""])
|
||||
assert env["FEISHU_ALLOWED_USERS"] == "ou_owner"
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group policy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSetupFeishuGroupPolicy:
|
||||
|
||||
def test_open_with_mention(self):
|
||||
env = _run_setup_feishu(
|
||||
qr_result={
|
||||
"app_id": "cli_test", "app_secret": "s", "domain": "feishu",
|
||||
"open_id": None, "bot_name": None, "bot_open_id": None,
|
||||
},
|
||||
prompt_yes_no_responses=[True],
|
||||
prompt_choice_responses=[0, 0, 0], # method=QR, dm=pairing, group=open
|
||||
prompt_responses=[""],
|
||||
)
|
||||
assert env["FEISHU_GROUP_POLICY"] == "open"
|
||||
|
||||
def test_disabled(self):
|
||||
env = _run_setup_feishu(
|
||||
qr_result={
|
||||
"app_id": "cli_test", "app_secret": "s", "domain": "feishu",
|
||||
"open_id": None, "bot_name": None, "bot_open_id": None,
|
||||
},
|
||||
prompt_yes_no_responses=[True],
|
||||
prompt_choice_responses=[0, 0, 1], # method=QR, dm=pairing, group=disabled
|
||||
prompt_responses=[""],
|
||||
)
|
||||
assert env["FEISHU_GROUP_POLICY"] == "disabled"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter integration: env vars → FeishuAdapterSettings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSetupFeishuAdapterIntegration:
|
||||
"""Verify that env vars written by _setup_feishu() produce a valid adapter config.
|
||||
|
||||
This bridges the gap between 'setup wrote the right env vars' and
|
||||
'the adapter will actually initialize correctly from those vars'.
|
||||
"""
|
||||
|
||||
def _make_env_from_setup(self, dm_idx=0, group_idx=0):
|
||||
"""Run _setup_feishu via QR path and return the env vars it would write."""
|
||||
return _run_setup_feishu(
|
||||
qr_result={
|
||||
"app_id": "cli_test_app",
|
||||
"app_secret": "test_secret_value",
|
||||
"domain": "feishu",
|
||||
"open_id": "ou_owner",
|
||||
"bot_name": "IntegrationBot",
|
||||
"bot_open_id": "ou_bot_integration",
|
||||
},
|
||||
prompt_yes_no_responses=[True],
|
||||
prompt_choice_responses=[0, dm_idx, group_idx], # method=QR, dm, group
|
||||
prompt_responses=[""],
|
||||
)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_qr_env_produces_valid_adapter_settings(self):
|
||||
"""QR setup → adapter initializes with websocket mode."""
|
||||
env = self._make_env_from_setup()
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
assert adapter._app_id == "cli_test_app"
|
||||
assert adapter._app_secret == "test_secret_value"
|
||||
assert adapter._domain_name == "feishu"
|
||||
assert adapter._connection_mode == "websocket"
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_open_dm_env_sets_correct_adapter_state(self):
|
||||
"""Setup with 'allow all DMs' → adapter sees allow-all flag."""
|
||||
env = self._make_env_from_setup(dm_idx=1)
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
from gateway.config import PlatformConfig
|
||||
# Verify adapter initializes without error and env var is correct.
|
||||
FeishuAdapter(PlatformConfig())
|
||||
assert os.getenv("FEISHU_ALLOW_ALL_USERS") == "true"
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_group_open_env_sets_adapter_group_policy(self):
|
||||
"""Setup with 'open groups' → adapter group_policy is 'open'."""
|
||||
env = self._make_env_from_setup(group_idx=0)
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
assert adapter._group_policy == "open"
|
||||
@@ -209,6 +209,33 @@ class TestScopedLocks:
|
||||
assert payload["pid"] == os.getpid()
|
||||
assert payload["metadata"]["platform"] == "telegram"
|
||||
|
||||
def test_acquire_scoped_lock_recovers_empty_lock_file(self, tmp_path, monkeypatch):
|
||||
"""Empty lock file (0 bytes) left by a crashed process should be treated as stale."""
|
||||
monkeypatch.setenv("HERMES_GATEWAY_LOCK_DIR", str(tmp_path / "locks"))
|
||||
lock_path = tmp_path / "locks" / "slack-app-token-2bb80d537b1da3e3.lock"
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
lock_path.write_text("") # simulate crash between O_CREAT and json.dump
|
||||
|
||||
acquired, existing = status.acquire_scoped_lock("slack-app-token", "secret", metadata={"platform": "slack"})
|
||||
|
||||
assert acquired is True
|
||||
payload = json.loads(lock_path.read_text())
|
||||
assert payload["pid"] == os.getpid()
|
||||
assert payload["metadata"]["platform"] == "slack"
|
||||
|
||||
def test_acquire_scoped_lock_recovers_corrupt_lock_file(self, tmp_path, monkeypatch):
|
||||
"""Lock file with invalid JSON should be treated as stale."""
|
||||
monkeypatch.setenv("HERMES_GATEWAY_LOCK_DIR", str(tmp_path / "locks"))
|
||||
lock_path = tmp_path / "locks" / "slack-app-token-2bb80d537b1da3e3.lock"
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
lock_path.write_text("{truncated") # simulate partial write
|
||||
|
||||
acquired, existing = status.acquire_scoped_lock("slack-app-token", "secret", metadata={"platform": "slack"})
|
||||
|
||||
assert acquired is True
|
||||
payload = json.loads(lock_path.read_text())
|
||||
assert payload["pid"] == os.getpid()
|
||||
|
||||
def test_release_scoped_lock_only_removes_current_owner(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_GATEWAY_LOCK_DIR", str(tmp_path / "locks"))
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ def _make_runner():
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_does_not_priority_interrupt_photo_followup():
|
||||
runner = _make_runner()
|
||||
source = SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm")
|
||||
source = SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm", user_id="u1")
|
||||
session_key = build_session_key(source)
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents[session_key] = running_agent
|
||||
|
||||
@@ -121,7 +121,7 @@ class TestSendWithReplyToMode:
|
||||
adapter = adapter_factory(reply_to_mode="off")
|
||||
adapter._bot = MagicMock()
|
||||
adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1))
|
||||
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]
|
||||
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2", "chunk3"]
|
||||
|
||||
await adapter.send("12345", "test content", reply_to="999")
|
||||
|
||||
@@ -133,7 +133,7 @@ class TestSendWithReplyToMode:
|
||||
adapter = adapter_factory(reply_to_mode="first")
|
||||
adapter._bot = MagicMock()
|
||||
adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1))
|
||||
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]
|
||||
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2", "chunk3"]
|
||||
|
||||
await adapter.send("12345", "test content", reply_to="999")
|
||||
|
||||
@@ -148,7 +148,7 @@ class TestSendWithReplyToMode:
|
||||
adapter = adapter_factory(reply_to_mode="all")
|
||||
adapter._bot = MagicMock()
|
||||
adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1))
|
||||
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]
|
||||
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2", "chunk3"]
|
||||
|
||||
await adapter.send("12345", "test content", reply_to="999")
|
||||
|
||||
@@ -162,7 +162,7 @@ class TestSendWithReplyToMode:
|
||||
adapter = adapter_factory(reply_to_mode="all")
|
||||
adapter._bot = MagicMock()
|
||||
adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1))
|
||||
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2"]
|
||||
adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2"]
|
||||
|
||||
await adapter.send("12345", "test content", reply_to=None)
|
||||
|
||||
@@ -175,7 +175,7 @@ class TestSendWithReplyToMode:
|
||||
adapter = adapter_factory(reply_to_mode="first")
|
||||
adapter._bot = MagicMock()
|
||||
adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1))
|
||||
adapter.truncate_message = lambda content, max_len: ["single chunk"]
|
||||
adapter.truncate_message = lambda content, max_len, **kw: ["single chunk"]
|
||||
|
||||
await adapter.send("12345", "test", reply_to="999")
|
||||
|
||||
|
||||
@@ -417,6 +417,7 @@ class TestDiscordPlayTtsSkip:
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
@@ -702,13 +703,18 @@ class TestVoiceChannelCommands:
|
||||
mock_adapter.join_voice_channel = AsyncMock(return_value=True)
|
||||
mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
|
||||
mock_adapter._voice_text_channels = {}
|
||||
mock_adapter._voice_sources = {}
|
||||
mock_adapter._voice_input_callback = None
|
||||
event = self._make_discord_event()
|
||||
event.source.chat_type = "group"
|
||||
event.source.chat_name = "Hermes Server / #general"
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
result = await runner._handle_voice_channel_join(event)
|
||||
assert "joined" in result.lower()
|
||||
assert "General" in result
|
||||
assert runner._voice_mode["123"] == "all"
|
||||
assert mock_adapter._voice_sources[111]["chat_id"] == "123"
|
||||
assert mock_adapter._voice_sources[111]["chat_type"] == "group"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_failure(self, runner):
|
||||
@@ -815,6 +821,7 @@ class TestVoiceChannelCommands:
|
||||
from gateway.config import Platform
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter._voice_text_channels = {111: 123}
|
||||
mock_adapter._voice_sources = {}
|
||||
mock_channel = AsyncMock()
|
||||
mock_adapter._client = MagicMock()
|
||||
mock_adapter._client.get_channel = MagicMock(return_value=mock_channel)
|
||||
@@ -828,12 +835,45 @@ class TestVoiceChannelCommands:
|
||||
assert event.source.chat_id == "123"
|
||||
assert event.source.chat_type == "channel"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_reuses_bound_source_metadata(self, runner):
|
||||
"""Voice input should share the linked text channel session metadata."""
|
||||
from gateway.config import Platform
|
||||
|
||||
bound_source = SessionSource(
|
||||
chat_id="123",
|
||||
chat_name="Hermes Server / #general",
|
||||
chat_type="group",
|
||||
user_id="user1",
|
||||
user_name="user1",
|
||||
platform=Platform.DISCORD,
|
||||
)
|
||||
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter._voice_text_channels = {111: 123}
|
||||
mock_adapter._voice_sources = {111: bound_source.to_dict()}
|
||||
mock_channel = AsyncMock()
|
||||
mock_adapter._client = MagicMock()
|
||||
mock_adapter._client.get_channel = MagicMock(return_value=mock_channel)
|
||||
mock_adapter.handle_message = AsyncMock()
|
||||
runner.adapters[Platform.DISCORD] = mock_adapter
|
||||
|
||||
await runner._handle_voice_channel_input(111, 42, "Hello from VC")
|
||||
|
||||
mock_adapter.handle_message.assert_called_once()
|
||||
event = mock_adapter.handle_message.call_args[0][0]
|
||||
assert event.source.chat_id == "123"
|
||||
assert event.source.chat_type == "group"
|
||||
assert event.source.chat_name == "Hermes Server / #general"
|
||||
assert event.source.user_id == "42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_posts_transcript_in_text_channel(self, runner):
|
||||
"""Voice input sends transcript message to text channel."""
|
||||
from gateway.config import Platform
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter._voice_text_channels = {111: 123}
|
||||
mock_adapter._voice_sources = {}
|
||||
mock_channel = AsyncMock()
|
||||
mock_adapter._client = MagicMock()
|
||||
mock_adapter._client.get_channel = MagicMock(return_value=mock_channel)
|
||||
@@ -892,6 +932,7 @@ class TestDiscordVoiceChannelMethods:
|
||||
adapter._client = MagicMock()
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
@@ -926,6 +967,7 @@ class TestDiscordVoiceChannelMethods:
|
||||
mock_vc.disconnect = AsyncMock()
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
adapter._voice_text_channels[111] = 123
|
||||
adapter._voice_sources[111] = {"chat_id": "123", "chat_type": "group"}
|
||||
|
||||
mock_receiver = MagicMock()
|
||||
adapter._voice_receivers[111] = mock_receiver
|
||||
@@ -944,6 +986,7 @@ class TestDiscordVoiceChannelMethods:
|
||||
mock_timeout.cancel.assert_called_once()
|
||||
assert 111 not in adapter._voice_clients
|
||||
assert 111 not in adapter._voice_text_channels
|
||||
assert 111 not in adapter._voice_sources
|
||||
assert 111 not in adapter._voice_receivers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -1670,6 +1713,7 @@ class TestVoiceTimeoutCleansRunnerState:
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
@@ -1759,6 +1803,7 @@ class TestPlaybackTimeout:
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
@@ -1939,6 +1984,7 @@ class TestVoiceChannelAwareness:
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._client = MagicMock()
|
||||
adapter._client.user = SimpleNamespace(id=99999, name="HermesBot")
|
||||
@@ -2408,6 +2454,7 @@ class TestVoiceTTSPlayback:
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_receivers = {}
|
||||
return adapter
|
||||
|
||||
@@ -2587,6 +2634,7 @@ class TestUDPKeepalive:
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
|
||||
|
||||
141
tests/gateway/test_weak_credential_guard.py
Normal file
141
tests/gateway/test_weak_credential_guard.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Tests for gateway weak credential rejection at startup.
|
||||
|
||||
Ported from openclaw/openclaw#64586: rejects known-weak placeholder
|
||||
tokens at gateway startup instead of letting them silently fail
|
||||
against platform APIs.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig, Platform, _validate_gateway_config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: create a minimal GatewayConfig with one enabled platform
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_gateway_config(platform, token, enabled=True, **extra_kwargs):
|
||||
"""Create a minimal GatewayConfig-like object for validation testing."""
|
||||
from gateway.config import GatewayConfig
|
||||
|
||||
config = GatewayConfig(platforms={})
|
||||
pconfig = PlatformConfig(enabled=enabled, token=token, **extra_kwargs)
|
||||
config.platforms[platform] = pconfig
|
||||
return config
|
||||
|
||||
|
||||
def _validate_and_return(config):
|
||||
"""Call _validate_gateway_config and return the config (mutated in place)."""
|
||||
_validate_gateway_config(config)
|
||||
return config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: platform token placeholder rejection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPlatformTokenPlaceholderGuard:
|
||||
"""Verify that _validate_gateway_config disables platforms with placeholder tokens."""
|
||||
|
||||
def test_rejects_triple_asterisk(self, caplog):
|
||||
"""'***' is the .env.example placeholder — should be rejected."""
|
||||
config = _make_gateway_config(Platform.TELEGRAM, "***")
|
||||
with caplog.at_level(logging.ERROR):
|
||||
_validate_and_return(config)
|
||||
assert config.platforms[Platform.TELEGRAM].enabled is False
|
||||
assert "placeholder" in caplog.text.lower()
|
||||
|
||||
def test_rejects_changeme(self, caplog):
|
||||
config = _make_gateway_config(Platform.DISCORD, "changeme")
|
||||
with caplog.at_level(logging.ERROR):
|
||||
_validate_and_return(config)
|
||||
assert config.platforms[Platform.DISCORD].enabled is False
|
||||
|
||||
def test_rejects_your_api_key(self, caplog):
|
||||
config = _make_gateway_config(Platform.SLACK, "your_api_key")
|
||||
with caplog.at_level(logging.ERROR):
|
||||
_validate_and_return(config)
|
||||
assert config.platforms[Platform.SLACK].enabled is False
|
||||
|
||||
def test_rejects_placeholder(self, caplog):
|
||||
config = _make_gateway_config(Platform.MATRIX, "placeholder")
|
||||
with caplog.at_level(logging.ERROR):
|
||||
_validate_and_return(config)
|
||||
assert config.platforms[Platform.MATRIX].enabled is False
|
||||
|
||||
def test_accepts_real_token(self, caplog):
|
||||
"""A real-looking bot token should pass validation."""
|
||||
config = _make_gateway_config(
|
||||
Platform.TELEGRAM, "7123456789:AAHdqTcvCH1vGWJxfSeOfSAs0K5PALDsaw"
|
||||
)
|
||||
with caplog.at_level(logging.ERROR):
|
||||
_validate_and_return(config)
|
||||
assert config.platforms[Platform.TELEGRAM].enabled is True
|
||||
assert "placeholder" not in caplog.text.lower()
|
||||
|
||||
def test_accepts_empty_token_without_error(self, caplog):
|
||||
"""Empty tokens get a warning (existing behavior), not a placeholder error."""
|
||||
config = _make_gateway_config(Platform.TELEGRAM, "")
|
||||
with caplog.at_level(logging.WARNING):
|
||||
_validate_and_return(config)
|
||||
# Empty token doesn't trigger placeholder rejection — enabled stays True
|
||||
# (the existing empty-token warning is separate)
|
||||
assert config.platforms[Platform.TELEGRAM].enabled is True
|
||||
|
||||
def test_disabled_platform_not_checked(self, caplog):
|
||||
"""Disabled platforms should not be validated."""
|
||||
config = _make_gateway_config(Platform.TELEGRAM, "***", enabled=False)
|
||||
with caplog.at_level(logging.ERROR):
|
||||
_validate_and_return(config)
|
||||
assert "placeholder" not in caplog.text.lower()
|
||||
|
||||
def test_rejects_whitespace_padded_placeholder(self, caplog):
|
||||
"""Whitespace-padded placeholders should still be caught."""
|
||||
config = _make_gateway_config(Platform.TELEGRAM, " *** ")
|
||||
with caplog.at_level(logging.ERROR):
|
||||
_validate_and_return(config)
|
||||
assert config.platforms[Platform.TELEGRAM].enabled is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration test: API server placeholder key on network-accessible host
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAPIServerPlaceholderKeyGuard:
|
||||
"""Verify that the API server rejects placeholder keys on network hosts."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refuses_wildcard_with_placeholder_key(self):
|
||||
from gateway.platforms.api_server import APIServerAdapter
|
||||
|
||||
adapter = APIServerAdapter(
|
||||
PlatformConfig(enabled=True, extra={"host": "0.0.0.0", "key": "changeme"})
|
||||
)
|
||||
result = await adapter.connect()
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refuses_wildcard_with_asterisk_key(self):
|
||||
from gateway.platforms.api_server import APIServerAdapter
|
||||
|
||||
adapter = APIServerAdapter(
|
||||
PlatformConfig(enabled=True, extra={"host": "0.0.0.0", "key": "***"})
|
||||
)
|
||||
result = await adapter.connect()
|
||||
assert result is False
|
||||
|
||||
def test_allows_loopback_with_placeholder_key(self):
|
||||
"""Loopback with a placeholder key is fine — not network-exposed."""
|
||||
from gateway.platforms.api_server import APIServerAdapter
|
||||
from gateway.platforms.base import is_network_accessible
|
||||
|
||||
adapter = APIServerAdapter(
|
||||
PlatformConfig(enabled=True, extra={"host": "127.0.0.1", "key": "changeme"})
|
||||
)
|
||||
# On loopback the placeholder guard doesn't fire
|
||||
assert is_network_accessible(adapter._host) is False
|
||||
@@ -30,7 +30,7 @@ class TestWeixinFormatting:
|
||||
|
||||
assert (
|
||||
adapter.format_message(content)
|
||||
== "【Title】\n\n**Plan**\n\nUse **bold** and [docs](https://example.com)."
|
||||
== "【Title】\n\n**Plan**\n\nUse **bold** and docs (https://example.com)."
|
||||
)
|
||||
|
||||
def test_format_message_rewrites_markdown_tables(self):
|
||||
@@ -374,3 +374,149 @@ class TestWeixinRemoteMediaSafety:
|
||||
assert "Blocked unsafe URL" in str(exc)
|
||||
else:
|
||||
raise AssertionError("expected ValueError for unsafe URL")
|
||||
|
||||
|
||||
class TestWeixinMarkdownLinks:
|
||||
"""Markdown links should be converted to plaintext since WeChat can't render them."""
|
||||
|
||||
def test_format_message_converts_markdown_links_to_plain_text(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = "Check [the docs](https://example.com) and [GitHub](https://github.com) for details"
|
||||
assert (
|
||||
adapter.format_message(content)
|
||||
== "Check the docs (https://example.com) and GitHub (https://github.com) for details"
|
||||
)
|
||||
|
||||
def test_format_message_preserves_links_inside_code_blocks(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = "See below:\n\n```\n[link](https://example.com)\n```\n\nDone."
|
||||
result = adapter.format_message(content)
|
||||
assert "[link](https://example.com)" in result
|
||||
|
||||
|
||||
class TestWeixinBlankMessagePrevention:
|
||||
"""Regression tests for the blank-bubble bugs.
|
||||
|
||||
Three separate guards now prevent a blank WeChat message from ever being
|
||||
dispatched:
|
||||
|
||||
1. ``_split_text_for_weixin_delivery("")`` returns ``[]`` — not ``[""]``.
|
||||
2. ``send()`` filters out empty/whitespace-only chunks before calling
|
||||
``_send_text_chunk``.
|
||||
3. ``_send_message()`` raises ``ValueError`` for empty text as a last-resort
|
||||
safety net.
|
||||
"""
|
||||
|
||||
def test_split_text_returns_empty_list_for_empty_string(self):
|
||||
adapter = _make_adapter()
|
||||
assert adapter._split_text("") == []
|
||||
|
||||
def test_split_text_returns_empty_list_for_empty_string_split_per_line(self):
|
||||
adapter = WeixinAdapter(
|
||||
PlatformConfig(
|
||||
enabled=True,
|
||||
extra={
|
||||
"account_id": "acct",
|
||||
"token": "test-tok",
|
||||
"split_multiline_messages": True,
|
||||
},
|
||||
)
|
||||
)
|
||||
assert adapter._split_text("") == []
|
||||
|
||||
@patch("gateway.platforms.weixin._send_message", new_callable=AsyncMock)
|
||||
def test_send_empty_content_does_not_call_send_message(self, send_message_mock):
|
||||
adapter = _make_adapter()
|
||||
adapter._session = object()
|
||||
adapter._token = "test-token"
|
||||
adapter._base_url = "https://weixin.example.com"
|
||||
adapter._token_store.get = lambda account_id, chat_id: "ctx-token"
|
||||
|
||||
result = asyncio.run(adapter.send("wxid_test123", ""))
|
||||
# Empty content → no chunks → no _send_message calls
|
||||
assert result.success is True
|
||||
send_message_mock.assert_not_awaited()
|
||||
|
||||
def test_send_message_rejects_empty_text(self):
|
||||
"""_send_message raises ValueError for empty/whitespace text."""
|
||||
import pytest
|
||||
with pytest.raises(ValueError, match="text must not be empty"):
|
||||
asyncio.run(
|
||||
weixin._send_message(
|
||||
AsyncMock(),
|
||||
base_url="https://example.com",
|
||||
token="tok",
|
||||
to="wxid_test",
|
||||
text="",
|
||||
context_token=None,
|
||||
client_id="cid",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TestWeixinStreamingCursorSuppression:
|
||||
"""WeChat doesn't support message editing — cursor must be suppressed."""
|
||||
|
||||
def test_supports_message_editing_is_false(self):
|
||||
adapter = _make_adapter()
|
||||
assert adapter.SUPPORTS_MESSAGE_EDITING is False
|
||||
|
||||
|
||||
class TestWeixinMediaBuilder:
|
||||
"""Media builder uses base64(hex_key), not base64(raw_bytes) for aes_key."""
|
||||
|
||||
def test_image_builder_aes_key_is_base64_of_hex(self):
|
||||
import base64
|
||||
adapter = _make_adapter()
|
||||
media_type, builder = adapter._outbound_media_builder("photo.jpg")
|
||||
assert media_type == weixin.MEDIA_IMAGE
|
||||
|
||||
fake_hex_key = "0123456789abcdef0123456789abcdef"
|
||||
expected_aes = base64.b64encode(fake_hex_key.encode("ascii")).decode("ascii")
|
||||
item = builder(
|
||||
encrypt_query_param="eq",
|
||||
aes_key_for_api=expected_aes,
|
||||
ciphertext_size=1024,
|
||||
plaintext_size=1000,
|
||||
filename="photo.jpg",
|
||||
rawfilemd5="abc123",
|
||||
)
|
||||
assert item["image_item"]["media"]["aes_key"] == expected_aes
|
||||
|
||||
def test_video_builder_includes_md5(self):
|
||||
adapter = _make_adapter()
|
||||
media_type, builder = adapter._outbound_media_builder("clip.mp4")
|
||||
assert media_type == weixin.MEDIA_VIDEO
|
||||
|
||||
item = builder(
|
||||
encrypt_query_param="eq",
|
||||
aes_key_for_api="fakekey",
|
||||
ciphertext_size=2048,
|
||||
plaintext_size=2000,
|
||||
filename="clip.mp4",
|
||||
rawfilemd5="deadbeef",
|
||||
)
|
||||
assert item["video_item"]["video_md5"] == "deadbeef"
|
||||
|
||||
def test_voice_builder_for_audio_files(self):
|
||||
adapter = _make_adapter()
|
||||
media_type, builder = adapter._outbound_media_builder("note.mp3")
|
||||
assert media_type == weixin.MEDIA_VOICE
|
||||
|
||||
item = builder(
|
||||
encrypt_query_param="eq",
|
||||
aes_key_for_api="fakekey",
|
||||
ciphertext_size=512,
|
||||
plaintext_size=500,
|
||||
filename="note.mp3",
|
||||
rawfilemd5="abc",
|
||||
)
|
||||
assert item["type"] == weixin.ITEM_VOICE
|
||||
assert "voice_item" in item
|
||||
|
||||
def test_voice_builder_for_silk_files(self):
|
||||
adapter = _make_adapter()
|
||||
media_type, builder = adapter._outbound_media_builder("recording.silk")
|
||||
assert media_type == weixin.MEDIA_VOICE
|
||||
|
||||
271
tests/gateway/test_whatsapp_formatting.py
Normal file
271
tests/gateway/test_whatsapp_formatting.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""Tests for WhatsApp message formatting and chunking.
|
||||
|
||||
Covers:
|
||||
- format_message(): markdown → WhatsApp syntax conversion
|
||||
- send(): message chunking for long responses
|
||||
- MAX_MESSAGE_LENGTH: practical UX limit
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_adapter():
|
||||
"""Create a WhatsAppAdapter with test attributes (bypass __init__)."""
|
||||
from gateway.platforms.whatsapp import WhatsAppAdapter
|
||||
|
||||
adapter = WhatsAppAdapter.__new__(WhatsAppAdapter)
|
||||
adapter.platform = Platform.WHATSAPP
|
||||
adapter.config = MagicMock()
|
||||
adapter.config.extra = {}
|
||||
adapter._bridge_port = 3000
|
||||
adapter._bridge_script = "/tmp/test-bridge.js"
|
||||
adapter._session_path = MagicMock()
|
||||
adapter._bridge_log_fh = None
|
||||
adapter._bridge_log = None
|
||||
adapter._bridge_process = None
|
||||
adapter._reply_prefix = None
|
||||
adapter._running = True
|
||||
adapter._message_handler = None
|
||||
adapter._fatal_error_code = None
|
||||
adapter._fatal_error_message = None
|
||||
adapter._fatal_error_retryable = True
|
||||
adapter._fatal_error_handler = None
|
||||
adapter._active_sessions = {}
|
||||
adapter._pending_messages = {}
|
||||
adapter._background_tasks = set()
|
||||
adapter._auto_tts_disabled_chats = set()
|
||||
adapter._message_queue = asyncio.Queue()
|
||||
adapter._http_session = MagicMock()
|
||||
adapter._mention_patterns = []
|
||||
return adapter
|
||||
|
||||
|
||||
class _AsyncCM:
|
||||
"""Minimal async context manager returning a fixed value."""
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.value
|
||||
|
||||
async def __aexit__(self, *exc):
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# format_message tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFormatMessage:
|
||||
"""WhatsApp markdown conversion."""
|
||||
|
||||
def test_bold_double_asterisk(self):
|
||||
adapter = _make_adapter()
|
||||
assert adapter.format_message("**hello**") == "*hello*"
|
||||
|
||||
def test_bold_double_underscore(self):
|
||||
adapter = _make_adapter()
|
||||
assert adapter.format_message("__hello__") == "*hello*"
|
||||
|
||||
def test_strikethrough(self):
|
||||
adapter = _make_adapter()
|
||||
assert adapter.format_message("~~deleted~~") == "~deleted~"
|
||||
|
||||
def test_headers_converted_to_bold(self):
|
||||
adapter = _make_adapter()
|
||||
assert adapter.format_message("# Title") == "*Title*"
|
||||
assert adapter.format_message("## Subtitle") == "*Subtitle*"
|
||||
assert adapter.format_message("### Deep") == "*Deep*"
|
||||
|
||||
def test_links_converted(self):
|
||||
adapter = _make_adapter()
|
||||
result = adapter.format_message("[click here](https://example.com)")
|
||||
assert result == "click here (https://example.com)"
|
||||
|
||||
def test_code_blocks_protected(self):
|
||||
"""Code blocks should not have their content reformatted."""
|
||||
adapter = _make_adapter()
|
||||
content = "before **bold** ```python\n**not bold**\n``` after **bold**"
|
||||
result = adapter.format_message(content)
|
||||
assert "```python\n**not bold**\n```" in result
|
||||
assert result.startswith("before *bold*")
|
||||
assert result.endswith("after *bold*")
|
||||
|
||||
def test_inline_code_protected(self):
|
||||
"""Inline code should not have its content reformatted."""
|
||||
adapter = _make_adapter()
|
||||
content = "use `**raw**` here"
|
||||
result = adapter.format_message(content)
|
||||
assert "`**raw**`" in result
|
||||
assert result.startswith("use ")
|
||||
|
||||
def test_empty_content(self):
|
||||
adapter = _make_adapter()
|
||||
assert adapter.format_message("") == ""
|
||||
assert adapter.format_message(None) is None
|
||||
|
||||
def test_plain_text_unchanged(self):
|
||||
adapter = _make_adapter()
|
||||
assert adapter.format_message("hello world") == "hello world"
|
||||
|
||||
def test_already_whatsapp_italic(self):
|
||||
"""Single *italic* should pass through unchanged."""
|
||||
adapter = _make_adapter()
|
||||
# After bold conversion, *text* is WhatsApp italic
|
||||
assert adapter.format_message("*italic*") == "*italic*"
|
||||
|
||||
def test_multiline_mixed(self):
|
||||
adapter = _make_adapter()
|
||||
content = "# Header\n\n**Bold text** and ~~strike~~\n\n```\ncode\n```"
|
||||
result = adapter.format_message(content)
|
||||
assert "*Header*" in result
|
||||
assert "*Bold text*" in result
|
||||
assert "~strike~" in result
|
||||
assert "```\ncode\n```" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MAX_MESSAGE_LENGTH tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMessageLimits:
|
||||
"""WhatsApp message length limits."""
|
||||
|
||||
def test_max_message_length_is_practical(self):
|
||||
from gateway.platforms.whatsapp import WhatsAppAdapter
|
||||
assert WhatsAppAdapter.MAX_MESSAGE_LENGTH == 4096
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send() chunking tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendChunking:
|
||||
"""WhatsApp send() splits long messages into chunks."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_short_message_single_send(self):
|
||||
adapter = _make_adapter()
|
||||
resp = MagicMock(status=200)
|
||||
resp.json = AsyncMock(return_value={"messageId": "msg1"})
|
||||
adapter._http_session.post = MagicMock(return_value=_AsyncCM(resp))
|
||||
|
||||
result = await adapter.send("chat1", "short message")
|
||||
assert result.success
|
||||
# Only one call to bridge /send
|
||||
assert adapter._http_session.post.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_message_chunked(self):
|
||||
adapter = _make_adapter()
|
||||
resp = MagicMock(status=200)
|
||||
resp.json = AsyncMock(return_value={"messageId": "msg1"})
|
||||
adapter._http_session.post = MagicMock(return_value=_AsyncCM(resp))
|
||||
|
||||
# Create a message longer than MAX_MESSAGE_LENGTH (4096)
|
||||
long_msg = "a " * 3000 # ~6000 chars
|
||||
|
||||
result = await adapter.send("chat1", long_msg)
|
||||
assert result.success
|
||||
# Should have made multiple calls
|
||||
assert adapter._http_session.post.call_count > 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_message_no_send(self):
|
||||
adapter = _make_adapter()
|
||||
result = await adapter.send("chat1", "")
|
||||
assert result.success
|
||||
assert adapter._http_session.post.call_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitespace_only_no_send(self):
|
||||
adapter = _make_adapter()
|
||||
result = await adapter.send("chat1", " \n ")
|
||||
assert result.success
|
||||
assert adapter._http_session.post.call_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_applied_before_send(self):
|
||||
"""Markdown should be converted to WhatsApp format before sending."""
|
||||
adapter = _make_adapter()
|
||||
resp = MagicMock(status=200)
|
||||
resp.json = AsyncMock(return_value={"messageId": "msg1"})
|
||||
adapter._http_session.post = MagicMock(return_value=_AsyncCM(resp))
|
||||
|
||||
await adapter.send("chat1", "**bold text**")
|
||||
|
||||
# Check the payload sent to the bridge
|
||||
call_args = adapter._http_session.post.call_args
|
||||
payload = call_args.kwargs.get("json") or call_args[1].get("json")
|
||||
assert payload["message"] == "*bold text*"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_to_only_on_first_chunk(self):
|
||||
"""reply_to should only be set on the first chunk."""
|
||||
adapter = _make_adapter()
|
||||
resp = MagicMock(status=200)
|
||||
resp.json = AsyncMock(return_value={"messageId": "msg1"})
|
||||
adapter._http_session.post = MagicMock(return_value=_AsyncCM(resp))
|
||||
|
||||
long_msg = "word " * 2000 # ~10000 chars, multiple chunks
|
||||
|
||||
await adapter.send("chat1", long_msg, reply_to="orig123")
|
||||
|
||||
calls = adapter._http_session.post.call_args_list
|
||||
assert len(calls) > 1
|
||||
|
||||
# First chunk should have replyTo
|
||||
first_payload = calls[0].kwargs.get("json") or calls[0][1].get("json")
|
||||
assert first_payload.get("replyTo") == "orig123"
|
||||
|
||||
# Subsequent chunks should NOT have replyTo
|
||||
for call in calls[1:]:
|
||||
payload = call.kwargs.get("json") or call[1].get("json")
|
||||
assert "replyTo" not in payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bridge_error_returns_failure(self):
|
||||
adapter = _make_adapter()
|
||||
resp = MagicMock(status=500)
|
||||
resp.text = AsyncMock(return_value="Internal Server Error")
|
||||
adapter._http_session.post = MagicMock(return_value=_AsyncCM(resp))
|
||||
|
||||
result = await adapter.send("chat1", "hello")
|
||||
assert not result.success
|
||||
assert "Internal Server Error" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_connected_returns_failure(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._running = False
|
||||
|
||||
result = await adapter.send("chat1", "hello")
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# display_config tier classification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWhatsAppTier:
|
||||
"""WhatsApp should be classified as TIER_MEDIUM."""
|
||||
|
||||
def test_whatsapp_streaming_follows_global(self):
|
||||
from gateway.display_config import resolve_display_setting
|
||||
# TIER_MEDIUM has streaming: None (follow global), not False
|
||||
assert resolve_display_setting({}, "whatsapp", "streaming") is None
|
||||
|
||||
def test_whatsapp_tool_progress_is_new(self):
|
||||
from gateway.display_config import resolve_display_setting
|
||||
assert resolve_display_setting({}, "whatsapp", "tool_progress") == "new"
|
||||
@@ -23,9 +23,9 @@ from hermes_cli.auth import (
|
||||
get_auth_status,
|
||||
AuthError,
|
||||
KIMI_CODE_BASE_URL,
|
||||
_try_gh_cli_token,
|
||||
_resolve_kimi_base_url,
|
||||
)
|
||||
from hermes_cli.copilot_auth import _try_gh_cli_token
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -68,7 +68,7 @@ class TestProviderRegistry:
|
||||
def test_copilot_env_vars(self):
|
||||
pconfig = PROVIDER_REGISTRY["copilot"]
|
||||
assert pconfig.api_key_env_vars == ("COPILOT_GITHUB_TOKEN", "GH_TOKEN", "GITHUB_TOKEN")
|
||||
assert pconfig.base_url_env_var == ""
|
||||
assert pconfig.base_url_env_var == "COPILOT_API_BASE_URL"
|
||||
|
||||
def test_kimi_env_vars(self):
|
||||
pconfig = PROVIDER_REGISTRY["kimi-coding"]
|
||||
@@ -381,13 +381,13 @@ class TestResolveApiKeyProviderCredentials:
|
||||
assert creds["source"] == "gh auth token"
|
||||
|
||||
def test_try_gh_cli_token_uses_homebrew_path_when_not_on_path(self, monkeypatch):
|
||||
monkeypatch.setattr("hermes_cli.auth.shutil.which", lambda command: None)
|
||||
monkeypatch.setattr("hermes_cli.copilot_auth.shutil.which", lambda command: None)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.os.path.isfile",
|
||||
"hermes_cli.copilot_auth.os.path.isfile",
|
||||
lambda path: path == "/opt/homebrew/bin/gh",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.os.access",
|
||||
"hermes_cli.copilot_auth.os.access",
|
||||
lambda path, mode: path == "/opt/homebrew/bin/gh" and mode == os.X_OK,
|
||||
)
|
||||
|
||||
@@ -397,11 +397,11 @@ class TestResolveApiKeyProviderCredentials:
|
||||
returncode = 0
|
||||
stdout = "gh-cli-secret\n"
|
||||
|
||||
def _fake_run(cmd, capture_output, text, timeout):
|
||||
def _fake_run(cmd, **kwargs):
|
||||
calls.append(cmd)
|
||||
return _Result()
|
||||
|
||||
monkeypatch.setattr("hermes_cli.auth.subprocess.run", _fake_run)
|
||||
monkeypatch.setattr("hermes_cli.copilot_auth.subprocess.run", _fake_run)
|
||||
|
||||
assert _try_gh_cli_token() == "gh-cli-secret"
|
||||
assert calls == [["/opt/homebrew/bin/gh", "auth", "token"]]
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Tests for hermes backup and import commands."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import zipfile
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
@@ -232,6 +234,44 @@ class TestBackup:
|
||||
assert len(zips) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_backup_zip tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidateBackupZip:
|
||||
def _make_zip(self, zip_path: Path, filenames: list[str]) -> None:
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
for name in filenames:
|
||||
zf.writestr(name, "dummy")
|
||||
|
||||
def test_state_db_passes(self, tmp_path):
|
||||
"""A zip containing state.db is accepted as a valid Hermes backup."""
|
||||
from hermes_cli.backup import _validate_backup_zip
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
self._make_zip(zip_path, ["state.db", "sessions/abc.json"])
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
ok, reason = _validate_backup_zip(zf)
|
||||
assert ok, reason
|
||||
|
||||
def test_old_wrong_db_name_fails(self, tmp_path):
|
||||
"""A zip with only hermes_state.db (old wrong name) is rejected."""
|
||||
from hermes_cli.backup import _validate_backup_zip
|
||||
zip_path = tmp_path / "old.zip"
|
||||
self._make_zip(zip_path, ["hermes_state.db", "memory_store.db"])
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
ok, reason = _validate_backup_zip(zf)
|
||||
assert not ok
|
||||
|
||||
def test_config_yaml_passes(self, tmp_path):
|
||||
"""A zip containing config.yaml is accepted (existing behaviour preserved)."""
|
||||
from hermes_cli.backup import _validate_backup_zip
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
self._make_zip(zip_path, ["config.yaml", "skills/x/SKILL.md"])
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
ok, reason = _validate_backup_zip(zf)
|
||||
assert ok, reason
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -895,3 +935,181 @@ class TestProfileRestoration:
|
||||
|
||||
# Files should still be restored even if wrappers can't be created
|
||||
assert (hermes_home / "profiles" / "coder" / "config.yaml").exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQLite safe copy tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSafeCopyDb:
|
||||
def test_copies_valid_database(self, tmp_path):
|
||||
from hermes_cli.backup import _safe_copy_db
|
||||
src = tmp_path / "test.db"
|
||||
dst = tmp_path / "copy.db"
|
||||
|
||||
conn = sqlite3.connect(str(src))
|
||||
conn.execute("CREATE TABLE t (x INTEGER)")
|
||||
conn.execute("INSERT INTO t VALUES (42)")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
result = _safe_copy_db(src, dst)
|
||||
assert result is True
|
||||
|
||||
conn = sqlite3.connect(str(dst))
|
||||
rows = conn.execute("SELECT x FROM t").fetchall()
|
||||
conn.close()
|
||||
assert rows == [(42,)]
|
||||
|
||||
def test_copies_wal_mode_database(self, tmp_path):
|
||||
from hermes_cli.backup import _safe_copy_db
|
||||
src = tmp_path / "wal.db"
|
||||
dst = tmp_path / "copy.db"
|
||||
|
||||
conn = sqlite3.connect(str(src))
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("CREATE TABLE t (x TEXT)")
|
||||
conn.execute("INSERT INTO t VALUES ('wal-test')")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
result = _safe_copy_db(src, dst)
|
||||
assert result is True
|
||||
|
||||
conn = sqlite3.connect(str(dst))
|
||||
rows = conn.execute("SELECT x FROM t").fetchall()
|
||||
conn.close()
|
||||
assert rows == [("wal-test",)]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quick state snapshot tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQuickSnapshot:
|
||||
@pytest.fixture
|
||||
def hermes_home(self, tmp_path):
|
||||
"""Create a fake HERMES_HOME with critical state files."""
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir()
|
||||
(home / "config.yaml").write_text("model:\n provider: openrouter\n")
|
||||
(home / ".env").write_text("OPENROUTER_API_KEY=test-key-123\n")
|
||||
(home / "auth.json").write_text('{"providers": {}}\n')
|
||||
(home / "cron").mkdir()
|
||||
(home / "cron" / "jobs.json").write_text('{"jobs": []}\n')
|
||||
|
||||
# Real SQLite database
|
||||
db_path = home / "state.db"
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.execute("CREATE TABLE sessions (id TEXT PRIMARY KEY, data TEXT)")
|
||||
conn.execute("INSERT INTO sessions VALUES ('s1', 'hello world')")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return home
|
||||
|
||||
def test_creates_snapshot(self, hermes_home):
|
||||
from hermes_cli.backup import create_quick_snapshot
|
||||
snap_id = create_quick_snapshot(hermes_home=hermes_home)
|
||||
assert snap_id is not None
|
||||
snap_dir = hermes_home / "state-snapshots" / snap_id
|
||||
assert snap_dir.is_dir()
|
||||
assert (snap_dir / "manifest.json").exists()
|
||||
|
||||
def test_label_in_id(self, hermes_home):
|
||||
from hermes_cli.backup import create_quick_snapshot
|
||||
snap_id = create_quick_snapshot(label="before-upgrade", hermes_home=hermes_home)
|
||||
assert "before-upgrade" in snap_id
|
||||
|
||||
def test_state_db_safely_copied(self, hermes_home):
|
||||
from hermes_cli.backup import create_quick_snapshot
|
||||
snap_id = create_quick_snapshot(hermes_home=hermes_home)
|
||||
db_copy = hermes_home / "state-snapshots" / snap_id / "state.db"
|
||||
assert db_copy.exists()
|
||||
|
||||
conn = sqlite3.connect(str(db_copy))
|
||||
rows = conn.execute("SELECT * FROM sessions").fetchall()
|
||||
conn.close()
|
||||
assert len(rows) == 1
|
||||
assert rows[0] == ("s1", "hello world")
|
||||
|
||||
def test_copies_nested_files(self, hermes_home):
|
||||
from hermes_cli.backup import create_quick_snapshot
|
||||
snap_id = create_quick_snapshot(hermes_home=hermes_home)
|
||||
assert (hermes_home / "state-snapshots" / snap_id / "cron" / "jobs.json").exists()
|
||||
|
||||
def test_missing_files_skipped(self, hermes_home):
|
||||
from hermes_cli.backup import create_quick_snapshot
|
||||
snap_id = create_quick_snapshot(hermes_home=hermes_home)
|
||||
with open(hermes_home / "state-snapshots" / snap_id / "manifest.json") as f:
|
||||
meta = json.load(f)
|
||||
# gateway_state.json etc. don't exist in fixture
|
||||
assert "gateway_state.json" not in meta["files"]
|
||||
|
||||
def test_empty_home_returns_none(self, tmp_path):
|
||||
from hermes_cli.backup import create_quick_snapshot
|
||||
empty = tmp_path / "empty"
|
||||
empty.mkdir()
|
||||
assert create_quick_snapshot(hermes_home=empty) is None
|
||||
|
||||
def test_list_snapshots(self, hermes_home):
|
||||
from hermes_cli.backup import create_quick_snapshot, list_quick_snapshots
|
||||
id1 = create_quick_snapshot(label="first", hermes_home=hermes_home)
|
||||
id2 = create_quick_snapshot(label="second", hermes_home=hermes_home)
|
||||
|
||||
snaps = list_quick_snapshots(hermes_home=hermes_home)
|
||||
assert len(snaps) == 2
|
||||
assert snaps[0]["id"] == id2 # most recent first
|
||||
assert snaps[1]["id"] == id1
|
||||
|
||||
def test_list_limit(self, hermes_home):
|
||||
from hermes_cli.backup import create_quick_snapshot, list_quick_snapshots
|
||||
for i in range(5):
|
||||
create_quick_snapshot(label=f"s{i}", hermes_home=hermes_home)
|
||||
snaps = list_quick_snapshots(limit=3, hermes_home=hermes_home)
|
||||
assert len(snaps) == 3
|
||||
|
||||
def test_restore_config(self, hermes_home):
|
||||
from hermes_cli.backup import create_quick_snapshot, restore_quick_snapshot
|
||||
snap_id = create_quick_snapshot(hermes_home=hermes_home)
|
||||
|
||||
(hermes_home / "config.yaml").write_text("model:\n provider: anthropic\n")
|
||||
assert "anthropic" in (hermes_home / "config.yaml").read_text()
|
||||
|
||||
result = restore_quick_snapshot(snap_id, hermes_home=hermes_home)
|
||||
assert result is True
|
||||
assert "openrouter" in (hermes_home / "config.yaml").read_text()
|
||||
|
||||
def test_restore_state_db(self, hermes_home):
|
||||
from hermes_cli.backup import create_quick_snapshot, restore_quick_snapshot
|
||||
snap_id = create_quick_snapshot(hermes_home=hermes_home)
|
||||
|
||||
conn = sqlite3.connect(str(hermes_home / "state.db"))
|
||||
conn.execute("INSERT INTO sessions VALUES ('s2', 'new')")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
restore_quick_snapshot(snap_id, hermes_home=hermes_home)
|
||||
|
||||
conn = sqlite3.connect(str(hermes_home / "state.db"))
|
||||
rows = conn.execute("SELECT * FROM sessions").fetchall()
|
||||
conn.close()
|
||||
assert len(rows) == 1
|
||||
|
||||
def test_restore_nonexistent(self, hermes_home):
|
||||
from hermes_cli.backup import restore_quick_snapshot
|
||||
assert restore_quick_snapshot("nonexistent", hermes_home=hermes_home) is False
|
||||
|
||||
def test_auto_prune(self, hermes_home):
|
||||
from hermes_cli.backup import create_quick_snapshot, list_quick_snapshots, _QUICK_DEFAULT_KEEP
|
||||
for i in range(_QUICK_DEFAULT_KEEP + 5):
|
||||
create_quick_snapshot(label=f"snap-{i:03d}", hermes_home=hermes_home)
|
||||
snaps = list_quick_snapshots(limit=100, hermes_home=hermes_home)
|
||||
assert len(snaps) <= _QUICK_DEFAULT_KEEP
|
||||
|
||||
def test_manual_prune(self, hermes_home):
|
||||
from hermes_cli.backup import create_quick_snapshot, prune_quick_snapshots, list_quick_snapshots
|
||||
for i in range(10):
|
||||
create_quick_snapshot(label=f"s{i}", hermes_home=hermes_home)
|
||||
deleted = prune_quick_snapshots(keep=3, hermes_home=hermes_home)
|
||||
assert deleted == 7
|
||||
assert len(list_quick_snapshots(hermes_home=hermes_home)) == 3
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Tests for hermes claw commands."""
|
||||
|
||||
from argparse import Namespace
|
||||
import subprocess
|
||||
from types import ModuleType
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@@ -197,6 +198,11 @@ class TestClawCommand:
|
||||
class TestCmdMigrate:
|
||||
"""Test the migrate command handler."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_openclaw_running(self):
|
||||
with patch.object(claw_mod, "_detect_openclaw_processes", return_value=[]):
|
||||
yield
|
||||
|
||||
def test_error_when_source_missing(self, tmp_path, capsys):
|
||||
args = Namespace(
|
||||
source=str(tmp_path / "nonexistent"),
|
||||
@@ -626,3 +632,120 @@ class TestPrintMigrationReport:
|
||||
claw_mod._print_migration_report(report, dry_run=False)
|
||||
captured = capsys.readouterr()
|
||||
assert "Nothing to migrate" in captured.out
|
||||
|
||||
|
||||
class TestDetectOpenclawProcesses:
|
||||
def test_returns_match_when_pgrep_finds_openclaw(self):
|
||||
with patch.object(claw_mod, "sys") as mock_sys:
|
||||
mock_sys.platform = "linux"
|
||||
with patch.object(claw_mod, "subprocess") as mock_subprocess:
|
||||
# systemd check misses, pgrep finds openclaw
|
||||
mock_subprocess.run.side_effect = [
|
||||
MagicMock(returncode=1, stdout=""), # systemctl
|
||||
MagicMock(returncode=0, stdout="1234\n"), # pgrep
|
||||
]
|
||||
mock_subprocess.TimeoutExpired = subprocess.TimeoutExpired
|
||||
result = claw_mod._detect_openclaw_processes()
|
||||
assert len(result) == 1
|
||||
assert "1234" in result[0]
|
||||
|
||||
def test_returns_empty_when_pgrep_finds_nothing(self):
|
||||
with patch.object(claw_mod, "sys") as mock_sys:
|
||||
mock_sys.platform = "darwin"
|
||||
with patch.object(claw_mod, "subprocess") as mock_subprocess:
|
||||
mock_subprocess.run.side_effect = [
|
||||
MagicMock(returncode=1, stdout=""), # systemctl (not found)
|
||||
MagicMock(returncode=1, stdout=""), # pgrep
|
||||
]
|
||||
mock_subprocess.TimeoutExpired = subprocess.TimeoutExpired
|
||||
result = claw_mod._detect_openclaw_processes()
|
||||
assert result == []
|
||||
|
||||
def test_detects_systemd_service(self):
|
||||
with patch.object(claw_mod, "sys") as mock_sys:
|
||||
mock_sys.platform = "linux"
|
||||
with patch.object(claw_mod, "subprocess") as mock_subprocess:
|
||||
mock_subprocess.run.side_effect = [
|
||||
MagicMock(returncode=0, stdout="active\n"), # systemctl
|
||||
MagicMock(returncode=1, stdout=""), # pgrep
|
||||
]
|
||||
mock_subprocess.TimeoutExpired = subprocess.TimeoutExpired
|
||||
result = claw_mod._detect_openclaw_processes()
|
||||
assert len(result) == 1
|
||||
assert "systemd" in result[0]
|
||||
|
||||
def test_returns_match_on_windows_when_openclaw_exe_running(self):
|
||||
with patch.object(claw_mod, "sys") as mock_sys:
|
||||
mock_sys.platform = "win32"
|
||||
with patch.object(claw_mod, "subprocess") as mock_subprocess:
|
||||
mock_subprocess.run.side_effect = [
|
||||
MagicMock(returncode=0, stdout="openclaw.exe 1234 Console 1 45,056 K\n"),
|
||||
]
|
||||
result = claw_mod._detect_openclaw_processes()
|
||||
assert len(result) >= 1
|
||||
assert any("openclaw.exe" in r for r in result)
|
||||
|
||||
def test_returns_match_on_windows_when_node_exe_has_openclaw_in_cmdline(self):
|
||||
with patch.object(claw_mod, "sys") as mock_sys:
|
||||
mock_sys.platform = "win32"
|
||||
with patch.object(claw_mod, "subprocess") as mock_subprocess:
|
||||
mock_subprocess.run.side_effect = [
|
||||
MagicMock(returncode=0, stdout=""), # tasklist openclaw.exe
|
||||
MagicMock(returncode=0, stdout=""), # tasklist clawd.exe
|
||||
MagicMock(returncode=0, stdout="1234\n"), # PowerShell
|
||||
]
|
||||
result = claw_mod._detect_openclaw_processes()
|
||||
assert len(result) >= 1
|
||||
assert any("node.exe" in r for r in result)
|
||||
|
||||
def test_returns_empty_on_windows_when_nothing_found(self):
|
||||
with patch.object(claw_mod, "sys") as mock_sys:
|
||||
mock_sys.platform = "win32"
|
||||
with patch.object(claw_mod, "subprocess") as mock_subprocess:
|
||||
mock_subprocess.run.side_effect = [
|
||||
MagicMock(returncode=0, stdout=""),
|
||||
MagicMock(returncode=0, stdout=""),
|
||||
MagicMock(returncode=0, stdout=""),
|
||||
]
|
||||
result = claw_mod._detect_openclaw_processes()
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestWarnIfOpenclawRunning:
|
||||
def test_noop_when_not_running(self, capsys):
|
||||
with patch.object(claw_mod, "_detect_openclaw_processes", return_value=[]):
|
||||
claw_mod._warn_if_openclaw_running(auto_yes=False)
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == ""
|
||||
|
||||
def test_warns_and_exits_when_running_and_user_declines(self, capsys):
|
||||
with patch.object(claw_mod, "_detect_openclaw_processes", return_value=["openclaw process(es) (PIDs: 1234)"]):
|
||||
with patch.object(claw_mod, "prompt_yes_no", return_value=False):
|
||||
with patch.object(claw_mod.sys.stdin, "isatty", return_value=True):
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
claw_mod._warn_if_openclaw_running(auto_yes=False)
|
||||
assert exc_info.value.code == 0
|
||||
captured = capsys.readouterr()
|
||||
assert "OpenClaw appears to be running" in captured.out
|
||||
|
||||
def test_warns_and_continues_when_running_and_user_accepts(self, capsys):
|
||||
with patch.object(claw_mod, "_detect_openclaw_processes", return_value=["openclaw process(es) (PIDs: 1234)"]):
|
||||
with patch.object(claw_mod, "prompt_yes_no", return_value=True):
|
||||
with patch.object(claw_mod.sys.stdin, "isatty", return_value=True):
|
||||
claw_mod._warn_if_openclaw_running(auto_yes=False)
|
||||
captured = capsys.readouterr()
|
||||
assert "OpenClaw appears to be running" in captured.out
|
||||
|
||||
def test_warns_and_continues_in_auto_yes_mode(self, capsys):
|
||||
with patch.object(claw_mod, "_detect_openclaw_processes", return_value=["openclaw process(es) (PIDs: 1234)"]):
|
||||
claw_mod._warn_if_openclaw_running(auto_yes=True)
|
||||
captured = capsys.readouterr()
|
||||
assert "OpenClaw appears to be running" in captured.out
|
||||
|
||||
def test_warns_and_continues_in_non_interactive_session(self, capsys):
|
||||
with patch.object(claw_mod, "_detect_openclaw_processes", return_value=["openclaw process(es) (PIDs: 1234)"]):
|
||||
with patch.object(claw_mod.sys.stdin, "isatty", return_value=False):
|
||||
claw_mod._warn_if_openclaw_running(auto_yes=False)
|
||||
captured = capsys.readouterr()
|
||||
assert "OpenClaw appears to be running" in captured.out
|
||||
assert "Non-interactive session" in captured.out
|
||||
|
||||
@@ -10,6 +10,7 @@ from hermes_cli.config import (
|
||||
DEFAULT_CONFIG,
|
||||
get_hermes_home,
|
||||
ensure_hermes_home,
|
||||
get_compatible_custom_providers,
|
||||
load_config,
|
||||
load_env,
|
||||
migrate_config,
|
||||
@@ -424,6 +425,146 @@ class TestAnthropicTokenMigration:
|
||||
assert load_env().get("ANTHROPIC_TOKEN") == "current-token"
|
||||
|
||||
|
||||
class TestCustomProviderCompatibility:
|
||||
"""Custom provider compatibility across legacy and v12+ config schemas."""
|
||||
|
||||
def test_v11_upgrade_moves_custom_providers_into_providers(self, tmp_path):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config_path.write_text(
|
||||
yaml.safe_dump(
|
||||
{
|
||||
"_config_version": 11,
|
||||
"model": {
|
||||
"default": "openai/gpt-5.4",
|
||||
"provider": "openrouter",
|
||||
},
|
||||
"custom_providers": [
|
||||
{
|
||||
"name": "OpenAI Direct",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "test-key",
|
||||
"api_mode": "codex_responses",
|
||||
"model": "gpt-5-mini",
|
||||
}
|
||||
],
|
||||
"fallback_providers": [
|
||||
{"provider": "openai-direct", "model": "gpt-5-mini"}
|
||||
],
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
migrate_config(interactive=False, quiet=True)
|
||||
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
|
||||
assert raw["_config_version"] == 17
|
||||
assert raw["providers"]["openai-direct"] == {
|
||||
"api": "https://api.openai.com/v1",
|
||||
"api_key": "test-key",
|
||||
"default_model": "gpt-5-mini",
|
||||
"name": "OpenAI Direct",
|
||||
"transport": "codex_responses",
|
||||
}
|
||||
# custom_providers removed by migration — runtime reads via compat layer
|
||||
assert "custom_providers" not in raw
|
||||
|
||||
def test_providers_dict_resolves_at_runtime(self, tmp_path):
|
||||
"""After migration deleted custom_providers, get_compatible_custom_providers
|
||||
still finds entries from the providers dict."""
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config_path.write_text(
|
||||
yaml.safe_dump(
|
||||
{
|
||||
"_config_version": 17,
|
||||
"providers": {
|
||||
"openai-direct": {
|
||||
"api": "https://api.openai.com/v1",
|
||||
"api_key": "test-key",
|
||||
"default_model": "gpt-5-mini",
|
||||
"name": "OpenAI Direct",
|
||||
"transport": "codex_responses",
|
||||
}
|
||||
},
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
compatible = get_compatible_custom_providers()
|
||||
|
||||
assert len(compatible) == 1
|
||||
assert compatible[0]["name"] == "OpenAI Direct"
|
||||
assert compatible[0]["base_url"] == "https://api.openai.com/v1"
|
||||
assert compatible[0]["provider_key"] == "openai-direct"
|
||||
assert compatible[0]["api_mode"] == "codex_responses"
|
||||
|
||||
def test_compatible_custom_providers_prefers_api_then_url_then_base_url(self, tmp_path):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config_path.write_text(
|
||||
yaml.safe_dump(
|
||||
{
|
||||
"_config_version": 17,
|
||||
"providers": {
|
||||
"my-provider": {
|
||||
"name": "My Provider",
|
||||
"api": "https://api.example.com/v1",
|
||||
"url": "https://url.example.com/v1",
|
||||
"base_url": "https://base.example.com/v1",
|
||||
}
|
||||
},
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
compatible = get_compatible_custom_providers()
|
||||
|
||||
assert compatible == [
|
||||
{
|
||||
"name": "My Provider",
|
||||
"base_url": "https://api.example.com/v1",
|
||||
"provider_key": "my-provider",
|
||||
}
|
||||
]
|
||||
|
||||
def test_dedup_across_legacy_and_providers(self, tmp_path):
|
||||
"""Same name+url in both schemas should not produce duplicates."""
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config_path.write_text(
|
||||
yaml.safe_dump(
|
||||
{
|
||||
"_config_version": 17,
|
||||
"custom_providers": [
|
||||
{
|
||||
"name": "OpenAI Direct",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "legacy-key",
|
||||
}
|
||||
],
|
||||
"providers": {
|
||||
"openai-direct": {
|
||||
"api": "https://api.openai.com/v1",
|
||||
"api_key": "new-key",
|
||||
"name": "OpenAI Direct",
|
||||
}
|
||||
},
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
compatible = get_compatible_custom_providers()
|
||||
|
||||
assert len(compatible) == 1
|
||||
# Legacy entry wins (read first)
|
||||
assert compatible[0]["api_key"] == "legacy-key"
|
||||
|
||||
|
||||
class TestInterimAssistantMessageConfig:
|
||||
"""Test the explicit gateway interim-message config gate."""
|
||||
|
||||
@@ -441,6 +582,6 @@ class TestInterimAssistantMessageConfig:
|
||||
migrate_config(interactive=False, quiet=True)
|
||||
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
|
||||
assert raw["_config_version"] == 16
|
||||
assert raw["_config_version"] == 17
|
||||
assert raw["display"]["tool_progress"] == "off"
|
||||
assert raw["display"]["interim_assistant_messages"] is True
|
||||
|
||||
@@ -12,49 +12,10 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from hermes_cli.config import (
|
||||
_is_inside_container,
|
||||
get_container_exec_info,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _is_inside_container
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_is_inside_container_dockerenv():
|
||||
"""Detects /.dockerenv marker file."""
|
||||
with patch("os.path.exists") as mock_exists:
|
||||
mock_exists.side_effect = lambda p: p == "/.dockerenv"
|
||||
assert _is_inside_container() is True
|
||||
|
||||
|
||||
def test_is_inside_container_containerenv():
|
||||
"""Detects Podman's /run/.containerenv marker."""
|
||||
with patch("os.path.exists") as mock_exists:
|
||||
mock_exists.side_effect = lambda p: p == "/run/.containerenv"
|
||||
assert _is_inside_container() is True
|
||||
|
||||
|
||||
def test_is_inside_container_cgroup_docker():
|
||||
"""Detects 'docker' in /proc/1/cgroup."""
|
||||
with patch("os.path.exists", return_value=False), \
|
||||
patch("builtins.open", create=True) as mock_open:
|
||||
mock_open.return_value.__enter__ = lambda s: s
|
||||
mock_open.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_open.return_value.read = MagicMock(
|
||||
return_value="12:memory:/docker/abc123\n"
|
||||
)
|
||||
assert _is_inside_container() is True
|
||||
|
||||
|
||||
def test_is_inside_container_false_on_host():
|
||||
"""Returns False when none of the container indicators are present."""
|
||||
with patch("os.path.exists", return_value=False), \
|
||||
patch("builtins.open", side_effect=OSError("no such file")):
|
||||
assert _is_inside_container() is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# get_container_exec_info
|
||||
# =============================================================================
|
||||
@@ -81,7 +42,7 @@ def container_env(tmp_path, monkeypatch):
|
||||
|
||||
def test_get_container_exec_info_returns_metadata(container_env):
|
||||
"""Reads .container-mode and returns all fields including exec_user."""
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
with patch("hermes_constants.is_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is not None
|
||||
@@ -93,7 +54,7 @@ def test_get_container_exec_info_returns_metadata(container_env):
|
||||
|
||||
def test_get_container_exec_info_none_inside_container(container_env):
|
||||
"""Returns None when we're already inside a container."""
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=True):
|
||||
with patch("hermes_constants.is_container", return_value=True):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is None
|
||||
@@ -106,7 +67,7 @@ def test_get_container_exec_info_none_without_file(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.delenv("HERMES_DEV", raising=False)
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
with patch("hermes_constants.is_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is None
|
||||
@@ -116,7 +77,7 @@ def test_get_container_exec_info_skipped_when_hermes_dev(container_env, monkeypa
|
||||
"""Returns None when HERMES_DEV=1 is set (dev mode bypass)."""
|
||||
monkeypatch.setenv("HERMES_DEV", "1")
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
with patch("hermes_constants.is_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is None
|
||||
@@ -126,7 +87,7 @@ def test_get_container_exec_info_not_skipped_when_hermes_dev_zero(container_env,
|
||||
"""HERMES_DEV=0 does NOT trigger bypass — only '1' does."""
|
||||
monkeypatch.setenv("HERMES_DEV", "0")
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
with patch("hermes_constants.is_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is not None
|
||||
@@ -143,7 +104,7 @@ def test_get_container_exec_info_defaults():
|
||||
"# minimal file with no keys\n"
|
||||
)
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False), \
|
||||
with patch("hermes_constants.is_container", return_value=False), \
|
||||
patch("hermes_cli.config.get_hermes_home", return_value=hermes_home), \
|
||||
patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("HERMES_DEV", None)
|
||||
@@ -165,7 +126,7 @@ def test_get_container_exec_info_docker_backend(container_env):
|
||||
"hermes_bin=/opt/hermes/bin/hermes\n"
|
||||
)
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
with patch("hermes_constants.is_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info["backend"] == "docker"
|
||||
@@ -176,7 +137,7 @@ def test_get_container_exec_info_docker_backend(container_env):
|
||||
|
||||
def test_get_container_exec_info_crashes_on_permission_error(container_env):
|
||||
"""PermissionError propagates instead of being silently swallowed."""
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False), \
|
||||
with patch("hermes_constants.is_container", return_value=False), \
|
||||
patch("builtins.open", side_effect=PermissionError("permission denied")):
|
||||
with pytest.raises(PermissionError):
|
||||
get_container_exec_info()
|
||||
|
||||
461
tests/hermes_cli/test_debug.py
Normal file
461
tests/hermes_cli/test_debug.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""Tests for ``hermes debug`` CLI command and debug utilities."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import urllib.error
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def hermes_home(tmp_path, monkeypatch):
|
||||
"""Set up an isolated HERMES_HOME with minimal logs."""
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
# Create log files
|
||||
logs_dir = home / "logs"
|
||||
logs_dir.mkdir()
|
||||
(logs_dir / "agent.log").write_text(
|
||||
"2026-04-12 17:00:00 INFO agent: session started\n"
|
||||
"2026-04-12 17:00:01 INFO tools.terminal: running ls\n"
|
||||
"2026-04-12 17:00:02 WARNING agent: high token usage\n"
|
||||
)
|
||||
(logs_dir / "errors.log").write_text(
|
||||
"2026-04-12 17:00:05 ERROR gateway.run: connection lost\n"
|
||||
)
|
||||
(logs_dir / "gateway.log").write_text(
|
||||
"2026-04-12 17:00:10 INFO gateway.run: started\n"
|
||||
)
|
||||
|
||||
return home
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests for upload helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUploadPasteRs:
|
||||
"""Test paste.rs upload path."""
|
||||
|
||||
def test_upload_paste_rs_success(self):
|
||||
from hermes_cli.debug import _upload_paste_rs
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.read.return_value = b"https://paste.rs/abc123\n"
|
||||
mock_resp.__enter__ = lambda s: s
|
||||
mock_resp.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("hermes_cli.debug.urllib.request.urlopen", return_value=mock_resp):
|
||||
url = _upload_paste_rs("hello world")
|
||||
|
||||
assert url == "https://paste.rs/abc123"
|
||||
|
||||
def test_upload_paste_rs_bad_response(self):
|
||||
from hermes_cli.debug import _upload_paste_rs
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.read.return_value = b"<html>error</html>"
|
||||
mock_resp.__enter__ = lambda s: s
|
||||
mock_resp.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("hermes_cli.debug.urllib.request.urlopen", return_value=mock_resp):
|
||||
with pytest.raises(ValueError, match="Unexpected response"):
|
||||
_upload_paste_rs("test")
|
||||
|
||||
def test_upload_paste_rs_network_error(self):
|
||||
from hermes_cli.debug import _upload_paste_rs
|
||||
|
||||
with patch(
|
||||
"hermes_cli.debug.urllib.request.urlopen",
|
||||
side_effect=urllib.error.URLError("connection refused"),
|
||||
):
|
||||
with pytest.raises(urllib.error.URLError):
|
||||
_upload_paste_rs("test")
|
||||
|
||||
|
||||
class TestUploadDpasteCom:
|
||||
"""Test dpaste.com fallback upload path."""
|
||||
|
||||
def test_upload_dpaste_com_success(self):
|
||||
from hermes_cli.debug import _upload_dpaste_com
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.read.return_value = b"https://dpaste.com/ABCDEFG\n"
|
||||
mock_resp.__enter__ = lambda s: s
|
||||
mock_resp.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("hermes_cli.debug.urllib.request.urlopen", return_value=mock_resp):
|
||||
url = _upload_dpaste_com("hello world", expiry_days=7)
|
||||
|
||||
assert url == "https://dpaste.com/ABCDEFG"
|
||||
|
||||
|
||||
class TestUploadToPastebin:
|
||||
"""Test the combined upload with fallback."""
|
||||
|
||||
def test_tries_paste_rs_first(self):
|
||||
from hermes_cli.debug import upload_to_pastebin
|
||||
|
||||
with patch("hermes_cli.debug._upload_paste_rs",
|
||||
return_value="https://paste.rs/test") as prs:
|
||||
url = upload_to_pastebin("content")
|
||||
|
||||
assert url == "https://paste.rs/test"
|
||||
prs.assert_called_once()
|
||||
|
||||
def test_falls_back_to_dpaste_com(self):
|
||||
from hermes_cli.debug import upload_to_pastebin
|
||||
|
||||
with patch("hermes_cli.debug._upload_paste_rs",
|
||||
side_effect=Exception("down")), \
|
||||
patch("hermes_cli.debug._upload_dpaste_com",
|
||||
return_value="https://dpaste.com/TEST") as dp:
|
||||
url = upload_to_pastebin("content")
|
||||
|
||||
assert url == "https://dpaste.com/TEST"
|
||||
dp.assert_called_once()
|
||||
|
||||
def test_raises_when_both_fail(self):
|
||||
from hermes_cli.debug import upload_to_pastebin
|
||||
|
||||
with patch("hermes_cli.debug._upload_paste_rs",
|
||||
side_effect=Exception("err1")), \
|
||||
patch("hermes_cli.debug._upload_dpaste_com",
|
||||
side_effect=Exception("err2")):
|
||||
with pytest.raises(RuntimeError, match="Failed to upload"):
|
||||
upload_to_pastebin("content")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Log reading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadFullLog:
|
||||
"""Test _read_full_log for standalone log uploads."""
|
||||
|
||||
def test_reads_small_file(self, hermes_home):
|
||||
from hermes_cli.debug import _read_full_log
|
||||
|
||||
content = _read_full_log("agent")
|
||||
assert content is not None
|
||||
assert "session started" in content
|
||||
|
||||
def test_returns_none_for_missing(self, tmp_path, monkeypatch):
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
from hermes_cli.debug import _read_full_log
|
||||
assert _read_full_log("agent") is None
|
||||
|
||||
def test_returns_none_for_empty(self, hermes_home):
|
||||
# Truncate agent.log to empty
|
||||
(hermes_home / "logs" / "agent.log").write_text("")
|
||||
|
||||
from hermes_cli.debug import _read_full_log
|
||||
assert _read_full_log("agent") is None
|
||||
|
||||
def test_truncates_large_file(self, hermes_home):
|
||||
"""Files larger than max_bytes get tail-truncated."""
|
||||
from hermes_cli.debug import _read_full_log
|
||||
|
||||
# Write a file larger than 1KB
|
||||
big_content = "x" * 100 + "\n"
|
||||
(hermes_home / "logs" / "agent.log").write_text(big_content * 200)
|
||||
|
||||
content = _read_full_log("agent", max_bytes=1024)
|
||||
assert content is not None
|
||||
assert "truncated" in content
|
||||
|
||||
def test_unknown_log_returns_none(self, hermes_home):
|
||||
from hermes_cli.debug import _read_full_log
|
||||
assert _read_full_log("nonexistent") is None
|
||||
|
||||
def test_falls_back_to_rotated_file(self, hermes_home):
|
||||
"""When gateway.log doesn't exist, falls back to gateway.log.1."""
|
||||
from hermes_cli.debug import _read_full_log
|
||||
|
||||
logs_dir = hermes_home / "logs"
|
||||
# Remove the primary (if any) and create a .1 rotation
|
||||
(logs_dir / "gateway.log").unlink(missing_ok=True)
|
||||
(logs_dir / "gateway.log.1").write_text(
|
||||
"2026-04-12 10:00:00 INFO gateway.run: rotated content\n"
|
||||
)
|
||||
|
||||
content = _read_full_log("gateway")
|
||||
assert content is not None
|
||||
assert "rotated content" in content
|
||||
|
||||
def test_prefers_primary_over_rotated(self, hermes_home):
|
||||
"""Primary log is used when it exists, even if .1 also exists."""
|
||||
from hermes_cli.debug import _read_full_log
|
||||
|
||||
logs_dir = hermes_home / "logs"
|
||||
(logs_dir / "gateway.log").write_text("primary content\n")
|
||||
(logs_dir / "gateway.log.1").write_text("rotated content\n")
|
||||
|
||||
content = _read_full_log("gateway")
|
||||
assert "primary content" in content
|
||||
assert "rotated" not in content
|
||||
|
||||
def test_falls_back_when_primary_empty(self, hermes_home):
|
||||
"""Empty primary log falls back to .1 rotation."""
|
||||
from hermes_cli.debug import _read_full_log
|
||||
|
||||
logs_dir = hermes_home / "logs"
|
||||
(logs_dir / "agent.log").write_text("")
|
||||
(logs_dir / "agent.log.1").write_text("rotated agent data\n")
|
||||
|
||||
content = _read_full_log("agent")
|
||||
assert content is not None
|
||||
assert "rotated agent data" in content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Debug report collection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCollectDebugReport:
|
||||
"""Test the debug report builder."""
|
||||
|
||||
def test_report_includes_dump_output(self, hermes_home):
|
||||
from hermes_cli.debug import collect_debug_report
|
||||
|
||||
with patch("hermes_cli.dump.run_dump") as mock_dump:
|
||||
mock_dump.side_effect = lambda args: print(
|
||||
"--- hermes dump ---\nversion: 0.8.0\n--- end dump ---"
|
||||
)
|
||||
report = collect_debug_report(log_lines=50)
|
||||
|
||||
assert "--- hermes dump ---" in report
|
||||
assert "version: 0.8.0" in report
|
||||
|
||||
def test_report_includes_agent_log(self, hermes_home):
|
||||
from hermes_cli.debug import collect_debug_report
|
||||
|
||||
with patch("hermes_cli.dump.run_dump"):
|
||||
report = collect_debug_report(log_lines=50)
|
||||
|
||||
assert "--- agent.log" in report
|
||||
assert "session started" in report
|
||||
|
||||
def test_report_includes_errors_log(self, hermes_home):
|
||||
from hermes_cli.debug import collect_debug_report
|
||||
|
||||
with patch("hermes_cli.dump.run_dump"):
|
||||
report = collect_debug_report(log_lines=50)
|
||||
|
||||
assert "--- errors.log" in report
|
||||
assert "connection lost" in report
|
||||
|
||||
def test_report_includes_gateway_log(self, hermes_home):
|
||||
from hermes_cli.debug import collect_debug_report
|
||||
|
||||
with patch("hermes_cli.dump.run_dump"):
|
||||
report = collect_debug_report(log_lines=50)
|
||||
|
||||
assert "--- gateway.log" in report
|
||||
|
||||
def test_missing_logs_handled(self, tmp_path, monkeypatch):
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
from hermes_cli.debug import collect_debug_report
|
||||
|
||||
with patch("hermes_cli.dump.run_dump"):
|
||||
report = collect_debug_report(log_lines=50)
|
||||
|
||||
assert "(file not found)" in report
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI entry point — run_debug_share
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunDebugShare:
|
||||
"""Test the run_debug_share CLI handler."""
|
||||
|
||||
def test_local_flag_prints_full_logs(self, hermes_home, capsys):
|
||||
"""--local prints the report plus full log contents."""
|
||||
from hermes_cli.debug import run_debug_share
|
||||
|
||||
args = MagicMock()
|
||||
args.lines = 50
|
||||
args.expire = 7
|
||||
args.local = True
|
||||
|
||||
with patch("hermes_cli.dump.run_dump"):
|
||||
run_debug_share(args)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "--- agent.log" in out
|
||||
assert "FULL agent.log" in out
|
||||
assert "FULL gateway.log" in out
|
||||
|
||||
def test_share_uploads_three_pastes(self, hermes_home, capsys):
|
||||
"""Successful share uploads report + agent.log + gateway.log."""
|
||||
from hermes_cli.debug import run_debug_share
|
||||
|
||||
args = MagicMock()
|
||||
args.lines = 50
|
||||
args.expire = 7
|
||||
args.local = False
|
||||
|
||||
call_count = [0]
|
||||
uploaded_content = []
|
||||
def _mock_upload(content, expiry_days=7):
|
||||
call_count[0] += 1
|
||||
uploaded_content.append(content)
|
||||
return f"https://paste.rs/paste{call_count[0]}"
|
||||
|
||||
with patch("hermes_cli.dump.run_dump") as mock_dump, \
|
||||
patch("hermes_cli.debug.upload_to_pastebin",
|
||||
side_effect=_mock_upload):
|
||||
mock_dump.side_effect = lambda a: print("--- hermes dump ---\nversion: test\n--- end dump ---")
|
||||
run_debug_share(args)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
# Should have 3 uploads: report, agent.log, gateway.log
|
||||
assert call_count[0] == 3
|
||||
assert "paste.rs/paste1" in out # Report
|
||||
assert "paste.rs/paste2" in out # agent.log
|
||||
assert "paste.rs/paste3" in out # gateway.log
|
||||
assert "Report" in out
|
||||
assert "agent.log" in out
|
||||
assert "gateway.log" in out
|
||||
|
||||
# Each log paste should start with the dump header
|
||||
agent_paste = uploaded_content[1]
|
||||
assert "--- hermes dump ---" in agent_paste
|
||||
assert "--- full agent.log ---" in agent_paste
|
||||
gateway_paste = uploaded_content[2]
|
||||
assert "--- hermes dump ---" in gateway_paste
|
||||
assert "--- full gateway.log ---" in gateway_paste
|
||||
|
||||
def test_share_skips_missing_logs(self, tmp_path, monkeypatch, capsys):
|
||||
"""Only uploads logs that exist."""
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
from hermes_cli.debug import run_debug_share
|
||||
|
||||
args = MagicMock()
|
||||
args.lines = 50
|
||||
args.expire = 7
|
||||
args.local = False
|
||||
|
||||
call_count = [0]
|
||||
def _mock_upload(content, expiry_days=7):
|
||||
call_count[0] += 1
|
||||
return f"https://paste.rs/paste{call_count[0]}"
|
||||
|
||||
with patch("hermes_cli.dump.run_dump"), \
|
||||
patch("hermes_cli.debug.upload_to_pastebin",
|
||||
side_effect=_mock_upload):
|
||||
run_debug_share(args)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
# Only the report should be uploaded (no log files exist)
|
||||
assert call_count[0] == 1
|
||||
assert "Report" in out
|
||||
|
||||
def test_share_continues_on_log_upload_failure(self, hermes_home, capsys):
|
||||
"""Log upload failure doesn't stop the report from being shared."""
|
||||
from hermes_cli.debug import run_debug_share
|
||||
|
||||
args = MagicMock()
|
||||
args.lines = 50
|
||||
args.expire = 7
|
||||
args.local = False
|
||||
|
||||
call_count = [0]
|
||||
def _mock_upload(content, expiry_days=7):
|
||||
call_count[0] += 1
|
||||
if call_count[0] > 1:
|
||||
raise RuntimeError("upload failed")
|
||||
return "https://paste.rs/report"
|
||||
|
||||
with patch("hermes_cli.dump.run_dump"), \
|
||||
patch("hermes_cli.debug.upload_to_pastebin",
|
||||
side_effect=_mock_upload):
|
||||
run_debug_share(args)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Report" in out
|
||||
assert "paste.rs/report" in out
|
||||
assert "failed to upload" in out
|
||||
|
||||
def test_share_exits_on_report_upload_failure(self, hermes_home, capsys):
|
||||
"""If the main report fails to upload, exit with code 1."""
|
||||
from hermes_cli.debug import run_debug_share
|
||||
|
||||
args = MagicMock()
|
||||
args.lines = 50
|
||||
args.expire = 7
|
||||
args.local = False
|
||||
|
||||
with patch("hermes_cli.dump.run_dump"), \
|
||||
patch("hermes_cli.debug.upload_to_pastebin",
|
||||
side_effect=RuntimeError("all failed")):
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
run_debug_share(args)
|
||||
|
||||
assert exc_info.value.code == 1
|
||||
out = capsys.readouterr()
|
||||
assert "all failed" in out.err
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_debug router
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunDebug:
|
||||
def test_no_subcommand_shows_usage(self, capsys):
|
||||
from hermes_cli.debug import run_debug
|
||||
|
||||
args = MagicMock()
|
||||
args.debug_command = None
|
||||
|
||||
run_debug(args)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "hermes debug share" in out
|
||||
|
||||
def test_share_subcommand_routes(self, hermes_home):
|
||||
from hermes_cli.debug import run_debug
|
||||
|
||||
args = MagicMock()
|
||||
args.debug_command = "share"
|
||||
args.lines = 200
|
||||
args.expire = 7
|
||||
args.local = True
|
||||
|
||||
with patch("hermes_cli.dump.run_dump"):
|
||||
run_debug(args)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Argparse integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestArgparseIntegration:
|
||||
def test_module_imports_clean(self):
|
||||
from hermes_cli.debug import run_debug, run_debug_share
|
||||
assert callable(run_debug)
|
||||
assert callable(run_debug_share)
|
||||
|
||||
def test_cmd_debug_dispatches(self):
|
||||
from hermes_cli.main import cmd_debug
|
||||
|
||||
args = MagicMock()
|
||||
args.debug_command = None
|
||||
cmd_debug(args)
|
||||
91
tests/hermes_cli/test_env_sanitize_on_load.py
Normal file
91
tests/hermes_cli/test_env_sanitize_on_load.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Tests for .env sanitization during load to prevent token duplication (#8908)."""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def test_load_env_sanitizes_concatenated_lines():
|
||||
"""Verify load_env() splits concatenated KEY=VALUE pairs.
|
||||
|
||||
Reproduces the scenario from #8908 where a corrupted .env file
|
||||
contained multiple tokens on a single line, causing the bot token
|
||||
to be duplicated 8 times.
|
||||
"""
|
||||
from hermes_cli.config import load_env
|
||||
|
||||
token = "8356550917:AAGGEkzg06Hrc3Hjb3Sa1jkGVDOdU_lYy2Q"
|
||||
# Simulate concatenated line: TOKEN=xxx followed immediately by another key
|
||||
corrupted = f"TELEGRAM_BOT_TOKEN={token}ANTHROPIC_API_KEY=sk-ant-test123\n"
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".env", delete=False, encoding="utf-8"
|
||||
) as f:
|
||||
f.write(corrupted)
|
||||
env_path = Path(f.name)
|
||||
|
||||
try:
|
||||
with patch("hermes_cli.config.get_env_path", return_value=env_path):
|
||||
result = load_env()
|
||||
assert result.get("TELEGRAM_BOT_TOKEN") == token, (
|
||||
f"Token should be exactly '{token}', got '{result.get('TELEGRAM_BOT_TOKEN')}'"
|
||||
)
|
||||
assert result.get("ANTHROPIC_API_KEY") == "sk-ant-test123"
|
||||
finally:
|
||||
env_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_load_env_normal_file_unchanged():
|
||||
"""A well-formed .env file should be parsed identically."""
|
||||
from hermes_cli.config import load_env
|
||||
|
||||
content = (
|
||||
"TELEGRAM_BOT_TOKEN=mytoken123\n"
|
||||
"ANTHROPIC_API_KEY=sk-ant-key\n"
|
||||
"# comment\n"
|
||||
"\n"
|
||||
"OPENAI_API_KEY=sk-openai\n"
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".env", delete=False, encoding="utf-8"
|
||||
) as f:
|
||||
f.write(content)
|
||||
env_path = Path(f.name)
|
||||
|
||||
try:
|
||||
with patch("hermes_cli.config.get_env_path", return_value=env_path):
|
||||
result = load_env()
|
||||
assert result["TELEGRAM_BOT_TOKEN"] == "mytoken123"
|
||||
assert result["ANTHROPIC_API_KEY"] == "sk-ant-key"
|
||||
assert result["OPENAI_API_KEY"] == "sk-openai"
|
||||
finally:
|
||||
env_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_env_loader_sanitizes_before_dotenv():
|
||||
"""Verify env_loader._sanitize_env_file_if_needed fixes corrupted files."""
|
||||
from hermes_cli.env_loader import _sanitize_env_file_if_needed
|
||||
|
||||
token = "8356550917:AAGGEkzg06Hrc3Hjb3Sa1jkGVDOdU_lYy2Q"
|
||||
corrupted = f"TELEGRAM_BOT_TOKEN={token}ANTHROPIC_API_KEY=sk-ant-test\n"
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".env", delete=False, encoding="utf-8"
|
||||
) as f:
|
||||
f.write(corrupted)
|
||||
env_path = Path(f.name)
|
||||
|
||||
try:
|
||||
_sanitize_env_file_if_needed(env_path)
|
||||
with open(env_path, encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
# Should be split into two separate lines
|
||||
assert len(lines) == 2, f"Expected 2 lines, got {len(lines)}: {lines}"
|
||||
assert lines[0].startswith("TELEGRAM_BOT_TOKEN=")
|
||||
assert lines[1].startswith("ANTHROPIC_API_KEY=")
|
||||
# Token should not contain the second key
|
||||
parsed_token = lines[0].strip().split("=", 1)[1]
|
||||
assert parsed_token == token
|
||||
finally:
|
||||
env_path.unlink(missing_ok=True)
|
||||
@@ -394,6 +394,21 @@ class TestLaunchdServiceRecovery:
|
||||
|
||||
|
||||
class TestGatewayServiceDetection:
|
||||
def test_supports_systemd_services_requires_systemctl_binary(self, monkeypatch):
|
||||
monkeypatch.setattr(gateway_cli, "is_linux", lambda: True)
|
||||
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli.shutil, "which", lambda name: None)
|
||||
|
||||
assert gateway_cli.supports_systemd_services() is False
|
||||
|
||||
def test_supports_systemd_services_returns_true_when_systemctl_present(self, monkeypatch):
|
||||
monkeypatch.setattr(gateway_cli, "is_linux", lambda: True)
|
||||
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_wsl", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli.shutil, "which", lambda name: "/usr/bin/systemctl")
|
||||
|
||||
assert gateway_cli.supports_systemd_services() is True
|
||||
|
||||
def test_is_service_running_checks_system_scope_when_user_scope_is_inactive(self, monkeypatch):
|
||||
user_unit = SimpleNamespace(exists=lambda: True)
|
||||
system_unit = SimpleNamespace(exists=lambda: True)
|
||||
@@ -418,6 +433,23 @@ class TestGatewayServiceDetection:
|
||||
|
||||
assert gateway_cli._is_service_running() is True
|
||||
|
||||
def test_is_service_running_returns_false_when_systemctl_missing(self, monkeypatch):
|
||||
unit = SimpleNamespace(exists=lambda: True)
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"get_systemd_unit_path",
|
||||
lambda system=False: unit,
|
||||
)
|
||||
|
||||
def fake_run(*args, **kwargs):
|
||||
raise FileNotFoundError("systemctl")
|
||||
|
||||
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
|
||||
|
||||
assert gateway_cli._is_service_running() is False
|
||||
|
||||
|
||||
class TestGatewaySystemServiceRouting:
|
||||
def test_systemd_restart_self_requests_graceful_restart_without_reload_or_restart(self, monkeypatch, capsys):
|
||||
@@ -1001,3 +1033,91 @@ class TestSystemUnitPathRemapping:
|
||||
# Target user paths should be present
|
||||
assert "/home/alice" in unit
|
||||
assert "WorkingDirectory=/home/alice/.hermes/hermes-agent" in unit
|
||||
|
||||
|
||||
class TestDockerAwareGateway:
|
||||
"""Tests for Docker container awareness in gateway commands."""
|
||||
|
||||
def test_run_systemctl_raises_runtimeerror_when_missing(self, monkeypatch):
|
||||
"""_run_systemctl raises RuntimeError with container guidance when systemctl is absent."""
|
||||
import pytest
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
raise FileNotFoundError("systemctl")
|
||||
|
||||
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
|
||||
|
||||
with pytest.raises(RuntimeError, match="systemctl is not available"):
|
||||
gateway_cli._run_systemctl(["start", "hermes-gateway"])
|
||||
|
||||
def test_run_systemctl_passes_through_on_success(self, monkeypatch):
|
||||
"""_run_systemctl delegates to subprocess.run when systemctl exists."""
|
||||
calls = []
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
calls.append(cmd)
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
|
||||
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
|
||||
|
||||
result = gateway_cli._run_systemctl(["status", "hermes-gateway"])
|
||||
assert result.returncode == 0
|
||||
assert len(calls) == 1
|
||||
assert "status" in calls[0]
|
||||
|
||||
def test_install_in_container_prints_docker_guidance(self, monkeypatch, capsys):
|
||||
"""'hermes gateway install' inside Docker exits 0 with container guidance."""
|
||||
import pytest
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "is_managed", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_wsl", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_container", lambda: True)
|
||||
|
||||
args = SimpleNamespace(gateway_command="install", force=False, system=False, run_as_user=None)
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
gateway_cli.gateway_command(args)
|
||||
|
||||
assert exc_info.value.code == 0
|
||||
out = capsys.readouterr().out
|
||||
assert "Docker" in out or "docker" in out
|
||||
assert "restart" in out.lower()
|
||||
|
||||
def test_uninstall_in_container_prints_docker_guidance(self, monkeypatch, capsys):
|
||||
"""'hermes gateway uninstall' inside Docker exits 0 with container guidance."""
|
||||
import pytest
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "is_managed", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_container", lambda: True)
|
||||
|
||||
args = SimpleNamespace(gateway_command="uninstall", system=False)
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
gateway_cli.gateway_command(args)
|
||||
|
||||
assert exc_info.value.code == 0
|
||||
out = capsys.readouterr().out
|
||||
assert "docker" in out.lower()
|
||||
|
||||
def test_start_in_container_prints_docker_guidance(self, monkeypatch, capsys):
|
||||
"""'hermes gateway start' inside Docker exits 0 with container guidance."""
|
||||
import pytest
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_wsl", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_container", lambda: True)
|
||||
|
||||
args = SimpleNamespace(gateway_command="start", system=False)
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
gateway_cli.gateway_command(args)
|
||||
|
||||
assert exc_info.value.code == 0
|
||||
out = capsys.readouterr().out
|
||||
assert "docker" in out.lower()
|
||||
assert "hermes gateway run" in out
|
||||
|
||||
@@ -54,14 +54,19 @@ class TestAnthropicDotToHyphen:
|
||||
|
||||
# ── OpenCode Zen regression ────────────────────────────────────────────
|
||||
|
||||
class TestOpenCodeZenDotToHyphen:
|
||||
"""OpenCode Zen follows Anthropic convention (dots→hyphens)."""
|
||||
class TestOpenCodeZenModelNormalization:
|
||||
"""OpenCode Zen preserves dots for most models, but Claude stays hyphenated."""
|
||||
|
||||
@pytest.mark.parametrize("model,expected", [
|
||||
("claude-sonnet-4.6", "claude-sonnet-4-6"),
|
||||
("glm-4.5", "glm-4-5"),
|
||||
("opencode-zen/claude-opus-4.5", "claude-opus-4-5"),
|
||||
("glm-4.5", "glm-4.5"),
|
||||
("glm-5.1", "glm-5.1"),
|
||||
("gpt-5.4", "gpt-5.4"),
|
||||
("minimax-m2.5-free", "minimax-m2.5-free"),
|
||||
("kimi-k2.5", "kimi-k2.5"),
|
||||
])
|
||||
def test_zen_converts_dots(self, model, expected):
|
||||
def test_zen_normalizes_models(self, model, expected):
|
||||
result = normalize_model_for_provider(model, "opencode-zen")
|
||||
assert result == expected
|
||||
|
||||
@@ -69,6 +74,10 @@ class TestOpenCodeZenDotToHyphen:
|
||||
result = normalize_model_for_provider("opencode-zen/claude-sonnet-4.6", "opencode-zen")
|
||||
assert result == "claude-sonnet-4-6"
|
||||
|
||||
def test_zen_strips_vendor_prefix_for_non_claude(self):
|
||||
result = normalize_model_for_provider("opencode-zen/glm-5.1", "opencode-zen")
|
||||
assert result == "glm-5.1"
|
||||
|
||||
|
||||
# ── Copilot dot preservation (regression) ──────────────────────────────
|
||||
|
||||
|
||||
84
tests/hermes_cli/test_nous_hermes_non_agentic.py
Normal file
84
tests/hermes_cli/test_nous_hermes_non_agentic.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Tests for the Nous-Hermes-3/4 non-agentic warning detector.
|
||||
|
||||
Prior to this check, the warning fired on any model whose name contained
|
||||
``"hermes"`` anywhere (case-insensitive). That false-positived on unrelated
|
||||
local Modelfiles such as ``hermes-brain:qwen3-14b-ctx16k`` — a tool-capable
|
||||
Qwen3 wrapper that happens to live under the "hermes" tag namespace.
|
||||
|
||||
``is_nous_hermes_non_agentic`` should only match the actual Nous Research
|
||||
Hermes-3 / Hermes-4 chat family.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.model_switch import (
|
||||
_HERMES_MODEL_WARNING,
|
||||
_check_hermes_model_warning,
|
||||
is_nous_hermes_non_agentic,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[
|
||||
"NousResearch/Hermes-3-Llama-3.1-70B",
|
||||
"NousResearch/Hermes-3-Llama-3.1-405B",
|
||||
"hermes-3",
|
||||
"Hermes-3",
|
||||
"hermes-4",
|
||||
"hermes-4-405b",
|
||||
"hermes_4_70b",
|
||||
"openrouter/hermes3:70b",
|
||||
"openrouter/nousresearch/hermes-4-405b",
|
||||
"NousResearch/Hermes3",
|
||||
"hermes-3.1",
|
||||
],
|
||||
)
|
||||
def test_matches_real_nous_hermes_chat_models(model_name: str) -> None:
|
||||
assert is_nous_hermes_non_agentic(model_name), (
|
||||
f"expected {model_name!r} to be flagged as Nous Hermes 3/4"
|
||||
)
|
||||
assert _check_hermes_model_warning(model_name) == _HERMES_MODEL_WARNING
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[
|
||||
# Kyle's local Modelfile — qwen3:14b under a custom tag
|
||||
"hermes-brain:qwen3-14b-ctx16k",
|
||||
"hermes-brain:qwen3-14b-ctx32k",
|
||||
"hermes-honcho:qwen3-8b-ctx8k",
|
||||
# Plain unrelated models
|
||||
"qwen3:14b",
|
||||
"qwen3-coder:30b",
|
||||
"qwen2.5:14b",
|
||||
"claude-opus-4-6",
|
||||
"anthropic/claude-sonnet-4.5",
|
||||
"gpt-5",
|
||||
"openai/gpt-4o",
|
||||
"google/gemini-2.5-flash",
|
||||
"deepseek-chat",
|
||||
# Non-chat Hermes models we don't warn about
|
||||
"hermes-llm-2",
|
||||
"hermes2-pro",
|
||||
"nous-hermes-2-mistral",
|
||||
# Edge cases
|
||||
"",
|
||||
"hermes", # bare "hermes" isn't the 3/4 family
|
||||
"hermes-brain",
|
||||
"brain-hermes-3-impostor", # "3" not preceded by /: boundary
|
||||
],
|
||||
)
|
||||
def test_does_not_match_unrelated_models(model_name: str) -> None:
|
||||
assert not is_nous_hermes_non_agentic(model_name), (
|
||||
f"expected {model_name!r} NOT to be flagged as Nous Hermes 3/4"
|
||||
)
|
||||
assert _check_hermes_model_warning(model_name) == ""
|
||||
|
||||
|
||||
def test_none_like_inputs_are_safe() -> None:
|
||||
assert is_nous_hermes_non_agentic("") is False
|
||||
# Defensive: the helper shouldn't crash on None-ish falsy input either.
|
||||
assert _check_hermes_model_warning("") == ""
|
||||
@@ -177,7 +177,8 @@ class TestCreateProfile:
|
||||
# No error; optional files just not copied
|
||||
assert not (profile_dir / "config.yaml").exists()
|
||||
assert not (profile_dir / ".env").exists()
|
||||
assert not (profile_dir / "SOUL.md").exists()
|
||||
# SOUL.md is always seeded with the default even when clone source lacks it
|
||||
assert (profile_dir / "SOUL.md").exists()
|
||||
|
||||
|
||||
# ===================================================================
|
||||
|
||||
@@ -119,6 +119,11 @@ def test_resolve_runtime_provider_falls_back_when_pool_empty(monkeypatch):
|
||||
|
||||
|
||||
def test_resolve_runtime_provider_codex(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"load_pool",
|
||||
lambda provider: type("P", (), {"has_credentials": lambda self: False})(),
|
||||
)
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openai-codex")
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
@@ -567,6 +572,87 @@ def test_named_custom_provider_uses_saved_credentials(monkeypatch):
|
||||
assert resolved["source"] == "custom_provider:Local"
|
||||
|
||||
|
||||
def test_named_custom_provider_uses_providers_dict_when_list_missing(monkeypatch):
|
||||
"""After v11→v12 migration deletes custom_providers, resolution should
|
||||
still find entries in the providers dict via get_compatible_custom_providers."""
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"load_config",
|
||||
lambda: {
|
||||
"providers": {
|
||||
"openai-direct-primary": {
|
||||
"api": "https://api.openai.com/v1",
|
||||
"api_key": "dir-key",
|
||||
"default_model": "gpt-5-mini",
|
||||
"name": "OpenAI Direct (Primary)",
|
||||
"transport": "codex_responses",
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"resolve_provider",
|
||||
lambda *a, **k: (_ for _ in ()).throw(
|
||||
AssertionError(
|
||||
"resolve_provider should not be called for named custom providers"
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="openai-direct-primary")
|
||||
|
||||
assert resolved["provider"] == "custom"
|
||||
assert resolved["api_mode"] == "codex_responses"
|
||||
assert resolved["base_url"] == "https://api.openai.com/v1"
|
||||
assert resolved["api_key"] == "dir-key"
|
||||
assert resolved["requested_provider"] == "openai-direct-primary"
|
||||
assert resolved["source"] == "custom_provider:OpenAI Direct (Primary)"
|
||||
assert resolved["model"] == "gpt-5-mini"
|
||||
|
||||
|
||||
def test_named_custom_provider_uses_key_env_from_providers_dict(monkeypatch):
|
||||
"""providers dict entries with key_env should resolve API key from env var."""
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.setenv("MYCORP_API_KEY", "env-secret")
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"load_config",
|
||||
lambda: {
|
||||
"providers": {
|
||||
"mycorp-proxy": {
|
||||
"base_url": "https://proxy.example.com/v1",
|
||||
"default_model": "acme-large",
|
||||
"key_env": "MYCORP_API_KEY",
|
||||
"name": "MyCorp Proxy",
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"resolve_provider",
|
||||
lambda *a, **k: (_ for _ in ()).throw(
|
||||
AssertionError(
|
||||
"resolve_provider should not be called for named custom providers"
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="mycorp-proxy")
|
||||
|
||||
assert resolved["provider"] == "custom"
|
||||
assert resolved["api_mode"] == "chat_completions"
|
||||
assert resolved["base_url"] == "https://proxy.example.com/v1"
|
||||
assert resolved["api_key"] == "env-secret"
|
||||
assert resolved["requested_provider"] == "mycorp-proxy"
|
||||
assert resolved["source"] == "custom_provider:MyCorp Proxy"
|
||||
assert resolved["model"] == "acme-large"
|
||||
|
||||
|
||||
def test_named_custom_provider_falls_back_to_openai_api_key(monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "env-openai-key")
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Tests for setup_model_provider — verifies the delegation to
|
||||
select_provider_and_model() and config dict sync."""
|
||||
"""Tests for setup.py configuration flows."""
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
@@ -8,6 +7,7 @@ import pytest
|
||||
|
||||
from hermes_cli.auth import get_active_provider
|
||||
from hermes_cli.config import load_config, save_config
|
||||
from hermes_cli import setup as setup_mod
|
||||
from hermes_cli.setup import setup_model_provider
|
||||
|
||||
|
||||
@@ -144,6 +144,85 @@ def test_setup_custom_providers_synced(tmp_path, monkeypatch):
|
||||
assert reloaded.get("custom_providers") == [{"name": "Local", "base_url": "http://localhost:8080/v1"}]
|
||||
|
||||
|
||||
def test_setup_gateway_skips_service_install_when_systemctl_missing(monkeypatch, capsys):
|
||||
env = {
|
||||
"TELEGRAM_BOT_TOKEN": "",
|
||||
"TELEGRAM_HOME_CHANNEL": "",
|
||||
"DISCORD_BOT_TOKEN": "",
|
||||
"DISCORD_HOME_CHANNEL": "",
|
||||
"SLACK_BOT_TOKEN": "",
|
||||
"SLACK_HOME_CHANNEL": "",
|
||||
"MATRIX_HOMESERVER": "https://matrix.example.com",
|
||||
"MATRIX_USER_ID": "@alice:example.com",
|
||||
"MATRIX_PASSWORD": "",
|
||||
"MATRIX_ACCESS_TOKEN": "token",
|
||||
"BLUEBUBBLES_SERVER_URL": "",
|
||||
"BLUEBUBBLES_HOME_CHANNEL": "",
|
||||
"WHATSAPP_ENABLED": "",
|
||||
"WEBHOOK_ENABLED": "",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(setup_mod, "get_env_value", lambda key: env.get(key, ""))
|
||||
monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("platform.system", lambda: "Linux")
|
||||
|
||||
import hermes_cli.gateway as gateway_mod
|
||||
|
||||
monkeypatch.setattr(gateway_mod, "supports_systemd_services", lambda: False)
|
||||
monkeypatch.setattr(gateway_mod, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(gateway_mod, "_is_service_installed", lambda: False)
|
||||
monkeypatch.setattr(gateway_mod, "_is_service_running", lambda: False)
|
||||
|
||||
setup_mod.setup_gateway({})
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Messaging platforms configured!" in out
|
||||
assert "Start the gateway to bring your bots online:" in out
|
||||
assert "hermes gateway" in out
|
||||
|
||||
|
||||
def test_setup_gateway_in_container_shows_docker_guidance(monkeypatch, capsys):
|
||||
"""setup_gateway() in a Docker container shows Docker-specific restart instructions."""
|
||||
env = {
|
||||
"TELEGRAM_BOT_TOKEN": "",
|
||||
"TELEGRAM_HOME_CHANNEL": "",
|
||||
"DISCORD_BOT_TOKEN": "",
|
||||
"DISCORD_HOME_CHANNEL": "",
|
||||
"SLACK_BOT_TOKEN": "",
|
||||
"SLACK_HOME_CHANNEL": "",
|
||||
"MATRIX_HOMESERVER": "https://matrix.example.com",
|
||||
"MATRIX_USER_ID": "@alice:example.com",
|
||||
"MATRIX_PASSWORD": "",
|
||||
"MATRIX_ACCESS_TOKEN": "token",
|
||||
"BLUEBUBBLES_SERVER_URL": "",
|
||||
"BLUEBUBBLES_HOME_CHANNEL": "",
|
||||
"WHATSAPP_ENABLED": "",
|
||||
"WEBHOOK_ENABLED": "",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(setup_mod, "get_env_value", lambda key: env.get(key, ""))
|
||||
monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("platform.system", lambda: "Linux")
|
||||
|
||||
import hermes_cli.gateway as gateway_mod
|
||||
|
||||
monkeypatch.setattr(gateway_mod, "supports_systemd_services", lambda: False)
|
||||
monkeypatch.setattr(gateway_mod, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(gateway_mod, "_is_service_installed", lambda: False)
|
||||
monkeypatch.setattr(gateway_mod, "_is_service_running", lambda: False)
|
||||
|
||||
# Patch is_container at the import location in setup.py
|
||||
import hermes_constants
|
||||
monkeypatch.setattr(hermes_constants, "is_container", lambda: True)
|
||||
|
||||
setup_mod.setup_gateway({})
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Messaging platforms configured!" in out
|
||||
assert "docker" in out.lower() or "Docker" in out
|
||||
assert "restart" in out.lower()
|
||||
|
||||
|
||||
def test_setup_syncs_custom_provider_removal_from_disk(tmp_path, monkeypatch):
|
||||
"""Removing the last custom provider in model setup should persist."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
@@ -119,8 +119,7 @@ def test_toolset_has_keys_for_vision_accepts_codex_auth(tmp_path, monkeypatch):
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("AUXILIARY_VISION_PROVIDER", raising=False)
|
||||
monkeypatch.delenv("CONTEXT_VISION_PROVIDER", raising=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"agent.auxiliary_client.resolve_vision_provider_client",
|
||||
lambda: ("openai-codex", object(), "gpt-4.1"),
|
||||
|
||||
280
tests/hermes_cli/test_user_providers_model_switch.py
Normal file
280
tests/hermes_cli/test_user_providers_model_switch.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""Tests for user-defined providers (providers: dict) in /model.
|
||||
|
||||
These tests ensure that providers defined in the config.yaml ``providers:`` section
|
||||
are properly resolved for model switching and that their full ``models:`` lists
|
||||
are exposed in the model picker.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from hermes_cli.model_switch import list_authenticated_providers, switch_model
|
||||
from hermes_cli import runtime_provider as rp
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for list_authenticated_providers including full models list
|
||||
# =============================================================================
|
||||
|
||||
def test_list_authenticated_providers_includes_full_models_list_from_user_providers(monkeypatch):
|
||||
"""User-defined providers should expose both default_model and full models list.
|
||||
|
||||
Regression test: previously only default_model was shown in /model picker.
|
||||
"""
|
||||
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
|
||||
monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {})
|
||||
|
||||
user_providers = {
|
||||
"local-ollama": {
|
||||
"name": "Local Ollama",
|
||||
"api": "http://localhost:11434/v1",
|
||||
"default_model": "minimax-m2.7:cloud",
|
||||
"models": [
|
||||
"minimax-m2.7:cloud",
|
||||
"kimi-k2.5:cloud",
|
||||
"glm-5.1:cloud",
|
||||
"qwen3.5:cloud",
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
providers = list_authenticated_providers(
|
||||
current_provider="local-ollama",
|
||||
user_providers=user_providers,
|
||||
custom_providers=[],
|
||||
max_models=50,
|
||||
)
|
||||
|
||||
# Find our user provider
|
||||
user_prov = next(
|
||||
(p for p in providers if p.get("is_user_defined") and p["slug"] == "local-ollama"),
|
||||
None
|
||||
)
|
||||
|
||||
assert user_prov is not None, "User provider 'local-ollama' should be in results"
|
||||
assert user_prov["total_models"] == 4, f"Expected 4 models, got {user_prov['total_models']}"
|
||||
assert "minimax-m2.7:cloud" in user_prov["models"]
|
||||
assert "kimi-k2.5:cloud" in user_prov["models"]
|
||||
assert "glm-5.1:cloud" in user_prov["models"]
|
||||
assert "qwen3.5:cloud" in user_prov["models"]
|
||||
|
||||
|
||||
def test_list_authenticated_providers_dedupes_models_when_default_in_list(monkeypatch):
|
||||
"""When default_model is also in models list, don't duplicate."""
|
||||
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
|
||||
monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {})
|
||||
|
||||
user_providers = {
|
||||
"my-provider": {
|
||||
"api": "http://example.com/v1",
|
||||
"default_model": "model-a", # Included in models list below
|
||||
"models": ["model-a", "model-b", "model-c"],
|
||||
}
|
||||
}
|
||||
|
||||
providers = list_authenticated_providers(
|
||||
current_provider="my-provider",
|
||||
user_providers=user_providers,
|
||||
custom_providers=[],
|
||||
)
|
||||
|
||||
user_prov = next(
|
||||
(p for p in providers if p.get("is_user_defined")),
|
||||
None
|
||||
)
|
||||
|
||||
assert user_prov is not None
|
||||
assert user_prov["total_models"] == 3, "Should have 3 unique models, not 4"
|
||||
assert user_prov["models"].count("model-a") == 1, "model-a should not be duplicated"
|
||||
|
||||
|
||||
def test_list_authenticated_providers_fallback_to_default_only(monkeypatch):
|
||||
"""When no models array is provided, should fall back to default_model."""
|
||||
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
|
||||
monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {})
|
||||
|
||||
user_providers = {
|
||||
"simple-provider": {
|
||||
"name": "Simple Provider",
|
||||
"api": "http://example.com/v1",
|
||||
"default_model": "single-model",
|
||||
# No 'models' key
|
||||
}
|
||||
}
|
||||
|
||||
providers = list_authenticated_providers(
|
||||
current_provider="",
|
||||
user_providers=user_providers,
|
||||
custom_providers=[],
|
||||
)
|
||||
|
||||
user_prov = next(
|
||||
(p for p in providers if p.get("is_user_defined")),
|
||||
None
|
||||
)
|
||||
|
||||
assert user_prov is not None
|
||||
assert user_prov["total_models"] == 1
|
||||
assert user_prov["models"] == ["single-model"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for _get_named_custom_provider with providers: dict
|
||||
# =============================================================================
|
||||
|
||||
def test_get_named_custom_provider_finds_user_providers_by_key(monkeypatch, tmp_path):
|
||||
"""Should resolve providers from providers: dict (new-style), not just custom_providers."""
|
||||
config = {
|
||||
"providers": {
|
||||
"local-localhost:11434": {
|
||||
"api": "http://localhost:11434/v1",
|
||||
"name": "Local (localhost:11434)",
|
||||
"default_model": "minimax-m2.7:cloud",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
import yaml
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(yaml.dump(config))
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
result = rp._get_named_custom_provider("local-localhost:11434")
|
||||
|
||||
assert result is not None
|
||||
assert result["base_url"] == "http://localhost:11434/v1"
|
||||
assert result["name"] == "Local (localhost:11434)"
|
||||
|
||||
|
||||
def test_get_named_custom_provider_finds_by_display_name(monkeypatch, tmp_path):
|
||||
"""Should match providers by their 'name' field as well as key."""
|
||||
config = {
|
||||
"providers": {
|
||||
"my-ollama-xyz": {
|
||||
"api": "http://ollama.example.com/v1",
|
||||
"name": "My Production Ollama",
|
||||
"default_model": "llama3",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
import yaml
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(yaml.dump(config))
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
# Should find by display name (normalized)
|
||||
result = rp._get_named_custom_provider("my-production-ollama")
|
||||
|
||||
assert result is not None
|
||||
assert result["base_url"] == "http://ollama.example.com/v1"
|
||||
|
||||
|
||||
def test_get_named_custom_provider_falls_back_to_legacy_format(monkeypatch, tmp_path):
|
||||
"""Should still work with custom_providers: list format."""
|
||||
config = {
|
||||
"providers": {},
|
||||
"custom_providers": [
|
||||
{
|
||||
"name": "Custom Endpoint",
|
||||
"base_url": "http://custom.example.com/v1",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
import yaml
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(yaml.dump(config))
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
result = rp._get_named_custom_provider("custom-endpoint")
|
||||
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_get_named_custom_provider_returns_none_for_unknown(monkeypatch, tmp_path):
|
||||
"""Should return None for providers that don't exist."""
|
||||
config = {
|
||||
"providers": {
|
||||
"known-provider": {
|
||||
"api": "http://known.example.com/v1",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
import yaml
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(yaml.dump(config))
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
result = rp._get_named_custom_provider("other-provider")
|
||||
|
||||
# "unknown-provider" partial-matches "known-provider" because "unknown" doesn't match
|
||||
# but our matching is loose (substring). Let's verify a truly non-matching provider
|
||||
result = rp._get_named_custom_provider("completely-different-name")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_named_custom_provider_skips_empty_base_url(monkeypatch, tmp_path):
|
||||
"""Should skip providers without a base_url."""
|
||||
config = {
|
||||
"providers": {
|
||||
"incomplete-provider": {
|
||||
"name": "Incomplete",
|
||||
# No api/base_url field
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
import yaml
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(yaml.dump(config))
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
result = rp._get_named_custom_provider("incomplete-provider")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration test for switch_model with user providers
|
||||
# =============================================================================
|
||||
|
||||
def test_switch_model_resolves_user_provider_credentials(monkeypatch, tmp_path):
|
||||
"""/model switch should resolve credentials for providers: dict providers."""
|
||||
import yaml
|
||||
|
||||
config = {
|
||||
"providers": {
|
||||
"local-ollama": {
|
||||
"api": "http://localhost:11434/v1",
|
||||
"name": "Local Ollama",
|
||||
"default_model": "minimax-m2.7:cloud",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(yaml.dump(config))
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
# Mock validation to pass
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.models.validate_requested_model",
|
||||
lambda *a, **k: {"accepted": True, "persist": True, "recognized": True, "message": None}
|
||||
)
|
||||
|
||||
result = switch_model(
|
||||
raw_input="kimi-k2.5:cloud",
|
||||
current_provider="local-ollama",
|
||||
current_model="minimax-m2.7:cloud",
|
||||
current_base_url="http://localhost:11434/v1",
|
||||
is_global=False,
|
||||
user_providers=config["providers"],
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.error_message == ""
|
||||
675
tests/hermes_cli/test_web_server.py
Normal file
675
tests/hermes_cli/test_web_server.py
Normal file
@@ -0,0 +1,675 @@
|
||||
"""Tests for hermes_cli.web_server and related config utilities."""
|
||||
|
||||
import os
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.config import (
|
||||
DEFAULT_CONFIG,
|
||||
reload_env,
|
||||
redact_key,
|
||||
_EXTRA_ENV_KEYS,
|
||||
OPTIONAL_ENV_VARS,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reload_env tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReloadEnv:
|
||||
"""Tests for reload_env() — re-reads .env into os.environ."""
|
||||
|
||||
def test_adds_new_vars(self, tmp_path):
|
||||
"""reload_env() adds vars from .env that are not in os.environ."""
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text("TEST_RELOAD_VAR=hello123\n")
|
||||
with patch("hermes_cli.config.get_env_path", return_value=env_file):
|
||||
os.environ.pop("TEST_RELOAD_VAR", None)
|
||||
count = reload_env()
|
||||
assert count >= 1
|
||||
assert os.environ.get("TEST_RELOAD_VAR") == "hello123"
|
||||
os.environ.pop("TEST_RELOAD_VAR", None)
|
||||
|
||||
def test_updates_changed_vars(self, tmp_path):
|
||||
"""reload_env() updates vars whose value changed on disk."""
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text("TEST_RELOAD_VAR=old_value\n")
|
||||
with patch("hermes_cli.config.get_env_path", return_value=env_file):
|
||||
os.environ["TEST_RELOAD_VAR"] = "old_value"
|
||||
# Now change the file
|
||||
env_file.write_text("TEST_RELOAD_VAR=new_value\n")
|
||||
count = reload_env()
|
||||
assert count >= 1
|
||||
assert os.environ.get("TEST_RELOAD_VAR") == "new_value"
|
||||
os.environ.pop("TEST_RELOAD_VAR", None)
|
||||
|
||||
def test_removes_deleted_known_vars(self, tmp_path):
|
||||
"""reload_env() removes known Hermes vars not present in .env."""
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text("") # empty .env
|
||||
# Pick a known key from OPTIONAL_ENV_VARS
|
||||
known_key = next(iter(OPTIONAL_ENV_VARS.keys()))
|
||||
with patch("hermes_cli.config.get_env_path", return_value=env_file):
|
||||
os.environ[known_key] = "stale_value"
|
||||
count = reload_env()
|
||||
assert known_key not in os.environ
|
||||
assert count >= 1
|
||||
|
||||
def test_does_not_remove_unknown_vars(self, tmp_path):
|
||||
"""reload_env() preserves non-Hermes env vars even when absent from .env."""
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text("")
|
||||
with patch("hermes_cli.config.get_env_path", return_value=env_file):
|
||||
os.environ["MY_CUSTOM_UNRELATED_VAR"] = "keep_me"
|
||||
reload_env()
|
||||
assert os.environ.get("MY_CUSTOM_UNRELATED_VAR") == "keep_me"
|
||||
os.environ.pop("MY_CUSTOM_UNRELATED_VAR", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# redact_key tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRedactKey:
|
||||
def test_long_key_shows_prefix_suffix(self):
|
||||
result = redact_key("sk-1234567890abcdef")
|
||||
assert result.startswith("sk-1")
|
||||
assert result.endswith("cdef")
|
||||
assert "..." in result
|
||||
|
||||
def test_short_key_fully_masked(self):
|
||||
assert redact_key("short") == "***"
|
||||
|
||||
def test_empty_key(self):
|
||||
result = redact_key("")
|
||||
assert "not set" in result.lower() or result == "***" or "\x1b" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# web_server tests (FastAPI endpoints)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWebServerEndpoints:
|
||||
"""Test the FastAPI REST endpoints using Starlette TestClient."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_test_client(self):
|
||||
"""Create a TestClient — import is deferred to avoid requiring fastapi."""
|
||||
try:
|
||||
from starlette.testclient import TestClient
|
||||
except ImportError:
|
||||
pytest.skip("fastapi/starlette not installed")
|
||||
|
||||
from hermes_cli.web_server import app
|
||||
self.client = TestClient(app)
|
||||
|
||||
def test_get_status(self):
|
||||
resp = self.client.get("/api/status")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "version" in data
|
||||
assert "hermes_home" in data
|
||||
assert "active_sessions" in data
|
||||
|
||||
def test_get_status_filters_unconfigured_gateway_platforms(self, monkeypatch):
|
||||
import gateway.config as gateway_config
|
||||
import hermes_cli.web_server as web_server
|
||||
|
||||
class _Platform:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
class _GatewayConfig:
|
||||
def get_connected_platforms(self):
|
||||
return [_Platform("telegram")]
|
||||
|
||||
monkeypatch.setattr(web_server, "get_running_pid", lambda: 1234)
|
||||
monkeypatch.setattr(
|
||||
web_server,
|
||||
"read_runtime_status",
|
||||
lambda: {
|
||||
"gateway_state": "running",
|
||||
"updated_at": "2026-04-12T00:00:00+00:00",
|
||||
"platforms": {
|
||||
"telegram": {"state": "connected", "updated_at": "2026-04-12T00:00:00+00:00"},
|
||||
"whatsapp": {"state": "retrying", "updated_at": "2026-04-12T00:00:00+00:00"},
|
||||
"feishu": {"state": "connected", "updated_at": "2026-04-12T00:00:00+00:00"},
|
||||
},
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(web_server, "check_config_version", lambda: (1, 1))
|
||||
monkeypatch.setattr(gateway_config, "load_gateway_config", lambda: _GatewayConfig())
|
||||
|
||||
resp = self.client.get("/api/status")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["gateway_platforms"] == {
|
||||
"telegram": {"state": "connected", "updated_at": "2026-04-12T00:00:00+00:00"},
|
||||
}
|
||||
|
||||
def test_get_status_hides_stale_platforms_when_gateway_not_running(self, monkeypatch):
|
||||
import gateway.config as gateway_config
|
||||
import hermes_cli.web_server as web_server
|
||||
|
||||
class _GatewayConfig:
|
||||
def get_connected_platforms(self):
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(web_server, "get_running_pid", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
web_server,
|
||||
"read_runtime_status",
|
||||
lambda: {
|
||||
"gateway_state": "startup_failed",
|
||||
"updated_at": "2026-04-12T00:00:00+00:00",
|
||||
"platforms": {
|
||||
"whatsapp": {"state": "retrying", "updated_at": "2026-04-12T00:00:00+00:00"},
|
||||
"feishu": {"state": "connected", "updated_at": "2026-04-12T00:00:00+00:00"},
|
||||
},
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(web_server, "check_config_version", lambda: (1, 1))
|
||||
monkeypatch.setattr(gateway_config, "load_gateway_config", lambda: _GatewayConfig())
|
||||
|
||||
resp = self.client.get("/api/status")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["gateway_state"] == "startup_failed"
|
||||
assert resp.json()["gateway_platforms"] == {}
|
||||
|
||||
def test_get_config_schema(self):
|
||||
resp = self.client.get("/api/config/schema")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "fields" in data
|
||||
assert "category_order" in data
|
||||
schema = data["fields"]
|
||||
assert len(schema) > 100 # Should have 150+ fields
|
||||
assert "model" in schema
|
||||
# Verify category_order is a non-empty list
|
||||
assert isinstance(data["category_order"], list)
|
||||
assert len(data["category_order"]) > 0
|
||||
assert "general" in data["category_order"]
|
||||
|
||||
def test_get_config_defaults(self):
|
||||
resp = self.client.get("/api/config/defaults")
|
||||
assert resp.status_code == 200
|
||||
defaults = resp.json()
|
||||
assert "model" in defaults
|
||||
|
||||
def test_get_env_vars(self):
|
||||
resp = self.client.get("/api/env")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
# Should contain known env var names
|
||||
assert any(k.endswith("_API_KEY") or k.endswith("_TOKEN") for k in data.keys())
|
||||
|
||||
def test_reveal_env_var(self, tmp_path):
|
||||
"""POST /api/env/reveal should return the real unredacted value."""
|
||||
from hermes_cli.config import save_env_value
|
||||
from hermes_cli.web_server import _SESSION_TOKEN
|
||||
save_env_value("TEST_REVEAL_KEY", "super-secret-value-12345")
|
||||
resp = self.client.post(
|
||||
"/api/env/reveal",
|
||||
json={"key": "TEST_REVEAL_KEY"},
|
||||
headers={"Authorization": f"Bearer {_SESSION_TOKEN}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["key"] == "TEST_REVEAL_KEY"
|
||||
assert data["value"] == "super-secret-value-12345"
|
||||
|
||||
def test_reveal_env_var_not_found(self):
|
||||
"""POST /api/env/reveal should 404 for unknown keys."""
|
||||
from hermes_cli.web_server import _SESSION_TOKEN
|
||||
resp = self.client.post(
|
||||
"/api/env/reveal",
|
||||
json={"key": "NONEXISTENT_KEY_XYZ"},
|
||||
headers={"Authorization": f"Bearer {_SESSION_TOKEN}"},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_reveal_env_var_no_token(self, tmp_path):
|
||||
"""POST /api/env/reveal without token should return 401."""
|
||||
from hermes_cli.config import save_env_value
|
||||
save_env_value("TEST_REVEAL_NOAUTH", "secret-value")
|
||||
resp = self.client.post(
|
||||
"/api/env/reveal",
|
||||
json={"key": "TEST_REVEAL_NOAUTH"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_reveal_env_var_bad_token(self, tmp_path):
|
||||
"""POST /api/env/reveal with wrong token should return 401."""
|
||||
from hermes_cli.config import save_env_value
|
||||
save_env_value("TEST_REVEAL_BADAUTH", "secret-value")
|
||||
resp = self.client.post(
|
||||
"/api/env/reveal",
|
||||
json={"key": "TEST_REVEAL_BADAUTH"},
|
||||
headers={"Authorization": "Bearer wrong-token-here"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_session_token_endpoint(self):
|
||||
"""GET /api/auth/session-token should return a token."""
|
||||
from hermes_cli.web_server import _SESSION_TOKEN
|
||||
resp = self.client.get("/api/auth/session-token")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["token"] == _SESSION_TOKEN
|
||||
|
||||
def test_path_traversal_blocked(self):
|
||||
"""Verify URL-encoded path traversal is blocked."""
|
||||
# %2e%2e = ..
|
||||
resp = self.client.get("/%2e%2e/%2e%2e/etc/passwd")
|
||||
# Should return 200 with index.html (SPA fallback), not the actual file
|
||||
assert resp.status_code in (200, 404)
|
||||
if resp.status_code == 200:
|
||||
# Should be the SPA fallback, not the system file
|
||||
assert "root:" not in resp.text
|
||||
|
||||
def test_path_traversal_dotdot_blocked(self):
|
||||
"""Direct .. path traversal via encoded sequences."""
|
||||
resp = self.client.get("/%2e%2e/hermes_cli/web_server.py")
|
||||
assert resp.status_code in (200, 404)
|
||||
if resp.status_code == 200:
|
||||
assert "FastAPI" not in resp.text # Should not serve the actual source
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_schema_from_config tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSchemaFromConfig:
|
||||
def test_produces_expected_field_count(self):
|
||||
from hermes_cli.web_server import CONFIG_SCHEMA
|
||||
# DEFAULT_CONFIG has ~150+ leaf fields
|
||||
assert len(CONFIG_SCHEMA) > 100
|
||||
|
||||
def test_schema_entries_have_required_fields(self):
|
||||
from hermes_cli.web_server import CONFIG_SCHEMA
|
||||
for key, entry in list(CONFIG_SCHEMA.items())[:10]:
|
||||
assert "type" in entry, f"Missing type for {key}"
|
||||
assert "category" in entry, f"Missing category for {key}"
|
||||
|
||||
def test_overrides_applied(self):
|
||||
from hermes_cli.web_server import CONFIG_SCHEMA
|
||||
# terminal.backend should be a select with options
|
||||
if "terminal.backend" in CONFIG_SCHEMA:
|
||||
entry = CONFIG_SCHEMA["terminal.backend"]
|
||||
assert entry["type"] == "select"
|
||||
assert "options" in entry
|
||||
assert "local" in entry["options"]
|
||||
|
||||
def test_empty_prefix_produces_correct_keys(self):
|
||||
from hermes_cli.web_server import _build_schema_from_config
|
||||
test_config = {"model": "test", "nested": {"key": "val"}}
|
||||
schema = _build_schema_from_config(test_config)
|
||||
assert "model" in schema
|
||||
assert "nested.key" in schema
|
||||
|
||||
def test_top_level_scalars_get_general_category(self):
|
||||
"""Top-level scalar fields should be in 'general' category."""
|
||||
from hermes_cli.web_server import CONFIG_SCHEMA
|
||||
assert CONFIG_SCHEMA["model"]["category"] == "general"
|
||||
|
||||
def test_nested_keys_get_parent_category(self):
|
||||
"""Nested fields should use the top-level parent as their category."""
|
||||
from hermes_cli.web_server import CONFIG_SCHEMA
|
||||
if "agent.max_turns" in CONFIG_SCHEMA:
|
||||
assert CONFIG_SCHEMA["agent.max_turns"]["category"] == "agent"
|
||||
|
||||
def test_category_merge_applied(self):
|
||||
"""Small categories should be merged into larger ones."""
|
||||
from hermes_cli.web_server import CONFIG_SCHEMA
|
||||
categories = {e["category"] for e in CONFIG_SCHEMA.values()}
|
||||
# These should be merged away
|
||||
assert "privacy" not in categories # merged into security
|
||||
assert "context" not in categories # merged into agent
|
||||
|
||||
def test_no_single_field_categories(self):
|
||||
"""After merging, no category should have just 1 field."""
|
||||
from hermes_cli.web_server import CONFIG_SCHEMA
|
||||
from collections import Counter
|
||||
cats = Counter(e["category"] for e in CONFIG_SCHEMA.values())
|
||||
for cat, count in cats.items():
|
||||
assert count >= 2, f"Category '{cat}' has only {count} field(s) — should be merged"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config round-trip tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfigRoundTrip:
|
||||
"""Verify config survives GET → edit → PUT without data loss."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup(self):
|
||||
try:
|
||||
from starlette.testclient import TestClient
|
||||
except ImportError:
|
||||
pytest.skip("fastapi/starlette not installed")
|
||||
from hermes_cli.web_server import app
|
||||
self.client = TestClient(app)
|
||||
|
||||
def test_get_config_no_internal_keys(self):
|
||||
"""GET /api/config should not expose _config_version or _model_meta."""
|
||||
config = self.client.get("/api/config").json()
|
||||
internal = [k for k in config if k.startswith("_")]
|
||||
assert not internal, f"Internal keys leaked to frontend: {internal}"
|
||||
|
||||
def test_get_config_model_is_string(self):
|
||||
"""GET /api/config should normalize model dict to a string."""
|
||||
config = self.client.get("/api/config").json()
|
||||
assert isinstance(config.get("model"), str), \
|
||||
f"model should be string, got {type(config.get('model'))}"
|
||||
|
||||
def test_round_trip_preserves_model_subkeys(self):
|
||||
"""Save and reload should not lose model.provider, model.base_url, etc."""
|
||||
from hermes_cli.config import load_config, save_config
|
||||
|
||||
# Set up a config with model as a dict (the common user config form)
|
||||
save_config({
|
||||
"model": {
|
||||
"default": "anthropic/claude-sonnet-4",
|
||||
"provider": "openrouter",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"api_mode": "openai",
|
||||
}
|
||||
})
|
||||
|
||||
before = load_config()
|
||||
assert isinstance(before.get("model"), dict)
|
||||
original_keys = set(before["model"].keys())
|
||||
|
||||
# GET → PUT unchanged
|
||||
web_config = self.client.get("/api/config").json()
|
||||
assert isinstance(web_config.get("model"), str), "GET should normalize model to string"
|
||||
|
||||
self.client.put("/api/config", json={"config": web_config})
|
||||
|
||||
after = load_config()
|
||||
assert isinstance(after.get("model"), dict), "model should still be a dict after save"
|
||||
assert set(after["model"].keys()) >= original_keys, \
|
||||
f"Lost model subkeys: {original_keys - set(after['model'].keys())}"
|
||||
|
||||
def test_edit_model_name_preserved(self):
|
||||
"""Changing the model string should update model.default on disk."""
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
web_config = self.client.get("/api/config").json()
|
||||
original_model = web_config["model"]
|
||||
|
||||
# Change model
|
||||
web_config["model"] = "test/editing-model"
|
||||
self.client.put("/api/config", json={"config": web_config})
|
||||
|
||||
after = load_config()
|
||||
if isinstance(after.get("model"), dict):
|
||||
assert after["model"]["default"] == "test/editing-model"
|
||||
else:
|
||||
assert after["model"] == "test/editing-model"
|
||||
|
||||
# Restore
|
||||
web_config["model"] = original_model
|
||||
self.client.put("/api/config", json={"config": web_config})
|
||||
|
||||
def test_edit_nested_value(self):
|
||||
"""Editing a nested config value should persist correctly."""
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
web_config = self.client.get("/api/config").json()
|
||||
original_turns = web_config.get("agent", {}).get("max_turns")
|
||||
|
||||
# Change max_turns
|
||||
if "agent" not in web_config:
|
||||
web_config["agent"] = {}
|
||||
web_config["agent"]["max_turns"] = 42
|
||||
|
||||
self.client.put("/api/config", json={"config": web_config})
|
||||
|
||||
after = load_config()
|
||||
assert after.get("agent", {}).get("max_turns") == 42
|
||||
|
||||
# Restore
|
||||
web_config["agent"]["max_turns"] = original_turns
|
||||
self.client.put("/api/config", json={"config": web_config})
|
||||
|
||||
def test_schema_types_match_config_values(self):
|
||||
"""Every schema field should have a matching-type value in the config."""
|
||||
config = self.client.get("/api/config").json()
|
||||
schema_resp = self.client.get("/api/config/schema").json()
|
||||
schema = schema_resp["fields"]
|
||||
|
||||
def get_nested(obj, path):
|
||||
parts = path.split(".")
|
||||
cur = obj
|
||||
for p in parts:
|
||||
if cur is None or not isinstance(cur, dict):
|
||||
return None
|
||||
cur = cur.get(p)
|
||||
return cur
|
||||
|
||||
mismatches = []
|
||||
for key, entry in schema.items():
|
||||
val = get_nested(config, key)
|
||||
if val is None:
|
||||
continue # not set in user config — fine
|
||||
expected = entry["type"]
|
||||
if expected in ("string", "select") and not isinstance(val, str):
|
||||
mismatches.append(f"{key}: expected str, got {type(val).__name__}")
|
||||
elif expected == "number" and not isinstance(val, (int, float)):
|
||||
mismatches.append(f"{key}: expected number, got {type(val).__name__}")
|
||||
elif expected == "boolean" and not isinstance(val, bool):
|
||||
mismatches.append(f"{key}: expected bool, got {type(val).__name__}")
|
||||
elif expected == "list" and not isinstance(val, list):
|
||||
mismatches.append(f"{key}: expected list, got {type(val).__name__}")
|
||||
assert not mismatches, f"Type mismatches:\n" + "\n".join(mismatches)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# New feature endpoint tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNewEndpoints:
|
||||
"""Tests for session detail, logs, cron, skills, tools, raw config, analytics."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup(self):
|
||||
try:
|
||||
from starlette.testclient import TestClient
|
||||
except ImportError:
|
||||
pytest.skip("fastapi/starlette not installed")
|
||||
from hermes_cli.web_server import app
|
||||
self.client = TestClient(app)
|
||||
|
||||
def test_get_logs_default(self):
|
||||
resp = self.client.get("/api/logs")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "file" in data
|
||||
assert "lines" in data
|
||||
assert isinstance(data["lines"], list)
|
||||
|
||||
def test_get_logs_invalid_file(self):
|
||||
resp = self.client.get("/api/logs?file=nonexistent")
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_cron_list(self):
|
||||
resp = self.client.get("/api/cron/jobs")
|
||||
assert resp.status_code == 200
|
||||
assert isinstance(resp.json(), list)
|
||||
|
||||
def test_cron_job_not_found(self):
|
||||
resp = self.client.get("/api/cron/jobs/nonexistent-id")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_skills_list(self):
|
||||
resp = self.client.get("/api/skills")
|
||||
assert resp.status_code == 200
|
||||
skills = resp.json()
|
||||
assert isinstance(skills, list)
|
||||
if skills:
|
||||
assert "name" in skills[0]
|
||||
assert "enabled" in skills[0]
|
||||
|
||||
def test_skills_list_includes_disabled_skills(self, monkeypatch):
|
||||
import tools.skills_tool as skills_tool
|
||||
import hermes_cli.skills_config as skills_config
|
||||
import hermes_cli.web_server as web_server
|
||||
|
||||
def _fake_find_all_skills(*, skip_disabled=False):
|
||||
if skip_disabled:
|
||||
return [
|
||||
{"name": "active-skill", "description": "active", "category": "demo"},
|
||||
{"name": "disabled-skill", "description": "disabled", "category": "demo"},
|
||||
]
|
||||
return [
|
||||
{"name": "active-skill", "description": "active", "category": "demo"},
|
||||
]
|
||||
|
||||
monkeypatch.setattr(skills_tool, "_find_all_skills", _fake_find_all_skills)
|
||||
monkeypatch.setattr(skills_config, "get_disabled_skills", lambda config: {"disabled-skill"})
|
||||
monkeypatch.setattr(web_server, "load_config", lambda: {"skills": {"disabled": ["disabled-skill"]}})
|
||||
|
||||
resp = self.client.get("/api/skills")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == [
|
||||
{
|
||||
"name": "active-skill",
|
||||
"description": "active",
|
||||
"category": "demo",
|
||||
"enabled": True,
|
||||
},
|
||||
{
|
||||
"name": "disabled-skill",
|
||||
"description": "disabled",
|
||||
"category": "demo",
|
||||
"enabled": False,
|
||||
},
|
||||
]
|
||||
|
||||
def test_toolsets_list(self):
|
||||
resp = self.client.get("/api/tools/toolsets")
|
||||
assert resp.status_code == 200
|
||||
toolsets = resp.json()
|
||||
assert isinstance(toolsets, list)
|
||||
if toolsets:
|
||||
assert "name" in toolsets[0]
|
||||
assert "label" in toolsets[0]
|
||||
assert "enabled" in toolsets[0]
|
||||
|
||||
def test_toolsets_list_matches_cli_enabled_state(self, monkeypatch):
|
||||
import hermes_cli.tools_config as tools_config
|
||||
import toolsets as toolsets_module
|
||||
import hermes_cli.web_server as web_server
|
||||
|
||||
monkeypatch.setattr(
|
||||
tools_config,
|
||||
"_get_effective_configurable_toolsets",
|
||||
lambda: [
|
||||
("web", "🔍 Web Search & Scraping", "web_search, web_extract"),
|
||||
("skills", "📚 Skills", "list, view, manage"),
|
||||
("memory", "💾 Memory", "persistent memory across sessions"),
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tools_config,
|
||||
"_get_platform_tools",
|
||||
lambda config, platform, include_default_mcp_servers=False: {"web", "skills"},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tools_config,
|
||||
"_toolset_has_keys",
|
||||
lambda ts_key, config=None: ts_key != "web",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
toolsets_module,
|
||||
"resolve_toolset",
|
||||
lambda name: {
|
||||
"web": ["web_search", "web_extract"],
|
||||
"skills": ["skills_list", "skill_view"],
|
||||
"memory": ["memory_read"],
|
||||
}[name],
|
||||
)
|
||||
monkeypatch.setattr(web_server, "load_config", lambda: {"platform_toolsets": {"cli": ["web", "skills"]}})
|
||||
|
||||
resp = self.client.get("/api/tools/toolsets")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == [
|
||||
{
|
||||
"name": "web",
|
||||
"label": "🔍 Web Search & Scraping",
|
||||
"description": "web_search, web_extract",
|
||||
"enabled": True,
|
||||
"available": True,
|
||||
"configured": False,
|
||||
"tools": ["web_extract", "web_search"],
|
||||
},
|
||||
{
|
||||
"name": "skills",
|
||||
"label": "📚 Skills",
|
||||
"description": "list, view, manage",
|
||||
"enabled": True,
|
||||
"available": True,
|
||||
"configured": True,
|
||||
"tools": ["skill_view", "skills_list"],
|
||||
},
|
||||
{
|
||||
"name": "memory",
|
||||
"label": "💾 Memory",
|
||||
"description": "persistent memory across sessions",
|
||||
"enabled": False,
|
||||
"available": False,
|
||||
"configured": True,
|
||||
"tools": ["memory_read"],
|
||||
},
|
||||
]
|
||||
|
||||
def test_config_raw_get(self):
|
||||
resp = self.client.get("/api/config/raw")
|
||||
assert resp.status_code == 200
|
||||
assert "yaml" in resp.json()
|
||||
|
||||
def test_config_raw_put_valid(self):
|
||||
resp = self.client.put(
|
||||
"/api/config/raw",
|
||||
json={"yaml_text": "model: test\ntoolsets:\n - all\n"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["ok"] is True
|
||||
|
||||
def test_config_raw_put_invalid(self):
|
||||
resp = self.client.put(
|
||||
"/api/config/raw",
|
||||
json={"yaml_text": "- this is a list not a dict"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_analytics_usage(self):
|
||||
resp = self.client.get("/api/analytics/usage?days=7")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "daily" in data
|
||||
assert "by_model" in data
|
||||
assert "totals" in data
|
||||
assert isinstance(data["daily"], list)
|
||||
assert "total_sessions" in data["totals"]
|
||||
|
||||
def test_session_token_endpoint(self):
|
||||
from hermes_cli.web_server import _SESSION_TOKEN
|
||||
resp = self.client.get("/api/auth/session-token")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["token"] == _SESSION_TOKEN
|
||||
@@ -102,7 +102,19 @@ class _PromptTooLongError(Exception):
|
||||
self.status_code = 400
|
||||
|
||||
|
||||
class _FakeMessages:
|
||||
"""Stub for client.messages.create() / client.messages.stream()."""
|
||||
def create(self, **kwargs):
|
||||
raise NotImplementedError("_FakeAnthropicClient.messages.create should not be called directly in tests")
|
||||
|
||||
def stream(self, **kwargs):
|
||||
raise NotImplementedError("_FakeAnthropicClient.messages.stream should not be called directly in tests")
|
||||
|
||||
|
||||
class _FakeAnthropicClient:
|
||||
def __init__(self):
|
||||
self.messages = _FakeMessages()
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
@@ -131,13 +143,14 @@ def _make_agent_cls(error_cls, recover_after=None):
|
||||
def run_conversation(self, user_message, conversation_history=None, task_id=None):
|
||||
calls = {"n": 0}
|
||||
|
||||
def _fake_api_call(api_kwargs):
|
||||
def _fake_api_call(api_kwargs, **kw):
|
||||
calls["n"] += 1
|
||||
if recover_after is not None and calls["n"] > recover_after:
|
||||
return _anthropic_response("Recovered")
|
||||
raise error_cls()
|
||||
|
||||
self._interruptible_api_call = _fake_api_call
|
||||
self._interruptible_streaming_api_call = _fake_api_call
|
||||
return super().run_conversation(
|
||||
user_message, conversation_history=conversation_history, task_id=task_id
|
||||
)
|
||||
@@ -352,10 +365,11 @@ def test_401_refresh_fails_is_non_retryable(monkeypatch):
|
||||
return False # Simulate failed credential refresh
|
||||
|
||||
def run_conversation(self, user_message, conversation_history=None, task_id=None):
|
||||
def _fake_api_call(api_kwargs):
|
||||
def _fake_api_call(api_kwargs, **kw):
|
||||
raise _UnauthorizedError()
|
||||
|
||||
self._interruptible_api_call = _fake_api_call
|
||||
self._interruptible_streaming_api_call = _fake_api_call
|
||||
return super().run_conversation(
|
||||
user_message, conversation_history=conversation_history, task_id=task_id
|
||||
)
|
||||
@@ -436,13 +450,14 @@ def test_prompt_too_long_triggers_compression(monkeypatch):
|
||||
def run_conversation(self, user_message, conversation_history=None, task_id=None):
|
||||
calls = {"n": 0}
|
||||
|
||||
def _fake_api_call(api_kwargs):
|
||||
def _fake_api_call(api_kwargs, **kw):
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
raise _PromptTooLongError()
|
||||
return _anthropic_response("Compressed and recovered")
|
||||
|
||||
self._interruptible_api_call = _fake_api_call
|
||||
self._interruptible_streaming_api_call = _fake_api_call
|
||||
return super().run_conversation(
|
||||
user_message, conversation_history=conversation_history, task_id=task_id
|
||||
)
|
||||
|
||||
@@ -38,6 +38,7 @@ def _make_agent(
|
||||
agent.status_callback = None
|
||||
agent.tool_progress_callback = None
|
||||
agent._compression_warning = None
|
||||
agent.config = None
|
||||
|
||||
compressor = MagicMock(spec=ContextCompressor)
|
||||
compressor.context_length = main_context
|
||||
@@ -130,6 +131,64 @@ def test_feasibility_check_passes_live_main_runtime():
|
||||
)
|
||||
|
||||
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=1_000_000)
|
||||
@patch("agent.auxiliary_client.get_text_auxiliary_client")
|
||||
def test_feasibility_check_passes_config_context_length(mock_get_client, mock_ctx_len):
|
||||
"""auxiliary.compression.context_length from config is forwarded to
|
||||
get_model_context_length so custom endpoints that lack /models still
|
||||
report the correct context window (fixes #8499)."""
|
||||
agent = _make_agent(main_context=200_000, threshold_percent=0.85)
|
||||
agent.config = {
|
||||
"auxiliary": {
|
||||
"compression": {
|
||||
"context_length": 1_000_000,
|
||||
},
|
||||
},
|
||||
}
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "http://custom-endpoint:8080/v1"
|
||||
mock_client.api_key = "sk-custom"
|
||||
mock_get_client.return_value = (mock_client, "custom/big-model")
|
||||
|
||||
agent._emit_status = lambda msg: None
|
||||
agent._check_compression_model_feasibility()
|
||||
|
||||
mock_ctx_len.assert_called_once_with(
|
||||
"custom/big-model",
|
||||
base_url="http://custom-endpoint:8080/v1",
|
||||
api_key="sk-custom",
|
||||
config_context_length=1_000_000,
|
||||
)
|
||||
|
||||
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=128_000)
|
||||
@patch("agent.auxiliary_client.get_text_auxiliary_client")
|
||||
def test_feasibility_check_ignores_invalid_context_length(mock_get_client, mock_ctx_len):
|
||||
"""Non-integer context_length in config is silently ignored."""
|
||||
agent = _make_agent(main_context=200_000, threshold_percent=0.50)
|
||||
agent.config = {
|
||||
"auxiliary": {
|
||||
"compression": {
|
||||
"context_length": "not-a-number",
|
||||
},
|
||||
},
|
||||
}
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "http://custom:8080/v1"
|
||||
mock_client.api_key = "sk-test"
|
||||
mock_get_client.return_value = (mock_client, "custom/model")
|
||||
|
||||
agent._emit_status = lambda msg: None
|
||||
agent._check_compression_model_feasibility()
|
||||
|
||||
mock_ctx_len.assert_called_once_with(
|
||||
"custom/model",
|
||||
base_url="http://custom:8080/v1",
|
||||
api_key="sk-test",
|
||||
config_context_length=None,
|
||||
)
|
||||
|
||||
|
||||
@patch("agent.auxiliary_client.get_text_auxiliary_client")
|
||||
def test_warns_when_no_auxiliary_provider(mock_get_client):
|
||||
"""Warning emitted when no auxiliary provider is configured."""
|
||||
|
||||
@@ -56,6 +56,7 @@ def _make_agent(monkeypatch, api_mode, provider, response_fn):
|
||||
|
||||
def run_conversation(self, msg, conversation_history=None, task_id=None):
|
||||
self._interruptible_api_call = lambda kw: response_fn()
|
||||
self._disable_streaming = True
|
||||
return super().run_conversation(msg, conversation_history=conversation_history, task_id=task_id)
|
||||
|
||||
return _A(model="test-model", api_key="test-key", provider=provider, api_mode=api_mode)
|
||||
|
||||
@@ -66,6 +66,7 @@ def test_tool_call_validation_accepts_dict_arguments(monkeypatch):
|
||||
quiet_mode=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent._disable_streaming = True
|
||||
|
||||
result = agent.run_conversation("read the file")
|
||||
|
||||
|
||||
89
tests/run_agent/test_plugin_context_engine_init.py
Normal file
89
tests/run_agent/test_plugin_context_engine_init.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Tests that plugin context engines get update_model() called during init.
|
||||
|
||||
Regression test for #9071 — plugin engines were never initialized with
|
||||
context_length, causing the CLI status bar to show 'ctx --'.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from agent.context_engine import ContextEngine
|
||||
|
||||
|
||||
class _StubEngine(ContextEngine):
|
||||
"""Minimal concrete context engine for testing."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "stub"
|
||||
|
||||
def update_from_response(self, usage):
|
||||
pass
|
||||
|
||||
def should_compress(self, prompt_tokens=None):
|
||||
return False
|
||||
|
||||
def compress(self, messages, current_tokens=None):
|
||||
return messages
|
||||
|
||||
|
||||
def test_plugin_engine_gets_context_length_on_init():
|
||||
"""Plugin context engine should have context_length set during AIAgent init."""
|
||||
engine = _StubEngine()
|
||||
assert engine.context_length == 0 # ABC default before fix
|
||||
|
||||
cfg = {"context": {"engine": "stub"}, "agent": {}}
|
||||
|
||||
with (
|
||||
patch("hermes_cli.config.load_config", return_value=cfg),
|
||||
patch("plugins.context_engine.load_context_engine", return_value=engine),
|
||||
patch("agent.model_metadata.get_model_context_length", return_value=204_800),
|
||||
patch("run_agent.get_tool_definitions", return_value=[]),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
|
||||
assert agent.context_compressor is engine
|
||||
assert engine.context_length == 204_800
|
||||
assert engine.threshold_tokens == int(204_800 * engine.threshold_percent)
|
||||
|
||||
|
||||
def test_plugin_engine_update_model_args():
|
||||
"""Verify update_model() receives model, context_length, base_url, api_key, provider."""
|
||||
engine = _StubEngine()
|
||||
engine.update_model = MagicMock()
|
||||
|
||||
cfg = {"context": {"engine": "stub"}, "agent": {}}
|
||||
|
||||
with (
|
||||
patch("hermes_cli.config.load_config", return_value=cfg),
|
||||
patch("plugins.context_engine.load_context_engine", return_value=engine),
|
||||
patch("agent.model_metadata.get_model_context_length", return_value=131_072),
|
||||
patch("run_agent.get_tool_definitions", return_value=[]),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
model="openrouter/auto",
|
||||
api_key="test-key-1234567890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
|
||||
engine.update_model.assert_called_once()
|
||||
kw = engine.update_model.call_args.kwargs
|
||||
assert kw["context_length"] == 131_072
|
||||
assert "model" in kw
|
||||
assert "provider" in kw
|
||||
# Should NOT pass api_mode — the ABC doesn't accept it
|
||||
assert "api_mode" not in kw
|
||||
@@ -44,11 +44,11 @@ class _FakeOpenAI:
|
||||
pass
|
||||
|
||||
|
||||
def _make_agent(monkeypatch, provider, api_mode="chat_completions", base_url="https://openrouter.ai/api/v1"):
|
||||
def _make_agent(monkeypatch, provider, api_mode="chat_completions", base_url="https://openrouter.ai/api/v1", model=None):
|
||||
monkeypatch.setattr("run_agent.get_tool_definitions", lambda **kw: _tool_defs("web_search", "terminal"))
|
||||
monkeypatch.setattr("run_agent.check_toolset_requirements", lambda: {})
|
||||
monkeypatch.setattr("run_agent.OpenAI", _FakeOpenAI)
|
||||
return AIAgent(
|
||||
kwargs = dict(
|
||||
api_key="test-key",
|
||||
base_url=base_url,
|
||||
provider=provider,
|
||||
@@ -58,6 +58,9 @@ def _make_agent(monkeypatch, provider, api_mode="chat_completions", base_url="ht
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
if model:
|
||||
kwargs["model"] = model
|
||||
return AIAgent(**kwargs)
|
||||
|
||||
|
||||
# ── _build_api_kwargs tests ─────────────────────────────────────────────────
|
||||
@@ -247,7 +250,7 @@ class TestBuildApiKwargsChatCompletionsServiceTier:
|
||||
|
||||
class TestBuildApiKwargsAIGateway:
|
||||
def test_uses_chat_completions_format(self, monkeypatch):
|
||||
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1")
|
||||
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1", model="gpt-4o")
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert "messages" in kwargs
|
||||
@@ -255,7 +258,7 @@ class TestBuildApiKwargsAIGateway:
|
||||
assert kwargs["messages"][-1]["content"] == "hi"
|
||||
|
||||
def test_no_responses_api_fields(self, monkeypatch):
|
||||
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1")
|
||||
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1", model="gpt-4o")
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert "input" not in kwargs
|
||||
@@ -263,7 +266,7 @@ class TestBuildApiKwargsAIGateway:
|
||||
assert "store" not in kwargs
|
||||
|
||||
def test_includes_reasoning_in_extra_body(self, monkeypatch):
|
||||
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1")
|
||||
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1", model="gpt-4o")
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
extra = kwargs.get("extra_body", {})
|
||||
@@ -271,7 +274,7 @@ class TestBuildApiKwargsAIGateway:
|
||||
assert extra["reasoning"]["enabled"] is True
|
||||
|
||||
def test_includes_tools(self, monkeypatch):
|
||||
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1")
|
||||
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1", model="gpt-4o")
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert "tools" in kwargs
|
||||
|
||||
@@ -76,7 +76,8 @@ class TestRealSubagentInterrupt(unittest.TestCase):
|
||||
parent._delegate_spinner = None
|
||||
parent.tool_progress_callback = None
|
||||
parent.iteration_budget = IterationBudget(max_total=100)
|
||||
parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}
|
||||
parent._client_kwargs = {"api_key": "***", "base_url": "http://localhost:1"}
|
||||
parent._execution_thread_id = None
|
||||
|
||||
from tools.delegate_tool import _run_single_child
|
||||
|
||||
|
||||
@@ -302,6 +302,17 @@ class TestStripThinkBlocks:
|
||||
assert "<think>" not in result
|
||||
assert "visible" in result
|
||||
|
||||
def test_thought_block_removed(self, agent):
|
||||
"""Gemma 4 uses <thought> tags for inline reasoning."""
|
||||
result = agent._strip_think_blocks("<thought>internal reasoning</thought> answer")
|
||||
assert "internal reasoning" not in result
|
||||
assert "<thought>" not in result
|
||||
assert "answer" in result
|
||||
|
||||
def test_orphaned_thought_tag(self, agent):
|
||||
result = agent._strip_think_blocks("<thought>orphaned reasoning without close")
|
||||
assert "<thought>" not in result
|
||||
|
||||
|
||||
class TestExtractReasoning:
|
||||
def test_reasoning_field(self, agent):
|
||||
@@ -869,6 +880,7 @@ class TestBuildApiKwargs:
|
||||
assert kwargs["extra_body"]["reasoning"] == {"enabled": False}
|
||||
|
||||
def test_reasoning_not_sent_for_unsupported_openrouter_model(self, agent):
|
||||
agent.base_url = "https://openrouter.ai/api/v1"
|
||||
agent.model = "minimax/minimax-m2.5"
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
@@ -1564,6 +1576,7 @@ class TestHandleMaxIterations:
|
||||
assert "API down" in result
|
||||
|
||||
def test_summary_skips_reasoning_for_unsupported_openrouter_model(self, agent):
|
||||
agent.base_url = "https://openrouter.ai/api/v1"
|
||||
agent.model = "minimax/minimax-m2.5"
|
||||
resp = _mock_response(content="Summary")
|
||||
agent.client.chat.completions.create.return_value = resp
|
||||
@@ -1694,27 +1707,6 @@ class TestRunConversation:
|
||||
assert result["completed"] is True
|
||||
assert result["api_calls"] == 2
|
||||
|
||||
def test_inline_think_blocks_reasoning_only_accepted(self, agent):
|
||||
"""Inline <think> reasoning-only responses accepted with (empty) content, no retries."""
|
||||
self._setup_agent(agent)
|
||||
empty_resp = _mock_response(
|
||||
content="<think>internal reasoning</think>",
|
||||
finish_reason="stop",
|
||||
)
|
||||
agent.client.chat.completions.create.side_effect = [empty_resp]
|
||||
with (
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
):
|
||||
result = agent.run_conversation("answer me")
|
||||
assert result["completed"] is True
|
||||
assert result["final_response"] == "(empty)"
|
||||
assert result["api_calls"] == 1 # no retries
|
||||
# Reasoning should be preserved in the assistant message
|
||||
assistant_msgs = [m for m in result["messages"] if m.get("role") == "assistant"]
|
||||
assert any(m.get("reasoning") for m in assistant_msgs)
|
||||
|
||||
def test_reasoning_only_local_resumed_no_compression_triggered(self, agent):
|
||||
"""Reasoning-only responses no longer trigger compression — prefill then accepted."""
|
||||
self._setup_agent(agent)
|
||||
@@ -1730,9 +1722,9 @@ class TestRunConversation:
|
||||
{"role": "assistant", "content": "old answer"},
|
||||
]
|
||||
|
||||
# 3 responses: original + 2 prefill continuations (structured reasoning triggers prefill)
|
||||
# 6 responses: original + 2 prefill + 3 retries after prefill exhaustion
|
||||
with (
|
||||
patch.object(agent, "_interruptible_api_call", side_effect=[empty_resp, empty_resp, empty_resp]),
|
||||
patch.object(agent, "_interruptible_api_call", side_effect=[empty_resp] * 6),
|
||||
patch.object(agent, "_compress_context") as mock_compress,
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
@@ -1743,18 +1735,18 @@ class TestRunConversation:
|
||||
mock_compress.assert_not_called() # no compression triggered
|
||||
assert result["completed"] is True
|
||||
assert result["final_response"] == "(empty)"
|
||||
assert result["api_calls"] == 3 # 1 original + 2 prefill continuations
|
||||
assert result["api_calls"] == 6 # 1 original + 2 prefill + 3 retries
|
||||
|
||||
def test_reasoning_only_response_prefill_then_empty(self, agent):
|
||||
"""Structured reasoning-only triggers prefill continuation (up to 2), then falls through to (empty)."""
|
||||
"""Structured reasoning-only triggers prefill (2), then retries (3), then (empty)."""
|
||||
self._setup_agent(agent)
|
||||
empty_resp = _mock_response(
|
||||
content=None,
|
||||
finish_reason="stop",
|
||||
reasoning_content="structured reasoning answer",
|
||||
)
|
||||
# 3 responses: original + 2 prefill continuations, all reasoning-only
|
||||
agent.client.chat.completions.create.side_effect = [empty_resp, empty_resp, empty_resp]
|
||||
# 6 responses: 1 original + 2 prefill + 3 retries after prefill exhaustion
|
||||
agent.client.chat.completions.create.side_effect = [empty_resp] * 6
|
||||
with (
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
@@ -1763,7 +1755,7 @@ class TestRunConversation:
|
||||
result = agent.run_conversation("answer me")
|
||||
assert result["completed"] is True
|
||||
assert result["final_response"] == "(empty)"
|
||||
assert result["api_calls"] == 3 # 1 original + 2 prefill continuations
|
||||
assert result["api_calls"] == 6 # 1 original + 2 prefill + 3 retries
|
||||
|
||||
def test_reasoning_only_prefill_succeeds_on_continuation(self, agent):
|
||||
"""When prefill continuation produces content, it becomes the final response."""
|
||||
@@ -1938,6 +1930,88 @@ class TestRunConversation:
|
||||
failure_msgs = [m for m in status_messages if "no content" in m.lower() or "no fallback" in m.lower()]
|
||||
assert len(failure_msgs) >= 1, f"Expected at least 1 failure status, got: {status_messages}"
|
||||
|
||||
def test_partial_stream_recovery_uses_streamed_content(self, agent):
|
||||
"""When streaming fails after partial delivery, recovered partial content becomes final response."""
|
||||
self._setup_agent(agent)
|
||||
# Simulate a partial-stream-stub response: content recovered from streaming
|
||||
partial_resp = _mock_response(
|
||||
content="Here is the partial answer that was stream",
|
||||
finish_reason="stop",
|
||||
)
|
||||
agent.client.chat.completions.create.return_value = partial_resp
|
||||
# Simulate that streaming had already delivered this text
|
||||
agent._current_streamed_assistant_text = "Here is the partial answer that was stream"
|
||||
with (
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
):
|
||||
result = agent.run_conversation("explain something")
|
||||
# The partial content should be used as-is (not empty, not retried)
|
||||
assert result["completed"] is True
|
||||
assert result["final_response"] == "Here is the partial answer that was stream"
|
||||
assert result["api_calls"] == 1 # No retries
|
||||
|
||||
def test_partial_stream_recovery_on_empty_stub(self, agent):
|
||||
"""When stub response has no content but text was streamed, use streamed text."""
|
||||
self._setup_agent(agent)
|
||||
# Stub response with no content (old behavior before fix)
|
||||
empty_stub = _mock_response(content=None, finish_reason="stop")
|
||||
|
||||
def _fake_api_call(api_kwargs):
|
||||
# Simulate what streaming does: accumulate text before returning
|
||||
# a stub with no content (connection died mid-stream)
|
||||
agent._current_streamed_assistant_text = "The answer to your question is that"
|
||||
return empty_stub
|
||||
|
||||
status_messages = []
|
||||
|
||||
def _capture_status(msg):
|
||||
status_messages.append(msg)
|
||||
|
||||
with (
|
||||
patch.object(agent, "_interruptible_api_call", side_effect=_fake_api_call),
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
patch.object(agent, "_emit_status", side_effect=_capture_status),
|
||||
):
|
||||
result = agent.run_conversation("ask me")
|
||||
# Should recover partial streamed content, not fall through to (empty)
|
||||
assert result["completed"] is True
|
||||
assert result["final_response"] == "The answer to your question is that"
|
||||
assert result["api_calls"] == 1 # No wasted retries
|
||||
# Should emit the stream-interrupted status, NOT the empty-retry status
|
||||
recovery_msgs = [m for m in status_messages if "stream interrupted" in m.lower()]
|
||||
assert len(recovery_msgs) >= 1, f"Expected stream recovery status, got: {status_messages}"
|
||||
# Should NOT have retry statuses
|
||||
retry_msgs = [m for m in status_messages if "retrying" in m.lower()]
|
||||
assert len(retry_msgs) == 0, f"Should not retry when stream content exists: {status_messages}"
|
||||
|
||||
def test_partial_stream_recovery_preempts_prior_turn_fallback(self, agent):
|
||||
"""Partial streamed content takes priority over _last_content_with_tools fallback."""
|
||||
self._setup_agent(agent)
|
||||
# Set up the prior-turn fallback content (from a previous turn with tool calls)
|
||||
agent._last_content_with_tools = "Old content from prior turn with tools"
|
||||
# Stub response with no content
|
||||
empty_stub = _mock_response(content=None, finish_reason="stop")
|
||||
|
||||
def _fake_api_call(api_kwargs):
|
||||
# Simulate partial streaming before connection death
|
||||
agent._current_streamed_assistant_text = "Fresh partial content from this turn"
|
||||
return empty_stub
|
||||
|
||||
with (
|
||||
patch.object(agent, "_interruptible_api_call", side_effect=_fake_api_call),
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
):
|
||||
result = agent.run_conversation("question")
|
||||
# Should use the streamed content, not the old prior-turn fallback
|
||||
assert result["final_response"] == "Fresh partial content from this turn"
|
||||
assert result["api_calls"] == 1
|
||||
|
||||
def test_nous_401_refreshes_after_remint_and_retries(self, agent):
|
||||
self._setup_agent(agent)
|
||||
agent.provider = "nous"
|
||||
@@ -3426,8 +3500,8 @@ class TestStreamingApiCall:
|
||||
call_kwargs = agent.client.chat.completions.create.call_args
|
||||
assert call_kwargs[1].get("stream") is True or call_kwargs.kwargs.get("stream") is True
|
||||
|
||||
def test_api_exception_falls_back_to_non_streaming(self, agent):
|
||||
"""When streaming fails before any deltas, fallback to non-streaming is attempted."""
|
||||
def test_api_exception_propagates_no_non_streaming_fallback(self, agent):
|
||||
"""When streaming fails before any deltas, error propagates to the main retry loop."""
|
||||
agent.client.chat.completions.create.side_effect = ConnectionError("fail")
|
||||
# Prevent stream retry logic from replacing the mock client
|
||||
with patch.object(agent, "_replace_primary_openai_client", return_value=False):
|
||||
|
||||
@@ -243,6 +243,22 @@ def test_api_mode_respects_explicit_openrouter_provider_over_codex_url(monkeypat
|
||||
assert agent.provider == "openrouter"
|
||||
|
||||
|
||||
def test_copilot_acp_stays_on_chat_completions_for_gpt_5_models(monkeypatch):
|
||||
_patch_agent_bootstrap(monkeypatch)
|
||||
agent = run_agent.AIAgent(
|
||||
model="gpt-5.4-mini",
|
||||
base_url="acp://copilot",
|
||||
provider="copilot-acp",
|
||||
api_key="copilot-acp",
|
||||
quiet_mode=True,
|
||||
max_iterations=1,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
assert agent.provider == "copilot-acp"
|
||||
assert agent.api_mode == "chat_completions"
|
||||
|
||||
|
||||
def test_build_api_kwargs_codex(monkeypatch):
|
||||
agent = _build_agent(monkeypatch)
|
||||
kwargs = agent._build_api_kwargs(
|
||||
|
||||
@@ -291,6 +291,38 @@ class TestStreamingCallbacks:
|
||||
|
||||
assert len(first_delta_calls) == 1
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_chat_stream_refreshes_activity_on_every_chunk(self, mock_close, mock_create):
|
||||
"""Each streamed chat chunk should refresh the activity timestamp."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
chunks = [
|
||||
_make_stream_chunk(content="a"),
|
||||
_make_stream_chunk(content="b"),
|
||||
_make_stream_chunk(finish_reason="stop"),
|
||||
]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = iter(chunks)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
touch_calls = []
|
||||
agent._touch_activity = lambda desc: touch_calls.append(desc)
|
||||
|
||||
agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert touch_calls.count("receiving stream response") == len(chunks)
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_tool_only_does_not_fire_callback(self, mock_close, mock_create):
|
||||
@@ -374,13 +406,19 @@ class TestStreamingCallbacks:
|
||||
|
||||
|
||||
class TestStreamingFallback:
|
||||
"""Verify fallback to non-streaming on ANY streaming error."""
|
||||
"""Verify streaming errors propagate to the main retry loop.
|
||||
|
||||
Previously, streaming errors triggered an inline fallback to
|
||||
non-streaming. Now they propagate so the main retry loop can apply
|
||||
richer recovery (credential rotation, provider fallback, backoff).
|
||||
The only special case: 'stream not supported' sets _disable_streaming
|
||||
so the *next* main-loop retry uses non-streaming automatically.
|
||||
"""
|
||||
|
||||
@patch("run_agent.AIAgent._interruptible_api_call")
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_stream_error_falls_back(self, mock_close, mock_create, mock_non_stream):
|
||||
"""'not supported' error triggers fallback to non-streaming."""
|
||||
def test_stream_not_supported_sets_flag_and_raises(self, mock_close, mock_create):
|
||||
"""'not supported' error sets _disable_streaming and propagates."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
mock_client = MagicMock()
|
||||
@@ -389,23 +427,6 @@ class TestStreamingFallback:
|
||||
)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
fallback_response = SimpleNamespace(
|
||||
id="fallback",
|
||||
model="test",
|
||||
choices=[SimpleNamespace(
|
||||
index=0,
|
||||
message=SimpleNamespace(
|
||||
role="assistant",
|
||||
content="fallback response",
|
||||
tool_calls=None,
|
||||
reasoning_content=None,
|
||||
),
|
||||
finish_reason="stop",
|
||||
)],
|
||||
usage=None,
|
||||
)
|
||||
mock_non_stream.return_value = fallback_response
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
@@ -415,16 +436,16 @@ class TestStreamingFallback:
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
with pytest.raises(Exception, match="Streaming is not supported"):
|
||||
agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert response.choices[0].message.content == "fallback response"
|
||||
mock_non_stream.assert_called_once()
|
||||
# The flag should be set so the main retry loop switches to non-streaming
|
||||
assert agent._disable_streaming is True
|
||||
|
||||
@patch("run_agent.AIAgent._interruptible_api_call")
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_any_stream_error_falls_back(self, mock_close, mock_create, mock_non_stream):
|
||||
"""ANY streaming error triggers fallback — not just specific messages."""
|
||||
def test_non_transport_error_propagates(self, mock_close, mock_create):
|
||||
"""Non-transport streaming errors propagate to the main retry loop."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
mock_client = MagicMock()
|
||||
@@ -433,23 +454,6 @@ class TestStreamingFallback:
|
||||
)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
fallback_response = SimpleNamespace(
|
||||
id="fallback",
|
||||
model="test",
|
||||
choices=[SimpleNamespace(
|
||||
index=0,
|
||||
message=SimpleNamespace(
|
||||
role="assistant",
|
||||
content="fallback after connection error",
|
||||
tool_calls=None,
|
||||
reasoning_content=None,
|
||||
),
|
||||
finish_reason="stop",
|
||||
)],
|
||||
usage=None,
|
||||
)
|
||||
mock_non_stream.return_value = fallback_response
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
@@ -459,24 +463,19 @@ class TestStreamingFallback:
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
with pytest.raises(Exception, match="Connection reset by peer"):
|
||||
agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert response.choices[0].message.content == "fallback after connection error"
|
||||
mock_non_stream.assert_called_once()
|
||||
|
||||
@patch("run_agent.AIAgent._interruptible_api_call")
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_fallback_error_propagates(self, mock_close, mock_create, mock_non_stream):
|
||||
"""When both streaming AND fallback fail, the fallback error propagates."""
|
||||
def test_stream_error_propagates_original(self, mock_close, mock_create):
|
||||
"""The original streaming error propagates (not a fallback error)."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.side_effect = Exception("stream broke")
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
mock_non_stream.side_effect = Exception("Rate limit exceeded")
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
@@ -486,14 +485,13 @@ class TestStreamingFallback:
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
with pytest.raises(Exception, match="Rate limit exceeded"):
|
||||
with pytest.raises(Exception, match="stream broke"):
|
||||
agent._interruptible_streaming_api_call({})
|
||||
|
||||
@patch("run_agent.AIAgent._interruptible_api_call")
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_exhausted_transient_stream_error_falls_back(self, mock_close, mock_create, mock_non_stream):
|
||||
"""Transient stream errors retry first, then fall back after retries are exhausted."""
|
||||
def test_exhausted_transient_stream_error_propagates(self, mock_close, mock_create):
|
||||
"""Transient stream errors retry first, then propagate after retries exhausted."""
|
||||
from run_agent import AIAgent
|
||||
import httpx
|
||||
|
||||
@@ -501,23 +499,6 @@ class TestStreamingFallback:
|
||||
mock_client.chat.completions.create.side_effect = httpx.ConnectError("socket closed")
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
fallback_response = SimpleNamespace(
|
||||
id="fallback",
|
||||
model="test",
|
||||
choices=[SimpleNamespace(
|
||||
index=0,
|
||||
message=SimpleNamespace(
|
||||
role="assistant",
|
||||
content="fallback after retries exhausted",
|
||||
tool_calls=None,
|
||||
reasoning_content=None,
|
||||
),
|
||||
finish_reason="stop",
|
||||
)],
|
||||
usage=None,
|
||||
)
|
||||
mock_non_stream.return_value = fallback_response
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
@@ -527,23 +508,22 @@ class TestStreamingFallback:
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
with pytest.raises(httpx.ConnectError, match="socket closed"):
|
||||
agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert response.choices[0].message.content == "fallback after retries exhausted"
|
||||
# Should have retried 3 times (default HERMES_STREAM_RETRIES=2 → 3 attempts)
|
||||
assert mock_client.chat.completions.create.call_count == 3
|
||||
mock_non_stream.assert_called_once()
|
||||
assert mock_close.call_count >= 1
|
||||
|
||||
@patch("run_agent.AIAgent._interruptible_api_call")
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_sse_connection_lost_retried_as_transient(self, mock_close, mock_create, mock_non_stream):
|
||||
def test_sse_connection_lost_retried_as_transient(self, mock_close, mock_create):
|
||||
"""SSE 'Network connection lost' (APIError w/ no status_code) retries like httpx errors.
|
||||
|
||||
OpenRouter sends {"error":{"message":"Network connection lost."}} as an SSE
|
||||
event when the upstream stream drops. The OpenAI SDK raises APIError from
|
||||
this. It should be retried at the streaming level, same as httpx connection
|
||||
errors, before falling back to non-streaming.
|
||||
errors, then propagate to the main retry loop after exhaustion.
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
import httpx
|
||||
@@ -561,23 +541,6 @@ class TestStreamingFallback:
|
||||
mock_client.chat.completions.create.side_effect = sse_error
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
fallback_response = SimpleNamespace(
|
||||
id="fallback",
|
||||
model="test",
|
||||
choices=[SimpleNamespace(
|
||||
index=0,
|
||||
message=SimpleNamespace(
|
||||
role="assistant",
|
||||
content="fallback after SSE retries",
|
||||
tool_calls=None,
|
||||
reasoning_content=None,
|
||||
),
|
||||
finish_reason="stop",
|
||||
)],
|
||||
usage=None,
|
||||
)
|
||||
mock_non_stream.return_value = fallback_response
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
@@ -587,21 +550,18 @@ class TestStreamingFallback:
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
with pytest.raises(OAIAPIError):
|
||||
agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert response.choices[0].message.content == "fallback after SSE retries"
|
||||
# Should retry 3 times (default HERMES_STREAM_RETRIES=2 → 3 attempts)
|
||||
# before falling back to non-streaming
|
||||
assert mock_client.chat.completions.create.call_count == 3
|
||||
mock_non_stream.assert_called_once()
|
||||
# Connection cleanup should happen for each failed retry
|
||||
assert mock_close.call_count >= 2
|
||||
|
||||
@patch("run_agent.AIAgent._interruptible_api_call")
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_sse_non_connection_error_falls_back_immediately(self, mock_close, mock_create, mock_non_stream):
|
||||
"""SSE errors that aren't connection-related still fall back immediately (no stream retry)."""
|
||||
def test_sse_non_connection_error_propagates_immediately(self, mock_close, mock_create):
|
||||
"""SSE errors that aren't connection-related propagate immediately (no stream retry)."""
|
||||
from run_agent import AIAgent
|
||||
import httpx
|
||||
|
||||
@@ -616,23 +576,6 @@ class TestStreamingFallback:
|
||||
mock_client.chat.completions.create.side_effect = sse_error
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
fallback_response = SimpleNamespace(
|
||||
id="fallback",
|
||||
model="test",
|
||||
choices=[SimpleNamespace(
|
||||
index=0,
|
||||
message=SimpleNamespace(
|
||||
role="assistant",
|
||||
content="fallback no retry",
|
||||
tool_calls=None,
|
||||
reasoning_content=None,
|
||||
),
|
||||
finish_reason="stop",
|
||||
)],
|
||||
usage=None,
|
||||
)
|
||||
mock_non_stream.return_value = fallback_response
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
@@ -642,12 +585,11 @@ class TestStreamingFallback:
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
with pytest.raises(OAIAPIError):
|
||||
agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert response.choices[0].message.content == "fallback no retry"
|
||||
# Should NOT retry — goes straight to non-streaming fallback
|
||||
# Should NOT retry — propagates immediately
|
||||
assert mock_client.chat.completions.create.call_count == 1
|
||||
mock_non_stream.assert_called_once()
|
||||
|
||||
|
||||
# ── Test: Reasoning Streaming ────────────────────────────────────────────
|
||||
@@ -783,6 +725,55 @@ class TestCodexStreamCallbacks:
|
||||
response = agent._run_codex_stream({}, client=mock_client)
|
||||
assert "Hello from Codex!" in deltas
|
||||
|
||||
def test_codex_stream_refreshes_activity_on_every_event(self):
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "codex_responses"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
touch_calls = []
|
||||
agent._touch_activity = lambda desc: touch_calls.append(desc)
|
||||
|
||||
mock_event_text_1 = SimpleNamespace(
|
||||
type="response.output_text.delta",
|
||||
delta="Hello",
|
||||
)
|
||||
mock_event_text_2 = SimpleNamespace(
|
||||
type="response.output_text.delta",
|
||||
delta=" world",
|
||||
)
|
||||
mock_event_done = SimpleNamespace(
|
||||
type="response.completed",
|
||||
delta="",
|
||||
)
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_stream.__enter__ = MagicMock(return_value=mock_stream)
|
||||
mock_stream.__exit__ = MagicMock(return_value=False)
|
||||
mock_stream.__iter__ = MagicMock(
|
||||
return_value=iter([mock_event_text_1, mock_event_text_2, mock_event_done])
|
||||
)
|
||||
mock_stream.get_final_response.return_value = SimpleNamespace(
|
||||
output=[SimpleNamespace(
|
||||
type="message",
|
||||
content=[SimpleNamespace(type="output_text", text="Hello world")],
|
||||
)],
|
||||
status="completed",
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.responses.stream.return_value = mock_stream
|
||||
|
||||
agent._run_codex_stream({}, client=mock_client)
|
||||
|
||||
assert touch_calls.count("receiving stream response") == 3
|
||||
|
||||
def test_codex_remote_protocol_error_falls_back_to_create_stream(self):
|
||||
from run_agent import AIAgent
|
||||
import httpx
|
||||
@@ -814,3 +805,102 @@ class TestCodexStreamCallbacks:
|
||||
|
||||
assert response is fallback_response
|
||||
mock_fallback.assert_called_once_with({}, client=mock_client)
|
||||
|
||||
def test_codex_create_stream_fallback_refreshes_activity_on_every_event(self):
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "codex_responses"
|
||||
|
||||
touch_calls = []
|
||||
agent._touch_activity = lambda desc: touch_calls.append(desc)
|
||||
|
||||
events = [
|
||||
SimpleNamespace(type="response.output_text.delta", delta="Hello"),
|
||||
SimpleNamespace(type="response.output_item.done", item=SimpleNamespace(type="message")),
|
||||
SimpleNamespace(
|
||||
type="response.completed",
|
||||
response=SimpleNamespace(
|
||||
output=[SimpleNamespace(
|
||||
type="message",
|
||||
content=[SimpleNamespace(type="output_text", text="Hello")],
|
||||
)]
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
class _FakeCreateStream:
|
||||
def __iter__(self_inner):
|
||||
return iter(events)
|
||||
|
||||
def close(self_inner):
|
||||
return None
|
||||
|
||||
mock_stream = _FakeCreateStream()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.responses.create.return_value = mock_stream
|
||||
|
||||
agent._run_codex_create_stream_fallback(
|
||||
{"model": "test/model", "instructions": "hi", "input": []},
|
||||
client=mock_client,
|
||||
)
|
||||
|
||||
assert touch_calls.count("receiving stream response") == len(events)
|
||||
|
||||
|
||||
class TestAnthropicStreamCallbacks:
|
||||
"""Verify Anthropic streaming refreshes activity on every event."""
|
||||
|
||||
def test_anthropic_stream_refreshes_activity_on_every_event(self):
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "anthropic_messages"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
touch_calls = []
|
||||
agent._touch_activity = lambda desc: touch_calls.append(desc)
|
||||
|
||||
events = [
|
||||
SimpleNamespace(
|
||||
type="content_block_delta",
|
||||
delta=SimpleNamespace(type="text_delta", text="Hello"),
|
||||
),
|
||||
SimpleNamespace(
|
||||
type="content_block_delta",
|
||||
delta=SimpleNamespace(type="thinking_delta", thinking="thinking"),
|
||||
),
|
||||
SimpleNamespace(
|
||||
type="content_block_start",
|
||||
content_block=SimpleNamespace(type="tool_use", name="terminal"),
|
||||
),
|
||||
]
|
||||
|
||||
final_message = SimpleNamespace(
|
||||
content=[],
|
||||
stop_reason="end_turn",
|
||||
)
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_stream.__enter__ = MagicMock(return_value=mock_stream)
|
||||
mock_stream.__exit__ = MagicMock(return_value=False)
|
||||
mock_stream.__iter__ = MagicMock(return_value=iter(events))
|
||||
mock_stream.get_final_message.return_value = final_message
|
||||
|
||||
agent._anthropic_client = MagicMock()
|
||||
agent._anthropic_client.messages.stream.return_value = mock_stream
|
||||
|
||||
agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert touch_calls.count("receiving stream response") == len(events)
|
||||
|
||||
@@ -9,6 +9,8 @@ import pytest
|
||||
from run_agent import (
|
||||
_strip_non_ascii,
|
||||
_sanitize_messages_non_ascii,
|
||||
_sanitize_structure_non_ascii,
|
||||
_sanitize_tools_non_ascii,
|
||||
_sanitize_messages_surrogates,
|
||||
)
|
||||
|
||||
@@ -138,3 +140,66 @@ class TestSurrogateVsAsciiSanitization:
|
||||
"""When no surrogates present, _sanitize_messages_surrogates returns False."""
|
||||
messages = [{"role": "user", "content": "hello ⚕ world"}]
|
||||
assert _sanitize_messages_surrogates(messages) is False
|
||||
|
||||
|
||||
class TestSanitizeToolsNonAscii:
|
||||
"""Tests for _sanitize_tools_non_ascii."""
|
||||
|
||||
def test_sanitizes_tool_description_and_parameter_descriptions(self):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"description": "Print structured output │ with emoji 🤖",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path │ with unicode",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
assert _sanitize_tools_non_ascii(tools) is True
|
||||
assert tools[0]["function"]["description"] == "Print structured output with emoji "
|
||||
assert tools[0]["function"]["parameters"]["properties"]["path"]["description"] == "File path with unicode"
|
||||
|
||||
def test_no_change_for_ascii_only_tools(self):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"description": "Read file content",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
assert _sanitize_tools_non_ascii(tools) is False
|
||||
|
||||
|
||||
class TestSanitizeStructureNonAscii:
|
||||
def test_sanitizes_nested_dict_structure(self):
|
||||
payload = {
|
||||
"default_headers": {
|
||||
"X-Title": "Hermes │ Agent",
|
||||
"User-Agent": "Hermes/1.0 🤖",
|
||||
}
|
||||
}
|
||||
assert _sanitize_structure_non_ascii(payload) is True
|
||||
assert payload["default_headers"]["X-Title"] == "Hermes Agent"
|
||||
assert payload["default_headers"]["User-Agent"] == "Hermes/1.0 "
|
||||
|
||||
@@ -179,6 +179,7 @@ class TestEphemeralMaxOutputTokens:
|
||||
return_value=[{"role": "user", "content": "hi"}]
|
||||
)
|
||||
agent._anthropic_preserve_dots = MagicMock(return_value=False)
|
||||
agent.request_overrides = {}
|
||||
return agent
|
||||
|
||||
def test_ephemeral_override_is_used_on_first_call(self):
|
||||
@@ -253,6 +254,7 @@ class TestContextNotHalvedOnOutputCapError:
|
||||
)
|
||||
agent._anthropic_preserve_dots = MagicMock(return_value=False)
|
||||
agent._vprint = MagicMock()
|
||||
agent.request_overrides = {}
|
||||
return agent
|
||||
|
||||
def test_output_cap_error_sets_ephemeral_not_context_length(self):
|
||||
|
||||
@@ -6,7 +6,8 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_constants import get_default_hermes_root
|
||||
import hermes_constants
|
||||
from hermes_constants import get_default_hermes_root, is_container
|
||||
|
||||
|
||||
class TestGetDefaultHermesRoot:
|
||||
@@ -60,3 +61,53 @@ class TestGetDefaultHermesRoot:
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile))
|
||||
assert get_default_hermes_root() == docker_root
|
||||
|
||||
|
||||
class TestIsContainer:
|
||||
"""Tests for is_container() — Docker/Podman detection."""
|
||||
|
||||
def _reset_cache(self, monkeypatch):
|
||||
"""Reset the cached detection result before each test."""
|
||||
monkeypatch.setattr(hermes_constants, "_container_detected", None)
|
||||
|
||||
def test_detects_dockerenv(self, monkeypatch, tmp_path):
|
||||
"""/.dockerenv triggers container detection."""
|
||||
self._reset_cache(monkeypatch)
|
||||
monkeypatch.setattr(os.path, "exists", lambda p: p == "/.dockerenv")
|
||||
assert is_container() is True
|
||||
|
||||
def test_detects_containerenv(self, monkeypatch, tmp_path):
|
||||
"""/run/.containerenv triggers container detection (Podman)."""
|
||||
self._reset_cache(monkeypatch)
|
||||
monkeypatch.setattr(os.path, "exists", lambda p: p == "/run/.containerenv")
|
||||
assert is_container() is True
|
||||
|
||||
def test_detects_cgroup_docker(self, monkeypatch, tmp_path):
|
||||
"""/proc/1/cgroup containing 'docker' triggers detection."""
|
||||
import builtins
|
||||
self._reset_cache(monkeypatch)
|
||||
monkeypatch.setattr(os.path, "exists", lambda p: False)
|
||||
cgroup_file = tmp_path / "cgroup"
|
||||
cgroup_file.write_text("12:memory:/docker/abc123\n")
|
||||
_real_open = builtins.open
|
||||
monkeypatch.setattr("builtins.open", lambda p, *a, **kw: _real_open(str(cgroup_file), *a, **kw) if p == "/proc/1/cgroup" else _real_open(p, *a, **kw))
|
||||
assert is_container() is True
|
||||
|
||||
def test_negative_case(self, monkeypatch, tmp_path):
|
||||
"""Returns False on a regular Linux host."""
|
||||
import builtins
|
||||
self._reset_cache(monkeypatch)
|
||||
monkeypatch.setattr(os.path, "exists", lambda p: False)
|
||||
cgroup_file = tmp_path / "cgroup"
|
||||
cgroup_file.write_text("12:memory:/\n")
|
||||
_real_open = builtins.open
|
||||
monkeypatch.setattr("builtins.open", lambda p, *a, **kw: _real_open(str(cgroup_file), *a, **kw) if p == "/proc/1/cgroup" else _real_open(p, *a, **kw))
|
||||
assert is_container() is False
|
||||
|
||||
def test_caches_result(self, monkeypatch):
|
||||
"""Second call uses cached value without re-probing."""
|
||||
monkeypatch.setattr(hermes_constants, "_container_detected", True)
|
||||
assert is_container() is True
|
||||
# Even if we make os.path.exists return False, cached value wins
|
||||
monkeypatch.setattr(os.path, "exists", lambda p: False)
|
||||
assert is_container() is True
|
||||
|
||||
@@ -298,8 +298,17 @@ class TestGatewayMode:
|
||||
"""agent.log (catch-all) still receives gateway AND tool records."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
|
||||
|
||||
logging.getLogger("gateway.run").info("gateway msg")
|
||||
logging.getLogger("tools.file_tools").info("file msg")
|
||||
gw_logger = logging.getLogger("gateway.run")
|
||||
file_logger = logging.getLogger("tools.file_tools")
|
||||
# Ensure propagation and levels are clean (cross-test pollution defense)
|
||||
gw_logger.propagate = True
|
||||
file_logger.propagate = True
|
||||
logging.getLogger("tools").propagate = True
|
||||
file_logger.setLevel(logging.NOTSET)
|
||||
logging.getLogger("tools").setLevel(logging.NOTSET)
|
||||
|
||||
gw_logger.info("gateway msg")
|
||||
file_logger.info("file msg")
|
||||
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
@@ -103,7 +103,7 @@ class TestSourceLineVerification:
|
||||
if "self.async_client = AsyncOpenAI(" in line and "_get_async_client" not in lines[max(0,i-3):i+1]:
|
||||
# Allow it inside _get_async_client method
|
||||
# Check if we're inside _get_async_client by looking at context
|
||||
context = "\n".join(lines[max(0,i-10):i+1])
|
||||
context = "\n".join(lines[max(0,i-20):i+1])
|
||||
if "_get_async_client" not in context:
|
||||
pytest.fail(
|
||||
f"Line {i}: AsyncOpenAI created eagerly outside _get_async_client()"
|
||||
|
||||
@@ -64,4 +64,4 @@ class TestCamofoxConfigDefaults:
|
||||
|
||||
# The current schema version is tracked globally; unrelated default
|
||||
# options may bump it after browser defaults are added.
|
||||
assert DEFAULT_CONFIG["_config_version"] == 15
|
||||
assert DEFAULT_CONFIG["_config_version"] == 17
|
||||
|
||||
@@ -79,5 +79,33 @@ class TestSafeWriteRoot:
|
||||
assert _is_write_denied(os.path.expanduser("~/.ssh/id_rsa")) is True
|
||||
|
||||
|
||||
class TestCheckSensitivePathMacOSBypass:
|
||||
"""Verify _check_sensitive_path blocks /private/etc paths (issue #8734)."""
|
||||
|
||||
def test_etc_hosts_blocked(self):
|
||||
from tools.file_tools import _check_sensitive_path
|
||||
assert _check_sensitive_path("/etc/hosts") is not None
|
||||
|
||||
def test_private_etc_hosts_blocked(self):
|
||||
from tools.file_tools import _check_sensitive_path
|
||||
assert _check_sensitive_path("/private/etc/hosts") is not None
|
||||
|
||||
def test_private_etc_ssh_config_blocked(self):
|
||||
from tools.file_tools import _check_sensitive_path
|
||||
assert _check_sensitive_path("/private/etc/ssh/sshd_config") is not None
|
||||
|
||||
def test_private_var_blocked(self):
|
||||
from tools.file_tools import _check_sensitive_path
|
||||
assert _check_sensitive_path("/private/var/db/something") is not None
|
||||
|
||||
def test_boot_still_blocked(self):
|
||||
from tools.file_tools import _check_sensitive_path
|
||||
assert _check_sensitive_path("/boot/grub/grub.cfg") is not None
|
||||
|
||||
def test_safe_path_allowed(self):
|
||||
from tools.file_tools import _check_sensitive_path
|
||||
assert _check_sensitive_path("/tmp/safe_file.txt") is None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
@@ -5,6 +5,7 @@ handler validation, and availability gating.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -18,6 +19,7 @@ from tools.homeassistant_tool import (
|
||||
_handle_call_service,
|
||||
_BLOCKED_DOMAINS,
|
||||
_ENTITY_ID_RE,
|
||||
_SERVICE_NAME_RE,
|
||||
)
|
||||
|
||||
|
||||
@@ -303,6 +305,147 @@ class TestEntityIdValidation:
|
||||
assert "Invalid entity_id" not in result["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# String-data deserialization (XML tool calling workaround)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCallServiceStringData:
|
||||
"""data param may arrive as a JSON string (XML tool calling mode)."""
|
||||
|
||||
@patch("tools.homeassistant_tool._run_async", return_value={"success": True})
|
||||
def test_string_data_deserialized(self, mock_run):
|
||||
"""JSON string data is parsed into a dict before dispatch."""
|
||||
_handle_call_service({
|
||||
"domain": "climate",
|
||||
"service": "set_hvac_mode",
|
||||
"entity_id": "climate.living_room",
|
||||
"data": '{"hvac_mode": "heat"}',
|
||||
})
|
||||
call_args = mock_run.call_args[0][0] # the coroutine arg
|
||||
# _run_async was called, meaning we got past validation
|
||||
|
||||
@patch("tools.homeassistant_tool._run_async", return_value={"success": True})
|
||||
def test_dict_data_passthrough(self, mock_run):
|
||||
"""Dict data (JSON tool calling mode) still works unchanged."""
|
||||
_handle_call_service({
|
||||
"domain": "light",
|
||||
"service": "turn_on",
|
||||
"entity_id": "light.bedroom",
|
||||
"data": {"brightness": 255},
|
||||
})
|
||||
mock_run.assert_called_once()
|
||||
|
||||
def test_invalid_json_string_returns_error(self):
|
||||
"""Malformed JSON string in data returns a clear error."""
|
||||
result = json.loads(_handle_call_service({
|
||||
"domain": "light",
|
||||
"service": "turn_on",
|
||||
"entity_id": "light.bedroom",
|
||||
"data": "{not valid json}",
|
||||
}))
|
||||
assert "error" in result
|
||||
assert "Invalid JSON" in result["error"]
|
||||
|
||||
@patch("tools.homeassistant_tool._run_async", return_value={"success": True})
|
||||
def test_empty_string_data_becomes_none(self, mock_run):
|
||||
"""Empty/whitespace string data is treated as None."""
|
||||
_handle_call_service({
|
||||
"domain": "light",
|
||||
"service": "turn_on",
|
||||
"entity_id": "light.bedroom",
|
||||
"data": " ",
|
||||
})
|
||||
mock_run.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Security: domain/service name format validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestServiceNameValidation:
|
||||
"""Verify domain/service format validation prevents path traversal in URL.
|
||||
|
||||
The domain and service parameters are interpolated into
|
||||
/api/services/{domain}/{service}, so allowing arbitrary strings would
|
||||
enable SSRF via path traversal or blocked-domain bypass.
|
||||
"""
|
||||
|
||||
def test_valid_domain_names(self):
|
||||
assert _SERVICE_NAME_RE.match("light")
|
||||
assert _SERVICE_NAME_RE.match("switch")
|
||||
assert _SERVICE_NAME_RE.match("climate")
|
||||
assert _SERVICE_NAME_RE.match("shell_command")
|
||||
assert _SERVICE_NAME_RE.match("media_player")
|
||||
|
||||
def test_valid_service_names(self):
|
||||
assert _SERVICE_NAME_RE.match("turn_on")
|
||||
assert _SERVICE_NAME_RE.match("turn_off")
|
||||
assert _SERVICE_NAME_RE.match("set_temperature")
|
||||
assert _SERVICE_NAME_RE.match("toggle")
|
||||
|
||||
def test_path_traversal_in_domain_rejected(self):
|
||||
assert _SERVICE_NAME_RE.match("../../api/config") is None
|
||||
assert _SERVICE_NAME_RE.match("light/../../../etc") is None
|
||||
assert _SERVICE_NAME_RE.match("../config") is None
|
||||
|
||||
def test_path_traversal_in_service_rejected(self):
|
||||
assert _SERVICE_NAME_RE.match("../../api/config") is None
|
||||
assert _SERVICE_NAME_RE.match("turn_on/../../config") is None
|
||||
|
||||
def test_blocked_domain_bypass_via_traversal_rejected(self):
|
||||
"""Ensure shell_command/../light is rejected, not just checked against blocklist."""
|
||||
assert _SERVICE_NAME_RE.match("shell_command/../light") is None
|
||||
assert _SERVICE_NAME_RE.match("python_script/../scene") is None
|
||||
assert _SERVICE_NAME_RE.match("hassio/../automation") is None
|
||||
|
||||
def test_slashes_rejected(self):
|
||||
assert _SERVICE_NAME_RE.match("light/turn_on") is None
|
||||
assert _SERVICE_NAME_RE.match("a/b/c") is None
|
||||
|
||||
def test_dots_rejected(self):
|
||||
assert _SERVICE_NAME_RE.match("light.turn_on") is None
|
||||
assert _SERVICE_NAME_RE.match("..") is None
|
||||
|
||||
def test_uppercase_rejected(self):
|
||||
assert _SERVICE_NAME_RE.match("LIGHT") is None
|
||||
assert _SERVICE_NAME_RE.match("Turn_On") is None
|
||||
|
||||
def test_special_chars_rejected(self):
|
||||
assert _SERVICE_NAME_RE.match("light;rm") is None
|
||||
assert _SERVICE_NAME_RE.match("light&cmd") is None
|
||||
assert _SERVICE_NAME_RE.match("light cmd") is None
|
||||
|
||||
def test_handler_rejects_traversal_domain(self):
|
||||
"""_handle_call_service must reject domain with path traversal."""
|
||||
result = json.loads(_handle_call_service({
|
||||
"domain": "../../api/config",
|
||||
"service": "turn_on",
|
||||
}))
|
||||
assert "error" in result
|
||||
assert "Invalid domain" in result["error"]
|
||||
|
||||
def test_handler_rejects_traversal_service(self):
|
||||
"""_handle_call_service must reject service with path traversal."""
|
||||
result = json.loads(_handle_call_service({
|
||||
"domain": "light",
|
||||
"service": "../../api/config",
|
||||
}))
|
||||
assert "error" in result
|
||||
assert "Invalid service" in result["error"]
|
||||
|
||||
def test_handler_rejects_blocklist_bypass_traversal(self):
|
||||
"""Blocklist bypass via shell_command/../light must be caught by format validation."""
|
||||
result = json.loads(_handle_call_service({
|
||||
"domain": "shell_command/../light",
|
||||
"service": "turn_on",
|
||||
}))
|
||||
assert "error" in result
|
||||
# Must be rejected as "Invalid domain", not slip through the blocklist
|
||||
assert "Invalid domain" in result["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Availability check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestInterruptModule:
|
||||
assert not is_interrupted()
|
||||
|
||||
def test_thread_safety(self):
|
||||
"""Set from one thread, check from another."""
|
||||
"""Set from one thread targeting another thread's ident."""
|
||||
from tools.interrupt import set_interrupt, is_interrupted
|
||||
set_interrupt(False)
|
||||
|
||||
@@ -45,11 +45,12 @@ class TestInterruptModule:
|
||||
time.sleep(0.05)
|
||||
assert not seen["value"]
|
||||
|
||||
set_interrupt(True)
|
||||
# Target the checker thread's ident so it sees the interrupt
|
||||
set_interrupt(True, thread_id=t.ident)
|
||||
t.join(timeout=1)
|
||||
assert seen["value"]
|
||||
|
||||
set_interrupt(False)
|
||||
set_interrupt(False, thread_id=t.ident)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -189,10 +190,10 @@ class TestSIGKILLEscalation:
|
||||
t.start()
|
||||
|
||||
time.sleep(0.5)
|
||||
set_interrupt(True)
|
||||
set_interrupt(True, thread_id=t.ident)
|
||||
|
||||
t.join(timeout=5)
|
||||
set_interrupt(False)
|
||||
set_interrupt(False, thread_id=t.ident)
|
||||
|
||||
assert result_holder["value"] is not None
|
||||
assert result_holder["value"]["returncode"] == 130
|
||||
|
||||
@@ -146,6 +146,40 @@ class TestTruncateAroundMatches:
|
||||
result = _truncate_around_matches(text, "KEYWORD")
|
||||
assert "KEYWORD" in result
|
||||
|
||||
def test_multiword_phrase_match_beats_individual_term(self):
|
||||
"""Full phrase deep in text should be found even when a single term
|
||||
appears much earlier in boilerplate."""
|
||||
boilerplate = "The project setup is complex. " * 500 # ~15K, has 'project' early
|
||||
filler = "x" * (MAX_SESSION_CHARS + 20000)
|
||||
target = "We reviewed the keystone project roadmap in detail."
|
||||
text = boilerplate + filler + target + filler
|
||||
result = _truncate_around_matches(text, "keystone project")
|
||||
assert "keystone project" in result.lower()
|
||||
|
||||
def test_multiword_proximity_cooccurrence(self):
|
||||
"""When exact phrase is absent, terms co-occurring within proximity
|
||||
should be preferred over a lone early term."""
|
||||
early = "project " + "a" * (MAX_SESSION_CHARS + 20000)
|
||||
# Place 'keystone' and 'project' near each other (but not as exact phrase)
|
||||
cooccur = "this keystone initiative for the project was pivotal"
|
||||
tail = "b" * (MAX_SESSION_CHARS + 20000)
|
||||
text = early + cooccur + tail
|
||||
result = _truncate_around_matches(text, "keystone project")
|
||||
assert "keystone" in result.lower()
|
||||
assert "project" in result.lower()
|
||||
|
||||
def test_multiword_window_maximises_coverage(self):
|
||||
"""Sliding window should capture as many match clusters as possible."""
|
||||
# Place two phrase matches: one at ~50K, one at ~60K, both should fit
|
||||
pre = "z" * 50000
|
||||
match1 = " alpha beta "
|
||||
gap = "z" * 10000
|
||||
match2 = " alpha beta "
|
||||
post = "z" * (MAX_SESSION_CHARS + 40000)
|
||||
text = pre + match1 + gap + match2 + post
|
||||
result = _truncate_around_matches(text, "alpha beta")
|
||||
assert result.lower().count("alpha beta") == 2
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# session_search (dispatcher)
|
||||
|
||||
145
tests/tools/test_tts_speed.py
Normal file
145
tests/tools/test_tts_speed.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Tests for TTS speed configuration across providers."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_env(monkeypatch):
|
||||
for key in ("OPENAI_API_KEY", "MINIMAX_API_KEY", "HERMES_SESSION_PLATFORM"):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge TTS speed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEdgeTtsSpeed:
|
||||
def _run(self, tts_config, tmp_path):
|
||||
mock_comm = MagicMock()
|
||||
mock_comm.save = AsyncMock()
|
||||
mock_edge = MagicMock()
|
||||
mock_edge.Communicate = MagicMock(return_value=mock_comm)
|
||||
|
||||
with patch("tools.tts_tool._import_edge_tts", return_value=mock_edge):
|
||||
from tools.tts_tool import _generate_edge_tts
|
||||
asyncio.run(_generate_edge_tts("Hello", str(tmp_path / "out.mp3"), tts_config))
|
||||
return mock_edge.Communicate
|
||||
|
||||
def test_default_no_rate_kwarg(self, tmp_path):
|
||||
"""No speed config => no rate kwarg passed to Communicate."""
|
||||
comm_cls = self._run({}, tmp_path)
|
||||
kwargs = comm_cls.call_args[1]
|
||||
assert "rate" not in kwargs
|
||||
|
||||
def test_global_speed_applied(self, tmp_path):
|
||||
"""Global tts.speed used as fallback."""
|
||||
comm_cls = self._run({"speed": 1.5}, tmp_path)
|
||||
kwargs = comm_cls.call_args[1]
|
||||
assert kwargs["rate"] == "+50%"
|
||||
|
||||
def test_provider_speed_overrides_global(self, tmp_path):
|
||||
"""tts.edge.speed takes precedence over tts.speed."""
|
||||
comm_cls = self._run({"speed": 1.5, "edge": {"speed": 2.0}}, tmp_path)
|
||||
kwargs = comm_cls.call_args[1]
|
||||
assert kwargs["rate"] == "+100%"
|
||||
|
||||
def test_speed_below_one(self, tmp_path):
|
||||
"""Speed < 1.0 produces a negative rate string."""
|
||||
comm_cls = self._run({"speed": 0.5}, tmp_path)
|
||||
kwargs = comm_cls.call_args[1]
|
||||
assert kwargs["rate"] == "-50%"
|
||||
|
||||
def test_speed_exactly_one_no_rate(self, tmp_path):
|
||||
"""Explicit speed=1.0 should not pass rate kwarg."""
|
||||
comm_cls = self._run({"speed": 1.0}, tmp_path)
|
||||
kwargs = comm_cls.call_args[1]
|
||||
assert "rate" not in kwargs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI TTS speed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOpenaiTtsSpeed:
|
||||
def _run(self, tts_config, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||
mock_response = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.speech.create.return_value = mock_response
|
||||
mock_cls = MagicMock(return_value=mock_client)
|
||||
|
||||
with patch("tools.tts_tool._import_openai_client", return_value=mock_cls), \
|
||||
patch("tools.tts_tool._resolve_openai_audio_client_config",
|
||||
return_value=("test-key", None)):
|
||||
from tools.tts_tool import _generate_openai_tts
|
||||
_generate_openai_tts("Hello", str(tmp_path / "out.mp3"), tts_config)
|
||||
return mock_client.audio.speech.create
|
||||
|
||||
def test_default_no_speed_kwarg(self, tmp_path, monkeypatch):
|
||||
"""No speed config => no speed kwarg in create call."""
|
||||
create = self._run({}, tmp_path, monkeypatch)
|
||||
kwargs = create.call_args[1]
|
||||
assert "speed" not in kwargs
|
||||
|
||||
def test_global_speed_applied(self, tmp_path, monkeypatch):
|
||||
"""Global tts.speed used as fallback."""
|
||||
create = self._run({"speed": 1.5}, tmp_path, monkeypatch)
|
||||
kwargs = create.call_args[1]
|
||||
assert kwargs["speed"] == 1.5
|
||||
|
||||
def test_provider_speed_overrides_global(self, tmp_path, monkeypatch):
|
||||
"""tts.openai.speed takes precedence over tts.speed."""
|
||||
create = self._run({"speed": 1.5, "openai": {"speed": 2.0}}, tmp_path, monkeypatch)
|
||||
kwargs = create.call_args[1]
|
||||
assert kwargs["speed"] == 2.0
|
||||
|
||||
def test_speed_clamped_low(self, tmp_path, monkeypatch):
|
||||
"""Speed below 0.25 is clamped to 0.25."""
|
||||
create = self._run({"speed": 0.1}, tmp_path, monkeypatch)
|
||||
kwargs = create.call_args[1]
|
||||
assert kwargs["speed"] == 0.25
|
||||
|
||||
def test_speed_clamped_high(self, tmp_path, monkeypatch):
|
||||
"""Speed above 4.0 is clamped to 4.0."""
|
||||
create = self._run({"speed": 10.0}, tmp_path, monkeypatch)
|
||||
kwargs = create.call_args[1]
|
||||
assert kwargs["speed"] == 4.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MiniMax TTS speed (global fallback wired)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMinimaxTtsSpeed:
|
||||
def _run(self, tts_config, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("MINIMAX_API_KEY", "test-key")
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": {"audio": "deadbeef"},
|
||||
"base_resp": {"status_code": 0, "status_msg": "success"},
|
||||
"extra_info": {"audio_size": 8},
|
||||
}
|
||||
|
||||
# requests is imported locally inside _generate_minimax_tts
|
||||
with patch("requests.post", return_value=mock_response) as mock_post:
|
||||
from tools.tts_tool import _generate_minimax_tts
|
||||
_generate_minimax_tts("Hello", str(tmp_path / "out.mp3"), tts_config)
|
||||
return mock_post
|
||||
|
||||
def test_global_speed_fallback(self, tmp_path, monkeypatch):
|
||||
"""Global tts.speed used when minimax.speed not set."""
|
||||
mock_post = self._run({"speed": 1.5}, tmp_path, monkeypatch)
|
||||
payload = mock_post.call_args[1]["json"]
|
||||
assert payload["voice_setting"]["speed"] == 1.5
|
||||
|
||||
def test_provider_speed_overrides_global(self, tmp_path, monkeypatch):
|
||||
"""tts.minimax.speed takes precedence over tts.speed."""
|
||||
mock_post = self._run(
|
||||
{"speed": 1.5, "minimax": {"speed": 2.0}}, tmp_path, monkeypatch
|
||||
)
|
||||
payload = mock_post.call_args[1]["json"]
|
||||
assert payload["voice_setting"]["speed"] == 2.0
|
||||
@@ -463,8 +463,6 @@ class TestVisionRequirements:
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("AUXILIARY_VISION_PROVIDER", raising=False)
|
||||
monkeypatch.delenv("CONTEXT_VISION_PROVIDER", raising=False)
|
||||
|
||||
assert check_vision_requirements() is True
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ def _make_voice_cli(**overrides):
|
||||
cli._voice_tts_done.set()
|
||||
cli._pending_input = queue.Queue()
|
||||
cli._app = None
|
||||
cli._attached_images = []
|
||||
cli.console = SimpleNamespace(width=80)
|
||||
for k, v in overrides.items():
|
||||
setattr(cli, k, v)
|
||||
|
||||
@@ -190,17 +190,38 @@ class TestGatewayCleanupWiring:
|
||||
def test_gateway_stop_calls_close(self):
|
||||
"""gateway stop() should call close() on all running agents."""
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
import threading
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
runner = MagicMock()
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._running = True
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner.adapters = {}
|
||||
runner._background_tasks = set()
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._pending_model_notes = {}
|
||||
runner._shutdown_event = asyncio.Event()
|
||||
runner._exit_reason = None
|
||||
runner._exit_code = None
|
||||
runner._stop_task = None
|
||||
runner._draining = False
|
||||
runner._restart_requested = False
|
||||
runner._restart_task_started = False
|
||||
runner._restart_detached = False
|
||||
runner._restart_via_service = False
|
||||
runner._restart_drain_timeout = 5.0
|
||||
runner._voice_mode = {}
|
||||
runner._session_model_overrides = {}
|
||||
runner._update_prompt_pending = {}
|
||||
runner._busy_input_mode = "interrupt"
|
||||
runner._agent_cache = {}
|
||||
runner._agent_cache_lock = threading.Lock()
|
||||
runner._shutdown_all_gateway_honcho = lambda: None
|
||||
runner._update_runtime_status = MagicMock()
|
||||
|
||||
mock_agent_1 = MagicMock()
|
||||
mock_agent_2 = MagicMock()
|
||||
@@ -209,8 +230,6 @@ class TestGatewayCleanupWiring:
|
||||
"session-2": mock_agent_2,
|
||||
}
|
||||
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
with patch("gateway.status.remove_pid_file"), \
|
||||
|
||||
Reference in New Issue
Block a user