Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor

This commit is contained in:
Brooklyn Nicholson
2026-04-12 13:18:55 -05:00
131 changed files with 12350 additions and 1164 deletions

View File

@@ -971,6 +971,74 @@ class TestTaskSpecificOverrides:
client, model = get_text_auxiliary_client("compression")
assert model == "google/gemini-3-flash-preview" # auto → OpenRouter
def test_resolve_auto_prefers_live_main_runtime_over_persisted_config(self, monkeypatch, tmp_path):
    """Session-only live model switches should override persisted config for auto routing.

    The persisted config names ``opencode-go`` / ``glm-5.1`` while the live
    runtime passed to ``_resolve_auto`` names ``openai-codex`` / ``gpt-5.4``.
    The first provider resolution must be made against the live runtime.
    """
    hermes_home = tmp_path / "hermes"
    hermes_home.mkdir(parents=True, exist_ok=True)
    # Child keys are indented so the YAML parses as model.default,
    # model.provider and compression.summary_provider rather than as a flat
    # list of top-level keys.
    (hermes_home / "config.yaml").write_text(
        "model:\n"
        "  default: glm-5.1\n"
        "  provider: opencode-go\n"
        "compression:\n"
        "  summary_provider: auto\n"
    )
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    calls = []

    def _fake_resolve(provider, model=None, *args, **kwargs):
        # Record every resolution attempt so the test can inspect the first one.
        calls.append((provider, model, kwargs))
        return MagicMock(), model or "resolved-model"

    with patch("agent.auxiliary_client.resolve_provider_client", side_effect=_fake_resolve):
        client, model = _resolve_auto(
            main_runtime={
                "provider": "openai-codex",
                "model": "gpt-5.4",
                "api_mode": "codex_responses",
            }
        )

    assert client is not None
    assert model == "gpt-5.4"
    # The very first resolution call must target the live runtime.
    assert calls[0][0] == "openai-codex"
    assert calls[0][1] == "gpt-5.4"
    assert calls[0][2]["api_mode"] == "codex_responses"
def test_explicit_compression_pin_still_wins_over_live_main_runtime(self, monkeypatch, tmp_path):
    """Task-level compression config should beat a live session override.

    ``auxiliary.compression`` pins ``openrouter``; even with a live
    ``openai-codex`` runtime supplied, the resolver must be called for
    ``openrouter`` and still receive the live runtime as ``main_runtime``.
    """
    hermes_home = tmp_path / "hermes"
    hermes_home.mkdir(parents=True, exist_ok=True)
    # Nesting matters here: without indentation the document would contain two
    # top-level ``model`` keys instead of auxiliary.compression.model and the
    # main ``model`` mapping.
    (hermes_home / "config.yaml").write_text(
        "auxiliary:\n"
        "  compression:\n"
        "    provider: openrouter\n"
        "    model: google/gemini-3-flash-preview\n"
        "model:\n"
        "  default: glm-5.1\n"
        "  provider: opencode-go\n"
    )
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    with patch(
        "agent.auxiliary_client.resolve_provider_client",
        return_value=(MagicMock(), "google/gemini-3-flash-preview"),
    ) as mock_resolve:
        client, model = get_text_auxiliary_client(
            "compression",
            main_runtime={
                "provider": "openai-codex",
                "model": "gpt-5.4",
            },
        )

    assert client is not None
    assert model == "google/gemini-3-flash-preview"
    assert mock_resolve.call_args.args[0] == "openrouter"
    assert mock_resolve.call_args.kwargs["main_runtime"] == {
        "provider": "openai-codex",
        "model": "gpt-5.4",
    }
def test_compression_summary_base_url_from_config(self, monkeypatch, tmp_path):
"""compression.summary_base_url should produce a custom-endpoint client."""
hermes_home = tmp_path / "hermes"
@@ -1560,3 +1628,74 @@ class TestStaleBaseUrlWarning:
assert not any("OPENAI_BASE_URL is set" in rec.message for rec in caplog.records), \
"Warning should not fire a second time"
# ---------------------------------------------------------------------------
# Anthropic-compatible image block conversion
# ---------------------------------------------------------------------------
class TestAnthropicCompatImageConversion:
    """Covers _is_anthropic_compat_endpoint and _convert_openai_images_to_anthropic."""

    def test_known_providers_detected(self):
        """Both MiniMax provider ids are detected even with an empty base URL."""
        from agent.auxiliary_client import _is_anthropic_compat_endpoint

        for provider_id in ("minimax", "minimax-cn"):
            assert _is_anthropic_compat_endpoint(provider_id, "")

    def test_openrouter_not_detected(self):
        """``openrouter`` and ``anthropic`` provider ids are not flagged."""
        from agent.auxiliary_client import _is_anthropic_compat_endpoint

        for provider_id in ("openrouter", "anthropic"):
            assert not _is_anthropic_compat_endpoint(provider_id, "")

    def test_url_based_detection(self):
        """For a ``custom`` provider the decision follows the base URL."""
        from agent.auxiliary_client import _is_anthropic_compat_endpoint

        flagged = (
            "https://api.minimax.io/anthropic",
            "https://example.com/anthropic/v1",
        )
        for base_url in flagged:
            assert _is_anthropic_compat_endpoint("custom", base_url)
        assert not _is_anthropic_compat_endpoint("custom", "https://api.openai.com/v1")

    def test_base64_image_converted(self):
        """A ``data:`` URL image_url block becomes a base64 ``image`` block."""
        from agent.auxiliary_client import _convert_openai_images_to_anthropic

        text_part = {"type": "text", "text": "describe"}
        image_part = {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR="}}
        messages = [{"role": "user", "content": [text_part, image_part]}]

        converted = _convert_openai_images_to_anthropic(messages)

        block = converted[0]["content"][1]
        assert block["type"] == "image"
        source = block["source"]
        assert source["type"] == "base64"
        assert source["media_type"] == "image/png"
        assert source["data"] == "iVBOR="

    def test_url_image_converted(self):
        """A plain https image_url block becomes a ``url``-sourced ``image`` block."""
        from agent.auxiliary_client import _convert_openai_images_to_anthropic

        image_part = {"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}}
        messages = [{"role": "user", "content": [image_part]}]

        converted = _convert_openai_images_to_anthropic(messages)

        block = converted[0]["content"][0]
        assert block["type"] == "image"
        assert block["source"]["type"] == "url"
        assert block["source"]["url"] == "https://example.com/img.jpg"

    def test_text_only_messages_unchanged(self):
        """String-content messages come back as the very same objects."""
        from agent.auxiliary_client import _convert_openai_images_to_anthropic

        messages = [{"role": "user", "content": "Hello"}]
        converted = _convert_openai_images_to_anthropic(messages)
        # Identity, not equality: the message must not have been copied.
        assert converted[0] is messages[0]

    def test_jpeg_media_type_parsed(self):
        """The media type is taken from the ``data:`` URL prefix."""
        from agent.auxiliary_client import _convert_openai_images_to_anthropic

        image_part = {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,/9j/="}}
        messages = [{"role": "user", "content": [image_part]}]

        converted = _convert_openai_images_to_anthropic(messages)

        assert converted[0]["content"][0]["source"]["media_type"] == "image/jpeg"

View File

@@ -0,0 +1,139 @@
"""Tests for focus_topic flowing through the compressor.
Verifies that _generate_summary and compress accept and use the focus_topic
parameter correctly. Inspired by Claude Code's /compact <focus>.
"""
from unittest.mock import MagicMock, patch
from agent.context_compressor import ContextCompressor
def _make_compressor():
    """Return a ContextCompressor built without running ``__init__``.

    Only the attributes listed below are set on the instance; nothing else is
    initialised.
    """
    instance = ContextCompressor.__new__(ContextCompressor)
    attributes = {
        "protect_first_n": 2,
        "protect_last_n": 5,
        "tail_token_budget": 20000,
        "context_length": 200000,
        "threshold_percent": 0.80,
        "threshold_tokens": 160000,
        "max_summary_tokens": 10000,
        "quiet_mode": True,
        "compression_count": 0,
        "last_prompt_tokens": 0,
        "_previous_summary": None,
        "_summary_failure_cooldown_until": 0.0,
        "summary_model": None,
    }
    for attr_name, attr_value in attributes.items():
        setattr(instance, attr_name, attr_value)
    return instance
def test_focus_topic_injected_into_summary_prompt():
    """Passing focus_topic adds focus guidance to the prompt sent to the LLM."""
    compressor = _make_compressor()
    conversation = [
        {"role": "user", "content": "Tell me about the database schema"},
        {"role": "assistant", "content": "The schema has tables: users, orders, products."},
    ]
    seen = {}

    def fake_call_llm(**kwargs):
        # Keep the outgoing messages so the prompt text can be inspected below.
        seen["messages"] = kwargs["messages"]
        response = MagicMock()
        response.choices = [MagicMock()]
        response.choices[0].message.content = "## Goal\nUnderstand DB schema."
        return response

    with patch("agent.context_compressor.call_llm", fake_call_llm):
        summary = compressor._generate_summary(conversation, focus_topic="database schema")

    assert summary is not None
    prompt = seen["messages"][0]["content"]
    for fragment in ('FOCUS TOPIC: "database schema"', "PRIORITISE", "60-70%"):
        assert fragment in prompt
def test_no_focus_topic_no_injection():
    """With no focus_topic the prompt carries no ``FOCUS TOPIC`` section."""
    compressor = _make_compressor()
    conversation = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi"},
    ]
    seen = {}

    def fake_call_llm(**kwargs):
        # Keep the outgoing messages so the prompt text can be inspected below.
        seen["messages"] = kwargs["messages"]
        response = MagicMock()
        response.choices = [MagicMock()]
        response.choices[0].message.content = "## Goal\nGreeting."
        return response

    with patch("agent.context_compressor.call_llm", fake_call_llm):
        compressor._generate_summary(conversation)

    prompt = seen["messages"][0]["content"]
    assert "FOCUS TOPIC" not in prompt
def test_compress_passes_focus_to_generate_summary():
    """compress() passes focus_topic through to _generate_summary."""
    compressor = _make_compressor()

    # Collect the keyword arguments _generate_summary is invoked with.
    received_kwargs = {}

    def tracking_generate(turns, **kwargs):
        received_kwargs.update(kwargs)
        return "## Goal\nTest."

    # Replace the real summariser on this instance only; no LLM call is made.
    compressor._generate_summary = tracking_generate

    messages = [
        {"role": "system", "content": "System prompt"},
        {"role": "user", "content": "first"},
        {"role": "assistant", "content": "reply1"},
        {"role": "user", "content": "second"},
        {"role": "assistant", "content": "reply2"},
        {"role": "user", "content": "third"},
        {"role": "assistant", "content": "reply3"},
        {"role": "user", "content": "fourth"},
        {"role": "assistant", "content": "reply4"},
    ]

    compressor.compress(messages, current_tokens=100000, focus_topic="authentication flow")

    assert received_kwargs.get("focus_topic") == "authentication flow"
def test_compress_none_focus_by_default():
    """Without an explicit focus_topic, _generate_summary sees none (or None)."""
    compressor = _make_compressor()
    captured = {}

    def recording_generate(turns, **kwargs):
        captured.update(kwargs)
        return "## Goal\nTest."

    compressor._generate_summary = recording_generate

    transcript = [{"role": "system", "content": "System prompt"}]
    exchanges = (
        ("first", "reply1"),
        ("second", "reply2"),
        ("third", "reply3"),
        ("fourth", "reply4"),
    )
    for question, answer in exchanges:
        transcript.append({"role": "user", "content": question})
        transcript.append({"role": "assistant", "content": answer})

    compressor.compress(transcript, current_tokens=100000)

    assert captured.get("focus_topic") is None

View File

@@ -191,6 +191,37 @@ class TestNonStringContent:
kwargs = mock_call.call_args.kwargs
assert "temperature" not in kwargs
def test_summary_call_passes_live_main_runtime(self):
    """_generate_summary hands the compressor's runtime settings to call_llm as main_runtime."""
    llm_reply = MagicMock()
    llm_reply.choices = [MagicMock()]
    llm_reply.choices[0].message.content = "ok"

    runtime = {
        "model": "gpt-5.4",
        "provider": "openai-codex",
        "base_url": "https://chatgpt.com/backend-api/codex",
        "api_key": "codex-token",
        "api_mode": "codex_responses",
    }

    with patch("agent.context_compressor.get_model_context_length", return_value=100000):
        compressor = ContextCompressor(quiet_mode=True, **runtime)

    conversation = [
        {"role": "user", "content": "do something"},
        {"role": "assistant", "content": "ok"},
    ]

    with patch("agent.context_compressor.call_llm", return_value=llm_reply) as mock_call:
        compressor._generate_summary(conversation)

    assert mock_call.call_args.kwargs["main_runtime"] == {
        "model": "gpt-5.4",
        "provider": "openai-codex",
        "base_url": "https://chatgpt.com/backend-api/codex",
        "api_key": "codex-token",
        "api_mode": "codex_responses",
    }
class TestSummaryFailureCooldown:
def test_summary_failure_enters_cooldown_and_skips_retry(self):
@@ -576,11 +607,19 @@ class TestSummaryTargetRatio:
assert c.summary_target_ratio == 0.80
def test_default_threshold_is_50_percent(self):
    """Default compression threshold should be 50%, with a 64K floor."""
    with patch("agent.context_compressor.get_model_context_length", return_value=100_000):
        c = ContextCompressor(model="test", quiet_mode=True)
    assert c.threshold_percent == 0.50
    # 50% of 100K = 50K, but the floor is 64K
    assert c.threshold_tokens == 64_000
def test_threshold_floor_does_not_apply_above_128k(self):
    """With a 200K context window the threshold is the plain 50% figure."""
    context_patch = patch(
        "agent.context_compressor.get_model_context_length",
        return_value=200_000,
    )
    with context_patch:
        compressor = ContextCompressor(model="test", quiet_mode=True)
    # Half of 200K is 100K, already above the 64K floor.
    assert compressor.threshold_tokens == 100_000
def test_default_protect_last_n_is_20(self):
"""Default protect_last_n should be 20."""

View File

@@ -50,7 +50,8 @@ class TestEstimateTokensRough:
assert estimate_tokens_rough("a" * 400) == 100
def test_short_text(self):
    """A five-character string rounds up to two tokens."""
    # "hello" = 5 chars → ceil(5/4) = 2
    assert estimate_tokens_rough("hello") == 2
def test_proportional(self):
short = estimate_tokens_rough("hello world")
@@ -68,10 +69,11 @@ class TestEstimateMessagesTokensRough:
assert estimate_messages_tokens_rough([]) == 0
def test_single_message_concrete_value(self):
    """Verify against known str(msg) length (ceiling division)."""
    msg = {"role": "user", "content": "a" * 400}
    result = estimate_messages_tokens_rough([msg])
    n = len(str(msg))
    # Adding 3 before floor-dividing by 4 rounds up, i.e. ceil(n / 4).
    expected = (n + 3) // 4
    assert result == expected
def test_multiple_messages_additive(self):
@@ -80,7 +82,8 @@ class TestEstimateMessagesTokensRough:
{"role": "assistant", "content": "Hi there, how can I help?"},
]
result = estimate_messages_tokens_rough(msgs)
expected = sum(len(str(m)) for m in msgs) // 4
n = sum(len(str(m)) for m in msgs)
expected = (n + 3) // 4
assert result == expected
def test_tool_call_message(self):
@@ -89,7 +92,7 @@ class TestEstimateMessagesTokensRough:
"tool_calls": [{"id": "1", "function": {"name": "terminal", "arguments": "{}"}}]}
result = estimate_messages_tokens_rough([msg])
assert result > 0
assert result == len(str(msg)) // 4
assert result == (len(str(msg)) + 3) // 4
def test_message_with_list_content(self):
"""Vision messages with multimodal content arrays."""
@@ -98,7 +101,7 @@ class TestEstimateMessagesTokensRough:
{"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}}
]}
result = estimate_messages_tokens_rough([msg])
assert result == len(str(msg)) // 4
assert result == (len(str(msg)) + 3) // 4
# =========================================================================

View File

@@ -87,7 +87,10 @@ class TestProviderMapping:
def test_unmapped_provider_not_in_dict(self):
    """``nous`` has no entry in PROVIDER_TO_MODELS_DEV."""
    assert "nous" not in PROVIDER_TO_MODELS_DEV
def test_openai_codex_mapped_to_openai(self):
    """Both ``openai`` and ``openai-codex`` map to ``openai``."""
    for provider_id in ("openai", "openai-codex"):
        assert PROVIDER_TO_MODELS_DEV[provider_id] == "openai"
class TestExtractContext:

View File

@@ -18,6 +18,7 @@ from agent.prompt_builder import (
build_skills_system_prompt,
build_nous_subscription_prompt,
build_context_files_prompt,
build_environment_hints,
CONTEXT_FILE_MAX_CHARS,
DEFAULT_AGENT_IDENTITY,
TOOL_USE_ENFORCEMENT_GUIDANCE,
@@ -26,6 +27,7 @@ from agent.prompt_builder import (
MEMORY_GUIDANCE,
SESSION_SEARCH_GUIDANCE,
PLATFORM_HINTS,
WSL_ENVIRONMENT_HINT,
)
from hermes_cli.nous_subscription import NousFeatureState, NousSubscriptionFeatures
@@ -770,6 +772,29 @@ class TestPromptBuilderConstants:
assert "cli" in PLATFORM_HINTS
# =========================================================================
# Environment hints
# =========================================================================
class TestEnvironmentHints:
    """Checks WSL_ENVIRONMENT_HINT and build_environment_hints()."""

    def test_wsl_hint_constant_mentions_mnt(self):
        """The hint text names both the /mnt/c/ path and WSL."""
        for fragment in ("/mnt/c/", "WSL"):
            assert fragment in WSL_ENVIRONMENT_HINT

    def test_build_environment_hints_on_wsl(self, monkeypatch):
        """With is_wsl() forced to True the hints mention /mnt/ and WSL."""
        import agent.prompt_builder as prompt_builder

        monkeypatch.setattr(prompt_builder, "is_wsl", lambda: True)
        hints = prompt_builder.build_environment_hints()
        assert "/mnt/" in hints
        assert "WSL" in hints

    def test_build_environment_hints_not_wsl(self, monkeypatch):
        """With is_wsl() forced to False the hints are the empty string."""
        import agent.prompt_builder as prompt_builder

        monkeypatch.setattr(prompt_builder, "is_wsl", lambda: False)
        hints = prompt_builder.build_environment_hints()
        assert hints == ""
# =========================================================================
# Conditional skill activation
# =========================================================================
@@ -1009,65 +1034,4 @@ class TestOpenAIModelExecutionGuidance:
# =========================================================================
class TestStripBudgetWarningsFromHistory:
    """Exercises run_agent._strip_budget_warnings_from_history on tool messages."""

    def test_strips_json_budget_warning_key(self):
        """The ``_budget_warning`` key is dropped from JSON tool content; other keys stay."""
        import json
        from run_agent import _strip_budget_warnings_from_history

        payload = {
            "output": "hello",
            "exit_code": 0,
            "_budget_warning": "[BUDGET: Iteration 55/60. 5 iterations left. Start consolidating your work.]",
        }
        history = [{"role": "tool", "tool_call_id": "c1", "content": json.dumps(payload)}]

        _strip_budget_warnings_from_history(history)

        decoded = json.loads(history[0]["content"])
        assert "_budget_warning" not in decoded
        assert decoded["output"] == "hello"
        assert decoded["exit_code"] == 0

    def test_strips_text_budget_warning(self):
        """A trailing plain-text budget warning is cut from tool content."""
        from run_agent import _strip_budget_warnings_from_history

        tool_message = {
            "role": "tool",
            "tool_call_id": "c1",
            "content": "some result\n\n[BUDGET WARNING: Iteration 58/60. Only 2 iteration(s) left. Provide your final response NOW. No more tool calls unless absolutely critical.]",
        }
        history = [tool_message]

        _strip_budget_warnings_from_history(history)

        assert history[0]["content"] == "some result"

    def test_leaves_non_tool_messages_unchanged(self):
        """Assistant and user messages keep their content untouched."""
        from run_agent import _strip_budget_warnings_from_history

        history = [
            {"role": "assistant", "content": "[BUDGET WARNING: Iteration 58/60. Only 2 iteration(s) left. Provide your final response NOW. No more tool calls unless absolutely critical.]"},
            {"role": "user", "content": "hello"},
        ]
        before = [entry["content"] for entry in history]

        _strip_budget_warnings_from_history(history)

        after = [entry["content"] for entry in history]
        assert after == before

    def test_handles_empty_and_missing_content(self):
        """Empty content and an absent content key are both tolerated."""
        from run_agent import _strip_budget_warnings_from_history

        history = [
            {"role": "tool", "tool_call_id": "c1", "content": ""},
            {"role": "tool", "tool_call_id": "c2"},
        ]

        _strip_budget_warnings_from_history(history)

        assert history[0]["content"] == ""

    def test_strips_caution_variant(self):
        """The earlier "18 iterations left" wording is stripped as well."""
        import json
        from run_agent import _strip_budget_warnings_from_history

        payload = {
            "output": "ok",
            "_budget_warning": "[BUDGET: Iteration 42/60. 18 iterations left. Start consolidating your work.]",
        }
        history = [{"role": "tool", "tool_call_id": "c1", "content": json.dumps(payload)}]

        _strip_budget_warnings_from_history(history)

        decoded = json.loads(history[0]["content"])
        assert "_budget_warning" not in decoded

View File

@@ -0,0 +1,118 @@
"""Tests for /compress <focus> — guided compression with focus topic.
Inspired by Claude Code's /compact <focus> feature.
"""
from unittest.mock import MagicMock, patch
from tests.cli.test_cli_init import _make_cli
def _make_history() -> list[dict[str, str]]:
return [
{"role": "user", "content": "one"},
{"role": "assistant", "content": "two"},
{"role": "user", "content": "three"},
{"role": "assistant", "content": "four"},
]
def test_focus_topic_extracted_and_passed(capsys):
    """Text after ``/compress`` is echoed and handed to _compress_context as focus_topic."""
    shell = _make_cli()
    history = _make_history()
    shell.conversation_history = history

    agent = shell.agent = MagicMock()
    agent.compression_enabled = True
    agent._cached_system_prompt = ""
    agent._compress_context.return_value = ([history[0], history[-1]], "")

    def _estimate(messages):
        # The untouched history counts as 100 tokens, anything else as 50.
        return 100 if messages is history else 50

    with patch("agent.model_metadata.estimate_messages_tokens_rough", side_effect=_estimate):
        shell._manual_compress("/compress database schema")

    printed = capsys.readouterr().out
    assert 'focus: "database schema"' in printed

    # The focus text must have travelled through as a keyword argument.
    shell.agent._compress_context.assert_called_once()
    recorded_call = shell.agent._compress_context.call_args
    assert recorded_call.kwargs.get("focus_topic") == "database schema"
def test_no_focus_topic_when_bare_command(capsys):
    """A bare ``/compress`` results in focus_topic being None."""
    shell = _make_cli()
    history = _make_history()
    shell.conversation_history = history

    agent = shell.agent = MagicMock()
    agent.compression_enabled = True
    agent._cached_system_prompt = ""
    agent._compress_context.return_value = (list(history), "")

    with patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=100):
        shell._manual_compress("/compress")

    shell.agent._compress_context.assert_called_once()
    recorded_call = shell.agent._compress_context.call_args
    assert recorded_call.kwargs.get("focus_topic") is None
def test_empty_focus_after_command_treated_as_none(capsys):
    """``/compress`` followed only by whitespace yields no focus topic."""
    shell = _make_cli()
    history = _make_history()
    shell.conversation_history = history

    agent = shell.agent = MagicMock()
    agent.compression_enabled = True
    agent._cached_system_prompt = ""
    agent._compress_context.return_value = (list(history), "")

    with patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=100):
        shell._manual_compress("/compress ")

    shell.agent._compress_context.assert_called_once()
    recorded_call = shell.agent._compress_context.call_args
    assert recorded_call.kwargs.get("focus_topic") is None
def test_focus_topic_printed_in_compression_banner(capsys):
    """The printed output includes the quoted focus topic."""
    shell = _make_cli()
    history = _make_history()
    shell.conversation_history = history

    agent = shell.agent = MagicMock()
    agent.compression_enabled = True
    agent._cached_system_prompt = ""
    agent._compress_context.return_value = ([history[0], history[-1]], "")

    with patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=100):
        shell._manual_compress("/compress API endpoints")

    printed = capsys.readouterr().out
    assert 'focus: "API endpoints"' in printed
def test_no_focus_prints_standard_banner(capsys):
    """A bare ``/compress`` prints "Compressing" and no ``focus:`` text."""
    shell = _make_cli()
    history = _make_history()
    shell.conversation_history = history

    agent = shell.agent = MagicMock()
    agent.compression_enabled = True
    agent._cached_system_prompt = ""
    agent._compress_context.return_value = ([history[0], history[-1]], "")

    with patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=100):
        shell._manual_compress("/compress")

    printed = capsys.readouterr().out
    assert "focus:" not in printed
    assert "Compressing" in printed

View File

@@ -0,0 +1,189 @@
"""Tests for stacked tool progress scrollback lines in the CLI TUI.
When tool_progress_mode is "all" or "new", _on_tool_progress should print
persistent lines to scrollback on tool.completed, restoring the stacked
tool history that was lost when the TUI switched to a single-line spinner.
"""
import os
import sys
import importlib
from unittest.mock import MagicMock, patch
# Put the parent of this file's directory at the front of sys.path
# (presumably so the top-level ``cli`` module is importable — confirm against
# the repository layout).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
# Module-level reference to the cli module (set by _make_cli on first call)
_cli_mod = None
def _make_cli(tool_progress="all"):
    """Build a HermesCLI whose ``display.tool_progress`` is *tool_progress*.

    prompt_toolkit modules are replaced by MagicMock stubs in ``sys.modules``
    and the ``cli`` module is reloaded under those stubs. The reloaded module
    is also stored in the module-level ``_cli_mod``.
    """
    global _cli_mod

    config = {
        "model": {
            "default": "anthropic/claude-opus-4.6",
            "base_url": "https://openrouter.ai/api/v1",
            "provider": "auto",
        },
        "display": {"compact": False, "tool_progress": tool_progress},
        "agent": {},
        "terminal": {"env_type": "local"},
    }
    env_overrides = {"LLM_MODEL": "", "HERMES_MAX_ITERATIONS": ""}
    stubbed_modules = (
        "prompt_toolkit",
        "prompt_toolkit.history",
        "prompt_toolkit.styles",
        "prompt_toolkit.patch_stdout",
        "prompt_toolkit.application",
        "prompt_toolkit.layout",
        "prompt_toolkit.layout.processors",
        "prompt_toolkit.filters",
        "prompt_toolkit.layout.dimension",
        "prompt_toolkit.layout.menus",
        "prompt_toolkit.widgets",
        "prompt_toolkit.key_binding",
        "prompt_toolkit.completion",
        "prompt_toolkit.formatted_text",
        "prompt_toolkit.auto_suggest",
    )
    # One independent MagicMock per stubbed module name.
    stubs = {module_name: MagicMock() for module_name in stubbed_modules}

    with patch.dict(sys.modules, stubs), patch.dict("os.environ", env_overrides, clear=False):
        import cli as cli_module

        # Reload so the module body runs again under the stubs and env overrides.
        cli_module = importlib.reload(cli_module)
        _cli_mod = cli_module

        with patch.object(cli_module, "get_tool_definitions", return_value=[]), patch.dict(
            cli_module.__dict__, {"CLI_CONFIG": config}
        ):
            return cli_module.HermesCLI()
class TestToolProgressScrollback:
    """Scrollback output of _on_tool_progress across the tool_progress modes."""

    def test_all_mode_prints_scrollback_on_completed(self):
        """'all' mode: a tool.completed event prints exactly one line."""
        shell = _make_cli(tool_progress="all")
        # Start the tool first so there is something to complete.
        shell._on_tool_progress("tool.started", "terminal", "git log", {"command": "git log"})

        with patch.object(_cli_mod, "_cprint") as printer:
            shell._on_tool_progress(
                "tool.completed", "terminal", None, None, duration=1.5, is_error=False
            )

        printer.assert_called_once()
        printed_line = printer.call_args[0][0]
        # The line should carry the command text or the shell prompt marker.
        assert "git log" in printed_line or "$" in printed_line

    def test_all_mode_prints_every_call(self):
        """'all' mode: two back-to-back calls of one tool print two lines."""
        shell = _make_cli(tool_progress="all")
        with patch.object(_cli_mod, "_cprint") as printer:
            # First read_file call.
            shell._on_tool_progress("tool.started", "read_file", "cli.py", {"path": "cli.py"})
            shell._on_tool_progress(
                "tool.completed", "read_file", None, None, duration=0.1, is_error=False
            )
            # Second read_file call, different path.
            shell._on_tool_progress(
                "tool.started", "read_file", "run_agent.py", {"path": "run_agent.py"}
            )
            shell._on_tool_progress(
                "tool.completed", "read_file", None, None, duration=0.2, is_error=False
            )

        assert printer.call_count == 2

    def test_new_mode_skips_consecutive_repeats(self):
        """'new' mode: a repeated tool name directly after itself prints once."""
        shell = _make_cli(tool_progress="new")
        with patch.object(_cli_mod, "_cprint") as printer:
            shell._on_tool_progress("tool.started", "read_file", "cli.py", {"path": "cli.py"})
            shell._on_tool_progress(
                "tool.completed", "read_file", None, None, duration=0.1, is_error=False
            )
            shell._on_tool_progress(
                "tool.started", "read_file", "run_agent.py", {"path": "run_agent.py"}
            )
            shell._on_tool_progress(
                "tool.completed", "read_file", None, None, duration=0.2, is_error=False
            )

        # Only the first read_file produced a line.
        assert printer.call_count == 1

    def test_new_mode_prints_when_tool_changes(self):
        """'new' mode: switching to another tool name starts a new line."""
        shell = _make_cli(tool_progress="new")
        with patch.object(_cli_mod, "_cprint") as printer:
            shell._on_tool_progress("tool.started", "read_file", "cli.py", {"path": "cli.py"})
            shell._on_tool_progress(
                "tool.completed", "read_file", None, None, duration=0.1, is_error=False
            )
            shell._on_tool_progress(
                "tool.started", "search_files", "pattern", {"pattern": "test"}
            )
            shell._on_tool_progress(
                "tool.completed", "search_files", None, None, duration=0.3, is_error=False
            )
            shell._on_tool_progress(
                "tool.started", "read_file", "run_agent.py", {"path": "run_agent.py"}
            )
            shell._on_tool_progress(
                "tool.completed", "read_file", None, None, duration=0.2, is_error=False
            )

        # read_file, search_files, then read_file again: search_files broke the streak.
        assert printer.call_count == 3

    def test_off_mode_no_scrollback(self):
        """'off' mode: nothing is printed."""
        shell = _make_cli(tool_progress="off")
        with patch.object(_cli_mod, "_cprint") as printer:
            shell._on_tool_progress("tool.started", "terminal", "ls", {"command": "ls"})
            shell._on_tool_progress(
                "tool.completed", "terminal", None, None, duration=0.5, is_error=False
            )

        printer.assert_not_called()

    def test_error_suffix_on_failed_tool(self):
        """A completion with is_error=True prints a line containing ``[error]``."""
        shell = _make_cli(tool_progress="all")
        shell._on_tool_progress("tool.started", "terminal", "bad cmd", {"command": "bad cmd"})

        with patch.object(_cli_mod, "_cprint") as printer:
            shell._on_tool_progress(
                "tool.completed", "terminal", None, None, duration=0.5, is_error=True
            )

        printed_line = printer.call_args[0][0]
        assert "[error]" in printed_line

    def test_spinner_still_updates_on_started(self):
        """tool.started puts the preview text into ``_spinner_text``."""
        shell = _make_cli(tool_progress="all")
        shell._on_tool_progress(
            "tool.started", "terminal", "git status", {"command": "git status"}
        )
        assert "git status" in shell._spinner_text

    def test_spinner_timer_clears_on_completed(self):
        """``_tool_start_time`` is positive after start and 0.0 after completion."""
        shell = _make_cli(tool_progress="all")
        shell._on_tool_progress(
            "tool.started", "terminal", "git status", {"command": "git status"}
        )
        assert shell._tool_start_time > 0

        with patch.object(_cli_mod, "_cprint"):
            shell._on_tool_progress(
                "tool.completed", "terminal", None, None, duration=0.5, is_error=False
            )

        assert shell._tool_start_time == 0.0

    def test_concurrent_tools_produce_stacked_lines(self):
        """Two starts followed by two completions print two lines."""
        shell = _make_cli(tool_progress="all")
        with patch.object(_cli_mod, "_cprint") as printer:
            # Both tools start before either finishes.
            shell._on_tool_progress("tool.started", "web_search", "query 1", {"query": "test 1"})
            shell._on_tool_progress("tool.started", "web_search", "query 2", {"query": "test 2"})
            # Then both finish.
            shell._on_tool_progress(
                "tool.completed", "web_search", None, None, duration=1.0, is_error=False
            )
            shell._on_tool_progress(
                "tool.completed", "web_search", None, None, duration=1.5, is_error=False
            )

        assert printer.call_count == 2

    def test_verbose_mode_no_duplicate_scrollback(self):
        """'verbose' mode: _on_tool_progress prints no scrollback line."""
        shell = _make_cli(tool_progress="verbose")
        with patch.object(_cli_mod, "_cprint") as printer:
            shell._on_tool_progress("tool.started", "terminal", "ls", {"command": "ls"})
            shell._on_tool_progress(
                "tool.completed", "terminal", None, None, duration=0.5, is_error=False
            )

        printer.assert_not_called()

    def test_pending_info_stores_on_started(self):
        """tool.started records the args under the tool name in ``_pending_tool_info``."""
        shell = _make_cli(tool_progress="all")
        shell._on_tool_progress("tool.started", "terminal", "ls", {"command": "ls"})

        assert "terminal" in shell._pending_tool_info
        queued = shell._pending_tool_info["terminal"]
        assert len(queued) == 1
        assert queued[0] == {"command": "ls"}

    def test_pending_info_consumed_on_completed(self):
        """tool.completed removes the oldest stored args and leaves the newer one."""
        shell = _make_cli(tool_progress="all")
        shell._on_tool_progress("tool.started", "terminal", "ls", {"command": "ls"})
        shell._on_tool_progress("tool.started", "terminal", "pwd", {"command": "pwd"})
        assert len(shell._pending_tool_info["terminal"]) == 2

        with patch.object(_cli_mod, "_cprint"):
            shell._on_tool_progress(
                "tool.completed", "terminal", None, None, duration=0.1, is_error=False
            )

        # The "ls" entry is gone; "pwd" is still waiting.
        assert len(shell._pending_tool_info.get("terminal", [])) == 1
        assert shell._pending_tool_info["terminal"][0] == {"command": "pwd"}

View File

@@ -0,0 +1,226 @@
"""Tests for the clean shutdown marker that prevents unwanted session auto-resets.
When the gateway shuts down gracefully (hermes update, gateway restart, /restart),
it writes a .clean_shutdown marker. On the next startup, if the marker exists,
suspend_recently_active() is skipped so users don't lose their sessions.
After a crash (no marker), suspension still fires as a safety net for stuck sessions.
"""
import os
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig, SessionResetPolicy
from gateway.session import SessionEntry, SessionSource, SessionStore
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_source(platform=Platform.TELEGRAM, chat_id="123", user_id="u1"):
    """Build a SessionSource from the given platform, chat id and user id."""
    source = SessionSource(
        platform=platform,
        chat_id=chat_id,
        user_id=user_id,
    )
    return source
def _make_store(tmp_path, policy=None):
    """Create a SessionStore in *tmp_path*, applying *policy* as the default reset policy when truthy."""
    gateway_config = GatewayConfig()
    if not policy:
        return SessionStore(sessions_dir=tmp_path, config=gateway_config)
    gateway_config.default_reset_policy = policy
    return SessionStore(sessions_dir=tmp_path, config=gateway_config)
# ---------------------------------------------------------------------------
# SessionStore.suspend_recently_active
# ---------------------------------------------------------------------------
class TestSuspendRecentlyActive:
    """Checks which sessions SessionStore.suspend_recently_active counts."""

    def test_suspends_recently_active_sessions(self, tmp_path):
        """A just-created session is counted and comes back flagged as auto-reset."""
        store = _make_store(tmp_path)
        source = _make_source()
        session = store.get_or_create_session(source)
        assert not session.suspended

        suspended_count = store.suspend_recently_active()
        assert suspended_count == 1

        # Fetching again after the suspension should report the auto-reset.
        refetched = store.get_or_create_session(source)
        assert refetched.was_auto_reset

    def test_does_not_suspend_old_sessions(self, tmp_path):
        """A session last updated 300s ago is outside a 120s window."""
        store = _make_store(tmp_path)
        source = _make_source()
        session = store.get_or_create_session(source)

        # Push updated_at back past the cutoff and persist the change.
        stale_timestamp_offset = timedelta(seconds=300)
        with store._lock:
            session.updated_at = datetime.now() - stale_timestamp_offset
            store._save()

        suspended_count = store.suspend_recently_active(max_age_seconds=120)
        assert suspended_count == 0

    def test_already_suspended_not_double_counted(self, tmp_path):
        """Each suspension pass counts one session here."""
        store = _make_store(tmp_path)
        source = _make_source()
        store.get_or_create_session(source)

        # First pass suspends the only session.
        first_pass = store.suspend_recently_active()
        assert first_pass == 1

        # Accessing the source again creates a fresh session (the old one was reset).
        store.get_or_create_session(source)

        # Second pass: the fresh session is recent and not yet suspended.
        second_pass = store.suspend_recently_active()
        assert second_pass == 1
# ---------------------------------------------------------------------------
# Clean shutdown marker integration
# ---------------------------------------------------------------------------
class TestCleanShutdownMarker:
    """Test that the marker file controls session suspension on startup."""

    @staticmethod
    def _make_bare_runner():
        """Return a ``GatewayRunner`` created without ``__init__``.

        Only the attributes these tests preset before calling ``stop()`` are
        assigned, so no real adapters or background tasks are involved.
        """
        from gateway.run import GatewayRunner

        runner = object.__new__(GatewayRunner)
        runner._restart_requested = False
        runner._restart_detached = False
        runner._restart_via_service = False
        runner._restart_task_started = False
        runner._running = True
        runner._draining = False
        runner._stop_task = None
        runner._running_agents = {}
        runner._pending_messages = {}
        runner._pending_approvals = {}
        runner._background_tasks = set()
        runner._shutdown_event = MagicMock()
        runner._restart_drain_timeout = 5
        runner._exit_code = None
        runner._exit_reason = None
        runner.adapters = {}
        runner.config = GatewayConfig()
        return runner

    @staticmethod
    def _run_stop(runner, **stop_kwargs):
        """Run ``runner.stop(**stop_kwargs)`` to completion with heavy dependencies patched.

        A dedicated event loop is created and closed here instead of calling
        ``asyncio.get_event_loop()``: that call is deprecated when no loop is
        current and may hand back a loop that another test already closed.
        """
        import asyncio

        with patch("gateway.run.GatewayRunner._drain_active_agents", new_callable=AsyncMock, return_value=([], False)), \
                patch("gateway.run.GatewayRunner._finalize_shutdown_agents"), \
                patch("gateway.run.GatewayRunner._update_runtime_status"), \
                patch("gateway.status.remove_pid_file"), \
                patch("tools.process_registry.process_registry") as mock_proc_reg, \
                patch("tools.terminal_tool.cleanup_all_environments"), \
                patch("tools.browser_tool.cleanup_all_browsers"):
            mock_proc_reg.kill_all = MagicMock()
            loop = asyncio.new_event_loop()
            try:
                loop.run_until_complete(runner.stop(**stop_kwargs))
            finally:
                # Always release the private loop, even when stop() raises.
                loop.close()

    def test_marker_written_on_graceful_stop(self, tmp_path, monkeypatch):
        """stop() should write .clean_shutdown marker."""
        monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
        marker = tmp_path / ".clean_shutdown"
        assert not marker.exists()

        # Create a minimal runner and call the shutdown logic directly
        self._run_stop(self._make_bare_runner())

        assert marker.exists(), ".clean_shutdown marker should exist after graceful stop"

    def test_marker_skips_suspension_on_startup(self, tmp_path, monkeypatch):
        """If .clean_shutdown exists, suspend_recently_active should NOT be called."""
        monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
        # Create the marker
        marker = tmp_path / ".clean_shutdown"
        marker.touch()

        # Create a store with a recently active session
        store = _make_store(tmp_path)
        source = _make_source()
        entry = store.get_or_create_session(source)
        assert not entry.suspended

        # Simulate what start() does: consume the marker and skip suspension,
        # or suspend when the marker is missing.
        if marker.exists():
            marker.unlink()
        else:
            store.suspend_recently_active()

        # Session should NOT be suspended
        with store._lock:
            store._ensure_loaded_locked()
            for e in store._entries.values():
                assert not e.suspended, "Session should NOT be suspended after clean shutdown"
        assert not marker.exists(), "Marker should be cleaned up"

    def test_no_marker_triggers_suspension(self, tmp_path, monkeypatch):
        """Without .clean_shutdown marker (crash), suspension should fire."""
        monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
        marker = tmp_path / ".clean_shutdown"
        assert not marker.exists()

        # Create a store with a recently active session
        store = _make_store(tmp_path)
        source = _make_source()
        entry = store.get_or_create_session(source)
        assert not entry.suspended

        # Simulate what start() does:
        if marker.exists():
            marker.unlink()
        else:
            store.suspend_recently_active()

        # Session SHOULD be suspended (crash recovery)
        with store._lock:
            store._ensure_loaded_locked()
            suspended_count = sum(1 for e in store._entries.values() if e.suspended)
        assert suspended_count == 1, "Session should be suspended after crash (no marker)"

    def test_marker_written_on_restart_stop(self, tmp_path, monkeypatch):
        """stop(restart=True) should also write the marker."""
        monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
        marker = tmp_path / ".clean_shutdown"

        self._run_stop(self._make_bare_runner(), restart=True)

        assert marker.exists(), ".clean_shutdown marker should exist after restart-stop too"

View File

@@ -0,0 +1,118 @@
"""Tests for gateway /compress <focus> — focus topic on the gateway side."""
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent
from gateway.session import SessionEntry, SessionSource, build_session_key
def _make_source() -> SessionSource:
    """Return the fixed Telegram DM source shared by the /compress tests."""
    source_fields = {
        "platform": Platform.TELEGRAM,
        "user_id": "u1",
        "chat_id": "c1",
        "user_name": "tester",
        "chat_type": "dm",
    }
    return SessionSource(**source_fields)
def _make_event(text: str = "/compress") -> MessageEvent:
    """Wrap *text* in a ``MessageEvent`` from the shared test source with id ``m1``."""
    origin = _make_source()
    return MessageEvent(text=text, source=origin, message_id="m1")
def _make_history() -> list[dict[str, str]]:
    """Return a fresh four-message transcript alternating user and assistant turns."""
    turns = (
        ("user", "one"),
        ("assistant", "two"),
        ("user", "three"),
        ("assistant", "four"),
    )
    return [{"role": speaker, "content": text} for speaker, text in turns]
def _make_runner(history: list[dict[str, str]]):
    """Build a ``GatewayRunner`` (bypassing ``__init__``) backed by a mocked session store.

    The mocked store always hands back one Telegram DM ``SessionEntry`` with
    id ``sess-1`` and returns *history* as that session's transcript.
    """
    from gateway.run import GatewayRunner

    entry = SessionEntry(
        session_key=build_session_key(_make_source()),
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
    )

    store = MagicMock()
    store.get_or_create_session.return_value = entry
    store.load_transcript.return_value = history
    # Explicit mocks for the write paths the runner may touch.
    store.rewrite_transcript = MagicMock()
    store.update_session = MagicMock()
    store._save = MagicMock()

    telegram_only = {Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
    runner = object.__new__(GatewayRunner)
    runner.config = GatewayConfig(platforms=telegram_only)
    runner.session_store = store
    return runner
@pytest.mark.asyncio
async def test_compress_focus_topic_passed_to_agent():
    """The text after ``/compress`` reaches ``_compress_context`` as ``focus_topic`` and is echoed back."""
    transcript = _make_history()
    shortened = [transcript[0], transcript[-1]]
    runner = _make_runner(transcript)

    agent = MagicMock()
    compressor = agent.context_compressor
    compressor.protect_first_n = 0
    compressor._align_boundary_forward.return_value = 0
    compressor._find_tail_cut_by_tokens.return_value = 2
    agent.session_id = "sess-1"
    agent._compress_context.return_value = (shortened, "")

    with (
        patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "***"}),
        patch("gateway.run._resolve_gateway_model", return_value="test-model"),
        patch("run_agent.AIAgent", return_value=agent),
        patch("agent.model_metadata.estimate_messages_tokens_rough", side_effect=lambda messages: 100),
    ):
        event = _make_event("/compress database schema")
        reply = await runner._handle_compress_command(event)

    # The focus topic must arrive as a keyword argument.
    agent._compress_context.assert_called_once()
    recorded = agent._compress_context.call_args
    assert recorded.kwargs.get("focus_topic") == "database schema"

    # The reply mentions the focus that was applied.
    assert 'Focus: "database schema"' in reply
@pytest.mark.asyncio
async def test_compress_no_focus_passes_none():
    """A bare ``/compress`` yields ``focus_topic=None`` and no ``Focus:`` line in the reply."""
    transcript = _make_history()
    runner = _make_runner(transcript)

    agent = MagicMock()
    compressor = agent.context_compressor
    compressor.protect_first_n = 0
    compressor._align_boundary_forward.return_value = 0
    compressor._find_tail_cut_by_tokens.return_value = 2
    agent.session_id = "sess-1"
    agent._compress_context.return_value = (list(transcript), "")

    with (
        patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "***"}),
        patch("gateway.run._resolve_gateway_model", return_value="test-model"),
        patch("run_agent.AIAgent", return_value=agent),
        patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=100),
    ):
        event = _make_event("/compress")
        reply = await runner._handle_compress_command(event)

    agent._compress_context.assert_called_once()
    recorded = agent._compress_context.call_args
    assert recorded.kwargs.get("focus_topic") is None

    # Without a focus topic the reply carries no focus line.
    assert "Focus:" not in reply

View File

@@ -74,6 +74,26 @@ class FakeBot:
return None
class SlowSyncTree(FakeTree):
    """``FakeTree`` whose ``sync`` signals that it started, then blocks until released."""

    def __init__(self):
        super().__init__()
        self.started = asyncio.Event()
        self.allow_finish = asyncio.Event()
        # AsyncMock keeps await counts while delegating to the blocking coroutine.
        self.sync = AsyncMock(side_effect=self._blocked_sync)

    async def _blocked_sync(self):
        """Set ``started``, wait for ``allow_finish``, then return an empty list."""
        self.started.set()
        await self.allow_finish.wait()
        return []
class SlowSyncBot(FakeBot):
    """``FakeBot`` variant whose command tree is a ``SlowSyncTree``."""

    def __init__(self, *, intents, proxy=None):
        base_kwargs = {"intents": intents, "proxy": proxy}
        super().__init__(**base_kwargs)
        # Replace the tree after base initialisation.
        self.tree = SlowSyncTree()
@pytest.mark.asyncio
@pytest.mark.parametrize(
("allowed_users", "expected_members_intent"),
@@ -138,3 +158,36 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch):
assert ok is False
assert released == [("discord-bot-token", "test-token")]
assert adapter._platform_lock_identity is None
@pytest.mark.asyncio
async def test_connect_does_not_wait_for_slash_sync(monkeypatch):
    """``connect()`` finishes within a second while the bot's tree sync is still blocked."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="test-token"))

    def _always_acquire(scope, identity, metadata=None):
        return (True, None)

    def _release_noop(scope, identity):
        return None

    monkeypatch.setattr("gateway.status.acquire_scoped_lock", _always_acquire)
    monkeypatch.setattr("gateway.status.release_scoped_lock", _release_noop)

    intents = SimpleNamespace(
        message_content=False,
        dm_messages=False,
        guild_messages=False,
        members=False,
        voice_states=False,
    )
    monkeypatch.setattr(discord_platform.Intents, "default", lambda: intents)

    captured = {}

    def fake_bot_factory(*, command_prefix, intents, proxy=None):
        # Remember the bot so the test can reach its tree afterwards.
        captured["bot"] = SlowSyncBot(intents=intents, proxy=proxy)
        return captured["bot"]

    monkeypatch.setattr(discord_platform.commands, "Bot", fake_bot_factory)
    monkeypatch.setattr(adapter, "_resolve_allowed_usernames", AsyncMock())

    connected = await asyncio.wait_for(adapter.connect(), timeout=1.0)
    assert connected is True
    assert adapter._ready_event.is_set()

    # The sync has been started exactly once but has not been allowed to finish.
    await asyncio.wait_for(captured["bot"].tree.started.wait(), timeout=1.0)
    assert captured["bot"].tree.sync.await_count == 1

    # Release the sync, yield once to the loop, then tear down.
    captured["bot"].tree.allow_finish.set()
    await asyncio.sleep(0)
    await adapter.disconnect()

View File

@@ -0,0 +1,355 @@
"""Tests for gateway.display_config — per-platform display/verbosity resolver."""
import pytest
# ---------------------------------------------------------------------------
# Resolver: resolution order
# ---------------------------------------------------------------------------
class TestResolveDisplaySetting:
    """Check the lookup order used by ``resolve_display_setting()``."""

    def test_explicit_platform_override_wins(self):
        """A ``display.platforms.<plat>.<key>`` entry beats the global ``display.<key>``."""
        from gateway.display_config import resolve_display_setting

        display = {
            "tool_progress": "all",
            "platforms": {
                "telegram": {"tool_progress": "verbose"},
            },
        }
        resolved = resolve_display_setting({"display": display}, "telegram", "tool_progress")
        assert resolved == "verbose"

    def test_global_setting_when_no_platform_override(self):
        """With an empty ``platforms`` mapping the global ``display.<key>`` is returned."""
        from gateway.display_config import resolve_display_setting

        display = {
            "tool_progress": "new",
            "platforms": {},
        }
        resolved = resolve_display_setting({"display": display}, "telegram", "tool_progress")
        assert resolved == "new"

    def test_platform_default_when_no_user_config(self):
        """An empty config falls through to the built-in per-platform default."""
        from gateway.display_config import resolve_display_setting

        empty = {}
        # Telegram resolves to "all", email to "off".
        expectations = (("telegram", "all"), ("email", "off"))
        for platform_name, expected in expectations:
            assert resolve_display_setting(empty, platform_name, "tool_progress") == expected

    def test_global_default_for_unknown_platform(self):
        """A platform name with no built-in entry resolves ``tool_progress`` to "all"."""
        from gateway.display_config import resolve_display_setting

        empty = {}
        resolved = resolve_display_setting(empty, "unknown_platform", "tool_progress")
        assert resolved == "all"

    def test_fallback_parameter_used_last(self):
        """The explicit fallback argument is returned for a key nothing else defines."""
        from gateway.display_config import resolve_display_setting

        empty = {}
        # "nonexistent_key" is not present in any defaults.
        resolved = resolve_display_setting(empty, "telegram", "nonexistent_key", "my_fallback")
        assert resolved == "my_fallback"

    def test_platform_override_only_affects_that_platform(self):
        """An override for Slack leaves Telegram on the global value."""
        from gateway.display_config import resolve_display_setting

        display = {
            "tool_progress": "all",
            "platforms": {
                "slack": {"tool_progress": "off"},
            },
        }
        cfg = {"display": display}
        assert resolve_display_setting(cfg, "slack", "tool_progress") == "off"
        assert resolve_display_setting(cfg, "telegram", "tool_progress") == "all"
# ---------------------------------------------------------------------------
# Backward compatibility: tool_progress_overrides
# ---------------------------------------------------------------------------
class TestBackwardCompat:
    """The legacy ``tool_progress_overrides`` mapping keeps working as a fallback."""

    def test_legacy_overrides_read(self):
        """Legacy overrides are honoured when ``display.platforms`` is absent."""
        from gateway.display_config import resolve_display_setting

        legacy = {
            "signal": "off",
            "telegram": "verbose",
        }
        cfg = {"display": {"tool_progress": "all", "tool_progress_overrides": legacy}}
        assert resolve_display_setting(cfg, "signal", "tool_progress") == "off"
        assert resolve_display_setting(cfg, "telegram", "tool_progress") == "verbose"

    def test_new_platforms_takes_precedence_over_legacy(self):
        """When both forms exist, ``display.platforms`` wins over the legacy mapping."""
        from gateway.display_config import resolve_display_setting

        display = {
            "tool_progress": "all",
            "tool_progress_overrides": {"telegram": "verbose"},
            "platforms": {"telegram": {"tool_progress": "new"}},
        }
        resolved = resolve_display_setting({"display": display}, "telegram", "tool_progress")
        assert resolved == "new"

    def test_legacy_overrides_only_for_tool_progress(self):
        """A legacy override does not leak into settings other than ``tool_progress``."""
        from gateway.display_config import resolve_display_setting

        display = {
            "tool_progress_overrides": {"telegram": "verbose"},
        }
        # show_reasoning must not be read from tool_progress_overrides.
        resolved = resolve_display_setting({"display": display}, "telegram", "show_reasoning")
        assert resolved is False
# ---------------------------------------------------------------------------
# YAML normalisation
# ---------------------------------------------------------------------------
class TestYAMLNormalisation:
    """Values YAML parses into unexpected types are normalised by the resolver."""

    def test_tool_progress_false_normalised_to_off(self):
        """Boolean ``False`` (what a bare YAML ``off`` becomes) resolves to ``"off"``."""
        from gateway.display_config import resolve_display_setting

        cfg = {"display": {"tool_progress": False}}
        resolved = resolve_display_setting(cfg, "telegram", "tool_progress")
        assert resolved == "off"

    def test_tool_progress_true_normalised_to_all(self):
        """Boolean ``True`` (what a bare YAML ``on`` becomes) resolves to ``"all"``."""
        from gateway.display_config import resolve_display_setting

        cfg = {"display": {"tool_progress": True}}
        resolved = resolve_display_setting(cfg, "telegram", "tool_progress")
        assert resolved == "all"

    def test_show_reasoning_string_true(self):
        """The string ``"true"`` resolves to the boolean ``True``."""
        from gateway.display_config import resolve_display_setting

        telegram = {"show_reasoning": "true"}
        cfg = {"display": {"platforms": {"telegram": telegram}}}
        assert resolve_display_setting(cfg, "telegram", "show_reasoning") is True

    def test_tool_preview_length_string(self):
        """A numeric string resolves to an ``int``."""
        from gateway.display_config import resolve_display_setting

        slack = {"tool_preview_length": "80"}
        cfg = {"display": {"platforms": {"slack": slack}}}
        assert resolve_display_setting(cfg, "slack", "tool_preview_length") == 80

    def test_platform_override_false_tool_progress(self):
        """Boolean ``False`` in a per-platform override also resolves to ``"off"``."""
        from gateway.display_config import resolve_display_setting

        slack = {"tool_progress": False}
        cfg = {"display": {"platforms": {"slack": slack}}}
        assert resolve_display_setting(cfg, "slack", "tool_progress") == "off"
# ---------------------------------------------------------------------------
# Built-in platform defaults (tier system)
# ---------------------------------------------------------------------------
class TestPlatformDefaults:
    """Built-in defaults differ per platform group when no user config is given."""

    def test_high_tier_platforms(self):
        """Telegram and Discord resolve ``tool_progress`` to "all"."""
        from gateway.display_config import resolve_display_setting

        names = ("telegram", "discord")
        resolved = {name: resolve_display_setting({}, name, "tool_progress") for name in names}
        assert resolved == dict.fromkeys(names, "all"), resolved

    def test_medium_tier_platforms(self):
        """Slack, Mattermost, Matrix and Feishu resolve ``tool_progress`` to "new"."""
        from gateway.display_config import resolve_display_setting

        names = ("slack", "mattermost", "matrix", "feishu")
        resolved = {name: resolve_display_setting({}, name, "tool_progress") for name in names}
        assert resolved == dict.fromkeys(names, "new"), resolved

    def test_low_tier_platforms(self):
        """Signal, WhatsApp, BlueBubbles, Weixin, WeCom and DingTalk resolve to "off"."""
        from gateway.display_config import resolve_display_setting

        names = ("signal", "whatsapp", "bluebubbles", "weixin", "wecom", "dingtalk")
        resolved = {name: resolve_display_setting({}, name, "tool_progress") for name in names}
        assert resolved == dict.fromkeys(names, "off"), resolved

    def test_minimal_tier_platforms(self):
        """Email, SMS, webhook and Home Assistant resolve ``tool_progress`` to "off"."""
        from gateway.display_config import resolve_display_setting

        names = ("email", "sms", "webhook", "homeassistant")
        resolved = {name: resolve_display_setting({}, name, "tool_progress") for name in names}
        assert resolved == dict.fromkeys(names, "off"), resolved

    def test_low_tier_streaming_defaults_to_false(self):
        """Signal and email resolve ``streaming`` to ``False`` by default."""
        from gateway.display_config import resolve_display_setting

        for name in ("signal", "email"):
            assert resolve_display_setting({}, name, "streaming") is False

    def test_high_tier_streaming_defaults_to_none(self):
        """Telegram resolves ``streaming`` to ``None`` by default (follow global)."""
        from gateway.display_config import resolve_display_setting

        resolved = resolve_display_setting({}, "telegram", "streaming")
        assert resolved is None
# ---------------------------------------------------------------------------
# get_effective_display / get_platform_defaults
# ---------------------------------------------------------------------------
class TestHelpers:
    """Cover ``get_effective_display`` and ``get_platform_defaults``."""

    def test_get_effective_display_merges_correctly(self):
        """The effective view combines platform override, global value and defaults."""
        from gateway.display_config import get_effective_display

        display = {
            "tool_progress": "new",
            "show_reasoning": True,
            "platforms": {
                "telegram": {"tool_progress": "verbose"},
            },
        }
        effective = get_effective_display({"display": display}, "telegram")
        assert effective["tool_progress"] == "verbose"  # from the platform override
        assert effective["show_reasoning"] is True  # from the global setting
        assert "tool_preview_length" in effective  # filled in from defaults

    def test_get_platform_defaults_returns_dict(self):
        """Defaults contain the expected keys and mutating the result does not stick."""
        from gateway.display_config import get_platform_defaults

        first = get_platform_defaults("telegram")
        for key in ("tool_progress", "show_reasoning"):
            assert key in first

        # A later call must not observe the mutation made on the earlier result.
        first["tool_progress"] = "changed"
        second = get_platform_defaults("telegram")
        assert second["tool_progress"] != "changed"
# ---------------------------------------------------------------------------
# Config migration: tool_progress_overrides → display.platforms
# ---------------------------------------------------------------------------
class TestConfigMigration:
    """Migration from config version 15 moves ``tool_progress_overrides`` into ``display.platforms``."""

    @staticmethod
    def _migrate(tmp_path, monkeypatch, starting_config):
        """Write *starting_config* under *tmp_path*, run the migration, return the re-read file."""
        import importlib

        import yaml

        config_file = tmp_path / "config.yaml"
        config_file.write_text(yaml.dump(starting_config))
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))

        # Reload so the module picks up the new HERMES_HOME.
        # NOTE(review): the module is not reloaded again after the test; confirm
        # this cannot leak tmp_path-based state into later tests.
        import hermes_cli.config as cfg_mod

        importlib.reload(cfg_mod)
        cfg_mod.migrate_config(interactive=False, quiet=True)
        return yaml.safe_load(config_file.read_text())

    def test_migration_creates_platforms_entries(self, tmp_path, monkeypatch):
        """Each legacy override shows up as ``display.platforms.<plat>.tool_progress``."""
        legacy_overrides = {
            "signal": "off",
            "telegram": "all",
        }
        starting = {
            "_config_version": 15,
            "display": {
                "tool_progress_overrides": legacy_overrides,
            },
        }
        migrated = self._migrate(tmp_path, monkeypatch, starting)
        per_platform = migrated.get("display", {}).get("platforms", {})
        assert per_platform.get("signal", {}).get("tool_progress") == "off"
        assert per_platform.get("telegram", {}).get("tool_progress") == "all"

    def test_migration_preserves_existing_platforms_entries(self, tmp_path, monkeypatch):
        """An existing ``display.platforms`` value survives a conflicting legacy override."""
        starting = {
            "_config_version": 15,
            "display": {
                "tool_progress_overrides": {"telegram": "off"},
                "platforms": {"telegram": {"tool_progress": "verbose"}},
            },
        }
        migrated = self._migrate(tmp_path, monkeypatch, starting)
        # The pre-existing "verbose" must not be replaced by the legacy "off".
        telegram = migrated["display"]["platforms"]["telegram"]
        assert telegram["tool_progress"] == "verbose"
# ---------------------------------------------------------------------------
# Streaming per-platform (None = follow global)
# ---------------------------------------------------------------------------
class TestStreamingPerPlatform:
    """Per-platform ``streaming`` overrides: ``None``, ``False`` and ``True``."""

    def test_none_means_follow_global(self):
        """With no config, Telegram's ``streaming`` resolves to ``None``."""
        from gateway.display_config import resolve_display_setting

        empty = {}
        resolved = resolve_display_setting(empty, "telegram", "streaming")
        # None tells the caller to consult the global streaming configuration.
        assert resolved is None

    def test_explicit_false_disables(self):
        """An explicit ``False`` for a platform is returned as ``False``."""
        from gateway.display_config import resolve_display_setting

        platforms = {"telegram": {"streaming": False}}
        cfg = {"display": {"platforms": platforms}}
        assert resolve_display_setting(cfg, "telegram", "streaming") is False

    def test_explicit_true_enables(self):
        """An explicit ``True`` for a platform is returned as ``True``."""
        from gateway.display_config import resolve_display_setting

        platforms = {"email": {"streaming": True}}
        cfg = {"display": {"platforms": platforms}}
        assert resolve_display_setting(cfg, "email", "streaming") is True

View File

@@ -28,12 +28,16 @@ class _FakeRegistry:
def __init__(self, sessions):
    """Copy *sessions* into a private list and start with no consumed completions."""
    self._completion_consumed: set = set()
    self._sessions = [*sessions]
def get(self, session_id):
    """Pop and return the next queued session, or ``None`` once the queue is empty.

    *session_id* is accepted for interface compatibility but not consulted.
    """
    return self._sessions.pop(0) if self._sessions else None
def is_completion_consumed(self, session_id):
    """Report whether *session_id* is in the consumed-completions set."""
    consumed = self._completion_consumed
    return session_id in consumed
def _build_runner(monkeypatch, tmp_path) -> GatewayRunner:
"""Create a GatewayRunner with notifications set to 'all'."""

View File

@@ -157,12 +157,44 @@ def _make_fake_mautrix():
mautrix_crypto_store = types.ModuleType("mautrix.crypto.store")
class MemoryCryptoStore:
def __init__(self, account_id="", pickle_key=""):
def __init__(self, account_id="", pickle_key=""): # noqa: S301
self.account_id = account_id
self.pickle_key = pickle_key
mautrix_crypto_store.MemoryCryptoStore = MemoryCryptoStore
# --- mautrix.crypto.store.asyncpg ---
mautrix_crypto_store_asyncpg = types.ModuleType("mautrix.crypto.store.asyncpg")
class PgCryptoStore:
    """Stand-in crypto store that records its constructor arguments."""

    upgrade_table = MagicMock()

    def __init__(self, account_id="", pickle_key="", db=None):  # noqa: S301
        self.db = db
        self.pickle_key = pickle_key
        self.account_id = account_id

    async def open(self):
        """No-op coroutine; nothing is opened by the fake."""
        return None
mautrix_crypto_store_asyncpg.PgCryptoStore = PgCryptoStore
# --- mautrix.util ---
mautrix_util = types.ModuleType("mautrix.util")
# --- mautrix.util.async_db ---
mautrix_util_async_db = types.ModuleType("mautrix.util.async_db")
class Database:
    """Stand-in database factory that hands out mocks with awaitable start/stop."""

    @classmethod
    def create(cls, url, upgrade_table=None):
        """Return a new ``MagicMock`` whose ``start`` and ``stop`` are ``AsyncMock``s.

        *url* and *upgrade_table* are accepted but not used.
        """
        return MagicMock(start=AsyncMock(), stop=AsyncMock())
mautrix_util_async_db.Database = Database
return {
"mautrix": mautrix,
"mautrix.api": mautrix_api,
@@ -171,6 +203,9 @@ def _make_fake_mautrix():
"mautrix.client.state_store": mautrix_client_state_store,
"mautrix.crypto": mautrix_crypto,
"mautrix.crypto.store": mautrix_crypto_store,
"mautrix.crypto.store.asyncpg": mautrix_crypto_store_asyncpg,
"mautrix.util": mautrix_util,
"mautrix.util.async_db": mautrix_util_async_db,
}
@@ -740,6 +775,12 @@ class TestMatrixAccessTokenAuth:
mock_client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "DEV123"))
mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}})
mock_client.add_event_handler = MagicMock()
mock_client.handle_sync = MagicMock(return_value=[])
mock_client.query_keys = AsyncMock(return_value={
"device_keys": {"@bot:example.org": {"DEV123": {
"keys": {"ed25519:DEV123": "fake_ed25519_key"},
}}},
})
mock_client.api = MagicMock()
mock_client.api.token = "syt_test_access_token"
mock_client.api.session = MagicMock()
@@ -751,6 +792,8 @@ class TestMatrixAccessTokenAuth:
mock_olm.share_keys = AsyncMock()
mock_olm.share_keys_min_trust = None
mock_olm.send_keys_min_trust = None
mock_olm.account = MagicMock()
mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"}
# Patch Client constructor to return our mock
fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client)
@@ -924,6 +967,12 @@ class TestMatrixDeviceId:
mock_client.whoami = AsyncMock(return_value=MagicMock(user_id="@bot:example.org", device_id="WHOAMI_DEV"))
mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}})
mock_client.add_event_handler = MagicMock()
mock_client.handle_sync = MagicMock(return_value=[])
mock_client.query_keys = AsyncMock(return_value={
"device_keys": {"@bot:example.org": {"MY_STABLE_DEVICE": {
"keys": {"ed25519:MY_STABLE_DEVICE": "fake_ed25519_key"},
}}},
})
mock_client.api = MagicMock()
mock_client.api.token = "syt_test_access_token"
mock_client.api.session = MagicMock()
@@ -934,6 +983,8 @@ class TestMatrixDeviceId:
mock_olm.share_keys = AsyncMock()
mock_olm.share_keys_min_trust = None
mock_olm.send_keys_min_trust = None
mock_olm.account = MagicMock()
mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"}
fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client)
fake_mautrix_mods["mautrix.crypto"].OlmMachine = MagicMock(return_value=mock_olm)
@@ -1030,8 +1081,8 @@ class TestMatrixDeviceIdConfig:
class TestMatrixSyncLoop:
@pytest.mark.asyncio
async def test_sync_loop_shares_keys_when_encryption_enabled(self):
"""_sync_loop should call crypto.share_keys() after each sync."""
async def test_sync_loop_dispatches_events_and_stores_token(self):
"""_sync_loop should call handle_sync() and persist next_batch."""
adapter = _make_adapter()
adapter._encryption = True
adapter._closing = False
@@ -1046,7 +1097,6 @@ class TestMatrixSyncLoop:
return {"rooms": {"join": {"!room:example.org": {}}}, "next_batch": "s1234"}
mock_crypto = MagicMock()
mock_crypto.share_keys = AsyncMock()
mock_sync_store = MagicMock()
mock_sync_store.get_next_batch = AsyncMock(return_value=None)
@@ -1062,7 +1112,6 @@ class TestMatrixSyncLoop:
await adapter._sync_loop()
fake_client.sync.assert_awaited_once()
mock_crypto.share_keys.assert_awaited_once()
fake_client.handle_sync.assert_called_once()
mock_sync_store.put_next_batch.assert_awaited_once_with("s1234")
@@ -1248,6 +1297,12 @@ class TestMatrixEncryptedEventHandler:
mock_client.whoami = AsyncMock(return_value=MagicMock(user_id="@bot:example.org", device_id="DEV123"))
mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}})
mock_client.add_event_handler = MagicMock()
mock_client.handle_sync = MagicMock(return_value=[])
mock_client.query_keys = AsyncMock(return_value={
"device_keys": {"@bot:example.org": {"DEV123": {
"keys": {"ed25519:DEV123": "fake_ed25519_key"},
}}},
})
mock_client.api = MagicMock()
mock_client.api.token = "syt_test_token"
mock_client.api.session = MagicMock()
@@ -1258,6 +1313,8 @@ class TestMatrixEncryptedEventHandler:
mock_olm.share_keys = AsyncMock()
mock_olm.share_keys_min_trust = None
mock_olm.send_keys_min_trust = None
mock_olm.account = MagicMock()
mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"}
fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client)
fake_mautrix_mods["mautrix.crypto"].OlmMachine = MagicMock(return_value=mock_olm)

View File

@@ -8,8 +8,8 @@ from types import SimpleNamespace
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, SendResult
from gateway.config import Platform, PlatformConfig, StreamingConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
from gateway.session import SessionSource
@@ -104,6 +104,11 @@ def _make_runner(adapter):
runner._session_db = None
runner._running_agents = {}
runner.hooks = SimpleNamespace(loaded_hooks=False)
runner.config = SimpleNamespace(
thread_sessions_per_user=False,
group_sessions_per_user=False,
stt_enabled=False,
)
return runner
@@ -118,6 +123,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
fake_run_agent = types.ModuleType("run_agent")
fake_run_agent.AIAgent = FakeAgent
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
import tools.terminal_tool # noqa: F401 - register terminal emoji for this fake-agent test
adapter = ProgressCaptureAdapter()
runner = _make_runner(adapter)
@@ -144,7 +150,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
assert adapter.sent == [
{
"chat_id": "-1001",
"content": '⚙️ terminal: "pwd"',
"content": '💻 terminal: "pwd"',
"reply_to": None,
"metadata": {"thread_id": "17585"},
}
@@ -334,3 +340,238 @@ def test_all_mode_no_truncation_when_preview_fits(monkeypatch, tmp_path):
content = adapter.sent[0]["content"]
# With a 200-char cap, the 165-char command should NOT be truncated
assert "..." not in content, f"Preview was truncated when it shouldn't be: {content}"
class CommentaryAgent:
    """Fake agent that emits one interim commentary message, then streams "done"."""

    def __init__(self, **kwargs):
        # Only the callbacks this fake uses are read; other kwargs are ignored.
        for hook in (
            "tool_progress_callback",
            "interim_assistant_callback",
            "stream_delta_callback",
        ):
            setattr(self, hook, kwargs.get(hook))
        self.tools = []

    def run_conversation(self, message, conversation_history=None, task_id=None):
        """Fire the interim callback, pause briefly, stream the reply, return the result."""
        emit_interim = self.interim_assistant_callback
        if emit_interim:
            emit_interim("I'll inspect the repo first.", already_streamed=False)

        # Short pause between the commentary and the streamed final text.
        time.sleep(0.1)

        emit_delta = self.stream_delta_callback
        if emit_delta:
            emit_delta("done")

        outcome = {
            "final_response": "done",
            "messages": [],
            "api_calls": 1,
        }
        return outcome
class PreviewedResponseAgent:
    """Fake agent whose final response was already shown through the interim callback."""

    def __init__(self, **kwargs):
        self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
        self.tools = []

    def run_conversation(self, message, conversation_history=None, task_id=None):
        """Preview the reply via the interim callback and flag it as previewed."""
        preview = self.interim_assistant_callback
        if preview:
            preview("You're welcome.", already_streamed=False)

        outcome = {
            "final_response": "You're welcome.",
            "response_previewed": True,
            "messages": [],
            "api_calls": 1,
        }
        return outcome
class QueuedCommentaryAgent:
    """Fake agent that emits commentary on its first run only and numbers its replies."""

    # Class-wide run counter shared by every instance; tests reset it to 0.
    calls = 0

    def __init__(self, **kwargs):
        self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
        self.tools = []

    def run_conversation(self, message, conversation_history=None, task_id=None):
        """Bump the shared counter and return ``final response <n>`` for this run."""
        agent_cls = type(self)
        agent_cls.calls += 1

        first_run = agent_cls.calls == 1
        if first_run and self.interim_assistant_callback:
            self.interim_assistant_callback("I'll inspect the repo first.", already_streamed=False)

        outcome = {
            "final_response": f"final response {agent_cls.calls}",
            "messages": [],
            "api_calls": 1,
        }
        return outcome
async def _run_with_agent(
    monkeypatch,
    tmp_path,
    agent_cls,
    *,
    session_id,
    pending_text=None,
    config_data=None,
):
    """Run ``runner._run_agent`` once with ``agent_cls`` standing in for ``AIAgent``.

    Args:
        monkeypatch: pytest fixture used to stub modules and gateway attributes.
        tmp_path: directory patched in as ``gateway.run._hermes_home``.
        agent_cls: fake agent class exposed as ``run_agent.AIAgent``.
        session_id: session id forwarded to ``_run_agent``.
        pending_text: when not ``None``, text of a message queued on the
            adapter under the session key before the run starts.
        config_data: optional mapping dumped to ``config.yaml``; a
            ``"streaming"`` entry is also applied to ``runner.config``.

    Returns:
        ``(adapter, result)``: the capturing adapter and ``_run_agent``'s result.
    """
    if config_data:
        import yaml

        config_file = tmp_path / "config.yaml"
        config_file.write_text(yaml.dump(config_data), encoding="utf-8")

    # Install stub modules before gateway.run is imported below.
    dotenv_stub = types.ModuleType("dotenv")
    dotenv_stub.load_dotenv = lambda *args, **kwargs: None
    monkeypatch.setitem(sys.modules, "dotenv", dotenv_stub)

    run_agent_stub = types.ModuleType("run_agent")
    run_agent_stub.AIAgent = agent_cls
    monkeypatch.setitem(sys.modules, "run_agent", run_agent_stub)

    adapter = ProgressCaptureAdapter()
    runner = _make_runner(adapter)
    gateway_run = importlib.import_module("gateway.run")

    if config_data and "streaming" in config_data:
        streaming_section = config_data["streaming"]
        runner.config.streaming = StreamingConfig.from_dict(streaming_section)

    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})

    source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-1001",
        chat_type="group",
        thread_id="17585",
    )
    session_key = "agent:main:telegram:group:-1001:17585"

    if pending_text is not None:
        queued_event = MessageEvent(
            text=pending_text,
            message_type=MessageType.TEXT,
            source=source,
            message_id="queued-1",
        )
        adapter._pending_messages[session_key] = queued_event

    result = await runner._run_agent(
        message="hello",
        context_prompt="",
        history=[],
        source=source,
        session_id=session_id,
        session_key=session_key,
    )
    return adapter, result
@pytest.mark.asyncio
async def test_run_agent_surfaces_real_interim_commentary(monkeypatch, tmp_path):
    """Interim commentary is delivered when ``interim_assistant_messages`` is on."""
    display_on = {"display": {"interim_assistant_messages": True}}
    capture, outcome = await _run_with_agent(
        monkeypatch,
        tmp_path,
        CommentaryAgent,
        session_id="sess-commentary",
        config_data=display_on,
    )

    assert outcome.get("already_sent") is not True
    delivered = (sent["content"] for sent in capture.sent)
    assert "I'll inspect the repo first." in delivered
@pytest.mark.asyncio
async def test_run_agent_surfaces_interim_commentary_by_default(monkeypatch, tmp_path):
    """With no config written, interim commentary is still delivered."""
    capture, outcome = await _run_with_agent(
        monkeypatch,
        tmp_path,
        CommentaryAgent,
        session_id="sess-commentary-default-on",
    )

    delivered = (sent["content"] for sent in capture.sent)
    assert "I'll inspect the repo first." in delivered
@pytest.mark.asyncio
async def test_run_agent_suppresses_interim_commentary_when_disabled(monkeypatch, tmp_path):
    """``interim_assistant_messages: False`` keeps the commentary out of sent messages."""
    display_off = {"display": {"interim_assistant_messages": False}}
    capture, outcome = await _run_with_agent(
        monkeypatch,
        tmp_path,
        CommentaryAgent,
        session_id="sess-commentary-disabled",
        config_data=display_off,
    )

    assert outcome.get("already_sent") is not True
    delivered = [sent["content"] for sent in capture.sent]
    assert "I'll inspect the repo first." not in delivered
@pytest.mark.asyncio
async def test_run_agent_tool_progress_does_not_control_interim_commentary(monkeypatch, tmp_path):
    """tool_progress=all with interim_assistant_messages=false should not surface commentary."""
    display_cfg = {"tool_progress": "all", "interim_assistant_messages": False}
    capture, outcome = await _run_with_agent(
        monkeypatch,
        tmp_path,
        CommentaryAgent,
        session_id="sess-commentary-tool-progress",
        config_data={"display": display_cfg},
    )

    assert outcome.get("already_sent") is not True
    delivered = [sent["content"] for sent in capture.sent]
    assert "I'll inspect the repo first." not in delivered
@pytest.mark.asyncio
async def test_run_agent_streaming_does_not_enable_completed_interim_commentary(
    monkeypatch, tmp_path
):
    """Streaming alone with interim_assistant_messages=false should not surface commentary."""
    streaming_only = {
        "display": {"tool_progress": "off", "interim_assistant_messages": False},
        "streaming": {"enabled": True},
    }
    capture, outcome = await _run_with_agent(
        monkeypatch,
        tmp_path,
        CommentaryAgent,
        session_id="sess-commentary-streaming",
        config_data=streaming_only,
    )

    # The streamed final reply counts as already sent.
    assert outcome.get("already_sent") is True
    delivered = [sent["content"] for sent in capture.sent]
    assert "I'll inspect the repo first." not in delivered
@pytest.mark.asyncio
async def test_run_agent_interim_commentary_works_with_tool_progress_off(monkeypatch, tmp_path):
    """Commentary is delivered when enabled even though tool progress is off."""
    display_cfg = {
        "tool_progress": "off",
        "interim_assistant_messages": True,
    }
    capture, outcome = await _run_with_agent(
        monkeypatch,
        tmp_path,
        CommentaryAgent,
        session_id="sess-commentary-explicit-on",
        config_data={"display": display_cfg},
    )

    assert outcome.get("already_sent") is not True
    delivered = (sent["content"] for sent in capture.sent)
    assert "I'll inspect the repo first." in delivered
@pytest.mark.asyncio
async def test_run_agent_previewed_final_marks_already_sent(monkeypatch, tmp_path):
    """A previewed final response is sent exactly once and marked as already sent."""
    capture, outcome = await _run_with_agent(
        monkeypatch,
        tmp_path,
        PreviewedResponseAgent,
        session_id="sess-previewed",
        config_data={"display": {"interim_assistant_messages": True}},
    )

    assert outcome.get("already_sent") is True
    delivered = [sent["content"] for sent in capture.sent]
    assert delivered == ["You're welcome."]
@pytest.mark.asyncio
async def test_run_agent_queued_message_does_not_treat_commentary_as_final(monkeypatch, tmp_path):
    """With a queued follow-up, both the commentary and the first final reply are sent."""
    # The fake agent counts runs on the class, so start from a clean slate.
    QueuedCommentaryAgent.calls = 0

    capture, outcome = await _run_with_agent(
        monkeypatch,
        tmp_path,
        QueuedCommentaryAgent,
        session_id="sess-queued-commentary",
        pending_text="queued follow-up",
        config_data={"display": {"interim_assistant_messages": True}},
    )

    delivered = [sent["content"] for sent in capture.sent]
    assert outcome["final_response"] == "final response 2"
    for expected in ("I'll inspect the repo first.", "final response 1"):
        assert expected in delivered

View File

@@ -1,4 +1,5 @@
import pytest
from unittest.mock import AsyncMock
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter
@@ -45,6 +46,23 @@ class _DisabledAdapter(BasePlatformAdapter):
return {"id": chat_id}
class _SuccessfulAdapter(BasePlatformAdapter):
    """Discord-flavoured test adapter whose ``connect`` always succeeds."""

    def __init__(self):
        enabled_config = PlatformConfig(enabled=True, token="***")
        super().__init__(enabled_config, Platform.DISCORD)

    async def connect(self) -> bool:
        """Report a successful connection without doing any I/O."""
        return True

    async def disconnect(self) -> None:
        """Mark the adapter as disconnected."""
        self._mark_disconnected()

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        """Sending is not supported by this fake."""
        raise NotImplementedError

    async def get_chat_info(self, chat_id):
        """Return a minimal chat-info mapping echoing ``chat_id``."""
        info = {"id": chat_id}
        return info
@pytest.mark.asyncio
async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch, tmp_path):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
@@ -65,7 +83,7 @@ async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch,
state = read_runtime_status()
assert state["gateway_state"] == "startup_failed"
assert "temporary DNS resolution failure" in state["exit_reason"]
assert state["platforms"]["telegram"]["state"] == "fatal"
assert state["platforms"]["telegram"]["state"] == "retrying"
assert state["platforms"]["telegram"]["error_code"] == "telegram_connect_error"
@@ -89,6 +107,64 @@ async def test_runner_allows_cron_only_mode_when_no_platforms_are_enabled(monkey
assert state["gateway_state"] == "running"
@pytest.mark.asyncio
async def test_runner_records_connected_platform_state_on_success(monkeypatch, tmp_path):
    """A successful start records ``connected`` with no error fields for the platform."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    discord_only = {Platform.DISCORD: PlatformConfig(enabled=True, token="***")}
    config = GatewayConfig(platforms=discord_only, sessions_dir=tmp_path / "sessions")
    runner = GatewayRunner(config)

    def _always_successful(platform, platform_config):
        return _SuccessfulAdapter()

    monkeypatch.setattr(runner, "_create_adapter", _always_successful)
    monkeypatch.setattr(runner.hooks, "discover_and_load", lambda: None)
    monkeypatch.setattr(runner.hooks, "emit", AsyncMock())

    started = await runner.start()
    assert started is True

    state = read_runtime_status()
    assert state["gateway_state"] == "running"

    discord_state = state["platforms"]["discord"]
    assert discord_state["state"] == "connected"
    assert discord_state["error_code"] is None
    assert discord_state["error_message"] is None
@pytest.mark.asyncio
async def test_start_gateway_verbosity_imports_redacting_formatter(monkeypatch, tmp_path):
    """Verbosity != None must not crash with NameError on RedactingFormatter (#8044)."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    class _CleanExitRunner:
        """Minimal runner stand-in that starts and stops without doing anything."""

        def __init__(self, config):
            self.config = config
            self.should_exit_cleanly = True
            self.exit_reason = None
            self.adapters = {}

        async def start(self):
            return True

        async def stop(self):
            return None

    # Replace everything start_gateway touches besides the logging path under test.
    replacements = [
        ("gateway.status.get_running_pid", lambda: None),
        ("tools.skills_sync.sync_skills", lambda quiet=True: None),
        ("hermes_logging.setup_logging", lambda hermes_home, mode: tmp_path),
        ("hermes_logging._add_rotating_handler", lambda *args, **kwargs: None),
        ("gateway.run.GatewayRunner", _CleanExitRunner),
    ]
    for target, replacement in replacements:
        monkeypatch.setattr(target, replacement)

    from gateway.run import start_gateway

    # verbosity=1 triggers the code path that uses RedactingFormatter.
    # Before the fix this raised NameError.
    started = await start_gateway(config=GatewayConfig(), replace=False, verbosity=1)
    assert started is True
@pytest.mark.asyncio
async def test_start_gateway_replace_force_uses_terminate_pid(monkeypatch, tmp_path):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))

View File

@@ -1,3 +1,4 @@
import asyncio
import os
from gateway.config import Platform
@@ -130,3 +131,99 @@ def test_set_session_env_handles_missing_optional_fields():
assert get_session_env("HERMES_SESSION_THREAD_ID") == ""
runner._clear_session_env(tokens)
# ---------------------------------------------------------------------------
# SESSION_KEY contextvars tests
# ---------------------------------------------------------------------------
def test_session_key_set_via_contextvars(monkeypatch):
    """set_session_vars should set HERMES_SESSION_KEY via contextvars."""
    monkeypatch.delenv("HERMES_SESSION_KEY", raising=False)

    tokens = set_session_vars(
        platform="telegram",
        chat_id="-1001",
        session_key="tg:-1001:17585",
    )
    # Always reset the context variables, even when the assertion fails, so a
    # failing run cannot leak session state into later tests.
    try:
        assert get_session_env("HERMES_SESSION_KEY") == "tg:-1001:17585"
    finally:
        clear_session_vars(tokens)

    assert get_session_env("HERMES_SESSION_KEY") == ""
def test_session_key_falls_back_to_os_environ(monkeypatch):
    """get_session_env for SESSION_KEY should fall back to os.environ."""
    monkeypatch.setenv("HERMES_SESSION_KEY", "env-session-123")

    # No contextvar set — should read from os.environ
    assert get_session_env("HERMES_SESSION_KEY") == "env-session-123"

    # Set contextvar — should prefer it. The reset sits in ``finally`` so a
    # failed assertion cannot leave the context variable set for later tests.
    tokens = set_session_vars(session_key="ctx-session-456")
    try:
        assert get_session_env("HERMES_SESSION_KEY") == "ctx-session-456"
    finally:
        clear_session_vars(tokens)

    # Restored — should fall back to os.environ
    assert get_session_env("HERMES_SESSION_KEY") == "env-session-123"
def test_set_session_env_includes_session_key():
    """_set_session_env should propagate session_key from SessionContext."""
    # Bypass __init__; only the session-env helpers are called on the runner.
    runner = object.__new__(GatewayRunner)
    source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-1001",
        chat_name="Group",
        chat_type="group",
        thread_id="17585",
    )
    context = SessionContext(
        source=source,
        connected_platforms=[],
        home_channels={},
        session_key="tg:-1001:17585",
    )

    tokens = runner._set_session_env(context)
    # Clear in ``finally`` so a failed assertion cannot leak the session key
    # into tests that run afterwards.
    try:
        assert get_session_env("HERMES_SESSION_KEY") == "tg:-1001:17585"
    finally:
        runner._clear_session_env(tokens)

    assert get_session_env("HERMES_SESSION_KEY") == ""
def test_session_key_no_race_condition_with_contextvars(monkeypatch):
    """Prove contextvars isolates SESSION_KEY across concurrent async tasks.

    Two overlapping tasks each set a different session key and read it back
    after sleeping. With contextvars every task sees its own value; with
    os.environ the second task would overwrite the first (the old bug).
    """
    monkeypatch.delenv("HERMES_SESSION_KEY", raising=False)

    observed = {}

    async def handler(key: str, delay: float):
        tokens = set_session_vars(session_key=key)
        try:
            await asyncio.sleep(delay)
            observed[key] = get_session_env("HERMES_SESSION_KEY")
        finally:
            clear_session_vars(tokens)

    async def run():
        # Task B starts while task A is still sleeping, so their lifetimes overlap.
        slow = asyncio.create_task(handler("session-A", 0.15))
        await asyncio.sleep(0.05)
        fast = asyncio.create_task(handler("session-B", 0.05))
        await asyncio.gather(slow, fast)

    asyncio.run(run())

    # Both tasks must read back their own session key
    assert observed["session-A"] == "session-A", (
        f"Session A got '{observed['session-A']}' instead of 'session-A' — race condition!"
    )
    assert observed["session-B"] == "session-B", (
        f"Session B got '{observed['session-B']}' instead of 'session-B' — race condition!"
    )

View File

@@ -104,6 +104,34 @@ class TestGatewayRuntimeStatus:
assert payload["platforms"]["telegram"]["error_code"] == "telegram_polling_conflict"
assert payload["platforms"]["telegram"]["error_message"] == "another poller is active"
def test_write_runtime_status_explicit_none_clears_stale_fields(self, tmp_path, monkeypatch):
    """Passing ``None`` explicitly on a later write replaces earlier error values."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    # First write: a failed startup with populated error fields.
    stale_status = {
        "gateway_state": "startup_failed",
        "exit_reason": "stale error",
        "platform": "discord",
        "platform_state": "fatal",
        "error_code": "discord_timeout",
        "error_message": "stale platform error",
    }
    status.write_runtime_status(**stale_status)

    # Second write: healthy state with every error field explicitly None.
    healthy_status = {
        "gateway_state": "running",
        "exit_reason": None,
        "platform": "discord",
        "platform_state": "connected",
        "error_code": None,
        "error_message": None,
    }
    status.write_runtime_status(**healthy_status)

    payload = status.read_runtime_status()
    assert payload["gateway_state"] == "running"
    assert payload["exit_reason"] is None

    discord = payload["platforms"]["discord"]
    assert discord["state"] == "connected"
    assert discord["error_code"] is None
    assert discord["error_message"] is None
class TestTerminatePid:
def test_force_uses_taskkill_on_windows(self, monkeypatch):

View File

@@ -505,3 +505,81 @@ class TestSegmentBreakOnToolBoundary:
assert len(sent_texts) == 3
assert sent_texts[0].startswith(prefix)
assert sum(len(t) for t in sent_texts[1:]) == len(tail)
class TestInterimCommentaryMessages:
    """Stream-consumer tests for commentary messages and sent-state bookkeeping."""

    @staticmethod
    def _build(send, **config_overrides):
        """Return ``(adapter, consumer)`` for chat ``chat_123`` using ``send`` as ``adapter.send``."""
        adapter = MagicMock()
        adapter.send = send
        adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
        adapter.MAX_MESSAGE_LENGTH = 4096
        config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, **config_overrides)
        return adapter, GatewayStreamConsumer(adapter, "chat_123", config)

    @staticmethod
    def _sent_contents(adapter):
        """Collect the ``content`` keyword of every ``adapter.send`` call, in order."""
        return [sent.kwargs["content"] for sent in adapter.send.call_args_list]

    @pytest.mark.asyncio
    async def test_commentary_message_stays_separate_from_final_stream(self):
        replies = [
            SimpleNamespace(success=True, message_id="msg_1"),
            SimpleNamespace(success=True, message_id="msg_2"),
        ]
        adapter, consumer = self._build(AsyncMock(side_effect=replies))

        consumer.on_commentary("I'll inspect the repository first.")
        consumer.on_delta("Done.")
        consumer.finish()
        await consumer.run()

        assert self._sent_contents(adapter) == ["I'll inspect the repository first.", "Done."]
        assert consumer.final_response_sent is True

    @pytest.mark.asyncio
    async def test_failed_final_send_does_not_mark_final_response_sent(self):
        failed_send = AsyncMock(return_value=SimpleNamespace(success=False, message_id=None))
        _adapter, consumer = self._build(failed_send)

        consumer.on_delta("Done.")
        consumer.finish()
        await consumer.run()

        assert consumer.final_response_sent is False
        assert consumer.already_sent is False

    @pytest.mark.asyncio
    async def test_success_without_message_id_marks_visible_and_sends_only_tail(self):
        replies = [
            SimpleNamespace(success=True, message_id=None),
            SimpleNamespace(success=True, message_id=None),
        ]
        # NOTE(review): cursor is the empty string here, yet the first expected
        # send below ends with the "▉" glyph — confirm against the consumer.
        adapter, consumer = self._build(AsyncMock(side_effect=replies), cursor="")

        consumer.on_delta("Hello")
        runner_task = asyncio.create_task(consumer.run())
        await asyncio.sleep(0.08)
        consumer.on_delta(" world")
        await asyncio.sleep(0.08)
        consumer.finish()
        await runner_task

        assert self._sent_contents(adapter) == ["Hello ▉", "world"]
        assert consumer.already_sent is True
        assert consumer.final_response_sent is True

View File

@@ -403,6 +403,56 @@ class TestWatchUpdateProgress:
# Should not crash; legacy notification handles this case
@pytest.mark.asyncio
async def test_prompt_forwarded_only_once(self, tmp_path):
    """Regression: prompt must not be re-sent on every poll cycle.

    Before the fix, the watcher never deleted .update_prompt.json after
    forwarding, causing the same prompt to be sent every poll_interval.
    """
    runner = _make_runner()
    hermes_home = tmp_path / "hermes"
    hermes_home.mkdir()

    pending = {
        "platform": "telegram",
        "chat_id": "111",
        "user_id": "222",
        "session_key": "agent:main:telegram:dm:111",
    }
    # The prompt file exists before the watcher starts; the watcher should
    # forward it exactly once, then delete it.
    prompt = {
        "prompt": "Would you like to configure new options now? Y/n",
        "default": "n",
        "id": "dup-test",
    }
    (hermes_home / ".update_pending.json").write_text(json.dumps(pending))
    (hermes_home / ".update_output.txt").write_text("")

    mock_adapter = AsyncMock()
    runner.adapters = {Platform.TELEGRAM: mock_adapter}

    (hermes_home / ".update_prompt.json").write_text(json.dumps(prompt))

    async def finish_after_polls():
        # Let several poll cycles pass, then simulate a response followed by
        # completion of the update.
        await asyncio.sleep(1.0)
        (hermes_home / ".update_response").write_text("n")
        await asyncio.sleep(0.3)
        (hermes_home / ".update_exit_code").write_text("0")

    with patch("gateway.run._hermes_home", hermes_home):
        finisher = asyncio.create_task(finish_after_polls())
        await runner._watch_update_progress(
            poll_interval=0.1,
            stream_interval=0.2,
            timeout=10.0,
        )
        await finisher

    # Count how many times the prompt text was sent
    all_sent = list(map(str, mock_adapter.send.call_args_list))
    prompt_sends = [text for text in all_sent if "configure new options" in text]
    assert len(prompt_sends) == 1, (
        f"Prompt was sent {len(prompt_sends)} times (expected 1). "
        f"All sends: {all_sent}"
    )
# ---------------------------------------------------------------------------
# Message interception for update prompts

View File

@@ -63,7 +63,7 @@ class TestVerboseCommand:
@pytest.mark.asyncio
async def test_enabled_cycles_mode(self, tmp_path, monkeypatch):
"""When enabled, /verbose cycles tool_progress mode."""
"""When enabled, /verbose cycles tool_progress mode per-platform."""
hermes_home = tmp_path / "hermes"
hermes_home.mkdir()
config_path = hermes_home / "config.yaml"
@@ -79,10 +79,11 @@ class TestVerboseCommand:
# all -> verbose
assert "VERBOSE" in result
assert "telegram" in result.lower() # per-platform feedback
# Verify config was saved
# Verify config was saved to display.platforms.telegram
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
assert saved["display"]["tool_progress"] == "verbose"
assert saved["display"]["platforms"]["telegram"]["tool_progress"] == "verbose"
@pytest.mark.asyncio
async def test_cycles_through_all_modes(self, tmp_path, monkeypatch):
@@ -103,8 +104,9 @@ class TestVerboseCommand:
for mode in expected:
result = await runner._handle_verbose_command(_make_event())
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
assert saved["display"]["tool_progress"] == mode, \
f"Expected {mode}, got {saved['display']['tool_progress']}"
actual = saved["display"]["platforms"]["telegram"]["tool_progress"]
assert actual == mode, \
f"Expected {mode}, got {actual}"
@pytest.mark.asyncio
async def test_defaults_to_all_when_no_tool_progress_set(self, tmp_path, monkeypatch):
@@ -122,10 +124,45 @@ class TestVerboseCommand:
runner = _make_runner()
result = await runner._handle_verbose_command(_make_event())
# default "all" -> verbose
# Telegram default is "all" (high tier) → cycles to verbose
assert "VERBOSE" in result
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
assert saved["display"]["tool_progress"] == "verbose"
assert saved["display"]["platforms"]["telegram"]["tool_progress"] == "verbose"
@pytest.mark.asyncio
async def test_per_platform_isolation(self, tmp_path, monkeypatch):
    """Cycling /verbose on Telegram doesn't change Slack's setting.

    Without a global tool_progress, each platform uses its built-in
    default: Telegram = 'all' (high tier), Slack = 'new' (medium tier).
    """
    hermes_home = tmp_path / "hermes"
    hermes_home.mkdir()
    config_path = hermes_home / "config.yaml"
    # No global tool_progress → built-in platform defaults apply
    config_path.write_text(
        "display:\n tool_progress_command: true\n",
        encoding="utf-8",
    )

    monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
    runner = _make_runner()

    # Cycle once on Telegram, then once on Slack.
    for platform in (Platform.TELEGRAM, Platform.SLACK):
        await runner._handle_verbose_command(_make_event(platform=platform))

    saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
    per_platform = saved["display"]["platforms"]
    # Telegram: all -> verbose (high tier default = all)
    assert per_platform["telegram"]["tool_progress"] == "verbose"
    # Slack: new -> all (medium tier default = new, cycle to all)
    assert per_platform["slack"]["tool_progress"] == "all"
@pytest.mark.asyncio
async def test_no_config_file_returns_disabled(self, tmp_path, monkeypatch):

View File

@@ -0,0 +1,185 @@
"""Tests for the WeCom callback-mode adapter."""
import asyncio
from xml.etree import ElementTree as ET
import pytest
from gateway.config import PlatformConfig
from gateway.platforms.wecom_callback import WecomCallbackAdapter
from gateway.platforms.wecom_crypto import WXBizMsgCrypt
def _app(name="test-app", corp_id="ww1234567890", agent_id="1000002"):
return {
"name": name,
"corp_id": corp_id,
"corp_secret": "test-secret",
"agent_id": agent_id,
"token": "test-callback-token",
"encoding_aes_key": "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG",
}
def _config(apps=None):
    """Return an enabled callback-mode ``PlatformConfig``; defaults to a single ``_app()``."""
    extra = {
        "mode": "callback",
        "host": "127.0.0.1",
        "port": 0,
        "apps": apps if apps else [_app()],
    }
    return PlatformConfig(enabled=True, extra=extra)
class TestWecomCrypto:
    """Encrypt/decrypt checks for ``WXBizMsgCrypt``."""

    def test_roundtrip_encrypt_decrypt(self):
        app = _app()
        crypt = WXBizMsgCrypt(app["token"], app["encoding_aes_key"], app["corp_id"])

        envelope = crypt.encrypt(
            "<xml><Content>hello</Content></xml>", nonce="nonce123", timestamp="123456",
        )
        root = ET.fromstring(envelope)

        # Feed the envelope's own signature, timestamp, nonce and ciphertext back in.
        fields = [
            root.findtext(tag, default="")
            for tag in ("MsgSignature", "TimeStamp", "Nonce", "Encrypt")
        ]
        decrypted = crypt.decrypt(*fields)

        assert b"<Content>hello</Content>" in decrypted

    def test_signature_mismatch_raises(self):
        app = _app()
        crypt = WXBizMsgCrypt(app["token"], app["encoding_aes_key"], app["corp_id"])
        envelope = crypt.encrypt("<xml/>", nonce="n", timestamp="1")
        ciphertext = ET.fromstring(envelope).findtext("Encrypt", default="")

        from gateway.platforms.wecom_crypto import SignatureError

        with pytest.raises(SignatureError):
            crypt.decrypt("bad-sig", "1", "n", ciphertext)
class TestWecomCallbackEventConstruction:
    """``_build_event`` turns callback XML into a message event, or ``None``."""

    def test_build_event_extracts_text_message(self):
        adapter = WecomCallbackAdapter(_config())
        xml_text = """
<xml>
<ToUserName>ww1234567890</ToUserName>
<FromUserName>zhangsan</FromUserName>
<CreateTime>1710000000</CreateTime>
<MsgType>text</MsgType>
<Content>\u4f60\u597d</Content>
<MsgId>123456789</MsgId>
</xml>
"""

        event = adapter._build_event(_app(), xml_text)

        assert event is not None
        origin = event.source
        assert origin is not None
        assert origin.user_id == "zhangsan"
        # The chat id is the corp id joined to the sender with a colon.
        assert origin.chat_id == "ww1234567890:zhangsan"
        assert event.message_id == "123456789"
        assert event.text == "\u4f60\u597d"

    def test_build_event_returns_none_for_subscribe(self):
        adapter = WecomCallbackAdapter(_config())
        xml_text = """
<xml>
<ToUserName>ww1234567890</ToUserName>
<FromUserName>zhangsan</FromUserName>
<CreateTime>1710000000</CreateTime>
<MsgType>event</MsgType>
<Event>subscribe</Event>
</xml>
"""

        assert adapter._build_event(_app(), xml_text) is None
class TestWecomCallbackRouting:
    """Tests for how outbound sends pick the app for a chat id."""

    def test_user_app_key_scopes_across_corps(self):
        adapter = WecomCallbackAdapter(_config())

        key_a = adapter._user_app_key("corpA", "alice")
        key_b = adapter._user_app_key("corpB", "alice")

        assert key_a == "corpA:alice"
        assert key_b == "corpB:alice"
        assert adapter._user_app_key("corpA", "alice") != adapter._user_app_key("corpB", "alice")

    @pytest.mark.asyncio
    async def test_send_selects_correct_app_for_scoped_chat_id(self):
        two_corps = [
            _app(name="corp-a", corp_id="corpA", agent_id="1001"),
            _app(name="corp-b", corp_id="corpB", agent_id="2002"),
        ]
        adapter = WecomCallbackAdapter(_config(apps=two_corps))
        adapter._user_app_map["corpB:alice"] = "corp-b"
        adapter._access_tokens["corp-b"] = {"token": "tok-b", "expires_at": 9999999999}

        recorded = {}

        class FakeResponse:
            def json(self):
                return {"errcode": 0, "msgid": "ok1"}

        class FakeClient:
            async def post(self, url, json):
                recorded.update(url=url, json=json)
                return FakeResponse()

        adapter._http_client = FakeClient()

        result = await adapter.send("corpB:alice", "hello")

        assert result.success is True
        payload = recorded["json"]
        assert payload["touser"] == "alice"
        assert payload["agentid"] == 2002
        assert "tok-b" in recorded["url"]

    @pytest.mark.asyncio
    async def test_send_falls_back_from_bare_user_id_when_unique(self):
        single_corp = [_app(name="corp-a", corp_id="corpA", agent_id="1001")]
        adapter = WecomCallbackAdapter(_config(apps=single_corp))
        adapter._user_app_map["corpA:alice"] = "corp-a"
        adapter._access_tokens["corp-a"] = {"token": "tok-a", "expires_at": 9999999999}

        recorded = {}

        class FakeResponse:
            def json(self):
                return {"errcode": 0, "msgid": "ok2"}

        class FakeClient:
            async def post(self, url, json):
                recorded.update(url=url, json=json)
                return FakeResponse()

        adapter._http_client = FakeClient()

        # A bare user id is passed, without the corp prefix.
        result = await adapter.send("alice", "hello")

        assert result.success is True
        assert recorded["json"]["agentid"] == 1001
class TestWecomCallbackPollLoop:
    """The poll loop hands queued events to ``handle_message``."""

    @pytest.mark.asyncio
    async def test_poll_loop_dispatches_handle_message(self, monkeypatch):
        adapter = WecomCallbackAdapter(_config())
        handled_texts = []

        async def fake_handle_message(event):
            handled_texts.append(event.text)

        monkeypatch.setattr(adapter, "handle_message", fake_handle_message)

        xml_text = """
<xml>
<ToUserName>ww1234567890</ToUserName>
<FromUserName>lisi</FromUserName>
<CreateTime>1710000000</CreateTime>
<MsgType>text</MsgType>
<Content>test</Content>
<MsgId>m2</MsgId>
</xml>
"""
        event = adapter._build_event(_app(), xml_text)

        poller = asyncio.create_task(adapter._poll_loop())
        await adapter._message_queue.put(event)
        # Give the loop a moment to pick the event up before cancelling it.
        await asyncio.sleep(0.05)
        poller.cancel()
        with pytest.raises(asyncio.CancelledError):
            await poller

        assert handled_texts == ["test"]

View File

@@ -64,13 +64,44 @@ class TestWeixinFormatting:
class TestWeixinChunking:
def test_split_text_keeps_short_multiline_message_in_single_chunk(self):
def test_split_text_splits_short_chatty_replies_into_separate_bubbles(self):
adapter = _make_adapter()
content = adapter.format_message("第一行\n第二行\n第三行")
chunks = adapter._split_text(content)
assert chunks == ["第一行\n第二行\n第三行"]
assert chunks == ["第一行", "第二行", "第三行"]
def test_split_text_keeps_structured_table_block_together(self):
    """A list of setting/value pairs comes back as one chunk."""
    adapter = _make_adapter()
    formatted = adapter.format_message(
        "- Setting: Timeout\n Value: 30s\n- Setting: Retries\n Value: 3"
    )

    expected = ["- Setting: Timeout\n Value: 30s\n- Setting: Retries\n Value: 3"]
    assert adapter._split_text(formatted) == expected
def test_split_text_keeps_four_line_structured_blocks_together(self):
    """A lead-in line followed by three bullets comes back as one chunk."""
    adapter = _make_adapter()
    formatted = adapter.format_message(
        "今天结论:\n"
        "- 留存下降 3%\n"
        "- 转化上涨 8%\n"
        "- 主要问题在首日激活"
    )

    expected = ["今天结论:\n- 留存下降 3%\n- 转化上涨 8%\n- 主要问题在首日激活"]
    assert adapter._split_text(formatted) == expected
def test_split_text_keeps_heading_with_body_together(self):
    """A heading and its body line are returned together as one chunk."""
    adapter = _make_adapter()
    formatted = adapter.format_message("## 结论\n这是正文")

    # The expected chunk shows the "##" heading rendered in bold.
    expected = ["**结论**\n这是正文"]
    assert adapter._split_text(formatted) == expected
def test_split_text_keeps_short_reformatted_table_in_single_chunk(self):
adapter = _make_adapter()

View File

@@ -14,6 +14,7 @@ from hermes_cli.auth import (
PROVIDER_REGISTRY,
_read_codex_tokens,
_save_codex_tokens,
_write_codex_cli_tokens,
_import_codex_cli_tokens,
get_codex_auth_status,
get_provider_auth_state,
@@ -161,7 +162,7 @@ def test_import_codex_cli_tokens_missing(tmp_path, monkeypatch):
def test_codex_tokens_not_written_to_shared_file(tmp_path, monkeypatch):
"""Verify Hermes never writes to ~/.codex/auth.json."""
"""Verify _save_codex_tokens writes only to Hermes auth store, not ~/.codex/."""
hermes_home = tmp_path / "hermes"
codex_home = tmp_path / "codex-cli"
hermes_home.mkdir(parents=True, exist_ok=True)
@@ -173,7 +174,7 @@ def test_codex_tokens_not_written_to_shared_file(tmp_path, monkeypatch):
_save_codex_tokens({"access_token": "hermes-at", "refresh_token": "hermes-rt"})
# ~/.codex/auth.json should NOT exist
# ~/.codex/auth.json should NOT exist — _save_codex_tokens only touches Hermes store
assert not (codex_home / "auth.json").exists()
# Hermes auth store should have the tokens
@@ -181,6 +182,98 @@ def test_codex_tokens_not_written_to_shared_file(tmp_path, monkeypatch):
assert data["tokens"]["access_token"] == "hermes-at"
def test_write_codex_cli_tokens_creates_file(tmp_path, monkeypatch):
    """_write_codex_cli_tokens creates ~/.codex/auth.json with refreshed tokens."""
    codex_home = tmp_path / "codex-cli"
    monkeypatch.setenv("CODEX_HOME", str(codex_home))

    _write_codex_cli_tokens("new-access", "new-refresh", last_refresh="2026-04-12T00:00:00Z")

    auth_path = codex_home / "auth.json"
    assert auth_path.exists()

    written = json.loads(auth_path.read_text())
    tokens = written["tokens"]
    assert tokens["access_token"] == "new-access"
    assert tokens["refresh_token"] == "new-refresh"
    assert written["last_refresh"] == "2026-04-12T00:00:00Z"

    # Verify file permissions are restricted
    permission_bits = auth_path.stat().st_mode & 0o777
    assert permission_bits == 0o600
def test_write_codex_cli_tokens_preserves_existing(tmp_path, monkeypatch):
    """_write_codex_cli_tokens preserves extra fields in existing auth.json."""
    codex_home = tmp_path / "codex-cli"
    codex_home.mkdir(parents=True, exist_ok=True)
    monkeypatch.setenv("CODEX_HOME", str(codex_home))
    auth_path = codex_home / "auth.json"

    # Seed the file with extra keys at both the top level and inside "tokens".
    seeded_tokens = {
        "access_token": "old-access",
        "refresh_token": "old-refresh",
        "extra_field": "preserved",
    }
    seeded = {
        "tokens": seeded_tokens,
        "last_refresh": "2026-01-01T00:00:00Z",
        "custom_key": "keep_me",
    }
    auth_path.write_text(json.dumps(seeded))

    _write_codex_cli_tokens("updated-access", "updated-refresh")

    written = json.loads(auth_path.read_text())
    tokens = written["tokens"]
    assert tokens["access_token"] == "updated-access"
    assert tokens["refresh_token"] == "updated-refresh"
    assert tokens["extra_field"] == "preserved"
    assert written["custom_key"] == "keep_me"
    # last_refresh not updated since we didn't pass it
    assert written["last_refresh"] == "2026-01-01T00:00:00Z"
def test_write_codex_cli_tokens_handles_missing_dir(tmp_path, monkeypatch):
    """Writing tokens works even when CODEX_HOME (and its parents) do not exist yet."""
    missing_dir = tmp_path / "does" / "not" / "exist"
    token_file = missing_dir / "auth.json"
    monkeypatch.setenv("CODEX_HOME", str(missing_dir))

    _write_codex_cli_tokens("at", "rt")

    assert token_file.exists()
    payload = json.loads(token_file.read_text())
    assert payload["tokens"]["access_token"] == "at"
def test_refresh_codex_auth_tokens_writes_back_to_cli(tmp_path, monkeypatch):
    """A token refresh is mirrored into the auth.json under CODEX_HOME.

    The network-facing refresh (``refresh_codex_oauth_pure``) is replaced by
    a stub returning fixed tokens; the CLI file seeded with old tokens must
    contain the refreshed ones afterwards.
    """
    from hermes_cli.auth import _refresh_codex_auth_tokens

    hermes_dir = tmp_path / "hermes"
    cli_dir = tmp_path / "codex-cli"
    for directory in (hermes_dir, cli_dir):
        directory.mkdir(parents=True, exist_ok=True)

    empty_store = {"version": 1, "providers": {}}
    (hermes_dir / "auth.json").write_text(json.dumps(empty_store))
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setenv("CODEX_HOME", str(cli_dir))

    # Seed the CLI file with the tokens that are about to be replaced.
    stale_tokens = {"access_token": "old-at", "refresh_token": "old-rt"}
    cli_file = cli_dir / "auth.json"
    cli_file.write_text(json.dumps({"tokens": stale_tokens}))

    def _stub_refresh(*_args, **_kwargs):
        # Stand-in for the pure refresh call: always hands back new tokens.
        return {
            "access_token": "refreshed-at",
            "refresh_token": "refreshed-rt",
            "last_refresh": "2026-04-12T01:00:00Z",
        }

    monkeypatch.setattr("hermes_cli.auth.refresh_codex_oauth_pure", _stub_refresh)

    _refresh_codex_auth_tokens(
        {"access_token": "old-at", "refresh_token": "old-rt"},
        timeout_seconds=10,
    )

    # The CLI file must now hold the refreshed pair.
    written_tokens = json.loads(cli_file.read_text())["tokens"]
    assert written_tokens["access_token"] == "refreshed-at"
    assert written_tokens["refresh_token"] == "refreshed-rt"
def test_resolve_returns_hermes_auth_store_source(tmp_path, monkeypatch):
hermes_home = tmp_path / "hermes"
_setup_hermes_auth(hermes_home)

View File

@@ -0,0 +1,897 @@
"""Tests for hermes backup and import commands."""
import os
import zipfile
from argparse import Namespace
from pathlib import Path
from unittest.mock import patch
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_hermes_tree(root: Path) -> None:
    """Create a realistic ~/.hermes directory structure for testing.

    Populates *root* with config, ``.env``, database, session, skill, skin,
    cron, memory, profile and log files, plus three kinds of entries marked
    below as expected to be excluded from backups: a ``hermes-agent``
    checkout, ``__pycache__`` contents and a PID file.

    Parent directories (including *root* itself) are created on demand with
    ``exist_ok=True``, so calling the helper again on an already populated
    tree rewrites the files instead of raising ``FileExistsError``.

    Args:
        root: Directory to populate.
    """
    # Relative path -> text content.
    text_files = {
        "config.yaml": "model:\n provider: openrouter\n",
        ".env": "OPENROUTER_API_KEY=sk-test-123\n",
        # Sessions
        "sessions/abc123.json": "{}",
        # Skills
        "skills/my-skill/SKILL.md": "# My Skill\n",
        # Skins
        "skins/cyber.yaml": "name: cyber\n",
        # Cron
        "cron/jobs.json": "[]",
        # Memories
        "memories/notes.json": "{}",
        # Profiles
        "profiles/coder/config.yaml": "model:\n provider: anthropic\n",
        "profiles/coder/.env": "ANTHROPIC_API_KEY=sk-ant-123\n",
        # hermes-agent repo (should be EXCLUDED)
        "hermes-agent/run_agent.py": "# big file\n",
        "hermes-agent/.git/HEAD": "ref: refs/heads/main\n",
        # PID files (should be EXCLUDED)
        "gateway.pid": "12345",
        # Logs (should be included)
        "logs/agent.log": "log line\n",
    }
    # Relative path -> raw bytes.
    binary_files = {
        "memory_store.db": b"fake-sqlite",
        "hermes_state.db": b"fake-state",
        # __pycache__ (should be EXCLUDED)
        "plugins/__pycache__/mod.cpython-312.pyc": b"\x00",
    }

    for rel_path, text in text_files.items():
        target = root / rel_path
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(text)

    for rel_path, blob in binary_files.items():
        target = root / rel_path
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(blob)
# ---------------------------------------------------------------------------
# _should_exclude tests
# ---------------------------------------------------------------------------
class TestShouldExclude:
    """Path-level checks for ``hermes_cli.backup._should_exclude``."""

    def test_excludes_hermes_agent(self):
        """Entries under hermes-agent/ are excluded, including its .git data."""
        from hermes_cli.backup import _should_exclude

        for rel in ("hermes-agent/run_agent.py", "hermes-agent/.git/HEAD"):
            assert _should_exclude(Path(rel))

    def test_excludes_pycache(self):
        """A file inside a __pycache__ directory is excluded."""
        from hermes_cli.backup import _should_exclude

        cached = Path("plugins/__pycache__/mod.cpython-312.pyc")
        assert _should_exclude(cached)

    def test_excludes_pyc_files(self):
        """A .pyc file is excluded even outside __pycache__."""
        from hermes_cli.backup import _should_exclude

        compiled = Path("some/module.pyc")
        assert _should_exclude(compiled)

    def test_excludes_pid_files(self):
        """Top-level *.pid files are excluded."""
        from hermes_cli.backup import _should_exclude

        for rel in ("gateway.pid", "cron.pid"):
            assert _should_exclude(Path(rel))

    def test_includes_config(self):
        """config.yaml is kept."""
        from hermes_cli.backup import _should_exclude

        kept = Path("config.yaml")
        assert not _should_exclude(kept)

    def test_includes_env(self):
        """.env is kept."""
        from hermes_cli.backup import _should_exclude

        kept = Path(".env")
        assert not _should_exclude(kept)

    def test_includes_skills(self):
        """Skill files are kept."""
        from hermes_cli.backup import _should_exclude

        kept = Path("skills/my-skill/SKILL.md")
        assert not _should_exclude(kept)

    def test_includes_profiles(self):
        """Profile configs are kept."""
        from hermes_cli.backup import _should_exclude

        kept = Path("profiles/coder/config.yaml")
        assert not _should_exclude(kept)

    def test_includes_sessions(self):
        """Session files are kept."""
        from hermes_cli.backup import _should_exclude

        kept = Path("sessions/abc.json")
        assert not _should_exclude(kept)

    def test_includes_logs(self):
        """Log files are kept."""
        from hermes_cli.backup import _should_exclude

        kept = Path("logs/agent.log")
        assert not _should_exclude(kept)
# ---------------------------------------------------------------------------
# Backup tests
# ---------------------------------------------------------------------------
class TestBackup:
    """End-to-end checks of ``run_backup`` against a populated sample home."""

    def _backup_sample_tree(self, tmp_path, monkeypatch):
        """Build the sample tree in tmp_path/.hermes, back it up, return the zip path."""
        home_dir = tmp_path / ".hermes"
        home_dir.mkdir()
        _make_hermes_tree(home_dir)
        monkeypatch.setenv("HERMES_HOME", str(home_dir))
        # get_default_hermes_root needs this
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        archive = tmp_path / "backup.zip"
        cli_args = Namespace(output=str(archive))

        from hermes_cli.backup import run_backup

        run_backup(cli_args)
        return archive

    def test_creates_zip(self, tmp_path, monkeypatch):
        """The archive exists and lists the regular content of the sample tree."""
        archive = self._backup_sample_tree(tmp_path, monkeypatch)
        assert archive.exists()

        with zipfile.ZipFile(archive, "r") as zf:
            members = zf.namelist()

        expected_members = (
            # Config should be present
            "config.yaml",
            ".env",
            # Skills
            "skills/my-skill/SKILL.md",
            # Profiles
            "profiles/coder/config.yaml",
            "profiles/coder/.env",
            # Sessions
            "sessions/abc123.json",
            # Logs
            "logs/agent.log",
            # Skins
            "skins/cyber.yaml",
        )
        for member in expected_members:
            assert member in members

    def test_excludes_hermes_agent(self, tmp_path, monkeypatch):
        """No archive member mentions the hermes-agent/ directory."""
        archive = self._backup_sample_tree(tmp_path, monkeypatch)

        with zipfile.ZipFile(archive, "r") as zf:
            members = zf.namelist()

        agent_files = [member for member in members if "hermes-agent" in member]
        assert agent_files == [], f"hermes-agent files leaked into backup: {agent_files}"

    def test_excludes_pycache(self, tmp_path, monkeypatch):
        """No archive member lives under a __pycache__ directory."""
        archive = self._backup_sample_tree(tmp_path, monkeypatch)

        with zipfile.ZipFile(archive, "r") as zf:
            members = zf.namelist()

        cached = [member for member in members if "__pycache__" in member]
        assert cached == []

    def test_excludes_pid_files(self, tmp_path, monkeypatch):
        """No archive member ends in .pid."""
        archive = self._backup_sample_tree(tmp_path, monkeypatch)

        with zipfile.ZipFile(archive, "r") as zf:
            members = zf.namelist()

        pid_members = [member for member in members if member.endswith(".pid")]
        assert pid_members == []

    def test_default_output_path(self, tmp_path, monkeypatch):
        """With ``output=None`` exactly one hermes-backup-*.zip lands in the home dir."""
        home_dir = tmp_path / ".hermes"
        home_dir.mkdir()
        (home_dir / "config.yaml").write_text("model: test\n")
        monkeypatch.setenv("HERMES_HOME", str(home_dir))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        cli_args = Namespace(output=None)

        from hermes_cli.backup import run_backup

        run_backup(cli_args)

        # Should exist in home dir
        produced = list(tmp_path.glob("hermes-backup-*.zip"))
        assert len(produced) == 1
# ---------------------------------------------------------------------------
# Import tests
# ---------------------------------------------------------------------------
class TestImport:
    """Behavioural tests for ``run_import`` (restoring a backup zip into HERMES_HOME)."""

    def _make_backup_zip(self, zip_path: Path, files: dict[str, str | bytes]) -> None:
        """Create a test zip with given files.

        Args:
            zip_path: Where to write the archive.
            files: Mapping of archive member name to its content.
                ``ZipFile.writestr`` accepts both ``str`` and ``bytes``, so
                no per-type branching is needed.
        """
        with zipfile.ZipFile(zip_path, "w") as zf:
            for name, content in files.items():
                zf.writestr(name, content)

    def test_restores_files(self, tmp_path, monkeypatch):
        """Import extracts files into hermes home."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        zip_path = tmp_path / "backup.zip"
        self._make_backup_zip(zip_path, {
            "config.yaml": "model:\n provider: openrouter\n",
            ".env": "OPENROUTER_API_KEY=sk-test\n",
            "skills/my-skill/SKILL.md": "# My Skill\n",
            "profiles/coder/config.yaml": "model:\n provider: anthropic\n",
        })

        args = Namespace(zipfile=str(zip_path), force=True)

        from hermes_cli.backup import run_import
        run_import(args)

        assert (hermes_home / "config.yaml").read_text() == "model:\n provider: openrouter\n"
        assert (hermes_home / ".env").read_text() == "OPENROUTER_API_KEY=sk-test\n"
        assert (hermes_home / "skills" / "my-skill" / "SKILL.md").read_text() == "# My Skill\n"
        assert (hermes_home / "profiles" / "coder" / "config.yaml").exists()

    def test_strips_hermes_prefix(self, tmp_path, monkeypatch):
        """Import strips .hermes/ prefix if all entries share it."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        zip_path = tmp_path / "backup.zip"
        self._make_backup_zip(zip_path, {
            ".hermes/config.yaml": "model: test\n",
            ".hermes/skills/a/SKILL.md": "# A\n",
        })

        args = Namespace(zipfile=str(zip_path), force=True)

        from hermes_cli.backup import run_import
        run_import(args)

        assert (hermes_home / "config.yaml").read_text() == "model: test\n"
        assert (hermes_home / "skills" / "a" / "SKILL.md").read_text() == "# A\n"

    def test_rejects_empty_zip(self, tmp_path, monkeypatch):
        """Import rejects an empty zip."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        zip_path = tmp_path / "empty.zip"
        with zipfile.ZipFile(zip_path, "w"):
            pass  # empty

        args = Namespace(zipfile=str(zip_path), force=True)

        from hermes_cli.backup import run_import
        with pytest.raises(SystemExit):
            run_import(args)

    def test_rejects_non_hermes_zip(self, tmp_path, monkeypatch):
        """Import rejects a zip that doesn't look like a hermes backup."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        zip_path = tmp_path / "random.zip"
        self._make_backup_zip(zip_path, {
            "some/random/file.txt": "hello",
            "another/thing.json": "{}",
        })

        args = Namespace(zipfile=str(zip_path), force=True)

        from hermes_cli.backup import run_import
        with pytest.raises(SystemExit):
            run_import(args)

    def test_blocks_path_traversal(self, tmp_path, monkeypatch):
        """Import blocks zip entries with path traversal."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        zip_path = tmp_path / "evil.zip"
        # Include a marker file so validation passes
        self._make_backup_zip(zip_path, {
            "config.yaml": "model: test\n",
            "../../etc/passwd": "root:x:0:0\n",
        })

        args = Namespace(zipfile=str(zip_path), force=True)

        from hermes_cli.backup import run_import
        run_import(args)

        # config.yaml should be restored
        assert (hermes_home / "config.yaml").exists()
        # "../../etc/passwd" joined onto hermes_home (tmp_path/.hermes) climbs
        # two levels, i.e. to tmp_path.parent/etc/passwd.  That is where an
        # unguarded extraction would land, so it must be checked explicitly.
        assert not (tmp_path.parent / "etc" / "passwd").exists()
        # traversal file should NOT exist one level up either
        assert not (tmp_path / "etc" / "passwd").exists()

    def test_confirmation_prompt_abort(self, tmp_path, monkeypatch):
        """Import aborts when user says no to confirmation."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        # Pre-existing config triggers the confirmation
        (hermes_home / "config.yaml").write_text("existing: true\n")
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        zip_path = tmp_path / "backup.zip"
        self._make_backup_zip(zip_path, {
            "config.yaml": "model: restored\n",
        })

        args = Namespace(zipfile=str(zip_path), force=False)

        from hermes_cli.backup import run_import
        with patch("builtins.input", return_value="n"):
            run_import(args)

        # Original config should be unchanged
        assert (hermes_home / "config.yaml").read_text() == "existing: true\n"

    def test_force_skips_confirmation(self, tmp_path, monkeypatch):
        """Import with --force skips confirmation and overwrites."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        (hermes_home / "config.yaml").write_text("existing: true\n")
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        zip_path = tmp_path / "backup.zip"
        self._make_backup_zip(zip_path, {
            "config.yaml": "model: restored\n",
        })

        args = Namespace(zipfile=str(zip_path), force=True)

        from hermes_cli.backup import run_import
        run_import(args)

        assert (hermes_home / "config.yaml").read_text() == "model: restored\n"

    def test_missing_file_exits(self, tmp_path, monkeypatch):
        """Import exits with error for nonexistent file."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))

        args = Namespace(zipfile=str(tmp_path / "nonexistent.zip"), force=True)

        from hermes_cli.backup import run_import
        with pytest.raises(SystemExit):
            run_import(args)
# ---------------------------------------------------------------------------
# Round-trip test
# ---------------------------------------------------------------------------
class TestRoundTrip:
    """Backup followed by import into a fresh home."""

    def test_backup_then_import(self, tmp_path, monkeypatch):
        """Back up one home, import the zip elsewhere, then compare the result."""
        # Source
        origin_user = tmp_path / "source"
        origin = origin_user / ".hermes"
        origin.mkdir(parents=True)
        _make_hermes_tree(origin)
        monkeypatch.setenv("HERMES_HOME", str(origin))
        monkeypatch.setattr(Path, "home", lambda: origin_user)

        # Backup
        from hermes_cli.backup import run_backup, run_import

        archive = tmp_path / "roundtrip.zip"
        run_backup(Namespace(output=str(archive)))
        assert archive.exists()

        # Import into a different location
        target_user = tmp_path / "dest"
        target = target_user / ".hermes"
        target.mkdir(parents=True)
        monkeypatch.setenv("HERMES_HOME", str(target))
        monkeypatch.setattr(Path, "home", lambda: target_user)
        run_import(Namespace(zipfile=str(archive), force=True))

        # Verify key files
        assert (target / "config.yaml").read_text() == "model:\n provider: openrouter\n"
        assert (target / ".env").read_text() == "OPENROUTER_API_KEY=sk-test-123\n"
        restored = (
            target / "skills" / "my-skill" / "SKILL.md",
            target / "profiles" / "coder" / "config.yaml",
            target / "sessions" / "abc123.json",
            target / "logs" / "agent.log",
        )
        for path in restored:
            assert path.exists()

        absent = (
            # hermes-agent should NOT be present
            target / "hermes-agent",
            # __pycache__ should NOT be present
            target / "plugins" / "__pycache__",
            # PID files should NOT be present
            target / "gateway.pid",
        )
        for path in absent:
            assert not path.exists()
# ---------------------------------------------------------------------------
# Format-size, validate and detect-prefix unit tests
# ---------------------------------------------------------------------------
class TestFormatSize:
    """Unit-suffix checks for ``hermes_cli.backup._format_size``."""

    def test_bytes(self):
        """512 bytes renders exactly as "512 B"."""
        from hermes_cli.backup import _format_size

        rendered = _format_size(512)
        assert rendered == "512 B"

    def test_kilobytes(self):
        """2048 bytes renders with a KB suffix."""
        from hermes_cli.backup import _format_size

        rendered = _format_size(2048)
        assert "KB" in rendered

    def test_megabytes(self):
        """5 MiB renders with an MB suffix."""
        from hermes_cli.backup import _format_size

        rendered = _format_size(5 * 1024 * 1024)
        assert "MB" in rendered

    def test_gigabytes(self):
        """3 GiB renders with a GB suffix."""
        from hermes_cli.backup import _format_size

        rendered = _format_size(3 * 1024 ** 3)
        assert "GB" in rendered

    def test_terabytes(self):
        """2 TiB renders with a TB suffix."""
        from hermes_cli.backup import _format_size

        rendered = _format_size(2 * 1024 ** 4)
        assert "TB" in rendered
class TestValidation:
    """Unit tests for ``_validate_backup_zip`` and ``_detect_prefix`` on in-memory zips."""

    @staticmethod
    def _in_memory_zip(entries):
        """Return a rewound BytesIO containing a zip of *entries* ((name, text) pairs)."""
        import io

        stream = io.BytesIO()
        with zipfile.ZipFile(stream, "w") as writer:
            for member_name, text in entries:
                writer.writestr(member_name, text)
        stream.seek(0)
        return stream

    def test_validate_with_config(self):
        """A zip holding config.yaml validates."""
        from hermes_cli.backup import _validate_backup_zip

        stream = self._in_memory_zip([("config.yaml", "test")])
        with zipfile.ZipFile(stream, "r") as reader:
            ok, _reason = _validate_backup_zip(reader)
        assert ok

    def test_validate_with_env(self):
        """A zip holding .env validates."""
        from hermes_cli.backup import _validate_backup_zip

        stream = self._in_memory_zip([(".env", "KEY=val")])
        with zipfile.ZipFile(stream, "r") as reader:
            ok, _reason = _validate_backup_zip(reader)
        assert ok

    def test_validate_rejects_random(self):
        """A zip without any hermes marker file does not validate."""
        from hermes_cli.backup import _validate_backup_zip

        stream = self._in_memory_zip([("random/file.txt", "hello")])
        with zipfile.ZipFile(stream, "r") as reader:
            ok, _reason = _validate_backup_zip(reader)
        assert not ok

    def test_detect_prefix_hermes(self):
        """A .hermes/ prefix shared by every entry is reported."""
        from hermes_cli.backup import _detect_prefix

        stream = self._in_memory_zip([
            (".hermes/config.yaml", "test"),
            (".hermes/skills/a/SKILL.md", "skill"),
        ])
        with zipfile.ZipFile(stream, "r") as reader:
            assert _detect_prefix(reader) == ".hermes/"

    def test_detect_prefix_none(self):
        """Entries stored at the archive root yield an empty prefix."""
        from hermes_cli.backup import _detect_prefix

        stream = self._in_memory_zip([
            ("config.yaml", "test"),
            ("skills/a/SKILL.md", "skill"),
        ])
        with zipfile.ZipFile(stream, "r") as reader:
            assert _detect_prefix(reader) == ""

    def test_detect_prefix_only_dirs(self):
        """A zip made only of directory entries yields an empty prefix."""
        from hermes_cli.backup import _detect_prefix

        # Only directory entries (trailing slash)
        stream = self._in_memory_zip([
            (".hermes/", ""),
            (".hermes/skills/", ""),
        ])
        with zipfile.ZipFile(stream, "r") as reader:
            assert _detect_prefix(reader) == ""
# ---------------------------------------------------------------------------
# Edge case tests for uncovered paths
# ---------------------------------------------------------------------------
class TestBackupEdgeCases:
    """Less common ``run_backup`` situations: odd output paths, empty or unreadable input."""

    @staticmethod
    def _home_with_config(tmp_path, monkeypatch):
        """Create tmp_path/.hermes with a one-line config.yaml and point the env at it."""
        home_dir = tmp_path / ".hermes"
        home_dir.mkdir()
        (home_dir / "config.yaml").write_text("model: test\n")
        monkeypatch.setenv("HERMES_HOME", str(home_dir))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)
        return home_dir

    def test_nonexistent_hermes_home(self, tmp_path, monkeypatch):
        """run_backup raises SystemExit when the hermes home directory is missing."""
        user_dir = tmp_path / "nonexistent"
        monkeypatch.setenv("HERMES_HOME", str(user_dir / ".hermes"))
        monkeypatch.setattr(Path, "home", lambda: user_dir)

        cli_args = Namespace(output=str(tmp_path / "out.zip"))

        from hermes_cli.backup import run_backup

        with pytest.raises(SystemExit):
            run_backup(cli_args)

    def test_output_is_directory(self, tmp_path, monkeypatch):
        """An existing directory as output receives a hermes-backup-*.zip inside it."""
        self._home_with_config(tmp_path, monkeypatch)

        target_dir = tmp_path / "backups"
        target_dir.mkdir()
        cli_args = Namespace(output=str(target_dir))

        from hermes_cli.backup import run_backup

        run_backup(cli_args)

        produced = list(target_dir.glob("hermes-backup-*.zip"))
        assert len(produced) == 1

    def test_output_without_zip_suffix(self, tmp_path, monkeypatch):
        """An output path lacking .zip gets the suffix appended."""
        self._home_with_config(tmp_path, monkeypatch)

        requested = tmp_path / "mybackup.tar"
        cli_args = Namespace(output=str(requested))

        from hermes_cli.backup import run_backup

        run_backup(cli_args)

        # Should have .tar.zip suffix
        assert (tmp_path / "mybackup.tar.zip").exists()

    def test_empty_hermes_home(self, tmp_path, monkeypatch):
        """A home containing only excluded entries produces no archive."""
        home_dir = tmp_path / ".hermes"
        home_dir.mkdir()
        # Only excluded dirs, no actual files
        cache_dir = home_dir / "__pycache__"
        cache_dir.mkdir()
        (cache_dir / "foo.pyc").write_bytes(b"\x00")
        monkeypatch.setenv("HERMES_HOME", str(home_dir))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        archive = tmp_path / "out.zip"
        cli_args = Namespace(output=str(archive))

        from hermes_cli.backup import run_backup

        run_backup(cli_args)

        # No zip should be created
        assert not archive.exists()

    def test_permission_error_during_backup(self, tmp_path, monkeypatch):
        """An unreadable file does not stop the archive from being produced."""
        home_dir = self._home_with_config(tmp_path, monkeypatch)

        # Create an unreadable file
        unreadable = home_dir / "secret.db"
        unreadable.write_text("data")
        unreadable.chmod(0o000)

        archive = tmp_path / "out.zip"
        cli_args = Namespace(output=str(archive))

        from hermes_cli.backup import run_backup

        try:
            run_backup(cli_args)
        finally:
            # Restore permissions for cleanup
            unreadable.chmod(0o644)

        # Zip should still be created with the readable files
        assert archive.exists()

    def test_skips_output_zip_inside_hermes(self, tmp_path, monkeypatch):
        """An archive written inside the hermes root does not contain itself."""
        home_dir = self._home_with_config(tmp_path, monkeypatch)

        # Output inside hermes home
        archive = home_dir / "backup.zip"
        cli_args = Namespace(output=str(archive))

        from hermes_cli.backup import run_backup

        run_backup(cli_args)

        # The zip should exist but not contain itself
        assert archive.exists()
        with zipfile.ZipFile(archive, "r") as zf:
            members = zf.namelist()
            assert "backup.zip" not in members
class TestImportEdgeCases:
    """Less common ``run_import`` situations: bad input, aborted prompts, partial failures."""

    def _make_backup_zip(self, zip_path: Path, files: dict[str, str | bytes]) -> None:
        """Write *files* (member name -> content) into a new zip at *zip_path*."""
        archive = zipfile.ZipFile(zip_path, "w")
        with archive:
            for member_name in files:
                archive.writestr(member_name, files[member_name])

    def test_not_a_zip(self, tmp_path, monkeypatch):
        """A file that is not a zip archive makes run_import raise SystemExit."""
        home_dir = tmp_path / ".hermes"
        home_dir.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(home_dir))

        bogus = tmp_path / "fake.zip"
        bogus.write_text("this is not a zip")
        cli_args = Namespace(zipfile=str(bogus), force=True)

        from hermes_cli.backup import run_import

        with pytest.raises(SystemExit):
            run_import(cli_args)

    def test_eof_during_confirmation(self, tmp_path, monkeypatch):
        """EOFError from the confirmation prompt ends in SystemExit."""
        home_dir = tmp_path / ".hermes"
        home_dir.mkdir()
        (home_dir / "config.yaml").write_text("existing\n")
        monkeypatch.setenv("HERMES_HOME", str(home_dir))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        archive = tmp_path / "backup.zip"
        self._make_backup_zip(archive, {"config.yaml": "new\n"})
        cli_args = Namespace(zipfile=str(archive), force=False)

        from hermes_cli.backup import run_import

        with patch("builtins.input", side_effect=EOFError), pytest.raises(SystemExit):
            run_import(cli_args)

    def test_keyboard_interrupt_during_confirmation(self, tmp_path, monkeypatch):
        """KeyboardInterrupt from the confirmation prompt ends in SystemExit."""
        home_dir = tmp_path / ".hermes"
        home_dir.mkdir()
        (home_dir / ".env").write_text("KEY=val\n")
        monkeypatch.setenv("HERMES_HOME", str(home_dir))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        archive = tmp_path / "backup.zip"
        self._make_backup_zip(archive, {"config.yaml": "new\n"})
        cli_args = Namespace(zipfile=str(archive), force=False)

        from hermes_cli.backup import run_import

        with patch("builtins.input", side_effect=KeyboardInterrupt), pytest.raises(SystemExit):
            run_import(cli_args)

    def test_permission_error_during_import(self, tmp_path, monkeypatch):
        """A read-only target directory does not prevent other files being restored."""
        home_dir = tmp_path / ".hermes"
        home_dir.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(home_dir))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        # Create a read-only directory so extraction fails
        read_only = home_dir / "locked"
        read_only.mkdir()
        read_only.chmod(0o555)

        archive = tmp_path / "backup.zip"
        members = {
            "config.yaml": "model: test\n",
            "locked/secret.txt": "data",
        }
        self._make_backup_zip(archive, members)
        cli_args = Namespace(zipfile=str(archive), force=True)

        from hermes_cli.backup import run_import

        try:
            run_import(cli_args)
        finally:
            read_only.chmod(0o755)

        # config.yaml should still be restored despite the error
        assert (home_dir / "config.yaml").exists()

    def test_progress_with_many_files(self, tmp_path, monkeypatch):
        """An archive with 600 session files plus a config imports completely."""
        home_dir = tmp_path / ".hermes"
        home_dir.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(home_dir))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        archive = tmp_path / "big.zip"
        members = {"config.yaml": "model: test\n"}
        members.update({f"sessions/s{index:04d}.json": "{}" for index in range(600)})
        self._make_backup_zip(archive, members)
        cli_args = Namespace(zipfile=str(archive), force=True)

        from hermes_cli.backup import run_import

        run_import(cli_args)

        assert (home_dir / "config.yaml").exists()
        assert (home_dir / "sessions" / "s0599.json").exists()
# ---------------------------------------------------------------------------
# Profile restoration tests
# ---------------------------------------------------------------------------
class TestProfileRestoration:
    """Tests for how ``run_import`` treats ``profiles/`` entries in a backup."""

    def _make_backup_zip(self, zip_path: Path, files: dict[str, str | bytes]) -> None:
        """Write *files* (member name -> content) into a new zip at *zip_path*."""
        with zipfile.ZipFile(zip_path, "w") as zf:
            for name, content in files.items():
                zf.writestr(name, content)

    def test_import_creates_profile_wrappers(self, tmp_path, monkeypatch):
        """Import auto-creates wrapper scripts for restored profiles."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        # Pre-create .local/bin under the patched home; the assertions below
        # expect the wrapper scripts to appear there.
        wrapper_dir = tmp_path / ".local" / "bin"
        wrapper_dir.mkdir(parents=True)

        zip_path = tmp_path / "backup.zip"
        self._make_backup_zip(zip_path, {
            "config.yaml": "model:\n provider: openrouter\n",
            "profiles/coder/config.yaml": "model:\n provider: anthropic\n",
            "profiles/coder/.env": "ANTHROPIC_API_KEY=sk-test\n",
            "profiles/researcher/config.yaml": "model:\n provider: deepseek\n",
        })

        args = Namespace(zipfile=str(zip_path), force=True)

        from hermes_cli.backup import run_import
        run_import(args)

        # Profile directories should exist
        assert (hermes_home / "profiles" / "coder" / "config.yaml").exists()
        assert (hermes_home / "profiles" / "researcher" / "config.yaml").exists()

        # Wrapper scripts should be created
        assert (wrapper_dir / "coder").exists()
        assert (wrapper_dir / "researcher").exists()

        # Wrappers should contain the right content
        coder_wrapper = (wrapper_dir / "coder").read_text()
        assert "hermes -p coder" in coder_wrapper

    def test_import_skips_profile_dirs_without_config(self, tmp_path, monkeypatch):
        """Import doesn't create wrappers for profile dirs without config."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        wrapper_dir = tmp_path / ".local" / "bin"
        wrapper_dir.mkdir(parents=True)

        zip_path = tmp_path / "backup.zip"
        self._make_backup_zip(zip_path, {
            "config.yaml": "model: test\n",
            "profiles/valid/config.yaml": "model: test\n",
            "profiles/empty/readme.txt": "nothing here\n",
        })

        args = Namespace(zipfile=str(zip_path), force=True)

        from hermes_cli.backup import run_import
        run_import(args)

        # Only valid profile should get a wrapper
        assert (wrapper_dir / "valid").exists()
        assert not (wrapper_dir / "empty").exists()

    def test_import_without_profiles_module(self, tmp_path, monkeypatch):
        """Import gracefully handles missing profiles module (fresh install)."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        zip_path = tmp_path / "backup.zip"
        self._make_backup_zip(zip_path, {
            "config.yaml": "model: test\n",
            "profiles/coder/config.yaml": "model: test\n",
        })

        args = Namespace(zipfile=str(zip_path), force=True)

        # Simulate profiles module not being available.  Grab the real import
        # hook from the ``builtins`` module rather than via ``__builtins__``,
        # which is a CPython implementation detail (module or dict depending
        # on context).
        import builtins

        original_import = builtins.__import__

        def fake_import(name, *a, **kw):
            # Fail only for the profiles module; delegate everything else.
            if name == "hermes_cli.profiles":
                raise ImportError("no profiles module")
            return original_import(name, *a, **kw)

        # Import run_import before patching so this import itself is not
        # routed through fake_import.
        from hermes_cli.backup import run_import
        with patch("builtins.__import__", side_effect=fake_import):
            run_import(args)

        # Files should still be restored even if wrappers can't be created
        assert (hermes_home / "profiles" / "coder" / "config.yaml").exists()

View File

@@ -58,13 +58,13 @@ class TestFindOpenclawDirs:
def test_finds_legacy_dirs(self, tmp_path):
# Creates legacy-named dot-directories under tmp_path, patches
# pathlib.Path.home to return tmp_path, and checks what
# claw_mod._find_openclaw_dirs() reports.
# NOTE(review): both a `.moldbot` and a `.moltbot` spelling appear below while
# the count assertion expects 2 results -- this looks like the before/after
# lines of a rename shown together; confirm which spelling is intended.
clawdbot = tmp_path / ".clawdbot"
clawdbot.mkdir()
moldbot = tmp_path / ".moldbot"
moldbot.mkdir()
moltbot = tmp_path / ".moltbot"
moltbot.mkdir()
with patch("pathlib.Path.home", return_value=tmp_path):
found = claw_mod._find_openclaw_dirs()
assert len(found) == 2
assert clawdbot in found
assert moldbot in found
assert moltbot in found
def test_returns_empty_when_none_exist(self, tmp_path):
with patch("pathlib.Path.home", return_value=tmp_path):
@@ -297,7 +297,6 @@ class TestCmdMigrate:
patch.object(claw_mod, "_load_migration_module", return_value=fake_mod),
patch.object(claw_mod, "get_config_path", return_value=config_path),
patch.object(claw_mod, "prompt_yes_no", return_value=True),
patch.object(claw_mod, "_offer_source_archival"),
patch("sys.stdin", mock_stdin),
):
claw_mod._cmd_migrate(args)
@@ -306,43 +305,8 @@ class TestCmdMigrate:
assert "Migration Results" in captured.out
assert "Migration complete!" in captured.out
def test_execute_offers_archival_on_success(self, tmp_path, capsys):
    """A non-dry-run migrate calls _offer_source_archival once with (source dir, True)."""
    source_dir = tmp_path / ".openclaw"
    source_dir.mkdir()

    # Canned report returned by the stubbed migrator.
    migration_report = {
        "summary": {"migrated": 3, "skipped": 0, "conflict": 0, "error": 0},
        "items": [
            {"kind": "soul", "status": "migrated", "destination": str(tmp_path / "SOUL.md")},
        ],
    }
    migrator_stub = MagicMock()
    migrator_stub.migrate.return_value = migration_report

    module_stub = ModuleType("openclaw_to_hermes")
    module_stub.resolve_selected_options = MagicMock(return_value={"soul"})
    module_stub.Migrator = MagicMock(return_value=migrator_stub)

    cli_options = {
        "source": str(source_dir),
        "dry_run": False,
        "preset": "full",
        "overwrite": False,
        "migrate_secrets": False,
        "workspace_target": None,
        "skill_conflict": "skip",
        "yes": True,
    }
    cli_args = Namespace(**cli_options)

    with (
        patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
        patch.object(claw_mod, "_load_migration_module", return_value=module_stub),
        patch.object(claw_mod, "get_config_path", return_value=tmp_path / "config.yaml"),
        patch.object(claw_mod, "save_config"),
        patch.object(claw_mod, "load_config", return_value={}),
        patch.object(claw_mod, "_offer_source_archival") as archival_spy,
    ):
        claw_mod._cmd_migrate(cli_args)

    archival_spy.assert_called_once_with(source_dir, True)
def test_dry_run_skips_archival(self, tmp_path, capsys):
"""Dry run should not offer archival."""
def test_dry_run_does_not_touch_source(self, tmp_path, capsys):
"""Dry run should not modify the source directory."""
openclaw_dir = tmp_path / ".openclaw"
openclaw_dir.mkdir()
@@ -369,11 +333,10 @@ class TestCmdMigrate:
patch.object(claw_mod, "get_config_path", return_value=tmp_path / "config.yaml"),
patch.object(claw_mod, "save_config"),
patch.object(claw_mod, "load_config", return_value={}),
patch.object(claw_mod, "_offer_source_archival") as mock_archival,
):
claw_mod._cmd_migrate(args)
mock_archival.assert_not_called()
assert openclaw_dir.is_dir() # Source untouched
def test_execute_cancelled_by_user(self, tmp_path, capsys):
openclaw_dir = tmp_path / ".openclaw"
@@ -506,73 +469,6 @@ class TestCmdMigrate:
assert call_kwargs["migrate_secrets"] is True
# ---------------------------------------------------------------------------
# _offer_source_archival
# ---------------------------------------------------------------------------
class TestOfferSourceArchival:
    """Behaviour of ``claw_mod._offer_source_archival`` after a migration."""

    def test_archives_with_auto_yes(self, tmp_path, capsys):
        """With ``auto_yes`` the source is replaced by ``.openclaw.pre-migration``."""
        src = tmp_path / ".openclaw"
        workspace = src / "workspace"
        src.mkdir()
        workspace.mkdir()
        (workspace / "todo.json").write_text("{}")

        claw_mod._offer_source_archival(src, auto_yes=True)

        printed = capsys.readouterr().out
        assert "Archived" in printed
        assert not src.exists()
        assert (tmp_path / ".openclaw.pre-migration").is_dir()

    def test_skips_when_user_declines(self, tmp_path, capsys):
        """A "no" answer at the prompt leaves the source directory in place."""
        src = tmp_path / ".openclaw"
        src.mkdir()

        tty_stdin = MagicMock()
        tty_stdin.isatty.return_value = True

        with patch.object(claw_mod, "prompt_yes_no", return_value=False), patch("sys.stdin", tty_stdin):
            claw_mod._offer_source_archival(src, auto_yes=False)

        printed = capsys.readouterr().out
        assert "Skipped" in printed
        assert src.is_dir()  # Still exists

    def test_noop_when_source_missing(self, tmp_path, capsys):
        """A non-existent source path produces no output at all."""
        missing = tmp_path / "nonexistent"

        claw_mod._offer_source_archival(missing, auto_yes=True)

        assert capsys.readouterr().out == ""  # No output

    def test_shows_state_files(self, tmp_path, capsys):
        """Files under ``workspace/`` are named in the output before the prompt result."""
        src = tmp_path / ".openclaw"
        workspace = src / "workspace"
        src.mkdir()
        workspace.mkdir()
        (workspace / "todo.json").write_text("{}")

        with patch.object(claw_mod, "prompt_yes_no", return_value=False):
            claw_mod._offer_source_archival(src, auto_yes=False)

        assert "todo.json" in capsys.readouterr().out

    def test_handles_archive_error(self, tmp_path, capsys):
        """An ``OSError`` from ``_archive_directory`` is reported, not raised."""
        src = tmp_path / ".openclaw"
        src.mkdir()
        failure = OSError("permission denied")

        with patch.object(claw_mod, "_archive_directory", side_effect=failure):
            claw_mod._offer_source_archival(src, auto_yes=True)

        assert "Could not archive" in capsys.readouterr().out
# ---------------------------------------------------------------------------
# _cmd_cleanup
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,254 @@
"""Tests for the interactive CLI /model picker (provider → model drill-down)."""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
class _FakeBuffer:
def __init__(self, text="draft text"):
self.text = text
self.cursor_position = len(text)
self.reset_calls = []
def reset(self, append_to_history=False):
self.reset_calls.append(append_to_history)
self.text = ""
self.cursor_position = 0
def _make_providers():
return [
{
"slug": "openrouter",
"name": "OpenRouter",
"is_current": True,
"is_user_defined": False,
"models": ["anthropic/claude-opus-4.6", "openai/gpt-5.4"],
"total_models": 2,
"source": "built-in",
},
{
"slug": "anthropic",
"name": "Anthropic",
"is_current": False,
"is_user_defined": False,
"models": ["claude-opus-4.6", "claude-sonnet-4.6"],
"total_models": 2,
"source": "built-in",
},
{
"slug": "custom:my-ollama",
"name": "My Ollama",
"is_current": False,
"is_user_defined": True,
"models": ["llama3", "mistral"],
"total_models": 2,
"source": "user-config",
"api_url": "http://localhost:11434/v1",
},
]
def _make_picker_cli(picker_return_value):
cli = MagicMock()
cli._run_curses_picker = MagicMock(return_value=picker_return_value)
cli._app = MagicMock()
cli._status_bar_visible = True
return cli
def _make_modal_cli():
    """Create a ``HermesCLI`` without running ``__init__`` and seed its state.

    Only the attributes assigned below are set; the app is a simple namespace
    whose ``current_buffer`` is a ``_FakeBuffer`` holding the default draft.
    """
    from cli import HermesCLI

    instance = HermesCLI.__new__(HermesCLI)
    initial_state = {
        "model": "gpt-5.4",
        "provider": "openrouter",
        "requested_provider": "openrouter",
        "base_url": "",
        "api_key": "",
        "api_mode": "",
        "_explicit_api_key": "",
        "_explicit_base_url": "",
        "_pending_model_switch_note": None,
        "_model_picker_state": None,
        "_modal_input_snapshot": None,
        "_status_bar_visible": True,
        "_invalidate": MagicMock(),
        "agent": None,
        "config": {},
        "console": MagicMock(),
        "_app": SimpleNamespace(
            current_buffer=_FakeBuffer(),
            invalidate=MagicMock(),
        ),
    }
    # Assign in the listed order, exactly as plain attribute writes would.
    for attr_name, attr_value in initial_state.items():
        setattr(instance, attr_name, attr_value)
    return instance
def test_provider_selection_returns_slug_on_choice():
    """Picker index 1 resolves to the second provider's slug (``anthropic``)."""
    from cli import HermesCLI

    fake_cli = _make_picker_cli(1)
    chosen = HermesCLI._interactive_provider_selection(
        fake_cli, _make_providers(), "gpt-5.4", "OpenRouter"
    )

    assert chosen == "anthropic"
    fake_cli._run_curses_picker.assert_called_once()
def test_provider_selection_returns_none_on_cancel():
    """A picker result of ``None`` is passed through as ``None``."""
    from cli import HermesCLI

    fake_cli = _make_picker_cli(None)
    chosen = HermesCLI._interactive_provider_selection(
        fake_cli, _make_providers(), "gpt-5.4", "OpenRouter"
    )

    assert chosen is None
def test_provider_selection_default_is_current():
    """The picker is opened with ``default_index`` 0, the ``is_current`` provider."""
    from cli import HermesCLI

    fake_cli = _make_picker_cli(0)
    HermesCLI._interactive_provider_selection(
        fake_cli, _make_providers(), "gpt-5.4", "OpenRouter"
    )

    picker_kwargs = fake_cli._run_curses_picker.call_args.kwargs
    assert picker_kwargs["default_index"] == 0
def test_model_selection_returns_model_on_choice():
    """Picker index 0 yields the first model of the provider's list."""
    from cli import HermesCLI

    openrouter = _make_providers()[0]
    fake_cli = _make_picker_cli(0)

    picked = HermesCLI._interactive_model_selection(fake_cli, openrouter["models"], openrouter)

    assert picked == "anthropic/claude-opus-4.6"
def test_model_selection_custom_entry_prompts_for_input():
    """Picking index 2 (one past the two listed models) asks for a model name.

    The typed value is returned and the prompt text is checked verbatim.
    """
    from cli import HermesCLI

    openrouter = _make_providers()[0]
    fake_cli = _make_picker_cli(2)
    text_prompt = MagicMock(return_value="my-custom-model")
    fake_cli._prompt_text_input = text_prompt

    picked = HermesCLI._interactive_model_selection(fake_cli, openrouter["models"], openrouter)

    assert picked == "my-custom-model"
    fake_cli._prompt_text_input.assert_called_once_with(" Enter model name: ")
def test_model_selection_empty_prompts_for_manual_input():
    """With an empty model list the user is prompted to type a model name."""
    from cli import HermesCLI

    empty_provider = dict(
        slug="custom:empty",
        name="Empty Provider",
        models=[],
        total_models=0,
    )
    fake_cli = _make_picker_cli(None)
    text_prompt = MagicMock(return_value="my-model")
    fake_cli._prompt_text_input = text_prompt

    picked = HermesCLI._interactive_model_selection(fake_cli, [], empty_provider)

    assert picked == "my-model"
    fake_cli._prompt_text_input.assert_called_once_with(" Enter model name manually (or Enter to cancel): ")
def test_prompt_text_input_uses_run_in_terminal_when_app_active():
    """``_prompt_text_input`` goes through ``run_in_terminal`` exactly once.

    ``run_in_terminal`` is replaced by a stub that simply calls the function it
    is given, and ``input`` is stubbed to return a fixed value; that value must
    come back and the status bar flag must still be ``True`` afterwards.
    """
    from cli import HermesCLI

    modal_cli = _make_modal_cli()

    def call_immediately(fn):
        return fn()

    with patch("prompt_toolkit.application.run_in_terminal", side_effect=call_immediately) as run_mock, \
            patch("builtins.input", return_value="manual-value"):
        typed = HermesCLI._prompt_text_input(modal_cli, "Enter value: ")

    assert typed == "manual-value"
    run_mock.assert_called_once()
    assert modal_cli._status_bar_visible is True
def test_should_handle_model_command_inline_uses_command_name_resolution():
# Checks that _should_handle_model_command_inline follows whatever command
# hermes_cli.commands.resolve_command reports for the input text, rather than
# the literal "/model" string: resolved name "model" -> True, "help" -> False.
from cli import HermesCLI
cli = _make_modal_cli()
# Resolver says the text is the "model" command -> handled inline.
with patch("hermes_cli.commands.resolve_command", return_value=SimpleNamespace(name="model")):
assert HermesCLI._should_handle_model_command_inline(cli, "/model") is True
# Same text, but the resolver reports a different command -> not handled inline.
with patch("hermes_cli.commands.resolve_command", return_value=SimpleNamespace(name="help")):
assert HermesCLI._should_handle_model_command_inline(cli, "/model") is False
# With has_images=True the result must be False.
# NOTE(review): indentation is lost in this view, so it is unclear whether this
# assert runs under the "help" patch above or against the real resolver - confirm.
assert HermesCLI._should_handle_model_command_inline(cli, "/model", has_images=True) is False
def test_process_command_model_without_args_opens_modal_picker_and_captures_draft():
    """Bare ``/model`` opens the picker at the provider stage and stashes the draft.

    After the command the picker state exists with ``stage == "provider"`` and
    selection 0, the previous buffer text and cursor are saved in
    ``_modal_input_snapshot``, and the live buffer is empty.
    """
    from cli import HermesCLI

    modal_cli = _make_modal_cli()
    draft = "draft text"

    with patch("hermes_cli.model_switch.list_authenticated_providers", return_value=_make_providers()), \
            patch("cli._cprint"):
        handled = modal_cli.process_command("/model")

    assert handled is True

    picker_state = modal_cli._model_picker_state
    assert picker_state is not None
    assert picker_state["stage"] == "provider"
    assert picker_state["selected"] == 0

    assert modal_cli._modal_input_snapshot == {"text": draft, "cursor_position": len(draft)}
    assert modal_cli._app.current_buffer.text == ""
def test_model_picker_provider_then_model_selection_applies_switch_result_and_restores_draft():
    """Drive the picker through both stages and check the switch is applied.

    Stage 1 selects provider index 1 (``anthropic``), which must move the
    picker to the ``model`` stage with that provider's model list. Stage 2
    selects model index 0; with ``switch_model`` stubbed to succeed, the CLI
    must adopt the new model/provider, close the picker and put the saved
    draft text back into the buffer.
    """
    from cli import HermesCLI

    anthropic_models = ["claude-opus-4.6", "claude-sonnet-4.6"]
    modal_cli = _make_modal_cli()

    # Open the picker.
    with patch("hermes_cli.model_switch.list_authenticated_providers", return_value=_make_providers()), \
            patch("cli._cprint"):
        assert modal_cli.process_command("/model") is True

    # Stage 1: choose the second provider.
    modal_cli._model_picker_state["selected"] = 1
    with patch("hermes_cli.models.provider_model_ids", return_value=anthropic_models):
        HermesCLI._handle_model_picker_selection(modal_cli)

    stage_two = modal_cli._model_picker_state
    assert stage_two["stage"] == "model"
    assert stage_two["provider_data"]["slug"] == "anthropic"
    assert stage_two["model_list"] == ["claude-opus-4.6", "claude-sonnet-4.6"]

    # Stage 2: choose the first model and let the stubbed switch succeed.
    modal_cli._model_picker_state["selected"] = 0
    successful_switch = SimpleNamespace(
        success=True,
        error_message=None,
        new_model="claude-opus-4.6",
        target_provider="anthropic",
        api_key="",
        base_url="",
        api_mode="anthropic_messages",
        provider_label="Anthropic",
        model_info=None,
        warning_message=None,
        provider_changed=True,
    )
    with patch("hermes_cli.model_switch.switch_model", return_value=successful_switch) as switch_mock, \
            patch("cli._cprint"):
        HermesCLI._handle_model_picker_selection(modal_cli)

    assert modal_cli._model_picker_state is None
    assert modal_cli.model == "claude-opus-4.6"
    assert modal_cli.provider == "anthropic"
    assert modal_cli.requested_provider == "anthropic"
    assert modal_cli._app.current_buffer.text == "draft text"

    switch_mock.assert_called_once()
    assert switch_mock.call_args.kwargs["explicit_provider"] == "anthropic"

View File

@@ -0,0 +1,241 @@
"""Regression test: openai-codex must appear in /model picker when
credentials are only in the Codex CLI shared file (~/.codex/auth.json)
and haven't been migrated to the Hermes auth store yet.
Root cause: list_authenticated_providers() checked the raw Hermes auth
store but didn't know about the Codex CLI fallback import path.
Fix: _seed_from_singletons() now imports from the Codex CLI when the
Hermes auth store has no openai-codex tokens, and
list_authenticated_providers() falls back to load_pool() for OAuth
providers.
"""
import base64
import json
import os
import sys
import time
from pathlib import Path
from unittest.mock import patch
import pytest
def _make_fake_jwt(expiry_offset: int = 3600) -> str:
"""Build a fake JWT with a future expiry."""
header = base64.urlsafe_b64encode(b'{"alg":"RS256"}').rstrip(b"=").decode()
exp = int(time.time()) + expiry_offset
payload_bytes = json.dumps({"exp": exp, "sub": "test"}).encode()
payload = base64.urlsafe_b64encode(payload_bytes).rstrip(b"=").decode()
return f"{header}.{payload}.fakesig"
@pytest.fixture()
def codex_cli_only_env(tmp_path, monkeypatch):
    """Environment where Codex tokens live only in ``$CODEX_HOME/auth.json``.

    ``HERMES_HOME`` and ``CODEX_HOME`` point at fresh temp directories, the
    Hermes auth store is written with an empty ``providers`` map, and the
    listed provider env vars are removed. Returns the Hermes home path.
    """
    home = tmp_path / ".hermes"
    codex_dir = tmp_path / ".codex"
    home.mkdir()
    codex_dir.mkdir()

    monkeypatch.setenv("HERMES_HOME", str(home))
    monkeypatch.setenv("CODEX_HOME", str(codex_dir))

    # Empty Hermes auth store
    empty_store = {"version": 2, "providers": {}}
    (home / "auth.json").write_text(json.dumps(empty_store))

    # Valid Codex CLI tokens
    codex_auth = {
        "tokens": {
            "access_token": _make_fake_jwt(),
            "refresh_token": "fake-refresh-token",
        }
    }
    (codex_dir / "auth.json").write_text(json.dumps(codex_auth))

    # Clear provider env vars so only OAuth is a detection path
    provider_env_vars = (
        "OPENROUTER_API_KEY",
        "OPENAI_API_KEY",
        "ANTHROPIC_API_KEY",
        "NOUS_API_KEY",
        "DEEPSEEK_API_KEY",
        "COPILOT_GITHUB_TOKEN",
        "GH_TOKEN",
        "GEMINI_API_KEY",
    )
    for name in provider_env_vars:
        monkeypatch.delenv(name, raising=False)

    return home
def test_codex_cli_tokens_detected_by_model_picker(codex_cli_only_env):
    """``openai-codex`` is listed, current, and has models with Codex-CLI-only tokens."""
    from hermes_cli.model_switch import list_authenticated_providers

    listed = list_authenticated_providers(current_provider="openai-codex", max_models=10)

    slugs = [entry["slug"] for entry in listed]
    assert "openai-codex" in slugs, (
        f"openai-codex not found in /model picker providers: {slugs}"
    )

    codex_entry = next(entry for entry in listed if entry["slug"] == "openai-codex")
    assert codex_entry["is_current"] is True
    assert codex_entry["total_models"] > 0
def test_codex_cli_tokens_migrated_after_detection(codex_cli_only_env):
    """Listing providers copies the Codex CLI tokens into the Hermes auth store.

    After one ``list_authenticated_providers`` call, ``auth.json`` under the
    Hermes home must contain an ``openai-codex`` entry with non-empty access
    and refresh tokens.
    """
    from hermes_cli.model_switch import list_authenticated_providers

    # First call triggers migration
    list_authenticated_providers(current_provider="openai-codex")

    # Verify tokens are now in Hermes auth store
    store_text = (codex_cli_only_env / "auth.json").read_text()
    stored_providers = json.loads(store_text).get("providers", {})
    assert "openai-codex" in stored_providers, (
        f"openai-codex not migrated to Hermes auth store: {list(stored_providers.keys())}"
    )

    migrated = stored_providers["openai-codex"].get("tokens", {})
    assert migrated.get("access_token"), "access_token missing after migration"
    assert migrated.get("refresh_token"), "refresh_token missing after migration"
@pytest.fixture()
def hermes_auth_only_env(tmp_path, monkeypatch):
    """Environment where ``openai-codex`` tokens are already in the Hermes store.

    ``CODEX_HOME`` is pointed at a directory that is never created, so the
    Codex CLI file cannot contribute. Returns the Hermes home path.
    """
    home = tmp_path / ".hermes"
    home.mkdir()
    monkeypatch.setenv("HERMES_HOME", str(home))
    # Point CODEX_HOME to nonexistent dir to prove it's not needed
    monkeypatch.setenv("CODEX_HOME", str(tmp_path / "no_codex"))

    codex_record = {
        "tokens": {
            "access_token": _make_fake_jwt(),
            "refresh_token": "fake-refresh",
        },
        "last_refresh": "2026-04-12T00:00:00Z",
    }
    store = {"version": 2, "providers": {"openai-codex": codex_record}}
    (home / "auth.json").write_text(json.dumps(store))

    api_key_vars = (
        "OPENROUTER_API_KEY",
        "OPENAI_API_KEY",
        "ANTHROPIC_API_KEY",
        "NOUS_API_KEY",
        "DEEPSEEK_API_KEY",
    )
    for name in api_key_vars:
        monkeypatch.delenv(name, raising=False)

    return home
def test_normal_path_still_works(hermes_auth_only_env):
    """``openai-codex`` is listed when its tokens are in the Hermes auth store."""
    from hermes_cli.model_switch import list_authenticated_providers

    listed = list_authenticated_providers(current_provider="openai-codex", max_models=10)

    assert "openai-codex" in [entry["slug"] for entry in listed]
@pytest.fixture()
def claude_code_only_env(tmp_path, monkeypatch):
    """Environment where Anthropic credentials exist only in the Claude Code file.

    Writes ``.claude/.credentials.json`` under ``tmp_path``, makes
    ``Path.home()`` return ``tmp_path``, leaves the Hermes auth store empty,
    points ``CODEX_HOME`` at a directory that is never created, and removes
    the listed credential env vars. Returns the Hermes home path.
    """
    home = tmp_path / ".hermes"
    home.mkdir()
    monkeypatch.setenv("HERMES_HOME", str(home))
    # No Codex CLI
    monkeypatch.setenv("CODEX_HOME", str(tmp_path / "no_codex"))

    empty_store = {"version": 2, "providers": {}}
    (home / "auth.json").write_text(json.dumps(empty_store))

    # Claude Code credentials in the correct format
    claude_home = tmp_path / ".claude"
    claude_home.mkdir()
    oauth_record = {
        "accessToken": _make_fake_jwt(),
        "refreshToken": "fake-refresh",
        "expiresAt": int(time.time() * 1000) + 3_600_000,
    }
    (claude_home / ".credentials.json").write_text(json.dumps({"claudeAiOauth": oauth_record}))

    # Patch Path.home() so the adapter finds the file
    monkeypatch.setattr(Path, "home", classmethod(lambda cls: tmp_path))

    credential_vars = (
        "OPENROUTER_API_KEY",
        "OPENAI_API_KEY",
        "ANTHROPIC_API_KEY",
        "ANTHROPIC_TOKEN",
        "CLAUDE_CODE_OAUTH_TOKEN",
        "NOUS_API_KEY",
        "DEEPSEEK_API_KEY",
    )
    for name in credential_vars:
        monkeypatch.delenv(name, raising=False)

    return home
def test_claude_code_file_detected_by_model_picker(claude_code_only_env):
    """``anthropic`` is listed, current, and has models with Claude-Code-only creds."""
    from hermes_cli.model_switch import list_authenticated_providers

    listed = list_authenticated_providers(current_provider="anthropic", max_models=10)

    slugs = [entry["slug"] for entry in listed]
    assert "anthropic" in slugs, (
        f"anthropic not found in /model picker providers: {slugs}"
    )

    anthropic_entry = next(entry for entry in listed if entry["slug"] == "anthropic")
    assert anthropic_entry["is_current"] is True
    assert anthropic_entry["total_models"] > 0
def test_no_codex_when_no_credentials(tmp_path, monkeypatch):
    """Without any stored or env credentials ``openai-codex`` is not listed."""
    home = tmp_path / ".hermes"
    home.mkdir()
    monkeypatch.setenv("HERMES_HOME", str(home))
    monkeypatch.setenv("CODEX_HOME", str(tmp_path / "no_codex"))

    empty_store = {"version": 2, "providers": {}}
    (home / "auth.json").write_text(json.dumps(empty_store))

    provider_env_vars = (
        "OPENROUTER_API_KEY",
        "OPENAI_API_KEY",
        "ANTHROPIC_API_KEY",
        "NOUS_API_KEY",
        "DEEPSEEK_API_KEY",
        "COPILOT_GITHUB_TOKEN",
        "GH_TOKEN",
        "GEMINI_API_KEY",
    )
    for name in provider_env_vars:
        monkeypatch.delenv(name, raising=False)

    from hermes_cli.model_switch import list_authenticated_providers

    listed = list_authenticated_providers(current_provider="openrouter", max_models=10)

    slugs = [entry["slug"] for entry in listed]
    assert "openai-codex" not in slugs, (
        "openai-codex should not appear without any credentials"
    )

View File

@@ -68,6 +68,7 @@ class TestLoadConfigDefaults:
assert "max_turns" not in config
assert "terminal" in config
assert config["terminal"]["backend"] == "local"
assert config["display"]["interim_assistant_messages"] is True
def test_legacy_root_level_max_turns_migrates_to_agent_config(self, tmp_path):
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
@@ -421,3 +422,25 @@ class TestAnthropicTokenMigration:
}):
migrate_config(interactive=False, quiet=True)
assert load_env().get("ANTHROPIC_TOKEN") == "current-token"
class TestInterimAssistantMessageConfig:
    """Covers the ``display.interim_assistant_messages`` config flag."""

    def test_default_config_enables_interim_assistant_messages(self):
        """The default config has the flag set to ``True``."""
        display_defaults = DEFAULT_CONFIG["display"]
        assert display_defaults["interim_assistant_messages"] is True

    def test_migrate_to_v15_adds_interim_assistant_message_gate(self, tmp_path):
        """Migrating a version-14 config adds the flag and keeps existing display keys.

        The final ``_config_version`` is asserted to be 16, i.e. the migration
        is expected to run past v15 to the version current for this test.
        """
        legacy_config = {"_config_version": 14, "display": {"tool_progress": "off"}}
        cfg_file = tmp_path / "config.yaml"
        cfg_file.write_text(yaml.safe_dump(legacy_config), encoding="utf-8")

        with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
            migrate_config(interactive=False, quiet=True)

        migrated = yaml.safe_load(cfg_file.read_text(encoding="utf-8"))
        display = migrated["display"]
        assert migrated["_config_version"] == 16
        assert display["tool_progress"] == "off"
        assert display["interim_assistant_messages"] is True

View File

@@ -0,0 +1,342 @@
"""Tests for container-aware CLI routing (NixOS container mode).
When container.enable = true in the NixOS module, the activation script
writes a .container-mode metadata file. The host CLI detects this and
execs into the container instead of running locally.
"""
import os
import subprocess
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from hermes_cli.config import (
_is_inside_container,
get_container_exec_info,
)
# =============================================================================
# _is_inside_container
# =============================================================================
def test_is_inside_container_dockerenv():
    """Only ``/.dockerenv`` existing is enough to report "inside a container"."""

    def only_dockerenv(path):
        return path == "/.dockerenv"

    with patch("os.path.exists", side_effect=only_dockerenv):
        assert _is_inside_container() is True
def test_is_inside_container_containerenv():
    """Only ``/run/.containerenv`` existing is enough to report a container."""

    def only_containerenv(path):
        return path == "/run/.containerenv"

    with patch("os.path.exists", side_effect=only_containerenv):
        assert _is_inside_container() is True
def test_is_inside_container_cgroup_docker():
    """No marker files exist, but the opened file's ``read()`` mentions ``docker``.

    ``open`` is patched so that the handle works as a context manager
    returning itself and ``read()`` yields a docker cgroup line.
    """
    cgroup_text = "12:memory:/docker/abc123\n"

    with patch("os.path.exists", return_value=False), \
            patch("builtins.open", create=True) as fake_open:
        handle = fake_open.return_value
        # Make `with open(...) as f` yield the handle itself.
        handle.__enter__.return_value = handle
        handle.__exit__.return_value = False
        handle.read.return_value = cgroup_text

        assert _is_inside_container() is True
def test_is_inside_container_false_on_host():
    """With no marker files and ``open`` raising ``OSError`` the result is ``False``."""
    unreadable = OSError("no such file")

    with patch("os.path.exists", return_value=False), patch("builtins.open", side_effect=unreadable):
        assert _is_inside_container() is False
# =============================================================================
# get_container_exec_info
# =============================================================================
@pytest.fixture
def container_env(tmp_path, monkeypatch):
    """Temp ``HERMES_HOME`` containing a podman ``.container-mode`` file.

    ``HERMES_DEV`` is removed from the environment. Returns the Hermes home.
    """
    home = tmp_path / ".hermes"
    home.mkdir()
    monkeypatch.setenv("HERMES_HOME", str(home))
    monkeypatch.delenv("HERMES_DEV", raising=False)

    metadata_lines = [
        "# Written by NixOS activation script. Do not edit manually.\n",
        "backend=podman\n",
        "container_name=hermes-agent\n",
        "exec_user=hermes\n",
        "hermes_bin=/data/current-package/bin/hermes\n",
    ]
    (home / ".container-mode").write_text("".join(metadata_lines))
    return home
def test_get_container_exec_info_returns_metadata(container_env):
    """All four fields written to ``.container-mode`` come back in the result."""
    expected = {
        "backend": "podman",
        "container_name": "hermes-agent",
        "exec_user": "hermes",
        "hermes_bin": "/data/current-package/bin/hermes",
    }

    with patch("hermes_cli.config._is_inside_container", return_value=False):
        info = get_container_exec_info()

    assert info is not None
    for field, value in expected.items():
        assert info[field] == value
def test_get_container_exec_info_none_inside_container(container_env):
    """When already inside a container the lookup yields ``None``."""
    with patch("hermes_cli.config._is_inside_container", return_value=True):
        result = get_container_exec_info()

    assert result is None
def test_get_container_exec_info_none_without_file(tmp_path, monkeypatch):
    """A Hermes home without ``.container-mode`` yields ``None``."""
    home = tmp_path / ".hermes"
    home.mkdir()
    monkeypatch.setenv("HERMES_HOME", str(home))
    monkeypatch.delenv("HERMES_DEV", raising=False)

    with patch("hermes_cli.config._is_inside_container", return_value=False):
        result = get_container_exec_info()

    assert result is None
def test_get_container_exec_info_skipped_when_hermes_dev(container_env, monkeypatch):
    """``HERMES_DEV=1`` makes the lookup return ``None`` despite the metadata file."""
    monkeypatch.setenv("HERMES_DEV", "1")

    with patch("hermes_cli.config._is_inside_container", return_value=False):
        result = get_container_exec_info()

    assert result is None
def test_get_container_exec_info_not_skipped_when_hermes_dev_zero(container_env, monkeypatch):
    """``HERMES_DEV=0`` does not bypass the lookup; metadata is still returned."""
    monkeypatch.setenv("HERMES_DEV", "0")

    with patch("hermes_cli.config._is_inside_container", return_value=False):
        result = get_container_exec_info()

    assert result is not None
def test_get_container_exec_info_defaults():
    """A ``.container-mode`` file with no keys yields the default field values."""
    import tempfile

    expected_defaults = {
        "backend": "docker",
        "container_name": "hermes-agent",
        "exec_user": "hermes",
        "hermes_bin": "/data/current-package/bin/hermes",
    }

    with tempfile.TemporaryDirectory() as scratch:
        home = Path(scratch) / ".hermes"
        home.mkdir()
        (home / ".container-mode").write_text("# minimal file with no keys\n")

        with (
            patch("hermes_cli.config._is_inside_container", return_value=False),
            patch("hermes_cli.config.get_hermes_home", return_value=home),
            patch.dict(os.environ, {}, clear=False),
        ):
            # patch.dict restores the environment on exit, so this pop is temporary.
            os.environ.pop("HERMES_DEV", None)
            info = get_container_exec_info()

    assert info is not None
    for field, default in expected_defaults.items():
        assert info[field] == default
def test_get_container_exec_info_docker_backend(container_env):
    """A rewritten metadata file with docker and custom values is read back as-is."""
    metadata_lines = [
        "backend=docker\n",
        "container_name=hermes-custom\n",
        "exec_user=myuser\n",
        "hermes_bin=/opt/hermes/bin/hermes\n",
    ]
    (container_env / ".container-mode").write_text("".join(metadata_lines))

    with patch("hermes_cli.config._is_inside_container", return_value=False):
        info = get_container_exec_info()

    expected = {
        "backend": "docker",
        "container_name": "hermes-custom",
        "exec_user": "myuser",
        "hermes_bin": "/opt/hermes/bin/hermes",
    }
    for field, value in expected.items():
        assert info[field] == value
def test_get_container_exec_info_crashes_on_permission_error(container_env):
    """A ``PermissionError`` raised by ``open`` reaches the caller unchanged."""
    denied = PermissionError("permission denied")

    with (
        patch("hermes_cli.config._is_inside_container", return_value=False),
        patch("builtins.open", side_effect=denied),
        pytest.raises(PermissionError),
    ):
        get_container_exec_info()
# =============================================================================
# _exec_in_container
# =============================================================================
@pytest.fixture
def docker_container_info():
    """Container metadata dict using the ``docker`` backend."""
    return dict(
        backend="docker",
        container_name="hermes-agent",
        exec_user="hermes",
        hermes_bin="/data/current-package/bin/hermes",
    )
@pytest.fixture
def podman_container_info():
    """Container metadata dict using the ``podman`` backend."""
    return dict(
        backend="podman",
        container_name="hermes-agent",
        exec_user="hermes",
        hermes_bin="/data/current-package/bin/hermes",
    )
def test_exec_in_container_calls_execvp(docker_container_info):
    """``os.execvp`` receives the full ``docker exec`` command line.

    With a TTY stdin and a successful probe, the argv must start with the
    runtime path and ``exec``, contain ``-it``, ``-u hermes``, ``-e`` entries
    for ``TERM`` and ``LANG``, the container name, the hermes binary and the
    forwarded CLI argument.
    """
    from hermes_cli.main import _exec_in_container

    terminal_env = {"TERM": "xterm-256color", "LANG": "en_US.UTF-8"}

    with (
        patch("shutil.which", return_value="/usr/bin/docker"),
        patch("subprocess.run") as probe,
        patch("sys.stdin") as fake_stdin,
        patch("os.execvp") as execvp,
        patch.dict(os.environ, terminal_env, clear=False),
    ):
        fake_stdin.isatty.return_value = True
        probe.return_value = MagicMock(returncode=0)
        _exec_in_container(docker_container_info, ["chat", "-m", "opus"])

    execvp.assert_called_once()
    argv = execvp.call_args[0][1]

    assert argv[0] == "/usr/bin/docker"
    assert argv[1] == "exec"
    assert "-it" in argv

    user_flag_at = argv.index("-u")
    assert argv[user_flag_at + 1] == "hermes"

    # Collect the value following every "-e" flag.
    forwarded_env = [argv[pos + 1] for pos, arg in enumerate(argv) if arg == "-e"]
    assert "TERM=xterm-256color" in forwarded_env
    assert "LANG=en_US.UTF-8" in forwarded_env

    assert "hermes-agent" in argv
    assert "/data/current-package/bin/hermes" in argv
    assert "chat" in argv
def test_exec_in_container_non_tty_uses_i_only(docker_container_info):
    """When stdin is not a TTY the argv has ``-i`` and no ``-it``."""
    from hermes_cli.main import _exec_in_container

    with (
        patch("shutil.which", return_value="/usr/bin/docker"),
        patch("subprocess.run") as probe,
        patch("sys.stdin") as fake_stdin,
        patch("os.execvp") as execvp,
    ):
        fake_stdin.isatty.return_value = False
        probe.return_value = MagicMock(returncode=0)
        _exec_in_container(docker_container_info, ["sessions", "list"])

    argv = execvp.call_args[0][1]
    assert "-i" in argv
    assert "-it" not in argv
def test_exec_in_container_no_runtime_hard_fails(podman_container_info):
    """A missing runtime binary exits non-zero without probing or exec'ing."""
    from hermes_cli.main import _exec_in_container

    with (
        patch("shutil.which", return_value=None),
        patch("subprocess.run") as probe,
        patch("os.execvp") as execvp,
        pytest.raises(SystemExit) as exit_info,
    ):
        _exec_in_container(podman_container_info, ["chat"])

    # Checked after the block: nothing placed after the raising call inside it would run.
    probe.assert_not_called()
    execvp.assert_not_called()
    assert exit_info.value.code != 0
def test_exec_in_container_sudo_probe_sets_prefix(podman_container_info):
    """If the direct probe fails but the second one succeeds, argv starts with ``sudo -n``."""
    from hermes_cli.main import _exec_in_container

    known_binaries = {"podman": "/usr/bin/podman", "sudo": "/usr/bin/sudo"}

    def fake_which(name):
        return known_binaries.get(name)

    probe_results = [
        MagicMock(returncode=1),  # direct probe fails
        MagicMock(returncode=0),  # sudo probe succeeds
    ]

    with (
        patch("shutil.which", side_effect=fake_which),
        patch("subprocess.run") as probe,
        patch("sys.stdin") as fake_stdin,
        patch("os.execvp") as execvp,
    ):
        fake_stdin.isatty.return_value = True
        probe.side_effect = probe_results
        _exec_in_container(podman_container_info, ["chat"])

    execvp.assert_called_once()
    argv = execvp.call_args[0][1]
    expected_prefix = ["/usr/bin/sudo", "-n", "/usr/bin/podman", "exec"]
    for position, expected_arg in enumerate(expected_prefix):
        assert argv[position] == expected_arg
def test_exec_in_container_probe_timeout_prints_message(docker_container_info):
    """A ``TimeoutExpired`` from the probe ends in ``SystemExit(1)`` with no exec."""
    from hermes_cli.main import _exec_in_container

    probe_timeout = subprocess.TimeoutExpired(cmd=["docker", "inspect"], timeout=15)

    with (
        patch("shutil.which", return_value="/usr/bin/docker"),
        patch("subprocess.run", side_effect=probe_timeout),
        patch("os.execvp") as execvp,
        pytest.raises(SystemExit) as exit_info,
    ):
        _exec_in_container(docker_container_info, ["chat"])

    execvp.assert_not_called()
    assert exit_info.value.code == 1
def test_exec_in_container_container_not_running_no_sudo(docker_container_info):
    """Probe fails and ``sudo`` is not on PATH: exit code 1 and no exec."""
    from hermes_cli.main import _exec_in_container

    def fake_which(name):
        # Only the docker binary can be located; everything else is missing.
        return "/usr/bin/docker" if name == "docker" else None

    with (
        patch("shutil.which", side_effect=fake_which),
        patch("subprocess.run") as probe,
        patch("os.execvp") as execvp,
        pytest.raises(SystemExit) as exit_info,
    ):
        probe.return_value = MagicMock(returncode=1)
        _exec_in_container(docker_container_info, ["chat"])

    execvp.assert_not_called()
    assert exit_info.value.code == 1

View File

@@ -122,3 +122,54 @@ class TestCustomProviderModelSwitch:
model = config.get("model")
assert isinstance(model, dict)
assert model["default"] == "model-X"
def test_api_mode_set_from_provider_info(self, config_home):
    """An ``api_mode`` on the custom provider entry is written to ``model.api_mode``."""
    import yaml
    from hermes_cli.main import _model_flow_named_custom

    proxy_provider = dict(
        name="Anthropic Proxy",
        base_url="https://proxy.example.com/anthropic",
        api_key="***",
        model="claude-3",
        api_mode="anthropic_messages",
    )

    with (
        patch("hermes_cli.models.fetch_api_models", return_value=["claude-3"]),
        patch.dict("sys.modules", {"simple_term_menu": None}),
        patch("builtins.input", return_value="1"),
        patch("builtins.print"),
    ):
        _model_flow_named_custom({}, proxy_provider)

    saved = yaml.safe_load((config_home / "config.yaml").read_text()) or {}
    model_section = saved.get("model")
    assert isinstance(model_section, dict)
    assert model_section.get("api_mode") == "anthropic_messages"
def test_api_mode_cleared_when_not_specified(self, config_home):
    """A pre-existing ``model.api_mode`` is dropped when the provider entry has none."""
    import yaml
    from hermes_cli.main import _model_flow_named_custom

    # Pre-seed a stale api_mode in config
    config_file = config_home / "config.yaml"
    stale_config = {"model": {"api_mode": "anthropic_messages"}}
    config_file.write_text(yaml.dump(stale_config))

    vllm_provider = dict(
        name="My vLLM",
        base_url="https://vllm.example.com/v1",
        api_key="***",
        model="llama-3",
    )

    with (
        patch("hermes_cli.models.fetch_api_models", return_value=["llama-3"]),
        patch.dict("sys.modules", {"simple_term_menu": None}),
        patch("builtins.input", return_value="1"),
        patch("builtins.print"),
    ):
        _model_flow_named_custom({}, vllm_provider)

    saved = yaml.safe_load((config_home / "config.yaml").read_text()) or {}
    model_section = saved.get("model")
    assert isinstance(model_section, dict)
    assert "api_mode" not in model_section, "Stale api_mode should be removed"

View File

@@ -1,288 +1,255 @@
"""Tests for hermes_cli/logs.py — log viewing and filtering."""
"""Tests for hermes_cli.logs — log viewing and filtering."""
import os
import textwrap
from datetime import datetime, timedelta
from io import StringIO
from pathlib import Path
from unittest.mock import patch
import pytest
from hermes_cli.logs import (
LOG_FILES,
_extract_level,
_extract_logger_name,
_line_matches_component,
_matches_filters,
_parse_line_timestamp,
_parse_since,
_read_last_n_lines,
list_logs,
tail_log,
_read_tail,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def log_dir(tmp_path, monkeypatch):
    """Return the ``logs/`` directory under HERMES_HOME, creating it if needed.

    The HERMES_HOME already present in the environment is used as-is.  When
    the variable is unset or empty, HERMES_HOME is pointed at a directory
    under ``tmp_path`` for the duration of the test instead of failing with
    ``KeyError``.
    """
    configured = os.environ.get("HERMES_HOME")
    if configured:
        home = Path(configured)
    else:
        # No home configured by the environment: isolate the test in tmp_path.
        home = tmp_path / "hermes"
        monkeypatch.setenv("HERMES_HOME", str(home))
    logs = home / "logs"
    logs.mkdir(parents=True, exist_ok=True)
    return logs
@pytest.fixture
def sample_agent_log(log_dir):
    """Create ``agent.log`` in *log_dir* with ten entries and return its path.

    The entries cover INFO, WARNING, ERROR and DEBUG levels, two session ids
    (``sess_aaa`` and ``sess_bbb``) and two logger names.
    """
    entries = (
        "2026-04-05 10:00:00,000 INFO run_agent: conversation turn: session=sess_aaa model=claude provider=openrouter platform=cli history=0 msg='hello'",
        "2026-04-05 10:00:01,000 INFO run_agent: tool terminal completed (0.50s, 200 chars)",
        "2026-04-05 10:00:02,000 INFO run_agent: API call #1: model=claude provider=openrouter in=1000 out=200 total=1200 latency=1.5s",
        "2026-04-05 10:00:03,000 WARNING run_agent: Tool web_search returned error (2.00s): timeout",
        "2026-04-05 10:00:04,000 INFO run_agent: conversation turn: session=sess_bbb model=gpt-5 provider=openai platform=telegram history=5 msg='fix bug'",
        "2026-04-05 10:00:05,000 ERROR run_agent: API call failed after 3 retries. rate limited",
        "2026-04-05 10:00:06,000 INFO run_agent: tool read_file completed (0.01s, 500 chars)",
        "2026-04-05 10:00:07,000 DEBUG run_agent: verbose internal detail",
        "2026-04-05 10:00:08,000 INFO credential_pool: credential pool: marking key-1 exhausted (status=429), rotating",
        "2026-04-05 10:00:09,000 INFO credential_pool: credential pool: rotated to key-2",
    )
    log_path = log_dir / "agent.log"
    # One entry per line, with a trailing newline after the last entry.
    log_path.write_text("\n".join(entries) + "\n")
    return log_path
@pytest.fixture
def sample_errors_log(log_dir):
    """Create ``errors.log`` in *log_dir* with one WARNING and one ERROR entry."""
    entries = (
        "2026-04-05 10:00:03,000 WARNING run_agent: Tool web_search returned error (2.00s): timeout",
        "2026-04-05 10:00:05,000 ERROR run_agent: API call failed after 3 retries. rate limited",
    )
    log_path = log_dir / "errors.log"
    # One entry per line, with a trailing newline after the last entry.
    log_path.write_text("\n".join(entries) + "\n")
    return log_path
# ---------------------------------------------------------------------------
# _parse_since
# Timestamp parsing
# ---------------------------------------------------------------------------
class TestParseSince:
    """``_parse_since`` turns a relative duration such as ``"2h"`` into a cutoff datetime."""

    @staticmethod
    def _assert_age(spec, expected_seconds):
        """Parse *spec* and check the cutoff lies *expected_seconds* before now.

        A two-second tolerance absorbs the time that passes between the
        parse and the comparison.
        """
        cutoff = _parse_since(spec)
        assert cutoff is not None
        age = (datetime.now() - cutoff).total_seconds()
        assert abs(age - expected_seconds) < 2

    def test_hours(self):
        self._assert_age("2h", 7200)

    def test_minutes(self):
        self._assert_age("30m", 1800)

    def test_days(self):
        self._assert_age("1d", 86400)

    def test_seconds(self):
        # Each duration is compared with the cutoff parsed from that same
        # duration, never with a cutoff produced by a different input.
        self._assert_age("60s", 60)
        self._assert_age("120s", 120)

    def test_invalid_returns_none(self):
        for bad in ("abc", "", "10x"):
            assert _parse_since(bad) is None

    def test_whitespace_handling(self):
        assert _parse_since(" 1h ") is not None

    def test_whitespace_tolerance(self):
        assert _parse_since(" 5m ") is not None
# ---------------------------------------------------------------------------
# _parse_line_timestamp
# ---------------------------------------------------------------------------
class TestParseLineTimestamp:
    """``_parse_line_timestamp`` reads the leading timestamp of a log line."""

    def test_standard_format(self):
        # Timestamp followed by a comma and milliseconds.
        with_millis = _parse_line_timestamp("2026-04-05 10:00:00,123 INFO something")
        assert with_millis is not None
        assert with_millis.year == 2026
        assert with_millis.hour == 10

        # Timestamp with whole seconds only.
        whole_seconds = _parse_line_timestamp("2026-04-11 10:23:45 INFO gateway.run: msg")
        assert whole_seconds == datetime(2026, 4, 11, 10, 23, 45)

    def test_no_timestamp(self):
        for text in ("just some text", "no timestamp here"):
            assert _parse_line_timestamp(text) is None

    def test_continuation_line(self):
        traceback_row = " at module.function (line 42)"
        assert _parse_line_timestamp(traceback_row) is None
# ---------------------------------------------------------------------------
# _extract_level
# ---------------------------------------------------------------------------
class TestExtractLevel:
    """``_extract_level`` returns the level name found in a log line, or None."""

    def test_info(self):
        for row in (
            "2026-04-05 10:00:00 INFO run_agent: something",
            "2026-01-01 00:00:00 INFO gateway.run: msg",
        ):
            assert _extract_level(row) == "INFO"

    def test_warning(self):
        for row in (
            "2026-04-05 10:00:00 WARNING run_agent: bad",
            "2026-01-01 00:00:00 WARNING tools.file: msg",
        ):
            assert _extract_level(row) == "WARNING"

    def test_error(self):
        for row in (
            "2026-04-05 10:00:00 ERROR run_agent: crash",
            "2026-01-01 00:00:00 ERROR run_agent: msg",
        ):
            assert _extract_level(row) == "ERROR"

    def test_debug(self):
        for row in (
            "2026-04-05 10:00:00 DEBUG run_agent: detail",
            "2026-01-01 00:00:00 DEBUG agent.aux: msg",
        ):
            assert _extract_level(row) == "DEBUG"

    def test_no_level(self):
        for row in ("just a plain line", "random text"):
            assert _extract_level(row) is None
# ---------------------------------------------------------------------------
# _matches_filters
# Logger name extraction (new for component filtering)
# ---------------------------------------------------------------------------
class TestExtractLoggerName:
    """``_extract_logger_name`` returns the logger name that precedes the message."""

    def test_standard_line(self):
        found = _extract_logger_name(
            "2026-04-11 10:23:45 INFO gateway.run: Starting gateway")
        assert found == "gateway.run"

    def test_nested_logger(self):
        found = _extract_logger_name(
            "2026-04-11 10:23:45 INFO gateway.platforms.telegram: connected")
        assert found == "gateway.platforms.telegram"

    def test_warning_level(self):
        found = _extract_logger_name(
            "2026-04-11 10:23:45 WARNING tools.terminal_tool: timeout")
        assert found == "tools.terminal_tool"

    def test_with_session_tag(self):
        # A bracketed tag sits between the level and the logger name.
        found = _extract_logger_name(
            "2026-04-11 10:23:45 INFO [abc123] tools.file_tools: reading file")
        assert found == "tools.file_tools"

    def test_with_session_tag_and_error(self):
        found = _extract_logger_name(
            "2026-04-11 10:23:45 ERROR [sess_xyz] agent.context_compressor: failed")
        assert found == "agent.context_compressor"

    def test_top_level_module(self):
        found = _extract_logger_name(
            "2026-04-11 10:23:45 INFO run_agent: starting conversation")
        assert found == "run_agent"

    def test_no_match(self):
        assert _extract_logger_name("random text") is None
class TestLineMatchesComponent:
    """``_line_matches_component`` checks a line's logger name against prefixes."""

    def test_gateway_component(self):
        assert _line_matches_component(
            "2026-04-11 10:23:45 INFO gateway.run: msg", ("gateway",))

    def test_gateway_nested(self):
        assert _line_matches_component(
            "2026-04-11 10:23:45 INFO gateway.platforms.telegram: msg", ("gateway",))

    def test_tools_component(self):
        assert _line_matches_component(
            "2026-04-11 10:23:45 INFO tools.terminal_tool: msg", ("tools",))

    def test_agent_with_multiple_prefixes(self):
        agent_prefixes = ("agent", "run_agent", "model_tools")
        # One line per prefix, checked in the same order as the prefixes.
        for row in (
            "2026-04-11 10:23:45 INFO agent.context_compressor: msg",
            "2026-04-11 10:23:45 INFO run_agent: msg",
            "2026-04-11 10:23:45 INFO model_tools: msg",
        ):
            assert _line_matches_component(row, agent_prefixes)

    def test_no_match(self):
        matched = _line_matches_component(
            "2026-04-11 10:23:45 INFO tools.browser: msg", ("gateway",))
        assert not matched

    def test_with_session_tag(self):
        assert _line_matches_component(
            "2026-04-11 10:23:45 INFO [abc] gateway.run: msg", ("gateway",))

    def test_unparseable_line(self):
        matched = _line_matches_component("random text", ("gateway",))
        assert not matched
# ---------------------------------------------------------------------------
# Combined filter
# ---------------------------------------------------------------------------
class TestMatchesFilters:
    """``_matches_filters`` combines level, session, component and time filters.

    NOTE(review): this class was rebuilt from a span in which two revisions
    of the same tests were interleaved line by line.  The later line of each
    pair was kept; confirm against the intended revision of the file.
    """

    def test_no_filters_passes_everything(self):
        assert _matches_filters("any line")

    def test_level_filter(self):
        assert _matches_filters(
            "2026-01-01 00:00:00 WARNING x: msg", min_level="WARNING")
        assert not _matches_filters(
            "2026-01-01 00:00:00 INFO x: msg", min_level="WARNING")

    def test_session_filter(self):
        assert _matches_filters(
            "2026-01-01 00:00:00 INFO [abc123] x: msg", session_filter="abc123")
        assert not _matches_filters(
            "2026-01-01 00:00:00 INFO [xyz789] x: msg", session_filter="abc123")

    def test_component_filter(self):
        assert _matches_filters(
            "2026-01-01 00:00:00 INFO gateway.run: msg",
            component_prefixes=("gateway",))
        assert not _matches_filters(
            "2026-01-01 00:00:00 INFO tools.file: msg",
            component_prefixes=("gateway",))

    def test_combined_filters(self):
        """All filters must pass for a line to match."""
        line = "2026-04-11 10:00:00 WARNING [sess_1] gateway.run: connection lost"
        assert _matches_filters(
            line,
            min_level="WARNING",
            session_filter="sess_1",
            component_prefixes=("gateway",),
        )
        # Fails component filter
        assert not _matches_filters(
            line,
            min_level="WARNING",
            session_filter="sess_1",
            component_prefixes=("tools",),
        )

    def test_since_filter(self):
        one_hour_ago = datetime.now() - timedelta(hours=1)
        # Line with a very old timestamp should be filtered out
        assert not _matches_filters(
            "2020-01-01 00:00:00 INFO x: old msg",
            since=one_hour_ago)
        # Line with a recent timestamp should pass
        recent = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        assert _matches_filters(
            f"{recent} INFO x: recent msg",
            since=datetime.now() - timedelta(hours=1))
# ---------------------------------------------------------------------------
# _read_last_n_lines
# File reading
# ---------------------------------------------------------------------------
class TestReadLastNLines:
    """``_read_last_n_lines`` against the ten-entry ``sample_agent_log`` fixture."""

    def test_reads_correct_count(self, sample_agent_log):
        lines = _read_last_n_lines(sample_agent_log, 3)
        assert len(lines) == 3

    def test_reads_all_when_fewer(self, sample_agent_log):
        lines = _read_last_n_lines(sample_agent_log, 100)
        assert len(lines) == 10  # sample has 10 lines

    def test_empty_file(self, log_dir):
        empty = log_dir / "empty.log"
        empty.write_text("")
        lines = _read_last_n_lines(empty, 10)
        assert lines == []

    def test_last_line_content(self, sample_agent_log):
        lines = _read_last_n_lines(sample_agent_log, 1)
        assert "rotated to key-2" in lines[0]


class TestReadTail:
    """Tail reading on files written directly under ``tmp_path``."""

    def test_read_small_file(self, tmp_path):
        log_file = tmp_path / "test.log"
        lines = [f"2026-01-01 00:00:0{i} INFO x: line {i}\n" for i in range(10)]
        log_file.write_text("".join(lines))

        result = _read_last_n_lines(log_file, 5)
        assert len(result) == 5
        assert "line 9" in result[-1]

    def test_read_with_component_filter(self, tmp_path):
        log_file = tmp_path / "test.log"
        lines = [
            "2026-01-01 00:00:00 INFO gateway.run: gw msg\n",
            "2026-01-01 00:00:01 INFO tools.file: tool msg\n",
            "2026-01-01 00:00:02 INFO gateway.session: session msg\n",
            "2026-01-01 00:00:03 INFO agent.compressor: agent msg\n",
        ]
        log_file.write_text("".join(lines))

        # Only the two gateway.* entries should survive, in file order.
        result = _read_tail(
            log_file, 50,
            has_filters=True,
            component_prefixes=("gateway",),
        )
        assert len(result) == 2
        assert "gw msg" in result[0]
        assert "session msg" in result[1]

    def test_empty_file(self, tmp_path):
        log_file = tmp_path / "empty.log"
        log_file.write_text("")
        result = _read_last_n_lines(log_file, 10)
        assert result == []
# ---------------------------------------------------------------------------
# tail_log
# LOG_FILES registry
# ---------------------------------------------------------------------------
class TestTailLog:
    """``tail_log`` output as captured from stdout."""

    @staticmethod
    def _content_lines(out):
        """Return the printed lines that do not start with ``---``."""
        return [row for row in out.strip().split("\n") if not row.startswith("---")]

    def test_basic_tail(self, sample_agent_log, capsys):
        tail_log("agent", num_lines=3)
        out = capsys.readouterr().out
        assert "agent.log" in out
        # One header line followed by the three requested entries.
        printed = out.strip().split("\n")
        assert len(printed) == 4

    def test_level_filter(self, sample_agent_log, capsys):
        tail_log("agent", num_lines=50, level="ERROR")
        out = capsys.readouterr().out
        assert "level>=ERROR" in out
        # The fixture contains exactly one ERROR entry.
        matching = self._content_lines(out)
        assert len(matching) == 1
        assert "API call failed" in matching[0]

    def test_session_filter(self, sample_agent_log, capsys):
        tail_log("agent", num_lines=50, session="sess_bbb")
        out = capsys.readouterr().out
        matching = self._content_lines(out)
        assert len(matching) == 1
        assert "sess_bbb" in matching[0]

    def test_errors_log(self, sample_errors_log, capsys):
        tail_log("errors", num_lines=10)
        out = capsys.readouterr().out
        assert "errors.log" in out
        assert "WARNING" in out or "ERROR" in out

    def test_unknown_log_exits(self):
        with pytest.raises(SystemExit):
            tail_log("nonexistent")

    def test_missing_file_exits(self, log_dir):
        # log_dir alone creates the directory but no agent.log inside it.
        with pytest.raises(SystemExit):
            tail_log("agent")
# ---------------------------------------------------------------------------
# list_logs
# ---------------------------------------------------------------------------
class TestListLogs:
    """``list_logs`` output as captured from stdout."""

    def test_lists_files(self, sample_agent_log, sample_errors_log, capsys):
        list_logs()
        out = capsys.readouterr().out
        for filename in ("agent.log", "errors.log"):
            assert filename in out

    def test_empty_dir(self, log_dir, capsys):
        list_logs()
        out = capsys.readouterr().out
        assert "no log files yet" in out

    def test_shows_sizes(self, sample_agent_log, capsys):
        list_logs()
        out = capsys.readouterr().out
        # The fixture file is small, so a byte or kilobyte unit is expected.
        assert "B" in out or "KB" in out
class TestLogFiles:
    """The ``LOG_FILES`` registry exposes the expected log names."""

    def test_known_log_files(self):
        for log_name in ("agent", "errors", "gateway"):
            assert log_name in LOG_FILES

View File

@@ -46,6 +46,8 @@ def _make_args(**kwargs):
"command": None,
"args": None,
"auth": None,
"preset": None,
"env": None,
"mcp_action": None,
}
defaults.update(kwargs)
@@ -269,6 +271,145 @@ class TestMcpAdd:
config = load_config()
assert config["mcp_servers"]["broken"]["enabled"] is False
def test_add_stdio_server_with_env(self, tmp_path, capsys, monkeypatch):
    """``env`` entries given as KEY=VALUE reach the probe and are saved with the server."""
    expected_env = {"MY_API_KEY": "secret123", "DEBUG": "true"}
    fake_tools = [FakeTool("search", "Search repos")]

    def mock_probe(name, config, **kw):
        # The probe must already receive the parsed environment mapping.
        assert config["env"] == expected_env
        return [(tool.name, tool.description) for tool in fake_tools]

    monkeypatch.setattr("hermes_cli.mcp_config._probe_single_server", mock_probe)
    # Every prompt is answered with an empty string.
    monkeypatch.setattr("builtins.input", lambda _: "")

    from hermes_cli.mcp_config import cmd_mcp_add

    add_args = _make_args(
        name="github",
        command="npx",
        args=["@mcp/github"],
        env=["MY_API_KEY=secret123", "DEBUG=true"],
    )
    cmd_mcp_add(add_args)
    assert "Saved" in capsys.readouterr().out

    from hermes_cli.config import load_config

    saved_server = load_config()["mcp_servers"]["github"]
    assert saved_server["env"] == expected_env
def test_add_stdio_server_rejects_invalid_env_name(self, capsys):
    """An ``env`` entry whose name is ``BAD-NAME`` produces an error message."""
    from hermes_cli.mcp_config import cmd_mcp_add

    add_args = _make_args(
        name="github",
        command="npx",
        args=["@mcp/github"],
        env=["BAD-NAME=value"],
    )
    cmd_mcp_add(add_args)

    printed = capsys.readouterr().out
    assert "Invalid --env variable name" in printed
def test_add_http_server_rejects_env_flag(self, capsys):
    """Combining ``url`` with ``env`` is reported as unsupported."""
    from hermes_cli.mcp_config import cmd_mcp_add

    add_args = _make_args(
        name="ink",
        url="https://mcp.ml.ink/mcp",
        env=["DEBUG=true"],
    )
    cmd_mcp_add(add_args)

    printed = capsys.readouterr().out
    assert "only supported for stdio MCP servers" in printed
def test_add_preset_fills_transport(self, tmp_path, capsys, monkeypatch):
    """With no explicit command, the preset's command and args are probed and saved."""
    preset_table = {
        "testmcp": {
            "command": "npx",
            "args": ["-y", "test-mcp-server"],
            "display_name": "Test MCP",
        },
    }
    monkeypatch.setattr("hermes_cli.mcp_config._MCP_PRESETS", preset_table)
    fake_tools = [FakeTool("do_thing", "Does a thing")]

    def mock_probe(name, config, **kw):
        # The probe sees the preset transport and no env mapping.
        assert name == "myserver"
        assert config["command"] == "npx"
        assert config["args"] == ["-y", "test-mcp-server"]
        assert "env" not in config
        return [(tool.name, tool.description) for tool in fake_tools]

    monkeypatch.setattr("hermes_cli.mcp_config._probe_single_server", mock_probe)
    # Every prompt is answered with an empty string.
    monkeypatch.setattr("builtins.input", lambda _: "")

    from hermes_cli.mcp_config import cmd_mcp_add
    from hermes_cli.config import read_raw_config

    cmd_mcp_add(_make_args(name="myserver", preset="testmcp"))
    assert "Saved" in capsys.readouterr().out

    saved_server = read_raw_config()["mcp_servers"]["myserver"]
    assert saved_server["command"] == "npx"
    assert saved_server["args"] == ["-y", "test-mcp-server"]
    assert "env" not in saved_server
def test_preset_does_not_override_explicit_command(self, tmp_path, capsys, monkeypatch):
    """An explicit command and args are used even when a preset is also named."""
    preset_table = {
        "testmcp": {
            "command": "npx",
            "args": ["-y", "test-mcp-server"],
            "display_name": "Test MCP",
        },
    }
    monkeypatch.setattr("hermes_cli.mcp_config._MCP_PRESETS", preset_table)
    fake_tools = [FakeTool("search", "Search repos")]

    def mock_probe(name, config, **kw):
        # The explicit transport, not the preset's, must reach the probe.
        assert config["command"] == "uvx"
        assert config["args"] == ["custom-server"]
        assert "env" not in config
        return [(tool.name, tool.description) for tool in fake_tools]

    monkeypatch.setattr("hermes_cli.mcp_config._probe_single_server", mock_probe)
    # Every prompt is answered with an empty string.
    monkeypatch.setattr("builtins.input", lambda _: "")

    from hermes_cli.mcp_config import cmd_mcp_add
    from hermes_cli.config import read_raw_config

    add_args = _make_args(
        name="custom",
        preset="testmcp",
        command="uvx",
        args=["custom-server"],
    )
    cmd_mcp_add(add_args)
    assert "Saved" in capsys.readouterr().out

    saved_server = read_raw_config()["mcp_servers"]["custom"]
    assert saved_server["command"] == "uvx"
    assert saved_server["args"] == ["custom-server"]
    assert "env" not in saved_server
def test_unknown_preset_rejected(self, capsys):
    """Naming a preset called ``nonexistent`` prints an unknown-preset error."""
    from hermes_cli.mcp_config import cmd_mcp_add

    add_args = _make_args(name="foo", preset="nonexistent")
    cmd_mcp_add(add_args)

    printed = capsys.readouterr().out
    assert "Unknown MCP preset" in printed
# ---------------------------------------------------------------------------
# Tests: cmd_mcp_test

View File

@@ -257,3 +257,76 @@ class TestProviderPersistsAfterModelSave:
assert model.get("provider") == "opencode-go"
assert model.get("default") == "minimax-m2.5"
assert model.get("api_mode") == "anthropic_messages"
class TestBaseUrlValidation:
    """Reject non-URL values in the base URL prompt (e.g. shell commands)."""

    @staticmethod
    def _skip_unless_zai_registered():
        """Skip the calling test when the provider registry has no ``zai`` entry."""
        from hermes_cli.auth import PROVIDER_REGISTRY

        if not PROVIDER_REGISTRY.get("zai"):
            pytest.skip("zai not in PROVIDER_REGISTRY")

    @staticmethod
    def _run_flow_answering(typed):
        """Run the ``zai`` API-key flow with every ``input()`` call returning *typed*.

        Model selection is stubbed to return ``glm-5`` and provider
        deactivation is patched out.
        """
        from hermes_cli.main import _model_flow_api_key_provider
        from hermes_cli.config import load_config

        with patch("hermes_cli.auth._prompt_model_selection", return_value="glm-5"), \
                patch("hermes_cli.auth.deactivate_provider"), \
                patch("builtins.input", return_value=typed):
            _model_flow_api_key_provider(load_config(), "zai", "old-model")

    def test_invalid_base_url_rejected(self, config_home, monkeypatch, capsys):
        """Typing a non-URL string should not be saved as the base URL."""
        self._skip_unless_zai_registered()
        monkeypatch.setenv("GLM_API_KEY", "test-key")

        # A shell command is typed where the flow asks for input.
        self._run_flow_answering("nano ~/.hermes/.env")

        from hermes_cli.config import get_env_value

        # Whatever is stored must be empty or look like an http(s) URL.
        saved = get_env_value("GLM_BASE_URL") or ""
        assert not saved or saved.startswith(("http://", "https://")), \
            f"Non-URL value was saved as GLM_BASE_URL: {saved}"
        printed = capsys.readouterr().out
        assert "Invalid URL" in printed

    def test_valid_base_url_accepted(self, config_home, monkeypatch):
        """A proper URL should be saved normally."""
        self._skip_unless_zai_registered()
        monkeypatch.setenv("GLM_API_KEY", "test-key")

        self._run_flow_answering("https://custom.z.ai/api/paas/v4")

        from hermes_cli.config import get_env_value

        saved = get_env_value("GLM_BASE_URL") or ""
        assert saved == "https://custom.z.ai/api/paas/v4"

    def test_empty_base_url_keeps_default(self, config_home, monkeypatch):
        """Pressing Enter (empty) should not change the base URL."""
        self._skip_unless_zai_registered()
        monkeypatch.setenv("GLM_API_KEY", "test-key")
        # Make sure no base URL is inherited from the process environment.
        monkeypatch.delenv("GLM_BASE_URL", raising=False)

        self._run_flow_answering("")

        from hermes_cli.config import get_env_value

        saved = get_env_value("GLM_BASE_URL") or ""
        assert saved == "", "Empty input should not save a base URL"

View File

@@ -1,8 +1,10 @@
from io import StringIO
from unittest.mock import patch
import pytest
from rich.console import Console
from cli import ChatConsole
from hermes_cli.skills_hub import do_check, do_install, do_list, do_update, handle_skills_slash
@@ -179,6 +181,21 @@ def test_do_update_reinstalls_outdated_skills(monkeypatch):
assert "Updated 1 skill" in output
def test_handle_skills_slash_search_accepts_chatconsole_without_status_errors():
    """``/skills search`` completes without raising when given a ``ChatConsole``."""
    # Minimal stand-in for a search hit: a throwaway class carrying the
    # attributes listed below as class attributes.
    hit_type = type("R", (), {
        "name": "kubernetes",
        "description": "Cluster orchestration",
        "source": "skills.sh",
        "trust_level": "community",
        "identifier": "skills-sh/example/kubernetes",
    })
    results = [hit_type()]

    with patch("tools.skills_hub.unified_search", return_value=results), \
            patch("tools.skills_hub.create_source_router", return_value={}), \
            patch("tools.skills_hub.GitHubAuth"):
        # No assertion needed: the test passes when this call does not raise.
        handle_skills_slash("/skills search kubernetes", console=ChatConsole())
def test_do_install_scans_with_resolved_identifier(monkeypatch, tmp_path, hub_env):
import tools.skills_guard as guard
import tools.skills_hub as hub

View File

@@ -0,0 +1,77 @@
"""Tests for hermes_cli/tips.py — random tip display at session start."""
import pytest
from hermes_cli.tips import TIPS, get_random_tip, get_tip_count
class TestTipsCorpus:
    """Sanity checks on the contents of the ``TIPS`` collection."""

    def test_has_at_least_200_tips(self):
        total = len(TIPS)
        assert total >= 200, f"Expected 200+ tips, got {len(TIPS)}"

    def test_no_duplicates(self):
        unique = set(TIPS)
        assert len(TIPS) == len(unique), "Duplicate tips found"

    def test_all_tips_are_strings(self):
        for idx, entry in enumerate(TIPS):
            assert isinstance(entry, str), f"Tip {idx} is not a string: {type(entry)}"

    def test_no_empty_tips(self):
        for idx, entry in enumerate(TIPS):
            assert entry.strip(), f"Tip {idx} is empty or whitespace-only"

    def test_max_length_reasonable(self):
        """No tip may be longer than 150 characters."""
        for idx, entry in enumerate(TIPS):
            assert len(entry) <= 150, (
                f"Tip {idx} too long ({len(entry)} chars): {entry[:60]}..."
            )

    def test_no_leading_trailing_whitespace(self):
        for idx, entry in enumerate(TIPS):
            assert entry == entry.strip(), f"Tip {idx} has leading/trailing whitespace"
class TestGetRandomTip:
    """Checks on values returned by ``get_random_tip()``."""

    def test_returns_string(self):
        picked = get_random_tip()
        assert isinstance(picked, str)
        assert len(picked) > 0

    def test_returns_tip_from_corpus(self):
        assert get_random_tip() in TIPS

    def test_randomness(self):
        """Fifty draws must produce at least ten distinct tips."""
        seen = {get_random_tip() for _ in range(50)}
        assert len(seen) >= 10, f"Only got {len(seen)} unique tips in 50 draws"
class TestGetTipCount:
    """``get_tip_count()`` reports the size of ``TIPS``."""

    def test_matches_corpus_length(self):
        expected = len(TIPS)
        assert get_tip_count() == expected
class TestTipIntegrationInCLI:
    """Checks on the import and markup string used to display a tip."""

    def test_tip_import_works(self):
        """``get_random_tip`` is importable from ``hermes_cli.tips`` and callable."""
        from hermes_cli.tips import get_random_tip as imported

        assert callable(imported)

    def test_tip_display_format(self):
        """A tip wrapped in the dim-colour markup contains exactly one closing tag."""
        color = "#B8860B"
        tip = get_random_tip()
        markup = f"[dim {color}]✦ Tip: {tip}[/]"
        # A tip that itself contained "[/]" would raise this count above one.
        closing_tags = markup.count("[/]")
        assert closing_tags == 1
        assert "[dim #B8860B]" in markup

View File

@@ -798,3 +798,120 @@ class TestFindGatewayPidsExclude:
pids = gateway_cli.find_gateway_pids()
assert pids == [100]
# ---------------------------------------------------------------------------
# Gateway mode writes exit code before restart (#8300)
# ---------------------------------------------------------------------------
class TestGatewayModeWritesExitCodeEarly:
    """``hermes update --gateway`` must write the exit code marker before restarting.

    Under systemd's ``KillMode=mixed`` the update process and its wrapping
    shell are killed during cgroup teardown, so the shell epilogue that
    normally records the exit code never runs.  Without an early write, the
    new gateway's update watcher polls for 30 minutes and then sends a
    spurious timeout message.
    """

    @staticmethod
    def _isolated_home(monkeypatch, tmp_path, *, systemd):
        """Patch platform checks and HERMES_HOME, returning the temporary home.

        ``is_macos`` and ``is_termux`` report False; ``supports_systemd_services``
        reports *systemd*.  ``get_hermes_home`` is replaced in both
        ``hermes_cli.config`` and ``hermes_cli.main``.
        """
        monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
        monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: systemd)
        monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)

        # The marker file must land in a temp dir, not the real home.
        home = tmp_path / ".hermes"
        home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(home))

        import hermes_cli.config as _cfg
        monkeypatch.setattr(_cfg, "get_hermes_home", lambda: home)
        # hermes_cli.main holds its own module-level reference.
        import hermes_cli.main as _main_mod
        monkeypatch.setattr(_main_mod, "get_hermes_home", lambda: home)
        return home

    @staticmethod
    def _run_update(*, gateway):
        """Invoke ``cmd_update`` with no running gateway processes reported."""
        with patch.object(gateway_cli, "find_gateway_pids", return_value=[]):
            cmd_update(SimpleNamespace(gateway=gateway))

    @patch("shutil.which", return_value=None)
    @patch("subprocess.run")
    def test_exit_code_written_in_gateway_mode(
        self, mock_run, _mock_which, capsys, tmp_path, monkeypatch,
    ):
        home = self._isolated_home(monkeypatch, tmp_path, systemd=False)
        mock_run.side_effect = _make_run_side_effect(commit_count="1")

        self._run_update(gateway=True)

        marker = home / ".update_exit_code"
        assert marker.exists(), ".update_exit_code not written in gateway mode"
        assert marker.read_text() == "0"

    @patch("shutil.which", return_value=None)
    @patch("subprocess.run")
    def test_exit_code_not_written_in_normal_mode(
        self, mock_run, _mock_which, capsys, tmp_path, monkeypatch,
    ):
        """Non-gateway mode should NOT write the exit code (the shell does it)."""
        home = self._isolated_home(monkeypatch, tmp_path, systemd=False)
        mock_run.side_effect = _make_run_side_effect(commit_count="1")

        self._run_update(gateway=False)

        marker = home / ".update_exit_code"
        assert not marker.exists(), ".update_exit_code should not be written outside gateway mode"

    @patch("shutil.which", return_value=None)
    @patch("subprocess.run")
    def test_exit_code_written_before_restart_call(
        self, mock_run, _mock_which, capsys, tmp_path, monkeypatch,
    ):
        """Exit code must exist BEFORE systemctl restart is called."""
        home = self._isolated_home(monkeypatch, tmp_path, systemd=True)
        marker = home / ".update_exit_code"

        # Records, at each "systemctl ... restart" call, whether the marker exists.
        marker_present_at_restart = []
        delegate = _make_run_side_effect(
            commit_count="1", systemd_active=True,
        )

        def tracking_side_effect(cmd, **kwargs):
            command_line = " ".join(str(part) for part in cmd)
            if "systemctl" in command_line and "restart" in command_line:
                marker_present_at_restart.append(marker.exists())
            return delegate(cmd, **kwargs)

        mock_run.side_effect = tracking_side_effect

        self._run_update(gateway=True)

        assert marker_present_at_restart, "systemctl restart was never called"
        assert marker_present_at_restart[0] is True, \
            ".update_exit_code must exist BEFORE systemctl restart (cgroup kill race)"

View File

@@ -26,6 +26,7 @@ def _make_agent(
agent.provider = "openrouter"
agent.base_url = "https://openrouter.ai/api/v1"
agent.api_key = "sk-test"
agent.api_mode = "chat_completions"
agent.quiet_mode = True
agent.log_prefix = ""
agent.compression_enabled = compression_enabled
@@ -99,6 +100,36 @@ def test_no_warning_when_aux_context_sufficient(mock_get_client, mock_ctx_len):
assert agent._compression_warning is None
def test_feasibility_check_passes_live_main_runtime():
    """The feasibility check forwards the agent's current runtime as ``main_runtime``."""
    live_runtime = {
        "model": "gpt-5.4",
        "provider": "openai-codex",
        "base_url": "https://chatgpt.com/backend-api/codex",
        "api_key": "codex-token",
        "api_mode": "codex_responses",
    }

    agent = _make_agent(main_context=200_000, threshold_percent=0.50)
    # Overwrite the agent's runtime attributes with the live values.
    for attr, value in live_runtime.items():
        setattr(agent, attr, value)
    agent._emit_status = lambda msg: None

    aux_client = MagicMock()
    aux_client.base_url = "https://chatgpt.com/backend-api/codex"
    aux_client.api_key = "codex-token"

    with patch(
        "agent.auxiliary_client.get_text_auxiliary_client",
        return_value=(aux_client, "gpt-5.4"),
    ) as mock_get_client, patch(
        "agent.model_metadata.get_model_context_length",
        return_value=200_000,
    ):
        agent._check_compression_model_feasibility()

    mock_get_client.assert_called_once_with(
        "compression",
        main_runtime={
            "model": "gpt-5.4",
            "provider": "openai-codex",
            "base_url": "https://chatgpt.com/backend-api/codex",
            "api_key": "codex-token",
            "api_mode": "codex_responses",
        },
    )
@patch("agent.auxiliary_client.get_text_auxiliary_client")
def test_warns_when_no_auxiliary_provider(mock_get_client):
"""Warning emitted when no auxiliary provider is configured."""

View File

@@ -2742,74 +2742,12 @@ class TestSystemPromptStability:
assert "Hermes Agent" in agent._cached_system_prompt
class TestBudgetPressure:
    """Budget pressure warning system (issue #414)."""
    # NOTE(review): this span comes from a diff hunk in which removed and added
    # lines are interleaved; the second string below looks like the replacement
    # docstring and, as written, is only a no-op expression statement.
    """Budget exhaustion grace call system."""

    def test_no_warning_below_caution(self, agent):
        # Iteration 30 of 60 (50%) must not produce a warning.
        agent.max_iterations = 60
        assert agent._get_budget_warning(30) is None

    def test_caution_at_70_percent(self, agent):
        # Iteration 42 of 60 (70%): caution message with 60 - 42 = 18 left.
        agent.max_iterations = 60
        msg = agent._get_budget_warning(42)
        assert msg is not None
        assert "[BUDGET:" in msg
        assert "18 iterations left" in msg

    def test_warning_at_90_percent(self, agent):
        # Iteration 54 of 60 (90%): the stronger "BUDGET WARNING" wording.
        agent.max_iterations = 60
        msg = agent._get_budget_warning(54)
        assert "[BUDGET WARNING:" in msg
        assert "Provide your final response NOW" in msg

    def test_last_iteration(self, agent):
        # Iteration 59 of 60 leaves exactly one iteration.
        agent.max_iterations = 60
        msg = agent._get_budget_warning(59)
        assert "1 iteration(s) left" in msg

    def test_disabled(self, agent):
        # With the feature flag off no warning is returned even at 55/60.
        agent.max_iterations = 60
        agent._budget_pressure_enabled = False
        assert agent._get_budget_warning(55) is None

    def test_zero_max_iterations(self, agent):
        # A zero budget must yield None rather than a warning.
        agent.max_iterations = 0
        assert agent._get_budget_warning(0) is None

    def test_injects_into_json_tool_result(self, agent):
        """Warning should be injected as _budget_warning field in JSON tool results."""
        import json
        agent.max_iterations = 10
        messages = [
            {"role": "tool", "content": json.dumps({"output": "done", "exit_code": 0}), "tool_call_id": "tc1"}
        ]
        warning = agent._get_budget_warning(9)
        assert warning is not None
        # Simulate the injection logic
        # (the merge is re-implemented here in the test; only
        # _get_budget_warning is actually called on the agent)
        last_content = messages[-1]["content"]
        parsed = json.loads(last_content)
        parsed["_budget_warning"] = warning
        messages[-1]["content"] = json.dumps(parsed, ensure_ascii=False)
        result = json.loads(messages[-1]["content"])
        assert "_budget_warning" in result
        assert "BUDGET WARNING" in result["_budget_warning"]
        assert result["output"] == "done"  # original content preserved

    def test_appends_to_non_json_tool_result(self, agent):
        """Warning should be appended as text for non-JSON tool results."""
        agent.max_iterations = 10
        messages = [
            {"role": "tool", "content": "plain text result", "tool_call_id": "tc1"}
        ]
        warning = agent._get_budget_warning(9)
        # Simulate injection logic for non-JSON
        last_content = messages[-1]["content"]
        try:
            import json
            json.loads(last_content)
        except (json.JSONDecodeError, TypeError):
            # The plain-text content is not valid JSON, so the warning is
            # appended after a blank line.
            messages[-1]["content"] = last_content + f"\n\n{warning}"
        assert "plain text result" in messages[-1]["content"]
        assert "BUDGET WARNING" in messages[-1]["content"]

    def test_grace_call_flags_initialized(self, agent):
        """Agent should have budget grace call flags."""
        # Both flags are expected to start out False on the fixture agent.
        assert agent._budget_exhausted_injected is False
        assert agent._budget_grace_call is False
class TestSafeWriter:

View File

@@ -744,6 +744,44 @@ def test_normalize_codex_response_marks_commentary_only_message_as_incomplete(mo
assert "inspect the repository" in (assistant_message.content or "")
def test_interim_commentary_is_not_marked_already_streamed_without_callbacks(monkeypatch):
    """Interim text is reported with ``already_streamed=False`` when this test installs no stream callback."""
    agent = _build_agent(monkeypatch)
    seen = {}

    def _record_interim(text, *, already_streamed=False):
        seen.update({"text": text, "already_streamed": already_streamed})

    # Fire the delta first; the interim callback is only registered afterwards.
    agent._fire_stream_delta("short version: yes")
    agent.interim_assistant_callback = _record_interim
    agent._emit_interim_assistant_message({"role": "assistant", "content": "short version: yes"})

    expected = {
        "text": "short version: yes",
        "already_streamed": False,
    }
    assert seen == expected
def test_interim_commentary_is_not_marked_already_streamed_when_stream_callback_fails(monkeypatch):
    """A stream callback that raises must not cause the text to be flagged as already streamed."""
    agent = _build_agent(monkeypatch)
    seen = {}

    def _broken_stream_callback(_text):
        raise RuntimeError("display failed")

    def _record_interim(text, *, already_streamed=False):
        seen.update({"text": text, "already_streamed": already_streamed})

    # Deliver the delta through the failing callback before the interim
    # callback is registered.
    agent.stream_delta_callback = _broken_stream_callback
    agent._fire_stream_delta("short version: yes")

    agent.interim_assistant_callback = _record_interim
    agent._emit_interim_assistant_message({"role": "assistant", "content": "short version: yes"})

    expected = {
        "text": "short version: yes",
        "already_streamed": False,
    }
    assert seen == expected
def test_run_conversation_codex_continues_after_commentary_phase_message(monkeypatch):
agent = _build_agent(monkeypatch)
responses = [

View File

@@ -185,6 +185,38 @@ def test_migrator_optionally_imports_supported_secrets_and_messaging_settings(tm
assert "TELEGRAM_BOT_TOKEN=123:abc" in env_text
def test_messaging_cwd_skipped_when_inside_source(tmp_path: Path):
    """A workspace located inside the OpenClaw source dir must not be written as MESSAGING_CWD."""
    mod = load_module()
    openclaw_dir = tmp_path / ".openclaw"
    hermes_dir = tmp_path / ".hermes"
    hermes_dir.mkdir()

    # The configured workspace lives underneath the source directory itself.
    workspace_inside_source = str(openclaw_dir / "workspace")
    (openclaw_dir / "credentials").mkdir(parents=True)
    config_payload = {"agents": {"defaults": {"workspace": workspace_inside_source}}}
    (openclaw_dir / "openclaw.json").write_text(
        json.dumps(config_payload),
        encoding="utf-8",
    )

    mod.Migrator(
        source_root=openclaw_dir,
        target_root=hermes_dir,
        execute=True,
        workspace_target=None,
        overwrite=False,
        migrate_secrets=True,
        output_dir=hermes_dir / "migration-report",
        selected_options={"messaging-settings"},
    ).migrate()

    # The .env file may legitimately be absent; only its content is checked.
    dotenv = hermes_dir / ".env"
    if dotenv.exists():
        assert "MESSAGING_CWD" not in dotenv.read_text(encoding="utf-8")
def test_migrator_can_execute_only_selected_categories(tmp_path: Path):
mod = load_module()
source = tmp_path / ".openclaw"
@@ -722,3 +754,98 @@ def test_skill_installs_cleanly_under_skills_guard():
KNOWN_FALSE_POSITIVES = {"agent_config_mod", "python_os_environ", "hermes_config_mod"}
for f in result.findings:
assert f.pattern_id in KNOWN_FALSE_POSITIVES, f"Unexpected finding: {f}"
# ── rebrand_text tests ────────────────────────────────────────
def test_rebrand_text_replaces_openclaw_variants():
    """Every spelling of the OpenClaw name exercised here is rewritten to "Hermes"."""
    mod = load_module()
    cases = (
        ("OpenClaw prefers Python 3.11", "Hermes prefers Python 3.11"),
        ("I told Open Claw to use dark mode", "I told Hermes to use dark mode"),
        ("Open-Claw config is great", "Hermes config is great"),
        ("openclaw should always respond concisely", "Hermes should always respond concisely"),
        ("OPENCLAW uses tools well", "Hermes uses tools well"),
    )
    for before, after in cases:
        assert mod.rebrand_text(before) == after
def test_rebrand_text_replaces_legacy_bot_names():
    """The ClawdBot and MoltBot names, in either casing used here, become "Hermes"."""
    mod = load_module()
    cases = (
        ("ClawdBot remembers my timezone", "Hermes remembers my timezone"),
        ("clawdbot prefers tabs", "Hermes prefers tabs"),
        ("MoltBot was configured for Spanish", "Hermes was configured for Spanish"),
        ("moltbot uses Python", "Hermes uses Python"),
    )
    for before, after in cases:
        assert mod.rebrand_text(before) == after
def test_rebrand_text_preserves_unrelated_content():
    """Text that mentions none of the rebranded names is returned unchanged."""
    mod = load_module()
    untouched = "User prefers dark mode and lives in Las Vegas"
    rebranded = mod.rebrand_text(untouched)
    assert rebranded == untouched
def test_rebrand_text_handles_multiple_replacements():
    """Several different legacy names within one string are all replaced."""
    mod = load_module()
    rebranded = mod.rebrand_text("OpenClaw said to ask ClawdBot about MoltBot settings")
    assert rebranded == "Hermes said to ask Hermes about Hermes settings"
def test_migrate_memory_rebrands_entries(tmp_path):
    """Migrating the "memory" option writes a MEMORY.md free of the legacy product names."""
    mod = load_module()

    # Source layout: <tmp>/openclaw/workspace/MEMORY.md
    openclaw_root = tmp_path / "openclaw"
    openclaw_root.mkdir()
    workspace_dir = openclaw_root / "workspace"
    workspace_dir.mkdir()
    (workspace_dir / "MEMORY.md").write_text(
        "# Memory\n\n- OpenClaw should use Python 3.11\n- ClawdBot prefers dark mode\n",
        encoding="utf-8",
    )

    # Target layout: <tmp>/hermes/memories/ pre-created and empty.
    hermes_root = tmp_path / "hermes"
    hermes_root.mkdir()
    (hermes_root / "memories").mkdir()

    mod.Migrator(
        source_root=openclaw_root,
        target_root=hermes_root,
        execute=True,
        workspace_target=None,
        overwrite=False,
        migrate_secrets=False,
        output_dir=tmp_path / "report",
        selected_options={"memory"},
    ).migrate()

    migrated = (hermes_root / "memories" / "MEMORY.md").read_text(encoding="utf-8")
    for legacy_name in ("OpenClaw", "ClawdBot"):
        assert legacy_name not in migrated
    assert "Hermes" in migrated
def test_migrate_soul_rebrands_content(tmp_path):
    """Migrating the "soul" option rewrites the OpenClaw name inside SOUL.md."""
    mod = load_module()

    openclaw_root = tmp_path / "openclaw"
    openclaw_root.mkdir()
    workspace_dir = openclaw_root / "workspace"
    workspace_dir.mkdir()
    (workspace_dir / "SOUL.md").write_text(
        "You are OpenClaw, an AI assistant made by SparkLab.", encoding="utf-8"
    )

    hermes_root = tmp_path / "hermes"
    hermes_root.mkdir()

    mod.Migrator(
        source_root=openclaw_root,
        target_root=hermes_root,
        execute=True,
        workspace_target=None,
        overwrite=False,
        migrate_secrets=False,
        output_dir=tmp_path / "report",
        selected_options={"soul"},
    ).migrate()

    migrated = (hermes_root / "SOUL.md").read_text(encoding="utf-8")
    assert "OpenClaw" not in migrated
    assert "You are Hermes" in migrated

View File

@@ -0,0 +1,120 @@
"""Tests for empty model fallback — when provider is configured but model is missing."""
from unittest.mock import MagicMock, patch
import pytest
class TestGetDefaultModelForProvider:
    """Unit tests for hermes_cli.models.get_default_model_for_provider."""

    def test_known_provider_returns_first_model(self):
        """A provider with a catalog entry yields a non-empty model string."""
        from hermes_cli.models import get_default_model_for_provider

        default_model = get_default_model_for_provider("openai-codex")
        # Should return first model from _PROVIDER_MODELS["openai-codex"]
        assert default_model
        assert isinstance(default_model, str)

    def test_openrouter_returns_empty(self):
        """OpenRouter uses dynamic model fetch, no static catalog entry."""
        from hermes_cli.models import get_default_model_for_provider

        # OpenRouter is not in _PROVIDER_MODELS — it uses live fetching
        default_model = get_default_model_for_provider("openrouter")
        assert default_model == ""

    def test_unknown_provider_returns_empty(self):
        """A provider name nobody registered resolves to the empty string."""
        from hermes_cli.models import get_default_model_for_provider

        default_model = get_default_model_for_provider("nonexistent-provider")
        assert default_model == ""

    def test_custom_provider_returns_empty(self):
        """Custom provider has no model catalog — should return empty."""
        from hermes_cli.models import get_default_model_for_provider

        # Custom providers don't have entries in _PROVIDER_MODELS
        default_model = get_default_model_for_provider("some-random-custom")
        assert default_model == ""
class TestGatewayEmptyModelFallback:
    """Test that _resolve_session_agent_runtime fills in empty model from provider catalog."""

    @staticmethod
    def _resolve_with(gateway_model, runtime_kwargs):
        """Run ``_resolve_session_agent_runtime`` on a bare runner with both resolvers patched."""
        from gateway.run import GatewayRunner

        # Bypass __init__: the test assigns only _session_model_overrides.
        runner = object.__new__(GatewayRunner)
        runner._session_model_overrides = {}
        model_patch = patch("gateway.run._resolve_gateway_model", return_value=gateway_model)
        kwargs_patch = patch("gateway.run._resolve_runtime_agent_kwargs", return_value=runtime_kwargs)
        with model_patch, kwargs_patch:
            return runner._resolve_session_agent_runtime()

    def test_empty_model_filled_from_provider(self):
        """When config has no model but provider is openai-codex, use first codex model."""
        # _resolve_gateway_model returns "" while the runtime kwargs name
        # the openai-codex provider.
        model, kwargs = self._resolve_with(
            "",
            {
                "provider": "openai-codex",
                "api_key": "test-key",
                "base_url": "https://chatgpt.com/backend-api/codex",
                "api_mode": "codex_responses",
            },
        )
        # Model should have been filled in from provider catalog
        assert model, "Model should not be empty when provider is known"
        assert isinstance(model, str)
        assert kwargs["provider"] == "openai-codex"

    def test_nonempty_model_not_overridden(self):
        """When config has a model set, don't override it."""
        model, kwargs = self._resolve_with(
            "gpt-5.4",
            {
                "provider": "openai-codex",
                "api_key": "test-key",
                "base_url": "https://chatgpt.com/backend-api/codex",
                "api_mode": "codex_responses",
            },
        )
        assert model == "gpt-5.4", "Explicit model should not be overridden"

    def test_empty_model_no_provider_stays_empty(self):
        """When both model and provider are empty, model stays empty."""
        model, kwargs = self._resolve_with(
            "",
            {
                "provider": "",
                "api_key": "test-key",
                "base_url": "https://example.com",
                "api_mode": "chat_completions",
            },
        )
        # Can't fill in a default without knowing the provider
        assert model == ""
class TestResolveGatewayModel:
    """Test _resolve_gateway_model reads model from config correctly."""

    def test_returns_default_key(self):
        """``model.default`` is returned when present."""
        from gateway.run import _resolve_gateway_model

        config = {"model": {"default": "gpt-5.4"}}
        assert _resolve_gateway_model(config) == "gpt-5.4"

    def test_returns_model_key_fallback(self):
        """``model.model`` is used when there is no ``default`` key."""
        from gateway.run import _resolve_gateway_model

        config = {"model": {"model": "gpt-5.4"}}
        assert _resolve_gateway_model(config) == "gpt-5.4"

    def test_returns_empty_when_missing(self):
        """An empty ``model`` mapping resolves to the empty string."""
        from gateway.run import _resolve_gateway_model

        config = {"model": {}}
        assert _resolve_gateway_model(config) == ""

    def test_returns_empty_when_no_model_section(self):
        """A config without a ``model`` section resolves to the empty string."""
        from gateway.run import _resolve_gateway_model

        config = {}
        assert _resolve_gateway_model(config) == ""

    def test_string_model_config(self):
        """A plain-string ``model`` value is returned as-is."""
        from gateway.run import _resolve_gateway_model

        config = {"model": "my-model"}
        assert _resolve_gateway_model(config) == "my-model"

View File

@@ -3,6 +3,7 @@
import logging
import os
import stat
import threading
from logging.handlers import RotatingFileHandler
from pathlib import Path
from unittest.mock import patch
@@ -34,6 +35,8 @@ def _reset_logging_state():
h.close()
else:
pre_existing.append(h)
# Ensure the record factory is installed (it's idempotent).
hermes_logging._install_session_record_factory()
yield
# Restore — remove any handlers added during the test.
for h in list(root.handlers):
@@ -41,6 +44,7 @@ def _reset_logging_state():
root.removeHandler(h)
h.close()
hermes_logging._logging_initialized = False
hermes_logging.clear_session_context()
@pytest.fixture
@@ -220,6 +224,294 @@ class TestSetupLogging:
]
assert agent_handlers[0].level == logging.WARNING
def test_record_factory_installed(self, hermes_home):
    """The custom record factory injects session_tag on all records."""
    hermes_logging.setup_logging(hermes_home=hermes_home)
    active_factory = logging.getLogRecordFactory()
    is_marked = getattr(active_factory, "_hermes_session_injector", False)
    assert is_marked, (
        "Record factory should have _hermes_session_injector marker"
    )
    # A record built straight from the factory must carry the attribute.
    fresh_record = active_factory("test", logging.INFO, "", 0, "msg", (), None)
    assert hasattr(fresh_record, "session_tag")
class TestGatewayMode:
    """setup_logging(mode='gateway') creates a filtered gateway.log."""

    @staticmethod
    def _gateway_file_handlers():
        """Return the root logger's rotating handlers whose file name contains gateway.log."""
        found = []
        for handler in logging.getLogger().handlers:
            if not isinstance(handler, RotatingFileHandler):
                continue
            if "gateway.log" in getattr(handler, "baseFilename", ""):
                found.append(handler)
        return found

    @staticmethod
    def _flush_root_handlers():
        """Flush every handler on the root logger so file contents can be read."""
        for handler in logging.getLogger().handlers:
            handler.flush()

    def test_gateway_log_created(self, hermes_home):
        """Gateway mode attaches exactly one gateway.log handler."""
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
        assert len(self._gateway_file_handlers()) == 1

    def test_gateway_log_not_created_in_cli_mode(self, hermes_home):
        """CLI mode attaches no gateway.log handler."""
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="cli")
        assert len(self._gateway_file_handlers()) == 0

    def test_gateway_log_receives_gateway_records(self, hermes_home):
        """gateway.log captures records from gateway.* loggers."""
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
        logging.getLogger("gateway.platforms.telegram").info("telegram connected")
        self._flush_root_handlers()

        gateway_log = hermes_home / "logs" / "gateway.log"
        assert gateway_log.exists()
        assert "telegram connected" in gateway_log.read_text()

    def test_gateway_log_rejects_non_gateway_records(self, hermes_home):
        """gateway.log does NOT capture records from tools.*, agent.*, etc."""
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
        logging.getLogger("tools.terminal_tool").info("running command")
        logging.getLogger("agent.context_compressor").info("compressing context")
        self._flush_root_handlers()

        gateway_log = hermes_home / "logs" / "gateway.log"
        # The file may not exist at all; only inspect it when it does.
        if gateway_log.exists():
            logged = gateway_log.read_text()
            assert "running command" not in logged
            assert "compressing context" not in logged

    def test_agent_log_still_receives_all(self, hermes_home):
        """agent.log (catch-all) still receives gateway AND tool records."""
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
        logging.getLogger("gateway.run").info("gateway msg")
        logging.getLogger("tools.file_tools").info("file msg")
        self._flush_root_handlers()

        logged = (hermes_home / "logs" / "agent.log").read_text()
        assert "gateway msg" in logged
        assert "file msg" in logged
class TestSessionContext:
    """set_session_context / clear_session_context + _SessionFilter."""

    def test_session_tag_in_log_output(self, hermes_home):
        """When session context is set, log lines include [session_id]."""
        hermes_logging.setup_logging(hermes_home=hermes_home)
        hermes_logging.set_session_context("abc123")
        test_logger = logging.getLogger("test.session_tag")
        test_logger.info("tagged message")
        for h in logging.getLogger().handlers:
            h.flush()
        agent_log = hermes_home / "logs" / "agent.log"
        content = agent_log.read_text()
        assert "[abc123]" in content
        assert "tagged message" in content

    def test_no_session_tag_without_context(self, hermes_home):
        """Without session context, log lines have no session tag."""
        hermes_logging.setup_logging(hermes_home=hermes_home)
        hermes_logging.clear_session_context()
        test_logger = logging.getLogger("test.no_session")
        test_logger.info("untagged message")
        for h in logging.getLogger().handlers:
            h.flush()
        agent_log = hermes_home / "logs" / "agent.log"
        content = agent_log.read_text()
        assert "untagged message" in content
        # Should not have any [xxx] session tag
        import re
        for line in content.splitlines():
            if "untagged message" in line:
                # Only the text between the level name and the logger name is
                # inspected for a bracketed tag.
                assert not re.search(r"\[.+?\]", line.split("INFO")[1].split("test.no_session")[0])

    def test_clear_session_context(self, hermes_home):
        """After clearing, session tag disappears."""
        hermes_logging.setup_logging(hermes_home=hermes_home)
        hermes_logging.set_session_context("xyz789")
        hermes_logging.clear_session_context()
        test_logger = logging.getLogger("test.cleared")
        test_logger.info("after clear")
        for h in logging.getLogger().handlers:
            h.flush()
        agent_log = hermes_home / "logs" / "agent.log"
        content = agent_log.read_text()
        assert "[xyz789]" not in content

    def test_session_context_thread_isolated(self, hermes_home):
        """Session context is per-thread — one thread's context doesn't leak."""
        hermes_logging.setup_logging(hermes_home=hermes_home)

        def thread_a():
            hermes_logging.set_session_context("thread_a_session")
            logging.getLogger("test.thread_a").info("from thread A")
            for h in logging.getLogger().handlers:
                h.flush()

        def thread_b():
            hermes_logging.set_session_context("thread_b_session")
            logging.getLogger("test.thread_b").info("from thread B")
            for h in logging.getLogger().handlers:
                h.flush()

        ta = threading.Thread(target=thread_a)
        tb = threading.Thread(target=thread_b)
        ta.start()
        ta.join()
        tb.start()
        tb.join()

        agent_log = hermes_home / "logs" / "agent.log"
        lines = agent_log.read_text().splitlines()
        lines_a = [line for line in lines if "from thread A" in line]
        lines_b = [line for line in lines if "from thread B" in line]
        # An exception inside a worker thread does not fail the test by
        # itself, so require that both records actually reached the log;
        # otherwise the per-line checks below would pass vacuously.
        assert lines_a, "thread A's record is missing from agent.log"
        assert lines_b, "thread B's record is missing from agent.log"

        # Each thread's message should have its own session tag
        for line in lines_a:
            assert "[thread_a_session]" in line
            assert "[thread_b_session]" not in line
        for line in lines_b:
            assert "[thread_b_session]" in line
            assert "[thread_a_session]" not in line
class TestRecordFactory:
    """Unit tests for the custom LogRecord factory."""

    @staticmethod
    def _make_record():
        """Build one INFO record through whichever factory is currently installed."""
        build = logging.getLogRecordFactory()
        return build("test", logging.INFO, "", 0, "msg", (), None)

    def test_record_has_session_tag(self):
        """Every record gets a session_tag attribute."""
        assert hasattr(self._make_record(), "session_tag")

    def test_empty_tag_without_context(self):
        """With the context cleared the tag is the empty string."""
        hermes_logging.clear_session_context()
        assert self._make_record().session_tag == ""

    def test_tag_with_context(self):
        """With a context set the tag is the bracketed id preceded by a space."""
        hermes_logging.set_session_context("sess_42")
        assert self._make_record().session_tag == " [sess_42]"

    def test_idempotent_install(self):
        """Calling _install_session_record_factory() twice doesn't double-wrap."""
        hermes_logging._install_session_record_factory()
        after_first = logging.getLogRecordFactory()
        hermes_logging._install_session_record_factory()
        after_second = logging.getLogRecordFactory()
        assert after_first is after_second

    def test_works_with_any_handler(self):
        """A handler using %(session_tag)s works even without _SessionFilter."""
        hermes_logging.set_session_context("any_handler_test")
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter("%(session_tag)s %(message)s"))
        probe_logger = logging.getLogger("_test_any_handler")
        probe_logger.addHandler(stream_handler)
        probe_logger.setLevel(logging.DEBUG)
        try:
            # Should not raise KeyError
            probe_logger.info("hello")
        finally:
            probe_logger.removeHandler(stream_handler)
class TestComponentFilter:
    """Unit tests for _ComponentFilter."""

    @staticmethod
    def _record(logger_name, message="msg"):
        """Create an INFO LogRecord for *logger_name* with no args or exc_info."""
        return logging.LogRecord(logger_name, logging.INFO, "", 0, message, (), None)

    def test_passes_matching_prefix(self):
        """A logger directly under the prefix is accepted."""
        component_filter = hermes_logging._ComponentFilter(("gateway",))
        assert component_filter.filter(self._record("gateway.run")) is True

    def test_passes_nested_matching_prefix(self):
        """A more deeply nested logger under the prefix is accepted."""
        component_filter = hermes_logging._ComponentFilter(("gateway",))
        assert component_filter.filter(self._record("gateway.platforms.telegram")) is True

    def test_blocks_non_matching(self):
        """A logger outside the prefix is rejected."""
        component_filter = hermes_logging._ComponentFilter(("gateway",))
        assert component_filter.filter(self._record("tools.terminal_tool")) is False

    def test_multiple_prefixes(self):
        """Any of several configured prefixes admits a record; others are rejected."""
        component_filter = hermes_logging._ComponentFilter(("agent", "run_agent", "model_tools"))
        assert component_filter.filter(self._record("agent.compressor", ""))
        assert component_filter.filter(self._record("run_agent", ""))
        assert component_filter.filter(self._record("model_tools", ""))
        assert not component_filter.filter(self._record("tools.browser", ""))
class TestComponentPrefixes:
    """COMPONENT_PREFIXES covers the expected components."""

    def test_gateway_prefix(self):
        """The gateway component maps to the single "gateway" prefix."""
        table = hermes_logging.COMPONENT_PREFIXES
        assert "gateway" in table
        assert ("gateway",) == table["gateway"]

    def test_agent_prefix(self):
        """The agent component includes agent, run_agent and model_tools."""
        agent_prefixes = hermes_logging.COMPONENT_PREFIXES["agent"]
        assert "agent" in agent_prefixes
        assert "run_agent" in agent_prefixes
        assert "model_tools" in agent_prefixes

    def test_tools_prefix(self):
        """The tools component maps to the single "tools" prefix."""
        assert ("tools",) == hermes_logging.COMPONENT_PREFIXES["tools"]

    def test_cli_prefix(self):
        """The cli component includes both hermes_cli and cli."""
        cli_prefixes = hermes_logging.COMPONENT_PREFIXES["cli"]
        assert "hermes_cli" in cli_prefixes
        assert "cli" in cli_prefixes

    def test_cron_prefix(self):
        """The cron component maps to the single "cron" prefix."""
        assert ("cron",) == hermes_logging.COMPONENT_PREFIXES["cron"]
class TestSetupVerboseLogging:
"""setup_verbose_logging() adds a DEBUG-level console handler."""
@@ -301,6 +593,59 @@ class TestAddRotatingHandler:
logger.removeHandler(h)
h.close()
def test_log_filter_attached(self, tmp_path):
    """Optional log_filter is attached to the handler."""
    target_logger = logging.getLogger("_test_rotating_filter")
    component_filter = hermes_logging._ComponentFilter(("test",))

    hermes_logging._add_rotating_handler(
        target_logger,
        tmp_path / "filtered.log",
        level=logging.INFO,
        max_bytes=1024,
        backup_count=1,
        formatter=logging.Formatter("%(message)s"),
        log_filter=component_filter,
    )

    rotating = [
        handler
        for handler in target_logger.handlers
        if isinstance(handler, RotatingFileHandler)
    ]
    assert len(rotating) == 1
    assert component_filter in rotating[0].filters

    # Clean up: detach and close every rotating handler on this logger.
    for handler in list(target_logger.handlers):
        if not isinstance(handler, RotatingFileHandler):
            continue
        target_logger.removeHandler(handler)
        handler.close()
def test_no_session_filter_on_handler(self, tmp_path):
    """Handlers rely on record factory, not per-handler _SessionFilter."""
    log_file = tmp_path / "no_session_filter.log"
    target_logger = logging.getLogger("_test_no_session_filter")

    hermes_logging._add_rotating_handler(
        target_logger,
        log_file,
        level=logging.INFO,
        max_bytes=1024,
        backup_count=1,
        formatter=logging.Formatter("%(session_tag)s%(message)s"),
    )

    rotating = [
        handler
        for handler in target_logger.handlers
        if isinstance(handler, RotatingFileHandler)
    ]
    assert len(rotating) == 1
    # No _SessionFilter on the handler — record factory handles it
    assert len(rotating[0].filters) == 0

    # But session_tag still works (via record factory)
    hermes_logging.set_session_context("factory_test")
    target_logger.info("test msg")
    rotating[0].flush()
    assert "[factory_test]" in log_file.read_text()

    # Clean up: detach and close every rotating handler on this logger.
    for handler in list(target_logger.handlers):
        if not isinstance(handler, RotatingFileHandler):
            continue
        target_logger.removeHandler(handler)
        handler.close()
def test_managed_mode_initial_open_sets_group_writable(self, tmp_path):
log_path = tmp_path / "managed-open.log"
logger = logging.getLogger("_test_rotating_managed_open")

View File

@@ -0,0 +1,114 @@
"""Tests for network.force_ipv4 — the socket.getaddrinfo monkey-patch."""
import importlib
import socket
from unittest.mock import patch, MagicMock
import pytest
def _reload_constants():
    """Re-import ``hermes_constants`` and return the module object bound after the reload."""
    import hermes_constants as constants_module

    importlib.reload(constants_module)
    return constants_module
class TestApplyIPv4Preference:
    """Tests for apply_ipv4_preference()."""

    def setup_method(self):
        """Save the original getaddrinfo before each test."""
        self._original = socket.getaddrinfo

    def teardown_method(self):
        """Restore the original getaddrinfo after each test."""
        socket.getaddrinfo = self._original

    def test_noop_when_force_false(self):
        """No patch when force=False."""
        from hermes_constants import apply_ipv4_preference

        original = socket.getaddrinfo
        apply_ipv4_preference(force=False)
        assert socket.getaddrinfo is original

    def test_patches_getaddrinfo_when_forced(self):
        """Patches socket.getaddrinfo when force=True."""
        from hermes_constants import apply_ipv4_preference

        original = socket.getaddrinfo
        apply_ipv4_preference(force=True)
        assert socket.getaddrinfo is not original
        # The replacement advertises itself through a marker attribute.
        assert getattr(socket.getaddrinfo, "_hermes_ipv4_patched", False) is True

    def test_double_patch_is_safe(self):
        """Calling apply twice doesn't double-wrap."""
        from hermes_constants import apply_ipv4_preference

        apply_ipv4_preference(force=True)
        first_patch = socket.getaddrinfo
        apply_ipv4_preference(force=True)
        assert socket.getaddrinfo is first_patch

    def test_af_unspec_becomes_af_inet(self):
        """AF_UNSPEC (default) calls get rewritten to AF_INET."""
        from hermes_constants import apply_ipv4_preference

        calls = []

        def mock_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
            calls.append(family)
            return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("93.184.216.34", 80))]

        # Install the recorder first so the patch wraps it, not the real resolver.
        socket.getaddrinfo = mock_getaddrinfo
        apply_ipv4_preference(force=True)

        # Call with default family (AF_UNSPEC = 0)
        socket.getaddrinfo("example.com", 80)
        assert calls[-1] == socket.AF_INET, "AF_UNSPEC should be rewritten to AF_INET"

    def test_explicit_family_preserved(self):
        """Explicit AF_INET6 requests are not intercepted."""
        from hermes_constants import apply_ipv4_preference

        calls = []

        def mock_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
            calls.append(family)
            return [(family, socket.SOCK_STREAM, 6, "", ("::1", 80))]

        socket.getaddrinfo = mock_getaddrinfo
        apply_ipv4_preference(force=True)

        socket.getaddrinfo("example.com", 80, family=socket.AF_INET6)
        assert calls[-1] == socket.AF_INET6, "Explicit AF_INET6 should pass through"

    def test_fallback_on_gaierror(self):
        """Falls back to AF_UNSPEC if AF_INET resolution fails."""
        from hermes_constants import apply_ipv4_preference

        call_families = []

        def mock_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
            call_families.append(family)
            if family == socket.AF_INET:
                raise socket.gaierror("No A record")
            # AF_UNSPEC fallback returns IPv6
            return [(socket.AF_INET6, socket.SOCK_STREAM, 6, "", ("::1", 80))]

        socket.getaddrinfo = mock_getaddrinfo
        apply_ipv4_preference(force=True)

        result = socket.getaddrinfo("ipv6only.example.com", 80)
        # Should have tried AF_INET first, then fallen back to AF_UNSPEC
        assert call_families == [socket.AF_INET, 0]
        assert result[0][0] == socket.AF_INET6
class TestConfigDefault:
    """Verify network section exists in DEFAULT_CONFIG."""

    def test_network_section_in_default_config(self):
        """DEFAULT_CONFIG has a ``network`` section with ``force_ipv4`` set to False."""
        from hermes_cli.config import DEFAULT_CONFIG as defaults

        assert "network" in defaults
        network_section = defaults["network"]
        assert network_section["force_ipv4"] is False

View File

@@ -59,8 +59,9 @@ class TestCamofoxConfigDefaults:
browser_cfg = DEFAULT_CONFIG["browser"]
assert browser_cfg["camofox"]["managed_persistence"] is False
def test_config_version_matches_current_schema(self):
    """DEFAULT_CONFIG carries the schema version this test suite was written against."""
    from hermes_cli.config import DEFAULT_CONFIG

    # The current schema version is tracked globally; unrelated default
    # options may bump it after browser defaults are added.
    # (The superseded ``== 13`` assertion and its body-less
    # ``test_config_version_unchanged`` stub contradicted this one and are
    # dropped.)
    assert DEFAULT_CONFIG["_config_version"] == 15

View File

@@ -380,7 +380,7 @@ class TestStubSchemaDrift(unittest.TestCase):
# Parameters that are internal (injected by the handler, not user-facing)
_INTERNAL_PARAMS = {"task_id", "user_task"}
# Parameters intentionally blocked in the sandbox
_BLOCKED_TERMINAL_PARAMS = {"background", "check_interval", "pty", "notify_on_complete"}
_BLOCKED_TERMINAL_PARAMS = {"background", "pty", "notify_on_complete"}
def test_stubs_cover_all_schema_params(self):
"""Every user-facing parameter in the real schema must appear in the

View File

@@ -0,0 +1,295 @@
"""Tests for Modal bulk upload via tar/base64 archive."""
import asyncio
import base64
import io
import tarfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from tools.environments import modal as modal_env
def _make_mock_modal_env(monkeypatch, tmp_path):
    """Build a ModalEnvironment whose sandbox and worker are mocks.

    ``__init__`` is bypassed (it requires the Modal SDK); only the attributes
    assigned below are set on the returned object. Both parameters are
    accepted but not used here.
    """
    env = object.__new__(modal_env.ModalEnvironment)
    stub_attrs = {
        "_sandbox": MagicMock(),
        "_worker": MagicMock(),
        "_persistent": False,
        "_task_id": "test",
        "_sync_manager": None,
    }
    for attr_name, attr_value in stub_attrs.items():
        setattr(env, attr_name, attr_value)
    return env
def _make_mock_stdin():
"""Create a mock stdin that captures written data."""
stdin = MagicMock()
written_chunks = []
def mock_write(data):
written_chunks.append(data)
stdin.write = mock_write
stdin.write_eof = MagicMock()
stdin.drain = MagicMock()
stdin.drain.aio = AsyncMock()
stdin._written_chunks = written_chunks
return stdin
def _wire_async_exec(env, exec_calls=None):
    """Install a fake ``sandbox.exec.aio`` and a real coroutine runner on *env*.

    Each call to the fake exec appends its positional arguments to
    *exec_calls* (a fresh list when none is supplied) and returns a process
    mock whose ``wait.aio`` resolves to 0, whose ``stderr.read.aio`` resolves
    to "" and whose ``stdin`` is the shared capturing mock.

    Returns (exec_calls, run_kwargs, stdin_mock).
    """
    recorded_calls = [] if exec_calls is None else exec_calls
    run_kwargs: dict = {}
    stdin_mock = _make_mock_stdin()

    async def _fake_exec(*args, **kwargs):
        recorded_calls.append(args)
        process = MagicMock()
        process.wait = MagicMock()
        process.wait.aio = AsyncMock(return_value=0)
        process.stdin = stdin_mock
        process.stderr = MagicMock()
        process.stderr.read = MagicMock()
        process.stderr.read.aio = AsyncMock(return_value="")
        return process

    def _run_on_fresh_loop(coro, **kwargs):
        # Remember the keyword arguments the caller passed, then drive the
        # coroutine to completion on a throwaway event loop.
        run_kwargs.update(kwargs)
        event_loop = asyncio.new_event_loop()
        try:
            return event_loop.run_until_complete(coro)
        finally:
            event_loop.close()

    env._sandbox.exec = MagicMock()
    env._sandbox.exec.aio = _fake_exec
    env._worker.run_coroutine = _run_on_fresh_loop
    return recorded_calls, run_kwargs, stdin_mock
class TestModalBulkUpload:
"""Test _modal_bulk_upload method."""
def test_empty_files_is_noop(self, monkeypatch, tmp_path):
    """Empty file list should not call worker.run_coroutine."""
    modal = _make_mock_modal_env(monkeypatch, tmp_path)
    no_files = []
    modal._modal_bulk_upload(no_files)
    modal._worker.run_coroutine.assert_not_called()
def test_tar_archive_contains_all_files(self, monkeypatch, tmp_path):
    """The tar archive sent via stdin should contain all files."""
    env = _make_mock_modal_env(monkeypatch, tmp_path)

    cred_src = tmp_path / "a.json"
    skill_src = tmp_path / "b.py"
    cred_src.write_text("cred_content")
    skill_src.write_text("skill_content")
    uploads = [
        (str(cred_src), "/root/.hermes/credentials/a.json"),
        (str(skill_src), "/root/.hermes/skills/b.py"),
    ]

    exec_calls, _, stdin_mock = _wire_async_exec(env)
    env._modal_bulk_upload(uploads)

    # Verify the command reads from stdin (no echo with embedded payload)
    assert len(exec_calls) == 1
    program, flag, shell_cmd = exec_calls[0][0], exec_calls[0][1], exec_calls[0][2]
    assert program == "bash"
    assert flag == "-c"
    for fragment in ("mkdir -p", "base64 -d", "tar xzf", "-C /"):
        assert fragment in shell_cmd

    # Reassemble the base64 payload from stdin chunks and verify tar contents
    encoded = "".join(stdin_mock._written_chunks)
    archive_bytes = io.BytesIO(base64.b64decode(encoded))
    with tarfile.open(fileobj=archive_bytes, mode="r:gz") as archive:
        member_names = sorted(archive.getnames())
        assert "root/.hermes/credentials/a.json" in member_names
        assert "root/.hermes/skills/b.py" in member_names

        # Verify content
        cred_bytes = archive.extractfile("root/.hermes/credentials/a.json").read()
        assert cred_bytes == b"cred_content"
        skill_bytes = archive.extractfile("root/.hermes/skills/b.py").read()
        assert skill_bytes == b"skill_content"

    # Verify stdin was closed
    stdin_mock.write_eof.assert_called_once()
def test_mkdir_includes_all_parents(self, monkeypatch, tmp_path):
"""Remote parent directories should be pre-created in the command."""
env = _make_mock_modal_env(monkeypatch, tmp_path)
src = tmp_path / "f.txt"
src.write_text("data")
files = [
(str(src), "/root/.hermes/credentials/f.txt"),
(str(src), "/root/.hermes/skills/deep/nested/f.txt"),
]
exec_calls, _, _ = _wire_async_exec(env)
env._modal_bulk_upload(files)
cmd = exec_calls[0][2]
assert "/root/.hermes/credentials" in cmd
assert "/root/.hermes/skills/deep/nested" in cmd
def test_single_exec_call(self, monkeypatch, tmp_path):
"""Bulk upload should use exactly one exec call regardless of file count."""
env = _make_mock_modal_env(monkeypatch, tmp_path)
files = []
for i in range(20):
src = tmp_path / f"file_{i}.txt"
src.write_text(f"content_{i}")
files.append((str(src), f"/root/.hermes/cache/file_{i}.txt"))
exec_calls, _, _ = _wire_async_exec(env)
env._modal_bulk_upload(files)
# Should be exactly 1 exec call, not 20
assert len(exec_calls) == 1
def test_bulk_upload_wired_in_filesyncmanager(self, monkeypatch):
"""Verify ModalEnvironment passes bulk_upload_fn to FileSyncManager."""
captured_kwargs = {}
def capture_fsm(**kwargs):
captured_kwargs.update(kwargs)
return type("M", (), {"sync": lambda self, **k: None})()
monkeypatch.setattr(modal_env, "FileSyncManager", capture_fsm)
# Create a minimal env without full __init__
env = object.__new__(modal_env.ModalEnvironment)
env._sandbox = MagicMock()
env._worker = MagicMock()
env._persistent = False
env._task_id = "test"
# Manually call the part of __init__ that wires FileSyncManager
from tools.environments.file_sync import iter_sync_files
env._sync_manager = modal_env.FileSyncManager(
get_files_fn=lambda: iter_sync_files("/root/.hermes"),
upload_fn=env._modal_upload,
delete_fn=env._modal_delete,
bulk_upload_fn=env._modal_bulk_upload,
)
assert "bulk_upload_fn" in captured_kwargs
assert captured_kwargs["bulk_upload_fn"] is not None
assert callable(captured_kwargs["bulk_upload_fn"])
def test_timeout_set_to_120(self, monkeypatch, tmp_path):
"""Bulk upload uses a 120s timeout (not the per-file 15s)."""
env = _make_mock_modal_env(monkeypatch, tmp_path)
src = tmp_path / "f.txt"
src.write_text("data")
files = [(str(src), "/root/.hermes/f.txt")]
_, run_kwargs, _ = _wire_async_exec(env)
env._modal_bulk_upload(files)
assert run_kwargs.get("timeout") == 120
def test_nonzero_exit_raises(self, monkeypatch, tmp_path):
"""Non-zero exit code from remote exec should raise RuntimeError."""
env = _make_mock_modal_env(monkeypatch, tmp_path)
src = tmp_path / "f.txt"
src.write_text("data")
files = [(str(src), "/root/.hermes/f.txt")]
stdin_mock = _make_mock_stdin()
async def mock_exec_fn(*args, **kwargs):
proc = MagicMock()
proc.wait = MagicMock()
proc.wait.aio = AsyncMock(return_value=1) # non-zero exit
proc.stdin = stdin_mock
proc.stderr = MagicMock()
proc.stderr.read = MagicMock()
proc.stderr.read.aio = AsyncMock(return_value="tar: error")
return proc
env._sandbox.exec = MagicMock()
env._sandbox.exec.aio = mock_exec_fn
def real_run_coroutine(coro, **kwargs):
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()
env._worker.run_coroutine = real_run_coroutine
with pytest.raises(RuntimeError, match="Modal bulk upload failed"):
env._modal_bulk_upload(files)
def test_payload_not_in_command_string(self, monkeypatch, tmp_path):
"""The base64 payload must NOT appear in the bash -c argument.
This is the core ARG_MAX fix: the payload goes through stdin,
not embedded in the command string.
"""
env = _make_mock_modal_env(monkeypatch, tmp_path)
src = tmp_path / "f.txt"
src.write_text("some data to upload")
files = [(str(src), "/root/.hermes/f.txt")]
exec_calls, _, stdin_mock = _wire_async_exec(env)
env._modal_bulk_upload(files)
# The command should NOT contain an echo with the payload
cmd = exec_calls[0][2]
assert "echo" not in cmd
# The payload should go through stdin
assert len(stdin_mock._written_chunks) > 0
def test_stdin_chunked_for_large_payloads(self, monkeypatch, tmp_path):
"""Payloads larger than _STDIN_CHUNK_SIZE should be split into multiple writes."""
env = _make_mock_modal_env(monkeypatch, tmp_path)
# Use random bytes so gzip cannot compress them -- ensures the
# base64 payload exceeds one 1 MB chunk.
import os as _os
src = tmp_path / "large.bin"
src.write_bytes(_os.urandom(1024 * 1024 + 512 * 1024))
files = [(str(src), "/root/.hermes/large.bin")]
exec_calls, _, stdin_mock = _wire_async_exec(env)
env._modal_bulk_upload(files)
# Should have multiple stdin write chunks
assert len(stdin_mock._written_chunks) >= 2
# Reassembled payload should still decode to valid tar
payload = "".join(stdin_mock._written_chunks)
tar_data = base64.b64decode(payload)
buf = io.BytesIO(tar_data)
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
names = tar.getnames()
assert "root/.hermes/large.bin" in names

View File

@@ -289,3 +289,62 @@ class TestCodeExecutionBlocked:
def test_notify_on_complete_blocked_in_sandbox(self):
    """``notify_on_complete`` must be listed in ``_TERMINAL_BLOCKED_PARAMS``."""
    from tools.code_execution_tool import _TERMINAL_BLOCKED_PARAMS as blocked_params

    assert "notify_on_complete" in blocked_params
# =========================================================================
# Completion consumed suppression
# =========================================================================
class TestCompletionConsumed:
    """Check that wait/poll/read_log mark an exited session's completion as consumed."""

    def test_wait_marks_completion_consumed(self, registry):
        """wait() returning an exited status marks the session as consumed."""
        session = _make_session(
            sid="proc_wait", notify_on_complete=True, output="done"
        )
        session.exited = True
        session.exit_code = 0
        registry._running[session.id] = session
        with patch.object(registry, "_write_checkpoint"):
            registry._move_to_finished(session)

        # Moving to finished queued a notification that nobody has read yet.
        assert not registry.completion_queue.empty()
        assert not registry.is_completion_consumed("proc_wait")

        # The agent collects the result directly through wait().
        outcome = registry.wait("proc_wait", timeout=1)
        assert outcome["status"] == "exited"

        # That read counts as consuming the completion.
        assert registry.is_completion_consumed("proc_wait")

    def test_poll_marks_completion_consumed(self, registry):
        """poll() returning an exited status marks the session as consumed."""
        session = _make_session(
            sid="proc_poll", notify_on_complete=True, output="done"
        )
        session.exited = True
        session.exit_code = 0
        registry._finished[session.id] = session

        outcome = registry.poll("proc_poll")

        assert outcome["status"] == "exited"
        assert registry.is_completion_consumed("proc_poll")

    def test_log_marks_completion_consumed(self, registry):
        """read_log() on an exited session marks it as consumed."""
        session = _make_session(
            sid="proc_log", notify_on_complete=True, output="line1\nline2"
        )
        session.exited = True
        session.exit_code = 0
        registry._finished[session.id] = session

        outcome = registry.read_log("proc_log")

        assert outcome["status"] == "exited"
        assert registry.is_completion_consumed("proc_log")

    def test_running_process_not_consumed(self, registry):
        """poll() on a still-running process must not mark it as consumed."""
        session = _make_session(
            sid="proc_running", notify_on_complete=True, output="partial"
        )
        registry._running[session.id] = session

        outcome = registry.poll("proc_running")

        assert outcome["status"] == "running"
        assert not registry.is_completion_consumed("proc_running")

View File

@@ -0,0 +1,517 @@
"""Tests for SSH bulk upload via tar pipe."""
import os
import subprocess
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from tools.environments import ssh as ssh_env
from tools.environments.file_sync import quoted_mkdir_command, unique_parent_dirs
from tools.environments.ssh import SSHEnvironment
def _mock_proc(*, returncode=0, poll_return=0, communicate_return=(b"", b""),
stderr_read=b""):
"""Create a MagicMock mimicking subprocess.Popen for tar/ssh pipes."""
m = MagicMock()
m.stdout = MagicMock()
m.returncode = returncode
m.poll.return_value = poll_return
m.communicate.return_value = communicate_return
m.stderr = MagicMock()
m.stderr.read.return_value = stderr_read
return m
@pytest.fixture
def mock_env(monkeypatch):
    """Return an SSHEnvironment whose connection and sync hooks are stubbed.

    ``shutil.which`` is patched to report an ssh binary, the connection and
    session methods become no-ops, the remote home is fixed to
    ``/home/testuser``, and ``FileSyncManager`` is replaced by an object
    whose ``sync`` does nothing.
    """
    monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")

    method_stubs = {
        "_establish_connection": lambda self: None,
        "_detect_remote_home": lambda self: "/home/testuser",
        "_ensure_remote_dirs": lambda self: None,
        "init_session": lambda self: None,
    }
    for method_name, stub in method_stubs.items():
        monkeypatch.setattr(ssh_env.SSHEnvironment, method_name, stub)

    def _noop_sync_manager(**kw):
        return type("M", (), {"sync": lambda self, **k: None})()

    monkeypatch.setattr(ssh_env, "FileSyncManager", _noop_sync_manager)
    return SSHEnvironment(host="example.com", user="testuser")
class TestSSHBulkUpload:
    """Unit tests for _ssh_bulk_upload — tar pipe mechanics.

    Fully configured process doubles come from the module-level
    ``_mock_proc`` helper instead of being rebuilt by hand in each test.
    Doubles that are deliberately only partially configured (the failing
    ssh process and the timeout case) are still built inline so their
    unset attributes behave exactly as before.
    """

    def test_empty_files_is_noop(self, mock_env):
        """Empty file list should not spawn any subprocesses."""
        with patch.object(subprocess, "run") as mock_run, \
                patch.object(subprocess, "Popen") as mock_popen:
            mock_env._ssh_bulk_upload([])

        mock_run.assert_not_called()
        mock_popen.assert_not_called()

    def test_mkdir_batched_into_single_call(self, mock_env, tmp_path):
        """All parent directories should be created in one SSH call."""
        f1 = tmp_path / "a.txt"
        f1.write_text("aaa")
        f2 = tmp_path / "b.txt"
        f2.write_text("bbb")
        files = [
            (str(f1), "/home/testuser/.hermes/skills/a.txt"),
            (str(f2), "/home/testuser/.hermes/credentials/b.txt"),
        ]

        # subprocess.run handles mkdir; Popen handles the tar pipe.
        mock_run = MagicMock(return_value=subprocess.CompletedProcess([], 0))

        with patch.object(subprocess, "run", mock_run), \
                patch.object(subprocess, "Popen",
                             side_effect=lambda cmd, **kwargs: _mock_proc()):
            mock_env._ssh_bulk_upload(files)

        # Exactly one subprocess.run call for mkdir
        assert mock_run.call_count == 1
        mkdir_cmd = mock_run.call_args[0][0]
        # Should contain mkdir -p with both parent dirs
        mkdir_str = " ".join(mkdir_cmd)
        assert "mkdir -p" in mkdir_str
        assert "/home/testuser/.hermes/skills" in mkdir_str
        assert "/home/testuser/.hermes/credentials" in mkdir_str

    def test_staging_symlinks_mirror_remote_layout(self, mock_env, tmp_path):
        """Symlinks in staging dir should mirror the remote path structure."""
        f1 = tmp_path / "local_a.txt"
        f1.write_text("content a")
        files = [
            (str(f1), "/home/testuser/.hermes/skills/my_skill.md"),
        ]
        staging_paths = []

        def capture_tar_cmd(cmd, **kwargs):
            if cmd[0] == "tar":
                # The staging dir is the argument following -C.
                staging_dir = cmd[cmd.index("-C") + 1]
                expected = os.path.join(
                    staging_dir, "home/testuser/.hermes/skills/my_skill.md"
                )
                staging_paths.append(expected)
                # Inspect while the staging dir is still on disk.
                assert os.path.islink(expected), f"Expected symlink at {expected}"
                assert os.readlink(expected) == os.path.abspath(str(f1))
            return _mock_proc()

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=capture_tar_cmd):
            mock_env._ssh_bulk_upload(files)

        assert len(staging_paths) == 1, "tar command should have been called"

    def test_tar_pipe_commands(self, mock_env, tmp_path):
        """Verify tar and SSH commands are wired correctly."""
        f1 = tmp_path / "x.txt"
        f1.write_text("x")
        files = [(str(f1), "/home/testuser/.hermes/cache/x.txt")]
        popen_cmds = []

        def capture_popen(cmd, **kwargs):
            popen_cmds.append(cmd)
            return _mock_proc()

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=capture_popen):
            mock_env._ssh_bulk_upload(files)

        assert len(popen_cmds) == 2, "Should spawn tar + ssh processes"
        tar_cmd = popen_cmds[0]
        ssh_cmd = popen_cmds[1]

        # tar: create, dereference symlinks, to stdout
        assert tar_cmd[0] == "tar"
        assert "-chf" in tar_cmd
        assert "-" in tar_cmd  # stdout
        assert "-C" in tar_cmd

        # ssh: extract from stdin at /
        ssh_str = " ".join(ssh_cmd)
        assert "ssh" in ssh_str
        assert "tar xf - -C /" in ssh_str
        assert "testuser@example.com" in ssh_str

    def test_mkdir_failure_raises(self, mock_env, tmp_path):
        """mkdir failure should raise RuntimeError before tar pipe."""
        f1 = tmp_path / "y.txt"
        f1.write_text("y")
        files = [(str(f1), "/home/testuser/.hermes/skills/y.txt")]

        failed_run = subprocess.CompletedProcess([], 1, stderr="Permission denied")
        with patch.object(subprocess, "run", return_value=failed_run):
            with pytest.raises(RuntimeError, match="remote mkdir failed"):
                mock_env._ssh_bulk_upload(files)

    def test_tar_create_failure_raises(self, mock_env, tmp_path):
        """tar create failure should raise RuntimeError."""
        f1 = tmp_path / "z.txt"
        f1.write_text("z")
        files = [(str(f1), "/home/testuser/.hermes/skills/z.txt")]

        mock_tar = _mock_proc(
            returncode=1,
            poll_return=1,
            communicate_return=(b"tar: error", b""),
            stderr_read=b"tar: error",
        )

        # Partially configured on purpose: only what the original test set.
        mock_ssh = MagicMock()
        mock_ssh.communicate.return_value = (b"", b"")
        mock_ssh.returncode = 0

        def popen_side_effect(cmd, **kwargs):
            if cmd[0] == "tar":
                return mock_tar
            return mock_ssh

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=popen_side_effect):
            with pytest.raises(RuntimeError, match="tar create failed"):
                mock_env._ssh_bulk_upload(files)

    def test_ssh_extract_failure_raises(self, mock_env, tmp_path):
        """SSH tar extract failure should raise RuntimeError."""
        f1 = tmp_path / "w.txt"
        f1.write_text("w")
        files = [(str(f1), "/home/testuser/.hermes/skills/w.txt")]

        mock_tar = _mock_proc()

        # Partially configured on purpose: only what the original test set.
        mock_ssh = MagicMock()
        mock_ssh.communicate.return_value = (b"", b"Permission denied")
        mock_ssh.returncode = 1

        def popen_side_effect(cmd, **kwargs):
            if cmd[0] == "tar":
                return mock_tar
            return mock_ssh

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=popen_side_effect):
            with pytest.raises(RuntimeError, match="tar extract over SSH failed"):
                mock_env._ssh_bulk_upload(files)

    def test_ssh_command_uses_control_socket(self, mock_env, tmp_path):
        """SSH command for tar extract should reuse ControlMaster socket."""
        f1 = tmp_path / "c.txt"
        f1.write_text("c")
        files = [(str(f1), "/home/testuser/.hermes/cache/c.txt")]
        popen_cmds = []

        def capture_popen(cmd, **kwargs):
            popen_cmds.append(cmd)
            return _mock_proc()

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=capture_popen):
            mock_env._ssh_bulk_upload(files)

        # The SSH command (second Popen call) should include ControlPath
        ssh_cmd = popen_cmds[1]
        assert f"ControlPath={mock_env.control_socket}" in " ".join(ssh_cmd)

    def test_custom_port_and_key_in_ssh_command(self, monkeypatch, tmp_path):
        """Bulk upload SSH command should include custom port and key."""
        # Same stubbing as the mock_env fixture, but with a different remote
        # home and an environment built with an explicit port and key.
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
        monkeypatch.setattr(
            ssh_env, "FileSyncManager",
            lambda **kw: type("M", (), {"sync": lambda self, **k: None})(),
        )
        env = SSHEnvironment(host="h", user="u", port=2222, key_path="/my/key")

        f1 = tmp_path / "d.txt"
        f1.write_text("d")
        files = [(str(f1), "/home/u/.hermes/skills/d.txt")]
        run_cmds = []
        popen_cmds = []

        def capture_run(cmd, **kwargs):
            run_cmds.append(cmd)
            return subprocess.CompletedProcess([], 0)

        def capture_popen(cmd, **kwargs):
            popen_cmds.append(cmd)
            return _mock_proc()

        with patch.object(subprocess, "run", side_effect=capture_run), \
                patch.object(subprocess, "Popen", side_effect=capture_popen):
            env._ssh_bulk_upload(files)

        # Check mkdir SSH call includes port and key
        assert len(run_cmds) == 1
        mkdir_cmd = run_cmds[0]
        assert "-p" in mkdir_cmd and "2222" in mkdir_cmd
        assert "-i" in mkdir_cmd and "/my/key" in mkdir_cmd

        # Check tar extract SSH call includes port and key
        ssh_cmd = popen_cmds[1]
        assert "-p" in ssh_cmd and "2222" in ssh_cmd
        assert "-i" in ssh_cmd and "/my/key" in ssh_cmd

    def test_parent_dirs_deduplicated(self, mock_env, tmp_path):
        """Multiple files in the same dir should produce one mkdir entry."""
        f1 = tmp_path / "a.txt"
        f1.write_text("a")
        f2 = tmp_path / "b.txt"
        f2.write_text("b")
        f3 = tmp_path / "c.txt"
        f3.write_text("c")
        files = [
            (str(f1), "/home/testuser/.hermes/skills/a.txt"),
            (str(f2), "/home/testuser/.hermes/skills/b.txt"),
            (str(f3), "/home/testuser/.hermes/credentials/c.txt"),
        ]
        run_cmds = []

        def capture_run(cmd, **kwargs):
            run_cmds.append(cmd)
            return subprocess.CompletedProcess([], 0)

        with patch.object(subprocess, "run", side_effect=capture_run), \
                patch.object(subprocess, "Popen",
                             side_effect=lambda cmd, **kwargs: _mock_proc()):
            mock_env._ssh_bulk_upload(files)

        # Only one mkdir call
        assert len(run_cmds) == 1
        mkdir_str = " ".join(run_cmds[0])
        # skills dir should appear exactly once despite two files
        assert mkdir_str.count("/home/testuser/.hermes/skills") == 1
        assert "/home/testuser/.hermes/credentials" in mkdir_str

    def test_tar_stdout_closed_for_sigpipe(self, mock_env, tmp_path):
        """tar_proc.stdout must be closed so SIGPIPE propagates correctly."""
        f1 = tmp_path / "s.txt"
        f1.write_text("s")
        files = [(str(f1), "/home/testuser/.hermes/skills/s.txt")]
        mock_tar_stdout = MagicMock()

        def make_proc(cmd, **kwargs):
            proc = _mock_proc()
            if cmd[0] == "tar":
                # Swap in a stdout we hold a reference to, so close() is observable.
                proc.stdout = mock_tar_stdout
            return proc

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=make_proc):
            mock_env._ssh_bulk_upload(files)

        mock_tar_stdout.close.assert_called_once()

    def test_timeout_kills_both_processes(self, mock_env, tmp_path):
        """TimeoutExpired during communicate should kill both processes."""
        f1 = tmp_path / "t.txt"
        f1.write_text("t")
        files = [(str(f1), "/home/testuser/.hermes/skills/t.txt")]

        # Both doubles look like still-running processes (returncode None);
        # they are built inline so no extra attributes are preconfigured.
        mock_tar = MagicMock()
        mock_tar.stdout = MagicMock()
        mock_tar.returncode = None
        mock_tar.poll.return_value = None

        mock_ssh = MagicMock()
        mock_ssh.communicate.side_effect = subprocess.TimeoutExpired("ssh", 120)
        mock_ssh.returncode = None

        def make_proc(cmd, **kwargs):
            if cmd[0] == "tar":
                return mock_tar
            return mock_ssh

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=make_proc):
            with pytest.raises(RuntimeError, match="SSH bulk upload timed out"):
                mock_env._ssh_bulk_upload(files)

        mock_tar.kill.assert_called_once()
        mock_ssh.kill.assert_called_once()
class TestSSHBulkUploadWiring:
    """Check that SSHEnvironment hands a bulk_upload_fn to FileSyncManager."""

    def test_filesyncmanager_receives_bulk_upload_fn(self, monkeypatch):
        """Constructing SSHEnvironment should pass a callable ``bulk_upload_fn``."""
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        method_stubs = {
            "_establish_connection": lambda self: None,
            "_detect_remote_home": lambda self: "/root",
            "_ensure_remote_dirs": lambda self: None,
            "init_session": lambda self: None,
        }
        for method_name, stub in method_stubs.items():
            monkeypatch.setattr(ssh_env.SSHEnvironment, method_name, stub)

        seen_kwargs = {}

        class FakeSyncManager:
            """Record constructor kwargs; ``sync`` does nothing."""

            def __init__(self, **kwargs):
                seen_kwargs.update(kwargs)

            def sync(self, **kw):
                pass

        monkeypatch.setattr(ssh_env, "FileSyncManager", FakeSyncManager)

        # Construction is what triggers the FileSyncManager wiring.
        env = SSHEnvironment(host="h", user="u")

        assert "bulk_upload_fn" in seen_kwargs
        bulk_fn = seen_kwargs["bulk_upload_fn"]
        assert bulk_fn is not None
        # Expected to be the bound method.
        assert callable(bulk_fn)
class TestSharedHelpers:
    """Direct unit tests for the ``quoted_mkdir_command`` and ``unique_parent_dirs`` helpers."""

    def test_quoted_mkdir_command_basic(self):
        """Plain paths are appended to ``mkdir -p`` unquoted."""
        command = quoted_mkdir_command(["/a", "/b/c"])
        assert command == "mkdir -p /a /b/c"

    def test_quoted_mkdir_command_quotes_special_chars(self):
        """A path containing spaces is wrapped in single quotes."""
        command = quoted_mkdir_command(["/path/with spaces", "/path/'quotes'"])
        assert "mkdir -p" in command
        # shlex.quote wraps in single quotes
        assert "'/path/with spaces'" in command

    def test_quoted_mkdir_command_empty(self):
        """An empty directory list still yields the bare command prefix."""
        command = quoted_mkdir_command([])
        assert command == "mkdir -p "

    def test_unique_parent_dirs_deduplicates(self):
        """Two files in the same remote dir contribute a single parent entry."""
        pairs = [
            ("/local/a.txt", "/remote/dir/a.txt"),
            ("/local/b.txt", "/remote/dir/b.txt"),
            ("/local/c.txt", "/remote/other/c.txt"),
        ]
        parents = unique_parent_dirs(pairs)
        assert parents == ["/remote/dir", "/remote/other"]

    def test_unique_parent_dirs_sorted(self):
        """Parents come back in sorted order regardless of input order."""
        pairs = [
            ("/local/z.txt", "/z/file.txt"),
            ("/local/a.txt", "/a/file.txt"),
        ]
        parents = unique_parent_dirs(pairs)
        assert parents == ["/a", "/z"]

    def test_unique_parent_dirs_empty(self):
        """No files means no parent directories."""
        assert unique_parent_dirs([]) == []
class TestSSHBulkUploadEdgeCases:
    """Edge cases for _ssh_bulk_upload."""

    def test_ssh_popen_failure_kills_tar(self, mock_env, tmp_path):
        """If the SSH Popen raises, the tar process must be killed and waited on."""
        local_file = tmp_path / "e.txt"
        local_file.write_text("e")
        upload_pairs = [(str(local_file), "/home/testuser/.hermes/skills/e.txt")]

        mock_tar = _mock_proc()
        spawned_cmds = []

        def failing_ssh_popen(cmd, **kwargs):
            spawned_cmds.append(cmd)
            if len(spawned_cmds) == 1:
                return mock_tar  # first Popen (tar) succeeds
            raise OSError("SSH binary not found")

        ok_run = subprocess.CompletedProcess([], 0)
        with patch.object(subprocess, "run", return_value=ok_run), \
                patch.object(subprocess, "Popen", side_effect=failing_ssh_popen):
            with pytest.raises(OSError, match="SSH binary not found"):
                mock_env._ssh_bulk_upload(upload_pairs)

        mock_tar.kill.assert_called_once()
        mock_tar.wait.assert_called_once()

View File

@@ -24,6 +24,18 @@ class TestWriteAndRead:
items[0]["content"] = "MUTATED"
assert store.read()[0]["content"] == "Task"
def test_write_deduplicates_duplicate_ids(self):
    """Writing two items with the same id keeps only the later one.

    The surviving item takes the position of its last occurrence, so the
    untouched item "2" comes first in the result.
    """
    first_version = {"id": "1", "content": "First version", "status": "pending"}
    other_task = {"id": "2", "content": "Other task", "status": "pending"}
    latest_version = {"id": "1", "content": "Latest version", "status": "in_progress"}

    store = TodoStore()
    result = store.write([first_version, other_task, latest_version])

    expected = [
        {"id": "2", "content": "Other task", "status": "pending"},
        {"id": "1", "content": "Latest version", "status": "in_progress"},
    ]
    assert result == expected
class TestHasItems:
def test_empty_store(self):