Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor
This commit is contained in:
@@ -971,6 +971,74 @@ class TestTaskSpecificOverrides:
|
||||
client, model = get_text_auxiliary_client("compression")
|
||||
assert model == "google/gemini-3-flash-preview" # auto → OpenRouter
|
||||
|
||||
def test_resolve_auto_prefers_live_main_runtime_over_persisted_config(self, monkeypatch, tmp_path):
|
||||
"""Session-only live model switches should override persisted config for auto routing."""
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
"""model:
|
||||
default: glm-5.1
|
||||
provider: opencode-go
|
||||
compression:
|
||||
summary_provider: auto
|
||||
"""
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
calls = []
|
||||
|
||||
def _fake_resolve(provider, model=None, *args, **kwargs):
|
||||
calls.append((provider, model, kwargs))
|
||||
return MagicMock(), model or "resolved-model"
|
||||
|
||||
with patch("agent.auxiliary_client.resolve_provider_client", side_effect=_fake_resolve):
|
||||
client, model = _resolve_auto(
|
||||
main_runtime={
|
||||
"provider": "openai-codex",
|
||||
"model": "gpt-5.4",
|
||||
"api_mode": "codex_responses",
|
||||
}
|
||||
)
|
||||
|
||||
assert client is not None
|
||||
assert model == "gpt-5.4"
|
||||
assert calls[0][0] == "openai-codex"
|
||||
assert calls[0][1] == "gpt-5.4"
|
||||
assert calls[0][2]["api_mode"] == "codex_responses"
|
||||
|
||||
def test_explicit_compression_pin_still_wins_over_live_main_runtime(self, monkeypatch, tmp_path):
|
||||
"""Task-level compression config should beat a live session override."""
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
"""auxiliary:
|
||||
compression:
|
||||
provider: openrouter
|
||||
model: google/gemini-3-flash-preview
|
||||
model:
|
||||
default: glm-5.1
|
||||
provider: opencode-go
|
||||
"""
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
with patch("agent.auxiliary_client.resolve_provider_client", return_value=(MagicMock(), "google/gemini-3-flash-preview")) as mock_resolve:
|
||||
client, model = get_text_auxiliary_client(
|
||||
"compression",
|
||||
main_runtime={
|
||||
"provider": "openai-codex",
|
||||
"model": "gpt-5.4",
|
||||
},
|
||||
)
|
||||
|
||||
assert client is not None
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert mock_resolve.call_args.args[0] == "openrouter"
|
||||
assert mock_resolve.call_args.kwargs["main_runtime"] == {
|
||||
"provider": "openai-codex",
|
||||
"model": "gpt-5.4",
|
||||
}
|
||||
|
||||
def test_compression_summary_base_url_from_config(self, monkeypatch, tmp_path):
|
||||
"""compression.summary_base_url should produce a custom-endpoint client."""
|
||||
hermes_home = tmp_path / "hermes"
|
||||
@@ -1560,3 +1628,74 @@ class TestStaleBaseUrlWarning:
|
||||
|
||||
assert not any("OPENAI_BASE_URL is set" in rec.message for rec in caplog.records), \
|
||||
"Warning should not fire a second time"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anthropic-compatible image block conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAnthropicCompatImageConversion:
|
||||
"""Tests for _is_anthropic_compat_endpoint and _convert_openai_images_to_anthropic."""
|
||||
|
||||
def test_known_providers_detected(self):
|
||||
from agent.auxiliary_client import _is_anthropic_compat_endpoint
|
||||
assert _is_anthropic_compat_endpoint("minimax", "")
|
||||
assert _is_anthropic_compat_endpoint("minimax-cn", "")
|
||||
|
||||
def test_openrouter_not_detected(self):
|
||||
from agent.auxiliary_client import _is_anthropic_compat_endpoint
|
||||
assert not _is_anthropic_compat_endpoint("openrouter", "")
|
||||
assert not _is_anthropic_compat_endpoint("anthropic", "")
|
||||
|
||||
def test_url_based_detection(self):
|
||||
from agent.auxiliary_client import _is_anthropic_compat_endpoint
|
||||
assert _is_anthropic_compat_endpoint("custom", "https://api.minimax.io/anthropic")
|
||||
assert _is_anthropic_compat_endpoint("custom", "https://example.com/anthropic/v1")
|
||||
assert not _is_anthropic_compat_endpoint("custom", "https://api.openai.com/v1")
|
||||
|
||||
def test_base64_image_converted(self):
|
||||
from agent.auxiliary_client import _convert_openai_images_to_anthropic
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "describe"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR="}}
|
||||
]
|
||||
}]
|
||||
result = _convert_openai_images_to_anthropic(messages)
|
||||
img_block = result[0]["content"][1]
|
||||
assert img_block["type"] == "image"
|
||||
assert img_block["source"]["type"] == "base64"
|
||||
assert img_block["source"]["media_type"] == "image/png"
|
||||
assert img_block["source"]["data"] == "iVBOR="
|
||||
|
||||
def test_url_image_converted(self):
|
||||
from agent.auxiliary_client import _convert_openai_images_to_anthropic
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}}
|
||||
]
|
||||
}]
|
||||
result = _convert_openai_images_to_anthropic(messages)
|
||||
img_block = result[0]["content"][0]
|
||||
assert img_block["type"] == "image"
|
||||
assert img_block["source"]["type"] == "url"
|
||||
assert img_block["source"]["url"] == "https://example.com/img.jpg"
|
||||
|
||||
def test_text_only_messages_unchanged(self):
|
||||
from agent.auxiliary_client import _convert_openai_images_to_anthropic
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = _convert_openai_images_to_anthropic(messages)
|
||||
assert result[0] is messages[0] # same object, not copied
|
||||
|
||||
def test_jpeg_media_type_parsed(self):
|
||||
from agent.auxiliary_client import _convert_openai_images_to_anthropic
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,/9j/="}}
|
||||
]
|
||||
}]
|
||||
result = _convert_openai_images_to_anthropic(messages)
|
||||
assert result[0]["content"][0]["source"]["media_type"] == "image/jpeg"
|
||||
|
||||
139
tests/agent/test_compress_focus.py
Normal file
139
tests/agent/test_compress_focus.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Tests for focus_topic flowing through the compressor.
|
||||
|
||||
Verifies that _generate_summary and compress accept and use the focus_topic
|
||||
parameter correctly. Inspired by Claude Code's /compact <focus>.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from agent.context_compressor import ContextCompressor
|
||||
|
||||
|
||||
def _make_compressor():
|
||||
"""Create a ContextCompressor with minimal state for testing."""
|
||||
compressor = ContextCompressor.__new__(ContextCompressor)
|
||||
compressor.protect_first_n = 2
|
||||
compressor.protect_last_n = 5
|
||||
compressor.tail_token_budget = 20000
|
||||
compressor.context_length = 200000
|
||||
compressor.threshold_percent = 0.80
|
||||
compressor.threshold_tokens = 160000
|
||||
compressor.max_summary_tokens = 10000
|
||||
compressor.quiet_mode = True
|
||||
compressor.compression_count = 0
|
||||
compressor.last_prompt_tokens = 0
|
||||
compressor._previous_summary = None
|
||||
compressor._summary_failure_cooldown_until = 0.0
|
||||
compressor.summary_model = None
|
||||
return compressor
|
||||
|
||||
|
||||
def test_focus_topic_injected_into_summary_prompt():
|
||||
"""When focus_topic is provided, the LLM prompt includes focus guidance."""
|
||||
compressor = _make_compressor()
|
||||
turns = [
|
||||
{"role": "user", "content": "Tell me about the database schema"},
|
||||
{"role": "assistant", "content": "The schema has tables: users, orders, products."},
|
||||
]
|
||||
|
||||
captured_prompt = {}
|
||||
|
||||
def mock_call_llm(**kwargs):
|
||||
captured_prompt["messages"] = kwargs["messages"]
|
||||
resp = MagicMock()
|
||||
resp.choices = [MagicMock()]
|
||||
resp.choices[0].message.content = "## Goal\nUnderstand DB schema."
|
||||
return resp
|
||||
|
||||
with patch("agent.context_compressor.call_llm", mock_call_llm):
|
||||
result = compressor._generate_summary(turns, focus_topic="database schema")
|
||||
|
||||
assert result is not None
|
||||
prompt_text = captured_prompt["messages"][0]["content"]
|
||||
assert 'FOCUS TOPIC: "database schema"' in prompt_text
|
||||
assert "PRIORITISE" in prompt_text
|
||||
assert "60-70%" in prompt_text
|
||||
|
||||
|
||||
def test_no_focus_topic_no_injection():
|
||||
"""Without focus_topic, the prompt doesn't contain focus guidance."""
|
||||
compressor = _make_compressor()
|
||||
turns = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
]
|
||||
|
||||
captured_prompt = {}
|
||||
|
||||
def mock_call_llm(**kwargs):
|
||||
captured_prompt["messages"] = kwargs["messages"]
|
||||
resp = MagicMock()
|
||||
resp.choices = [MagicMock()]
|
||||
resp.choices[0].message.content = "## Goal\nGreeting."
|
||||
return resp
|
||||
|
||||
with patch("agent.context_compressor.call_llm", mock_call_llm):
|
||||
result = compressor._generate_summary(turns)
|
||||
|
||||
prompt_text = captured_prompt["messages"][0]["content"]
|
||||
assert "FOCUS TOPIC" not in prompt_text
|
||||
|
||||
|
||||
def test_compress_passes_focus_to_generate_summary():
|
||||
"""compress() passes focus_topic through to _generate_summary."""
|
||||
compressor = _make_compressor()
|
||||
|
||||
# Track what _generate_summary receives
|
||||
received_kwargs = {}
|
||||
original_generate = compressor._generate_summary
|
||||
|
||||
def tracking_generate(turns, **kwargs):
|
||||
received_kwargs.update(kwargs)
|
||||
return "## Goal\nTest."
|
||||
|
||||
compressor._generate_summary = tracking_generate
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "System prompt"},
|
||||
{"role": "user", "content": "first"},
|
||||
{"role": "assistant", "content": "reply1"},
|
||||
{"role": "user", "content": "second"},
|
||||
{"role": "assistant", "content": "reply2"},
|
||||
{"role": "user", "content": "third"},
|
||||
{"role": "assistant", "content": "reply3"},
|
||||
{"role": "user", "content": "fourth"},
|
||||
{"role": "assistant", "content": "reply4"},
|
||||
]
|
||||
|
||||
compressor.compress(messages, current_tokens=100000, focus_topic="authentication flow")
|
||||
|
||||
assert received_kwargs.get("focus_topic") == "authentication flow"
|
||||
|
||||
|
||||
def test_compress_none_focus_by_default():
|
||||
"""compress() passes None focus_topic by default."""
|
||||
compressor = _make_compressor()
|
||||
|
||||
received_kwargs = {}
|
||||
|
||||
def tracking_generate(turns, **kwargs):
|
||||
received_kwargs.update(kwargs)
|
||||
return "## Goal\nTest."
|
||||
|
||||
compressor._generate_summary = tracking_generate
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "System prompt"},
|
||||
{"role": "user", "content": "first"},
|
||||
{"role": "assistant", "content": "reply1"},
|
||||
{"role": "user", "content": "second"},
|
||||
{"role": "assistant", "content": "reply2"},
|
||||
{"role": "user", "content": "third"},
|
||||
{"role": "assistant", "content": "reply3"},
|
||||
{"role": "user", "content": "fourth"},
|
||||
{"role": "assistant", "content": "reply4"},
|
||||
]
|
||||
|
||||
compressor.compress(messages, current_tokens=100000)
|
||||
|
||||
assert received_kwargs.get("focus_topic") is None
|
||||
@@ -191,6 +191,37 @@ class TestNonStringContent:
|
||||
kwargs = mock_call.call_args.kwargs
|
||||
assert "temperature" not in kwargs
|
||||
|
||||
def test_summary_call_passes_live_main_runtime(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "ok"
|
||||
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=100000):
|
||||
c = ContextCompressor(
|
||||
model="gpt-5.4",
|
||||
provider="openai-codex",
|
||||
base_url="https://chatgpt.com/backend-api/codex",
|
||||
api_key="codex-token",
|
||||
api_mode="codex_responses",
|
||||
quiet_mode=True,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "do something"},
|
||||
{"role": "assistant", "content": "ok"},
|
||||
]
|
||||
|
||||
with patch("agent.context_compressor.call_llm", return_value=mock_response) as mock_call:
|
||||
c._generate_summary(messages)
|
||||
|
||||
assert mock_call.call_args.kwargs["main_runtime"] == {
|
||||
"model": "gpt-5.4",
|
||||
"provider": "openai-codex",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_key": "codex-token",
|
||||
"api_mode": "codex_responses",
|
||||
}
|
||||
|
||||
|
||||
class TestSummaryFailureCooldown:
|
||||
def test_summary_failure_enters_cooldown_and_skips_retry(self):
|
||||
@@ -576,11 +607,19 @@ class TestSummaryTargetRatio:
|
||||
assert c.summary_target_ratio == 0.80
|
||||
|
||||
def test_default_threshold_is_50_percent(self):
|
||||
"""Default compression threshold should be 50%."""
|
||||
"""Default compression threshold should be 50%, with a 64K floor."""
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=100_000):
|
||||
c = ContextCompressor(model="test", quiet_mode=True)
|
||||
assert c.threshold_percent == 0.50
|
||||
assert c.threshold_tokens == 50_000
|
||||
# 50% of 100K = 50K, but the floor is 64K
|
||||
assert c.threshold_tokens == 64_000
|
||||
|
||||
def test_threshold_floor_does_not_apply_above_128k(self):
|
||||
"""On large-context models the 50% percentage is used directly."""
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=200_000):
|
||||
c = ContextCompressor(model="test", quiet_mode=True)
|
||||
# 50% of 200K = 100K, which is above the 64K floor
|
||||
assert c.threshold_tokens == 100_000
|
||||
|
||||
def test_default_protect_last_n_is_20(self):
|
||||
"""Default protect_last_n should be 20."""
|
||||
|
||||
@@ -50,7 +50,8 @@ class TestEstimateTokensRough:
|
||||
assert estimate_tokens_rough("a" * 400) == 100
|
||||
|
||||
def test_short_text(self):
|
||||
assert estimate_tokens_rough("hello") == 1
|
||||
# "hello" = 5 chars → ceil(5/4) = 2
|
||||
assert estimate_tokens_rough("hello") == 2
|
||||
|
||||
def test_proportional(self):
|
||||
short = estimate_tokens_rough("hello world")
|
||||
@@ -68,10 +69,11 @@ class TestEstimateMessagesTokensRough:
|
||||
assert estimate_messages_tokens_rough([]) == 0
|
||||
|
||||
def test_single_message_concrete_value(self):
|
||||
"""Verify against known str(msg) length."""
|
||||
"""Verify against known str(msg) length (ceiling division)."""
|
||||
msg = {"role": "user", "content": "a" * 400}
|
||||
result = estimate_messages_tokens_rough([msg])
|
||||
expected = len(str(msg)) // 4
|
||||
n = len(str(msg))
|
||||
expected = (n + 3) // 4
|
||||
assert result == expected
|
||||
|
||||
def test_multiple_messages_additive(self):
|
||||
@@ -80,7 +82,8 @@ class TestEstimateMessagesTokensRough:
|
||||
{"role": "assistant", "content": "Hi there, how can I help?"},
|
||||
]
|
||||
result = estimate_messages_tokens_rough(msgs)
|
||||
expected = sum(len(str(m)) for m in msgs) // 4
|
||||
n = sum(len(str(m)) for m in msgs)
|
||||
expected = (n + 3) // 4
|
||||
assert result == expected
|
||||
|
||||
def test_tool_call_message(self):
|
||||
@@ -89,7 +92,7 @@ class TestEstimateMessagesTokensRough:
|
||||
"tool_calls": [{"id": "1", "function": {"name": "terminal", "arguments": "{}"}}]}
|
||||
result = estimate_messages_tokens_rough([msg])
|
||||
assert result > 0
|
||||
assert result == len(str(msg)) // 4
|
||||
assert result == (len(str(msg)) + 3) // 4
|
||||
|
||||
def test_message_with_list_content(self):
|
||||
"""Vision messages with multimodal content arrays."""
|
||||
@@ -98,7 +101,7 @@ class TestEstimateMessagesTokensRough:
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}}
|
||||
]}
|
||||
result = estimate_messages_tokens_rough([msg])
|
||||
assert result == len(str(msg)) // 4
|
||||
assert result == (len(str(msg)) + 3) // 4
|
||||
|
||||
|
||||
# =========================================================================
|
||||
|
||||
@@ -87,7 +87,10 @@ class TestProviderMapping:
|
||||
|
||||
def test_unmapped_provider_not_in_dict(self):
|
||||
assert "nous" not in PROVIDER_TO_MODELS_DEV
|
||||
assert "openai-codex" not in PROVIDER_TO_MODELS_DEV
|
||||
|
||||
def test_openai_codex_mapped_to_openai(self):
|
||||
assert PROVIDER_TO_MODELS_DEV["openai"] == "openai"
|
||||
assert PROVIDER_TO_MODELS_DEV["openai-codex"] == "openai"
|
||||
|
||||
|
||||
class TestExtractContext:
|
||||
|
||||
@@ -18,6 +18,7 @@ from agent.prompt_builder import (
|
||||
build_skills_system_prompt,
|
||||
build_nous_subscription_prompt,
|
||||
build_context_files_prompt,
|
||||
build_environment_hints,
|
||||
CONTEXT_FILE_MAX_CHARS,
|
||||
DEFAULT_AGENT_IDENTITY,
|
||||
TOOL_USE_ENFORCEMENT_GUIDANCE,
|
||||
@@ -26,6 +27,7 @@ from agent.prompt_builder import (
|
||||
MEMORY_GUIDANCE,
|
||||
SESSION_SEARCH_GUIDANCE,
|
||||
PLATFORM_HINTS,
|
||||
WSL_ENVIRONMENT_HINT,
|
||||
)
|
||||
from hermes_cli.nous_subscription import NousFeatureState, NousSubscriptionFeatures
|
||||
|
||||
@@ -770,6 +772,29 @@ class TestPromptBuilderConstants:
|
||||
assert "cli" in PLATFORM_HINTS
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Environment hints
|
||||
# =========================================================================
|
||||
|
||||
class TestEnvironmentHints:
|
||||
def test_wsl_hint_constant_mentions_mnt(self):
|
||||
assert "/mnt/c/" in WSL_ENVIRONMENT_HINT
|
||||
assert "WSL" in WSL_ENVIRONMENT_HINT
|
||||
|
||||
def test_build_environment_hints_on_wsl(self, monkeypatch):
|
||||
import agent.prompt_builder as _pb
|
||||
monkeypatch.setattr(_pb, "is_wsl", lambda: True)
|
||||
result = _pb.build_environment_hints()
|
||||
assert "/mnt/" in result
|
||||
assert "WSL" in result
|
||||
|
||||
def test_build_environment_hints_not_wsl(self, monkeypatch):
|
||||
import agent.prompt_builder as _pb
|
||||
monkeypatch.setattr(_pb, "is_wsl", lambda: False)
|
||||
result = _pb.build_environment_hints()
|
||||
assert result == ""
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Conditional skill activation
|
||||
# =========================================================================
|
||||
@@ -1009,65 +1034,4 @@ class TestOpenAIModelExecutionGuidance:
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestStripBudgetWarningsFromHistory:
|
||||
def test_strips_json_budget_warning_key(self):
|
||||
import json
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1", "content": json.dumps({
|
||||
"output": "hello",
|
||||
"exit_code": 0,
|
||||
"_budget_warning": "[BUDGET: Iteration 55/60. 5 iterations left. Start consolidating your work.]",
|
||||
})},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
parsed = json.loads(messages[0]["content"])
|
||||
assert "_budget_warning" not in parsed
|
||||
assert parsed["output"] == "hello"
|
||||
assert parsed["exit_code"] == 0
|
||||
|
||||
def test_strips_text_budget_warning(self):
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1",
|
||||
"content": "some result\n\n[BUDGET WARNING: Iteration 58/60. Only 2 iteration(s) left. Provide your final response NOW. No more tool calls unless absolutely critical.]"},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
assert messages[0]["content"] == "some result"
|
||||
|
||||
def test_leaves_non_tool_messages_unchanged(self):
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "assistant", "content": "[BUDGET WARNING: Iteration 58/60. Only 2 iteration(s) left. Provide your final response NOW. No more tool calls unless absolutely critical.]"},
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
original_contents = [m["content"] for m in messages]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
assert [m["content"] for m in messages] == original_contents
|
||||
|
||||
def test_handles_empty_and_missing_content(self):
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1", "content": ""},
|
||||
{"role": "tool", "tool_call_id": "c2"},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
assert messages[0]["content"] == ""
|
||||
|
||||
def test_strips_caution_variant(self):
|
||||
import json
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1", "content": json.dumps({
|
||||
"output": "ok",
|
||||
"_budget_warning": "[BUDGET: Iteration 42/60. 18 iterations left. Start consolidating your work.]",
|
||||
})},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
parsed = json.loads(messages[0]["content"])
|
||||
assert "_budget_warning" not in parsed
|
||||
|
||||
118
tests/cli/test_compress_focus.py
Normal file
118
tests/cli/test_compress_focus.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Tests for /compress <focus> — guided compression with focus topic.
|
||||
|
||||
Inspired by Claude Code's /compact <focus> feature.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tests.cli.test_cli_init import _make_cli
|
||||
|
||||
|
||||
def _make_history() -> list[dict[str, str]]:
|
||||
return [
|
||||
{"role": "user", "content": "one"},
|
||||
{"role": "assistant", "content": "two"},
|
||||
{"role": "user", "content": "three"},
|
||||
{"role": "assistant", "content": "four"},
|
||||
]
|
||||
|
||||
|
||||
def test_focus_topic_extracted_and_passed(capsys):
|
||||
"""Focus topic is extracted from the command and passed to _compress_context."""
|
||||
shell = _make_cli()
|
||||
history = _make_history()
|
||||
compressed = [history[0], history[-1]]
|
||||
shell.conversation_history = history
|
||||
shell.agent = MagicMock()
|
||||
shell.agent.compression_enabled = True
|
||||
shell.agent._cached_system_prompt = ""
|
||||
shell.agent._compress_context.return_value = (compressed, "")
|
||||
|
||||
def _estimate(messages):
|
||||
if messages is history:
|
||||
return 100
|
||||
return 50
|
||||
|
||||
with patch("agent.model_metadata.estimate_messages_tokens_rough", side_effect=_estimate):
|
||||
shell._manual_compress("/compress database schema")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert 'focus: "database schema"' in output
|
||||
|
||||
# Verify focus_topic was passed through
|
||||
shell.agent._compress_context.assert_called_once()
|
||||
call_kwargs = shell.agent._compress_context.call_args
|
||||
assert call_kwargs.kwargs.get("focus_topic") == "database schema"
|
||||
|
||||
|
||||
def test_no_focus_topic_when_bare_command(capsys):
|
||||
"""When no focus topic is provided, None is passed."""
|
||||
shell = _make_cli()
|
||||
history = _make_history()
|
||||
shell.conversation_history = history
|
||||
shell.agent = MagicMock()
|
||||
shell.agent.compression_enabled = True
|
||||
shell.agent._cached_system_prompt = ""
|
||||
shell.agent._compress_context.return_value = (list(history), "")
|
||||
|
||||
with patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=100):
|
||||
shell._manual_compress("/compress")
|
||||
|
||||
shell.agent._compress_context.assert_called_once()
|
||||
call_kwargs = shell.agent._compress_context.call_args
|
||||
assert call_kwargs.kwargs.get("focus_topic") is None
|
||||
|
||||
|
||||
def test_empty_focus_after_command_treated_as_none(capsys):
|
||||
"""Trailing whitespace after /compress does not produce a focus topic."""
|
||||
shell = _make_cli()
|
||||
history = _make_history()
|
||||
shell.conversation_history = history
|
||||
shell.agent = MagicMock()
|
||||
shell.agent.compression_enabled = True
|
||||
shell.agent._cached_system_prompt = ""
|
||||
shell.agent._compress_context.return_value = (list(history), "")
|
||||
|
||||
with patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=100):
|
||||
shell._manual_compress("/compress ")
|
||||
|
||||
shell.agent._compress_context.assert_called_once()
|
||||
call_kwargs = shell.agent._compress_context.call_args
|
||||
assert call_kwargs.kwargs.get("focus_topic") is None
|
||||
|
||||
|
||||
def test_focus_topic_printed_in_compression_banner(capsys):
|
||||
"""The focus topic shows in the compression progress banner."""
|
||||
shell = _make_cli()
|
||||
history = _make_history()
|
||||
compressed = [history[0], history[-1]]
|
||||
shell.conversation_history = history
|
||||
shell.agent = MagicMock()
|
||||
shell.agent.compression_enabled = True
|
||||
shell.agent._cached_system_prompt = ""
|
||||
shell.agent._compress_context.return_value = (compressed, "")
|
||||
|
||||
with patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=100):
|
||||
shell._manual_compress("/compress API endpoints")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert 'focus: "API endpoints"' in output
|
||||
|
||||
|
||||
def test_no_focus_prints_standard_banner(capsys):
|
||||
"""Without focus, the standard banner (no focus: line) is printed."""
|
||||
shell = _make_cli()
|
||||
history = _make_history()
|
||||
compressed = [history[0], history[-1]]
|
||||
shell.conversation_history = history
|
||||
shell.agent = MagicMock()
|
||||
shell.agent.compression_enabled = True
|
||||
shell.agent._cached_system_prompt = ""
|
||||
shell.agent._compress_context.return_value = (compressed, "")
|
||||
|
||||
with patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=100):
|
||||
shell._manual_compress("/compress")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "focus:" not in output
|
||||
assert "Compressing" in output
|
||||
189
tests/cli/test_tool_progress_scrollback.py
Normal file
189
tests/cli/test_tool_progress_scrollback.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Tests for stacked tool progress scrollback lines in the CLI TUI.
|
||||
|
||||
When tool_progress_mode is "all" or "new", _on_tool_progress should print
|
||||
persistent lines to scrollback on tool.completed, restoring the stacked
|
||||
tool history that was lost when the TUI switched to a single-line spinner.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import importlib
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
# Module-level reference to the cli module (set by _make_cli on first call)
|
||||
_cli_mod = None
|
||||
|
||||
|
||||
def _make_cli(tool_progress="all"):
|
||||
"""Create a HermesCLI instance with minimal mocking."""
|
||||
global _cli_mod
|
||||
_clean_config = {
|
||||
"model": {
|
||||
"default": "anthropic/claude-opus-4.6",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"provider": "auto",
|
||||
},
|
||||
"display": {"compact": False, "tool_progress": tool_progress},
|
||||
"agent": {},
|
||||
"terminal": {"env_type": "local"},
|
||||
}
|
||||
clean_env = {"LLM_MODEL": "", "HERMES_MAX_ITERATIONS": ""}
|
||||
prompt_toolkit_stubs = {
|
||||
"prompt_toolkit": MagicMock(),
|
||||
"prompt_toolkit.history": MagicMock(),
|
||||
"prompt_toolkit.styles": MagicMock(),
|
||||
"prompt_toolkit.patch_stdout": MagicMock(),
|
||||
"prompt_toolkit.application": MagicMock(),
|
||||
"prompt_toolkit.layout": MagicMock(),
|
||||
"prompt_toolkit.layout.processors": MagicMock(),
|
||||
"prompt_toolkit.filters": MagicMock(),
|
||||
"prompt_toolkit.layout.dimension": MagicMock(),
|
||||
"prompt_toolkit.layout.menus": MagicMock(),
|
||||
"prompt_toolkit.widgets": MagicMock(),
|
||||
"prompt_toolkit.key_binding": MagicMock(),
|
||||
"prompt_toolkit.completion": MagicMock(),
|
||||
"prompt_toolkit.formatted_text": MagicMock(),
|
||||
"prompt_toolkit.auto_suggest": MagicMock(),
|
||||
}
|
||||
with patch.dict(sys.modules, prompt_toolkit_stubs), \
|
||||
patch.dict("os.environ", clean_env, clear=False):
|
||||
import cli as mod
|
||||
mod = importlib.reload(mod)
|
||||
_cli_mod = mod
|
||||
with patch.object(mod, "get_tool_definitions", return_value=[]), \
|
||||
patch.dict(mod.__dict__, {"CLI_CONFIG": _clean_config}):
|
||||
return mod.HermesCLI()
|
||||
|
||||
|
||||
class TestToolProgressScrollback:
|
||||
"""Stacked scrollback lines for 'all' and 'new' modes."""
|
||||
|
||||
def test_all_mode_prints_scrollback_on_completed(self):
|
||||
"""In 'all' mode, tool.completed prints a stacked line."""
|
||||
cli = _make_cli(tool_progress="all")
|
||||
# Simulate tool.started
|
||||
cli._on_tool_progress("tool.started", "terminal", "git log", {"command": "git log"})
|
||||
# Simulate tool.completed
|
||||
with patch.object(_cli_mod, "_cprint") as mock_print:
|
||||
cli._on_tool_progress("tool.completed", "terminal", None, None, duration=1.5, is_error=False)
|
||||
|
||||
mock_print.assert_called_once()
|
||||
line = mock_print.call_args[0][0]
|
||||
# Should contain tool info (the cute message format has "git log" for terminal)
|
||||
assert "git log" in line or "$" in line
|
||||
|
||||
def test_all_mode_prints_every_call(self):
|
||||
"""In 'all' mode, consecutive calls to the same tool each get a line."""
|
||||
cli = _make_cli(tool_progress="all")
|
||||
with patch.object(_cli_mod, "_cprint") as mock_print:
|
||||
# First call
|
||||
cli._on_tool_progress("tool.started", "read_file", "cli.py", {"path": "cli.py"})
|
||||
cli._on_tool_progress("tool.completed", "read_file", None, None, duration=0.1, is_error=False)
|
||||
# Second call (same tool)
|
||||
cli._on_tool_progress("tool.started", "read_file", "run_agent.py", {"path": "run_agent.py"})
|
||||
cli._on_tool_progress("tool.completed", "read_file", None, None, duration=0.2, is_error=False)
|
||||
|
||||
assert mock_print.call_count == 2
|
||||
|
||||
def test_new_mode_skips_consecutive_repeats(self):
|
||||
"""In 'new' mode, consecutive calls to the same tool only print once."""
|
||||
cli = _make_cli(tool_progress="new")
|
||||
with patch.object(_cli_mod, "_cprint") as mock_print:
|
||||
cli._on_tool_progress("tool.started", "read_file", "cli.py", {"path": "cli.py"})
|
||||
cli._on_tool_progress("tool.completed", "read_file", None, None, duration=0.1, is_error=False)
|
||||
cli._on_tool_progress("tool.started", "read_file", "run_agent.py", {"path": "run_agent.py"})
|
||||
cli._on_tool_progress("tool.completed", "read_file", None, None, duration=0.2, is_error=False)
|
||||
|
||||
assert mock_print.call_count == 1 # Only the first read_file
|
||||
|
||||
def test_new_mode_prints_when_tool_changes(self):
|
||||
"""In 'new' mode, a different tool name triggers a new line."""
|
||||
cli = _make_cli(tool_progress="new")
|
||||
with patch.object(_cli_mod, "_cprint") as mock_print:
|
||||
cli._on_tool_progress("tool.started", "read_file", "cli.py", {"path": "cli.py"})
|
||||
cli._on_tool_progress("tool.completed", "read_file", None, None, duration=0.1, is_error=False)
|
||||
cli._on_tool_progress("tool.started", "search_files", "pattern", {"pattern": "test"})
|
||||
cli._on_tool_progress("tool.completed", "search_files", None, None, duration=0.3, is_error=False)
|
||||
cli._on_tool_progress("tool.started", "read_file", "run_agent.py", {"path": "run_agent.py"})
|
||||
cli._on_tool_progress("tool.completed", "read_file", None, None, duration=0.2, is_error=False)
|
||||
|
||||
# read_file, search_files, read_file (3rd prints because search_files broke the streak)
|
||||
assert mock_print.call_count == 3
|
||||
|
||||
def test_off_mode_no_scrollback(self):
|
||||
"""In 'off' mode, no stacked lines are printed."""
|
||||
cli = _make_cli(tool_progress="off")
|
||||
with patch.object(_cli_mod, "_cprint") as mock_print:
|
||||
cli._on_tool_progress("tool.started", "terminal", "ls", {"command": "ls"})
|
||||
cli._on_tool_progress("tool.completed", "terminal", None, None, duration=0.5, is_error=False)
|
||||
|
||||
mock_print.assert_not_called()
|
||||
|
||||
def test_error_suffix_on_failed_tool(self):
|
||||
"""When is_error=True, the stacked line includes [error]."""
|
||||
cli = _make_cli(tool_progress="all")
|
||||
cli._on_tool_progress("tool.started", "terminal", "bad cmd", {"command": "bad cmd"})
|
||||
with patch.object(_cli_mod, "_cprint") as mock_print:
|
||||
cli._on_tool_progress("tool.completed", "terminal", None, None, duration=0.5, is_error=True)
|
||||
|
||||
line = mock_print.call_args[0][0]
|
||||
assert "[error]" in line
|
||||
|
||||
def test_spinner_still_updates_on_started(self):
|
||||
"""tool.started still updates the spinner text for live display."""
|
||||
cli = _make_cli(tool_progress="all")
|
||||
cli._on_tool_progress("tool.started", "terminal", "git status", {"command": "git status"})
|
||||
assert "git status" in cli._spinner_text
|
||||
|
||||
def test_spinner_timer_clears_on_completed(self):
|
||||
"""tool.completed still clears the tool timer."""
|
||||
cli = _make_cli(tool_progress="all")
|
||||
cli._on_tool_progress("tool.started", "terminal", "git status", {"command": "git status"})
|
||||
assert cli._tool_start_time > 0
|
||||
with patch.object(_cli_mod, "_cprint"):
|
||||
cli._on_tool_progress("tool.completed", "terminal", None, None, duration=0.5, is_error=False)
|
||||
assert cli._tool_start_time == 0.0
|
||||
|
||||
def test_concurrent_tools_produce_stacked_lines(self):
|
||||
"""Multiple tool.started followed by multiple tool.completed all produce lines."""
|
||||
cli = _make_cli(tool_progress="all")
|
||||
with patch.object(_cli_mod, "_cprint") as mock_print:
|
||||
# All start first (concurrent pattern)
|
||||
cli._on_tool_progress("tool.started", "web_search", "query 1", {"query": "test 1"})
|
||||
cli._on_tool_progress("tool.started", "web_search", "query 2", {"query": "test 2"})
|
||||
# All complete
|
||||
cli._on_tool_progress("tool.completed", "web_search", None, None, duration=1.0, is_error=False)
|
||||
cli._on_tool_progress("tool.completed", "web_search", None, None, duration=1.5, is_error=False)
|
||||
|
||||
assert mock_print.call_count == 2
|
||||
|
||||
def test_verbose_mode_no_duplicate_scrollback(self):
|
||||
"""In 'verbose' mode, scrollback lines are NOT printed (run_agent handles verbose output)."""
|
||||
cli = _make_cli(tool_progress="verbose")
|
||||
with patch.object(_cli_mod, "_cprint") as mock_print:
|
||||
cli._on_tool_progress("tool.started", "terminal", "ls", {"command": "ls"})
|
||||
cli._on_tool_progress("tool.completed", "terminal", None, None, duration=0.5, is_error=False)
|
||||
|
||||
mock_print.assert_not_called()
|
||||
|
||||
def test_pending_info_stores_on_started(self):
|
||||
"""tool.started stores args for later use by tool.completed."""
|
||||
cli = _make_cli(tool_progress="all")
|
||||
cli._on_tool_progress("tool.started", "terminal", "ls", {"command": "ls"})
|
||||
assert "terminal" in cli._pending_tool_info
|
||||
assert len(cli._pending_tool_info["terminal"]) == 1
|
||||
assert cli._pending_tool_info["terminal"][0] == {"command": "ls"}
|
||||
|
||||
def test_pending_info_consumed_on_completed(self):
|
||||
"""tool.completed consumes stored args (FIFO for concurrent)."""
|
||||
cli = _make_cli(tool_progress="all")
|
||||
cli._on_tool_progress("tool.started", "terminal", "ls", {"command": "ls"})
|
||||
cli._on_tool_progress("tool.started", "terminal", "pwd", {"command": "pwd"})
|
||||
assert len(cli._pending_tool_info["terminal"]) == 2
|
||||
with patch.object(_cli_mod, "_cprint"):
|
||||
cli._on_tool_progress("tool.completed", "terminal", None, None, duration=0.1, is_error=False)
|
||||
# First entry consumed, second remains
|
||||
assert len(cli._pending_tool_info.get("terminal", [])) == 1
|
||||
assert cli._pending_tool_info["terminal"][0] == {"command": "pwd"}
|
||||
226
tests/gateway/test_clean_shutdown_marker.py
Normal file
226
tests/gateway/test_clean_shutdown_marker.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""Tests for the clean shutdown marker that prevents unwanted session auto-resets.
|
||||
|
||||
When the gateway shuts down gracefully (hermes update, gateway restart, /restart),
|
||||
it writes a .clean_shutdown marker. On the next startup, if the marker exists,
|
||||
suspend_recently_active() is skipped so users don't lose their sessions.
|
||||
|
||||
After a crash (no marker), suspension still fires as a safety net for stuck sessions.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig, SessionResetPolicy
|
||||
from gateway.session import SessionEntry, SessionSource, SessionStore
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_source(platform=Platform.TELEGRAM, chat_id="123", user_id="u1"):
|
||||
return SessionSource(platform=platform, chat_id=chat_id, user_id=user_id)
|
||||
|
||||
|
||||
def _make_store(tmp_path, policy=None):
|
||||
config = GatewayConfig()
|
||||
if policy:
|
||||
config.default_reset_policy = policy
|
||||
return SessionStore(sessions_dir=tmp_path, config=config)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionStore.suspend_recently_active
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSuspendRecentlyActive:
|
||||
"""Verify suspend_recently_active only marks recent sessions."""
|
||||
|
||||
def test_suspends_recently_active_sessions(self, tmp_path):
|
||||
store = _make_store(tmp_path)
|
||||
source = _make_source()
|
||||
entry = store.get_or_create_session(source)
|
||||
assert not entry.suspended
|
||||
|
||||
count = store.suspend_recently_active()
|
||||
assert count == 1
|
||||
|
||||
# Re-fetch — should be suspended now
|
||||
refreshed = store.get_or_create_session(source)
|
||||
assert refreshed.was_auto_reset
|
||||
|
||||
def test_does_not_suspend_old_sessions(self, tmp_path):
|
||||
store = _make_store(tmp_path)
|
||||
source = _make_source()
|
||||
entry = store.get_or_create_session(source)
|
||||
|
||||
# Backdate the session's updated_at beyond the cutoff
|
||||
with store._lock:
|
||||
entry.updated_at = datetime.now() - timedelta(seconds=300)
|
||||
store._save()
|
||||
|
||||
count = store.suspend_recently_active(max_age_seconds=120)
|
||||
assert count == 0
|
||||
|
||||
def test_already_suspended_not_double_counted(self, tmp_path):
|
||||
store = _make_store(tmp_path)
|
||||
source = _make_source()
|
||||
entry = store.get_or_create_session(source)
|
||||
|
||||
# Suspend once
|
||||
count1 = store.suspend_recently_active()
|
||||
assert count1 == 1
|
||||
|
||||
# Create a new session (the old one got reset on next access)
|
||||
entry2 = store.get_or_create_session(source)
|
||||
|
||||
# Suspend again — the new session is recent but not yet suspended
|
||||
count2 = store.suspend_recently_active()
|
||||
assert count2 == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Clean shutdown marker integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCleanShutdownMarker:
|
||||
"""Test that the marker file controls session suspension on startup."""
|
||||
|
||||
def test_marker_written_on_graceful_stop(self, tmp_path, monkeypatch):
|
||||
"""stop() should write .clean_shutdown marker."""
|
||||
monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
|
||||
marker = tmp_path / ".clean_shutdown"
|
||||
assert not marker.exists()
|
||||
|
||||
# Create a minimal runner and call the shutdown logic directly
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._restart_requested = False
|
||||
runner._restart_detached = False
|
||||
runner._restart_via_service = False
|
||||
runner._restart_task_started = False
|
||||
runner._running = True
|
||||
runner._draining = False
|
||||
runner._stop_task = None
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._background_tasks = set()
|
||||
runner._shutdown_event = MagicMock()
|
||||
runner._restart_drain_timeout = 5
|
||||
runner._exit_code = None
|
||||
runner._exit_reason = None
|
||||
runner.adapters = {}
|
||||
runner.config = GatewayConfig()
|
||||
|
||||
# Mock heavy dependencies
|
||||
with patch("gateway.run.GatewayRunner._drain_active_agents", new_callable=AsyncMock, return_value=([], False)), \
|
||||
patch("gateway.run.GatewayRunner._finalize_shutdown_agents"), \
|
||||
patch("gateway.run.GatewayRunner._update_runtime_status"), \
|
||||
patch("gateway.status.remove_pid_file"), \
|
||||
patch("tools.process_registry.process_registry") as mock_proc_reg, \
|
||||
patch("tools.terminal_tool.cleanup_all_environments"), \
|
||||
patch("tools.browser_tool.cleanup_all_browsers"):
|
||||
mock_proc_reg.kill_all = MagicMock()
|
||||
|
||||
import asyncio
|
||||
asyncio.get_event_loop().run_until_complete(runner.stop())
|
||||
|
||||
assert marker.exists(), ".clean_shutdown marker should exist after graceful stop"
|
||||
|
||||
def test_marker_skips_suspension_on_startup(self, tmp_path, monkeypatch):
|
||||
"""If .clean_shutdown exists, suspend_recently_active should NOT be called."""
|
||||
monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
|
||||
|
||||
# Create the marker
|
||||
marker = tmp_path / ".clean_shutdown"
|
||||
marker.touch()
|
||||
|
||||
# Create a store with a recently active session
|
||||
store = _make_store(tmp_path)
|
||||
source = _make_source()
|
||||
entry = store.get_or_create_session(source)
|
||||
assert not entry.suspended
|
||||
|
||||
# Simulate what start() does:
|
||||
if marker.exists():
|
||||
marker.unlink()
|
||||
# Should NOT call suspend_recently_active
|
||||
else:
|
||||
store.suspend_recently_active()
|
||||
|
||||
# Session should NOT be suspended
|
||||
with store._lock:
|
||||
store._ensure_loaded_locked()
|
||||
for e in store._entries.values():
|
||||
assert not e.suspended, "Session should NOT be suspended after clean shutdown"
|
||||
|
||||
assert not marker.exists(), "Marker should be cleaned up"
|
||||
|
||||
def test_no_marker_triggers_suspension(self, tmp_path, monkeypatch):
|
||||
"""Without .clean_shutdown marker (crash), suspension should fire."""
|
||||
monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
|
||||
|
||||
marker = tmp_path / ".clean_shutdown"
|
||||
assert not marker.exists()
|
||||
|
||||
# Create a store with a recently active session
|
||||
store = _make_store(tmp_path)
|
||||
source = _make_source()
|
||||
entry = store.get_or_create_session(source)
|
||||
assert not entry.suspended
|
||||
|
||||
# Simulate what start() does:
|
||||
if marker.exists():
|
||||
marker.unlink()
|
||||
else:
|
||||
store.suspend_recently_active()
|
||||
|
||||
# Session SHOULD be suspended (crash recovery)
|
||||
with store._lock:
|
||||
store._ensure_loaded_locked()
|
||||
suspended_count = sum(1 for e in store._entries.values() if e.suspended)
|
||||
assert suspended_count == 1, "Session should be suspended after crash (no marker)"
|
||||
|
||||
def test_marker_written_on_restart_stop(self, tmp_path, monkeypatch):
|
||||
"""stop(restart=True) should also write the marker."""
|
||||
monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
|
||||
marker = tmp_path / ".clean_shutdown"
|
||||
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._restart_requested = False
|
||||
runner._restart_detached = False
|
||||
runner._restart_via_service = False
|
||||
runner._restart_task_started = False
|
||||
runner._running = True
|
||||
runner._draining = False
|
||||
runner._stop_task = None
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._background_tasks = set()
|
||||
runner._shutdown_event = MagicMock()
|
||||
runner._restart_drain_timeout = 5
|
||||
runner._exit_code = None
|
||||
runner._exit_reason = None
|
||||
runner.adapters = {}
|
||||
runner.config = GatewayConfig()
|
||||
|
||||
with patch("gateway.run.GatewayRunner._drain_active_agents", new_callable=AsyncMock, return_value=([], False)), \
|
||||
patch("gateway.run.GatewayRunner._finalize_shutdown_agents"), \
|
||||
patch("gateway.run.GatewayRunner._update_runtime_status"), \
|
||||
patch("gateway.status.remove_pid_file"), \
|
||||
patch("tools.process_registry.process_registry") as mock_proc_reg, \
|
||||
patch("tools.terminal_tool.cleanup_all_environments"), \
|
||||
patch("tools.browser_tool.cleanup_all_browsers"):
|
||||
mock_proc_reg.kill_all = MagicMock()
|
||||
|
||||
import asyncio
|
||||
asyncio.get_event_loop().run_until_complete(runner.stop(restart=True))
|
||||
|
||||
assert marker.exists(), ".clean_shutdown marker should exist after restart-stop too"
|
||||
118
tests/gateway/test_compress_focus.py
Normal file
118
tests/gateway/test_compress_focus.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Tests for gateway /compress <focus> — focus topic on the gateway side."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_source() -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="u1",
|
||||
chat_id="c1",
|
||||
user_name="tester",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def _make_event(text: str = "/compress") -> MessageEvent:
|
||||
return MessageEvent(text=text, source=_make_source(), message_id="m1")
|
||||
|
||||
|
||||
def _make_history() -> list[dict[str, str]]:
|
||||
return [
|
||||
{"role": "user", "content": "one"},
|
||||
{"role": "assistant", "content": "two"},
|
||||
{"role": "user", "content": "three"},
|
||||
{"role": "assistant", "content": "four"},
|
||||
]
|
||||
|
||||
|
||||
def _make_runner(history: list[dict[str, str]]):
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||
)
|
||||
session_entry = SessionEntry(
|
||||
session_key=build_session_key(_make_source()),
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.get_or_create_session.return_value = session_entry
|
||||
runner.session_store.load_transcript.return_value = history
|
||||
runner.session_store.rewrite_transcript = MagicMock()
|
||||
runner.session_store.update_session = MagicMock()
|
||||
runner.session_store._save = MagicMock()
|
||||
return runner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compress_focus_topic_passed_to_agent():
|
||||
"""Focus topic from /compress <focus> is passed through to _compress_context."""
|
||||
history = _make_history()
|
||||
compressed = [history[0], history[-1]]
|
||||
runner = _make_runner(history)
|
||||
agent_instance = MagicMock()
|
||||
agent_instance.context_compressor.protect_first_n = 0
|
||||
agent_instance.context_compressor._align_boundary_forward.return_value = 0
|
||||
agent_instance.context_compressor._find_tail_cut_by_tokens.return_value = 2
|
||||
agent_instance.session_id = "sess-1"
|
||||
agent_instance._compress_context.return_value = (compressed, "")
|
||||
|
||||
def _estimate(messages):
|
||||
return 100
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "***"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=agent_instance),
|
||||
patch("agent.model_metadata.estimate_messages_tokens_rough", side_effect=_estimate),
|
||||
):
|
||||
result = await runner._handle_compress_command(_make_event("/compress database schema"))
|
||||
|
||||
# Verify focus_topic was passed
|
||||
agent_instance._compress_context.assert_called_once()
|
||||
call_kwargs = agent_instance._compress_context.call_args
|
||||
assert call_kwargs.kwargs.get("focus_topic") == "database schema"
|
||||
|
||||
# Verify focus is mentioned in response
|
||||
assert 'Focus: "database schema"' in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compress_no_focus_passes_none():
|
||||
"""Bare /compress passes focus_topic=None."""
|
||||
history = _make_history()
|
||||
runner = _make_runner(history)
|
||||
agent_instance = MagicMock()
|
||||
agent_instance.context_compressor.protect_first_n = 0
|
||||
agent_instance.context_compressor._align_boundary_forward.return_value = 0
|
||||
agent_instance.context_compressor._find_tail_cut_by_tokens.return_value = 2
|
||||
agent_instance.session_id = "sess-1"
|
||||
agent_instance._compress_context.return_value = (list(history), "")
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "***"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=agent_instance),
|
||||
patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=100),
|
||||
):
|
||||
result = await runner._handle_compress_command(_make_event("/compress"))
|
||||
|
||||
agent_instance._compress_context.assert_called_once()
|
||||
call_kwargs = agent_instance._compress_context.call_args
|
||||
assert call_kwargs.kwargs.get("focus_topic") is None
|
||||
|
||||
# No focus line in response
|
||||
assert "Focus:" not in result
|
||||
@@ -74,6 +74,26 @@ class FakeBot:
|
||||
return None
|
||||
|
||||
|
||||
class SlowSyncTree(FakeTree):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.started = asyncio.Event()
|
||||
self.allow_finish = asyncio.Event()
|
||||
|
||||
async def _slow_sync():
|
||||
self.started.set()
|
||||
await self.allow_finish.wait()
|
||||
return []
|
||||
|
||||
self.sync = AsyncMock(side_effect=_slow_sync)
|
||||
|
||||
|
||||
class SlowSyncBot(FakeBot):
|
||||
def __init__(self, *, intents, proxy=None):
|
||||
super().__init__(intents=intents, proxy=proxy)
|
||||
self.tree = SlowSyncTree()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("allowed_users", "expected_members_intent"),
|
||||
@@ -138,3 +158,36 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch):
|
||||
assert ok is False
|
||||
assert released == [("discord-bot-token", "test-token")]
|
||||
assert adapter._platform_lock_identity is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_does_not_wait_for_slash_sync(monkeypatch):
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="test-token"))
|
||||
|
||||
monkeypatch.setattr("gateway.status.acquire_scoped_lock", lambda scope, identity, metadata=None: (True, None))
|
||||
monkeypatch.setattr("gateway.status.release_scoped_lock", lambda scope, identity: None)
|
||||
|
||||
intents = SimpleNamespace(message_content=False, dm_messages=False, guild_messages=False, members=False, voice_states=False)
|
||||
monkeypatch.setattr(discord_platform.Intents, "default", lambda: intents)
|
||||
|
||||
created = {}
|
||||
|
||||
def fake_bot_factory(*, command_prefix, intents, proxy=None):
|
||||
bot = SlowSyncBot(intents=intents, proxy=proxy)
|
||||
created["bot"] = bot
|
||||
return bot
|
||||
|
||||
monkeypatch.setattr(discord_platform.commands, "Bot", fake_bot_factory)
|
||||
monkeypatch.setattr(adapter, "_resolve_allowed_usernames", AsyncMock())
|
||||
|
||||
ok = await asyncio.wait_for(adapter.connect(), timeout=1.0)
|
||||
|
||||
assert ok is True
|
||||
assert adapter._ready_event.is_set()
|
||||
|
||||
await asyncio.wait_for(created["bot"].tree.started.wait(), timeout=1.0)
|
||||
assert created["bot"].tree.sync.await_count == 1
|
||||
|
||||
created["bot"].tree.allow_finish.set()
|
||||
await asyncio.sleep(0)
|
||||
await adapter.disconnect()
|
||||
|
||||
355
tests/gateway/test_display_config.py
Normal file
355
tests/gateway/test_display_config.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""Tests for gateway.display_config — per-platform display/verbosity resolver."""
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Resolver: resolution order
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResolveDisplaySetting:
|
||||
"""resolve_display_setting() resolves with correct priority."""
|
||||
|
||||
def test_explicit_platform_override_wins(self):
|
||||
"""display.platforms.<plat>.<key> takes top priority."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"platforms": {
|
||||
"telegram": {"tool_progress": "verbose"},
|
||||
},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "verbose"
|
||||
|
||||
def test_global_setting_when_no_platform_override(self):
|
||||
"""Falls back to display.<key> when no platform override exists."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "new",
|
||||
"platforms": {},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "new"
|
||||
|
||||
def test_platform_default_when_no_user_config(self):
|
||||
"""Falls back to built-in platform default."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
# Empty config — should get built-in defaults
|
||||
config = {}
|
||||
# Telegram defaults to tier_high → "all"
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "all"
|
||||
# Email defaults to tier_minimal → "off"
|
||||
assert resolve_display_setting(config, "email", "tool_progress") == "off"
|
||||
|
||||
def test_global_default_for_unknown_platform(self):
|
||||
"""Unknown platforms get the global defaults."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {}
|
||||
# Unknown platform, no config → global default "all"
|
||||
assert resolve_display_setting(config, "unknown_platform", "tool_progress") == "all"
|
||||
|
||||
def test_fallback_parameter_used_last(self):
|
||||
"""Explicit fallback is used when nothing else matches."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {}
|
||||
# "nonexistent_key" isn't in any defaults
|
||||
result = resolve_display_setting(config, "telegram", "nonexistent_key", "my_fallback")
|
||||
assert result == "my_fallback"
|
||||
|
||||
def test_platform_override_only_affects_that_platform(self):
|
||||
"""Other platforms are unaffected by a specific platform override."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"platforms": {
|
||||
"slack": {"tool_progress": "off"},
|
||||
},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "slack", "tool_progress") == "off"
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "all"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backward compatibility: tool_progress_overrides
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBackwardCompat:
|
||||
"""Legacy tool_progress_overrides is still respected as a fallback."""
|
||||
|
||||
def test_legacy_overrides_read(self):
|
||||
"""tool_progress_overrides is read when no platforms entry exists."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"tool_progress_overrides": {
|
||||
"signal": "off",
|
||||
"telegram": "verbose",
|
||||
},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "signal", "tool_progress") == "off"
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "verbose"
|
||||
|
||||
def test_new_platforms_takes_precedence_over_legacy(self):
|
||||
"""display.platforms beats tool_progress_overrides."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"tool_progress_overrides": {"telegram": "verbose"},
|
||||
"platforms": {"telegram": {"tool_progress": "new"}},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "new"
|
||||
|
||||
def test_legacy_overrides_only_for_tool_progress(self):
|
||||
"""Legacy overrides don't affect other settings."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress_overrides": {"telegram": "verbose"},
|
||||
}
|
||||
}
|
||||
# show_reasoning should NOT read from tool_progress_overrides
|
||||
assert resolve_display_setting(config, "telegram", "show_reasoning") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# YAML normalisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestYAMLNormalisation:
|
||||
"""YAML 1.1 quirks (bare off → False, on → True) are handled."""
|
||||
|
||||
def test_tool_progress_false_normalised_to_off(self):
|
||||
"""YAML's bare `off` parses as False — normalised to 'off' string."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"tool_progress": False}}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "off"
|
||||
|
||||
def test_tool_progress_true_normalised_to_all(self):
|
||||
"""YAML's bare `on` parses as True — normalised to 'all'."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"tool_progress": True}}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "all"
|
||||
|
||||
def test_show_reasoning_string_true(self):
|
||||
"""String 'true' is normalised to bool True."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"platforms": {"telegram": {"show_reasoning": "true"}}}}
|
||||
assert resolve_display_setting(config, "telegram", "show_reasoning") is True
|
||||
|
||||
def test_tool_preview_length_string(self):
|
||||
"""String numbers are normalised to int."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"platforms": {"slack": {"tool_preview_length": "80"}}}}
|
||||
assert resolve_display_setting(config, "slack", "tool_preview_length") == 80
|
||||
|
||||
def test_platform_override_false_tool_progress(self):
|
||||
"""Per-platform bare off → normalised."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"platforms": {"slack": {"tool_progress": False}}}}
|
||||
assert resolve_display_setting(config, "slack", "tool_progress") == "off"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Built-in platform defaults (tier system)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPlatformDefaults:
|
||||
"""Built-in defaults reflect platform capability tiers."""
|
||||
|
||||
def test_high_tier_platforms(self):
|
||||
"""Telegram and Discord default to 'all' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("telegram", "discord"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "all", plat
|
||||
|
||||
def test_medium_tier_platforms(self):
|
||||
"""Slack, Mattermost, Matrix default to 'new' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("slack", "mattermost", "matrix", "feishu"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "new", plat
|
||||
|
||||
def test_low_tier_platforms(self):
|
||||
"""Signal, WhatsApp, etc. default to 'off' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("signal", "whatsapp", "bluebubbles", "weixin", "wecom", "dingtalk"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "off", plat
|
||||
|
||||
def test_minimal_tier_platforms(self):
|
||||
"""Email, SMS, webhook default to 'off' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("email", "sms", "webhook", "homeassistant"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "off", plat
|
||||
|
||||
def test_low_tier_streaming_defaults_to_false(self):
|
||||
"""Low-tier platforms default streaming to False."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
assert resolve_display_setting({}, "signal", "streaming") is False
|
||||
assert resolve_display_setting({}, "email", "streaming") is False
|
||||
|
||||
def test_high_tier_streaming_defaults_to_none(self):
|
||||
"""High-tier platforms default streaming to None (follow global)."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
assert resolve_display_setting({}, "telegram", "streaming") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_effective_display / get_platform_defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHelpers:
|
||||
"""Helper functions return correct composite results."""
|
||||
|
||||
def test_get_effective_display_merges_correctly(self):
|
||||
from gateway.display_config import get_effective_display
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "new",
|
||||
"show_reasoning": True,
|
||||
"platforms": {
|
||||
"telegram": {"tool_progress": "verbose"},
|
||||
},
|
||||
}
|
||||
}
|
||||
eff = get_effective_display(config, "telegram")
|
||||
assert eff["tool_progress"] == "verbose" # platform override
|
||||
assert eff["show_reasoning"] is True # global
|
||||
assert "tool_preview_length" in eff # default filled in
|
||||
|
||||
def test_get_platform_defaults_returns_dict(self):
|
||||
from gateway.display_config import get_platform_defaults
|
||||
|
||||
defaults = get_platform_defaults("telegram")
|
||||
assert "tool_progress" in defaults
|
||||
assert "show_reasoning" in defaults
|
||||
# Returns a new dict (not the shared tier dict)
|
||||
defaults["tool_progress"] = "changed"
|
||||
assert get_platform_defaults("telegram")["tool_progress"] != "changed"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config migration: tool_progress_overrides → display.platforms
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestConfigMigration:
|
||||
"""Version 16 migration moves tool_progress_overrides into display.platforms."""
|
||||
|
||||
def test_migration_creates_platforms_entries(self, tmp_path, monkeypatch):
|
||||
"""Old overrides are migrated into display.platforms.<plat>.tool_progress."""
|
||||
import yaml
|
||||
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config = {
|
||||
"_config_version": 15,
|
||||
"display": {
|
||||
"tool_progress_overrides": {
|
||||
"signal": "off",
|
||||
"telegram": "all",
|
||||
},
|
||||
},
|
||||
}
|
||||
config_path.write_text(yaml.dump(config))
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
# Re-import to pick up the new HERMES_HOME
|
||||
import importlib
|
||||
import hermes_cli.config as cfg_mod
|
||||
importlib.reload(cfg_mod)
|
||||
|
||||
result = cfg_mod.migrate_config(interactive=False, quiet=True)
|
||||
# Re-read config
|
||||
updated = yaml.safe_load(config_path.read_text())
|
||||
platforms = updated.get("display", {}).get("platforms", {})
|
||||
assert platforms.get("signal", {}).get("tool_progress") == "off"
|
||||
assert platforms.get("telegram", {}).get("tool_progress") == "all"
|
||||
|
||||
def test_migration_preserves_existing_platforms_entries(self, tmp_path, monkeypatch):
|
||||
"""Existing display.platforms entries are NOT overwritten by migration."""
|
||||
import yaml
|
||||
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config = {
|
||||
"_config_version": 15,
|
||||
"display": {
|
||||
"tool_progress_overrides": {"telegram": "off"},
|
||||
"platforms": {"telegram": {"tool_progress": "verbose"}},
|
||||
},
|
||||
}
|
||||
config_path.write_text(yaml.dump(config))
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
import importlib
|
||||
import hermes_cli.config as cfg_mod
|
||||
importlib.reload(cfg_mod)
|
||||
|
||||
cfg_mod.migrate_config(interactive=False, quiet=True)
|
||||
updated = yaml.safe_load(config_path.read_text())
|
||||
# Existing "verbose" should NOT be overwritten by legacy "off"
|
||||
assert updated["display"]["platforms"]["telegram"]["tool_progress"] == "verbose"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming per-platform (None = follow global)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStreamingPerPlatform:
|
||||
"""Streaming per-platform override semantics."""
|
||||
|
||||
def test_none_means_follow_global(self):
|
||||
"""When streaming is None, the caller should use global config."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {}
|
||||
# Telegram has no streaming override in defaults → None
|
||||
result = resolve_display_setting(config, "telegram", "streaming")
|
||||
assert result is None # caller should check global StreamingConfig
|
||||
|
||||
def test_explicit_false_disables(self):
|
||||
"""Explicit False disables streaming for that platform."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"platforms": {"telegram": {"streaming": False}},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "streaming") is False
|
||||
|
||||
def test_explicit_true_enables(self):
|
||||
"""Explicit True enables streaming for that platform."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"platforms": {"email": {"streaming": True}},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "email", "streaming") is True
|
||||
@@ -28,12 +28,16 @@ class _FakeRegistry:
|
||||
|
||||
def __init__(self, sessions):
|
||||
self._sessions = list(sessions)
|
||||
self._completion_consumed: set = set()
|
||||
|
||||
def get(self, session_id):
|
||||
if self._sessions:
|
||||
return self._sessions.pop(0)
|
||||
return None
|
||||
|
||||
def is_completion_consumed(self, session_id):
|
||||
return session_id in self._completion_consumed
|
||||
|
||||
|
||||
def _build_runner(monkeypatch, tmp_path) -> GatewayRunner:
|
||||
"""Create a GatewayRunner with notifications set to 'all'."""
|
||||
|
||||
@@ -157,12 +157,44 @@ def _make_fake_mautrix():
|
||||
mautrix_crypto_store = types.ModuleType("mautrix.crypto.store")
|
||||
|
||||
class MemoryCryptoStore:
|
||||
def __init__(self, account_id="", pickle_key=""):
|
||||
def __init__(self, account_id="", pickle_key=""): # noqa: S301
|
||||
self.account_id = account_id
|
||||
self.pickle_key = pickle_key
|
||||
|
||||
mautrix_crypto_store.MemoryCryptoStore = MemoryCryptoStore
|
||||
|
||||
# --- mautrix.crypto.store.asyncpg ---
|
||||
mautrix_crypto_store_asyncpg = types.ModuleType("mautrix.crypto.store.asyncpg")
|
||||
|
||||
class PgCryptoStore:
|
||||
upgrade_table = MagicMock()
|
||||
|
||||
def __init__(self, account_id="", pickle_key="", db=None): # noqa: S301
|
||||
self.account_id = account_id
|
||||
self.pickle_key = pickle_key
|
||||
self.db = db
|
||||
|
||||
async def open(self):
|
||||
pass
|
||||
|
||||
mautrix_crypto_store_asyncpg.PgCryptoStore = PgCryptoStore
|
||||
|
||||
# --- mautrix.util ---
|
||||
mautrix_util = types.ModuleType("mautrix.util")
|
||||
|
||||
# --- mautrix.util.async_db ---
|
||||
mautrix_util_async_db = types.ModuleType("mautrix.util.async_db")
|
||||
|
||||
class Database:
|
||||
@classmethod
|
||||
def create(cls, url, upgrade_table=None):
|
||||
db = MagicMock()
|
||||
db.start = AsyncMock()
|
||||
db.stop = AsyncMock()
|
||||
return db
|
||||
|
||||
mautrix_util_async_db.Database = Database
|
||||
|
||||
return {
|
||||
"mautrix": mautrix,
|
||||
"mautrix.api": mautrix_api,
|
||||
@@ -171,6 +203,9 @@ def _make_fake_mautrix():
|
||||
"mautrix.client.state_store": mautrix_client_state_store,
|
||||
"mautrix.crypto": mautrix_crypto,
|
||||
"mautrix.crypto.store": mautrix_crypto_store,
|
||||
"mautrix.crypto.store.asyncpg": mautrix_crypto_store_asyncpg,
|
||||
"mautrix.util": mautrix_util,
|
||||
"mautrix.util.async_db": mautrix_util_async_db,
|
||||
}
|
||||
|
||||
|
||||
@@ -740,6 +775,12 @@ class TestMatrixAccessTokenAuth:
|
||||
mock_client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "DEV123"))
|
||||
mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}})
|
||||
mock_client.add_event_handler = MagicMock()
|
||||
mock_client.handle_sync = MagicMock(return_value=[])
|
||||
mock_client.query_keys = AsyncMock(return_value={
|
||||
"device_keys": {"@bot:example.org": {"DEV123": {
|
||||
"keys": {"ed25519:DEV123": "fake_ed25519_key"},
|
||||
}}},
|
||||
})
|
||||
mock_client.api = MagicMock()
|
||||
mock_client.api.token = "syt_test_access_token"
|
||||
mock_client.api.session = MagicMock()
|
||||
@@ -751,6 +792,8 @@ class TestMatrixAccessTokenAuth:
|
||||
mock_olm.share_keys = AsyncMock()
|
||||
mock_olm.share_keys_min_trust = None
|
||||
mock_olm.send_keys_min_trust = None
|
||||
mock_olm.account = MagicMock()
|
||||
mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"}
|
||||
|
||||
# Patch Client constructor to return our mock
|
||||
fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client)
|
||||
@@ -924,6 +967,12 @@ class TestMatrixDeviceId:
|
||||
mock_client.whoami = AsyncMock(return_value=MagicMock(user_id="@bot:example.org", device_id="WHOAMI_DEV"))
|
||||
mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}})
|
||||
mock_client.add_event_handler = MagicMock()
|
||||
mock_client.handle_sync = MagicMock(return_value=[])
|
||||
mock_client.query_keys = AsyncMock(return_value={
|
||||
"device_keys": {"@bot:example.org": {"MY_STABLE_DEVICE": {
|
||||
"keys": {"ed25519:MY_STABLE_DEVICE": "fake_ed25519_key"},
|
||||
}}},
|
||||
})
|
||||
mock_client.api = MagicMock()
|
||||
mock_client.api.token = "syt_test_access_token"
|
||||
mock_client.api.session = MagicMock()
|
||||
@@ -934,6 +983,8 @@ class TestMatrixDeviceId:
|
||||
mock_olm.share_keys = AsyncMock()
|
||||
mock_olm.share_keys_min_trust = None
|
||||
mock_olm.send_keys_min_trust = None
|
||||
mock_olm.account = MagicMock()
|
||||
mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"}
|
||||
|
||||
fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client)
|
||||
fake_mautrix_mods["mautrix.crypto"].OlmMachine = MagicMock(return_value=mock_olm)
|
||||
@@ -1030,8 +1081,8 @@ class TestMatrixDeviceIdConfig:
|
||||
|
||||
class TestMatrixSyncLoop:
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_loop_shares_keys_when_encryption_enabled(self):
|
||||
"""_sync_loop should call crypto.share_keys() after each sync."""
|
||||
async def test_sync_loop_dispatches_events_and_stores_token(self):
|
||||
"""_sync_loop should call handle_sync() and persist next_batch."""
|
||||
adapter = _make_adapter()
|
||||
adapter._encryption = True
|
||||
adapter._closing = False
|
||||
@@ -1046,7 +1097,6 @@ class TestMatrixSyncLoop:
|
||||
return {"rooms": {"join": {"!room:example.org": {}}}, "next_batch": "s1234"}
|
||||
|
||||
mock_crypto = MagicMock()
|
||||
mock_crypto.share_keys = AsyncMock()
|
||||
|
||||
mock_sync_store = MagicMock()
|
||||
mock_sync_store.get_next_batch = AsyncMock(return_value=None)
|
||||
@@ -1062,7 +1112,6 @@ class TestMatrixSyncLoop:
|
||||
await adapter._sync_loop()
|
||||
|
||||
fake_client.sync.assert_awaited_once()
|
||||
mock_crypto.share_keys.assert_awaited_once()
|
||||
fake_client.handle_sync.assert_called_once()
|
||||
mock_sync_store.put_next_batch.assert_awaited_once_with("s1234")
|
||||
|
||||
@@ -1248,6 +1297,12 @@ class TestMatrixEncryptedEventHandler:
|
||||
mock_client.whoami = AsyncMock(return_value=MagicMock(user_id="@bot:example.org", device_id="DEV123"))
|
||||
mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}})
|
||||
mock_client.add_event_handler = MagicMock()
|
||||
mock_client.handle_sync = MagicMock(return_value=[])
|
||||
mock_client.query_keys = AsyncMock(return_value={
|
||||
"device_keys": {"@bot:example.org": {"DEV123": {
|
||||
"keys": {"ed25519:DEV123": "fake_ed25519_key"},
|
||||
}}},
|
||||
})
|
||||
mock_client.api = MagicMock()
|
||||
mock_client.api.token = "syt_test_token"
|
||||
mock_client.api.session = MagicMock()
|
||||
@@ -1258,6 +1313,8 @@ class TestMatrixEncryptedEventHandler:
|
||||
mock_olm.share_keys = AsyncMock()
|
||||
mock_olm.share_keys_min_trust = None
|
||||
mock_olm.send_keys_min_trust = None
|
||||
mock_olm.account = MagicMock()
|
||||
mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"}
|
||||
|
||||
fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client)
|
||||
fake_mautrix_mods["mautrix.crypto"].OlmMachine = MagicMock(return_value=mock_olm)
|
||||
|
||||
@@ -8,8 +8,8 @@ from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, SendResult
|
||||
from gateway.config import Platform, PlatformConfig, StreamingConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
@@ -104,6 +104,11 @@ def _make_runner(adapter):
|
||||
runner._session_db = None
|
||||
runner._running_agents = {}
|
||||
runner.hooks = SimpleNamespace(loaded_hooks=False)
|
||||
runner.config = SimpleNamespace(
|
||||
thread_sessions_per_user=False,
|
||||
group_sessions_per_user=False,
|
||||
stt_enabled=False,
|
||||
)
|
||||
return runner
|
||||
|
||||
|
||||
@@ -118,6 +123,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = FakeAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
import tools.terminal_tool # noqa: F401 - register terminal emoji for this fake-agent test
|
||||
|
||||
adapter = ProgressCaptureAdapter()
|
||||
runner = _make_runner(adapter)
|
||||
@@ -144,7 +150,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
|
||||
assert adapter.sent == [
|
||||
{
|
||||
"chat_id": "-1001",
|
||||
"content": '⚙️ terminal: "pwd"',
|
||||
"content": '💻 terminal: "pwd"',
|
||||
"reply_to": None,
|
||||
"metadata": {"thread_id": "17585"},
|
||||
}
|
||||
@@ -334,3 +340,238 @@ def test_all_mode_no_truncation_when_preview_fits(monkeypatch, tmp_path):
|
||||
content = adapter.sent[0]["content"]
|
||||
# With a 200-char cap, the 165-char command should NOT be truncated
|
||||
assert "..." not in content, f"Preview was truncated when it shouldn't be: {content}"
|
||||
|
||||
|
||||
class CommentaryAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool_progress_callback = kwargs.get("tool_progress_callback")
|
||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||
self.stream_delta_callback = kwargs.get("stream_delta_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
if self.interim_assistant_callback:
|
||||
self.interim_assistant_callback("I'll inspect the repo first.", already_streamed=False)
|
||||
time.sleep(0.1)
|
||||
if self.stream_delta_callback:
|
||||
self.stream_delta_callback("done")
|
||||
return {
|
||||
"final_response": "done",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
class PreviewedResponseAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
if self.interim_assistant_callback:
|
||||
self.interim_assistant_callback("You're welcome.", already_streamed=False)
|
||||
return {
|
||||
"final_response": "You're welcome.",
|
||||
"response_previewed": True,
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
class QueuedCommentaryAgent:
|
||||
calls = 0
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
type(self).calls += 1
|
||||
if type(self).calls == 1 and self.interim_assistant_callback:
|
||||
self.interim_assistant_callback("I'll inspect the repo first.", already_streamed=False)
|
||||
return {
|
||||
"final_response": f"final response {type(self).calls}",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
async def _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
agent_cls,
|
||||
*,
|
||||
session_id,
|
||||
pending_text=None,
|
||||
config_data=None,
|
||||
):
|
||||
if config_data:
|
||||
import yaml
|
||||
|
||||
(tmp_path / "config.yaml").write_text(yaml.dump(config_data), encoding="utf-8")
|
||||
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = agent_cls
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
adapter = ProgressCaptureAdapter()
|
||||
runner = _make_runner(adapter)
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
if config_data and "streaming" in config_data:
|
||||
runner.config.streaming = StreamingConfig.from_dict(config_data["streaming"])
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
)
|
||||
session_key = "agent:main:telegram:group:-1001:17585"
|
||||
if pending_text is not None:
|
||||
adapter._pending_messages[session_key] = MessageEvent(
|
||||
text=pending_text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id="queued-1",
|
||||
)
|
||||
|
||||
result = await runner._run_agent(
|
||||
message="hello",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id=session_id,
|
||||
session_key=session_key,
|
||||
)
|
||||
return adapter, result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_surfaces_real_interim_commentary(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary",
|
||||
config_data={"display": {"interim_assistant_messages": True}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_surfaces_interim_commentary_by_default(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-default-on",
|
||||
)
|
||||
|
||||
assert any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_suppresses_interim_commentary_when_disabled(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-disabled",
|
||||
config_data={"display": {"interim_assistant_messages": False}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert not any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_tool_progress_does_not_control_interim_commentary(monkeypatch, tmp_path):
|
||||
"""tool_progress=all with interim_assistant_messages=false should not surface commentary."""
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-tool-progress",
|
||||
config_data={"display": {"tool_progress": "all", "interim_assistant_messages": False}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert not any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_streaming_does_not_enable_completed_interim_commentary(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
"""Streaming alone with interim_assistant_messages=false should not surface commentary."""
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-streaming",
|
||||
config_data={
|
||||
"display": {"tool_progress": "off", "interim_assistant_messages": False},
|
||||
"streaming": {"enabled": True},
|
||||
},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is True
|
||||
assert not any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_interim_commentary_works_with_tool_progress_off(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-explicit-on",
|
||||
config_data={
|
||||
"display": {
|
||||
"tool_progress": "off",
|
||||
"interim_assistant_messages": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_previewed_final_marks_already_sent(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
PreviewedResponseAgent,
|
||||
session_id="sess-previewed",
|
||||
config_data={"display": {"interim_assistant_messages": True}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is True
|
||||
assert [call["content"] for call in adapter.sent] == ["You're welcome."]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_queued_message_does_not_treat_commentary_as_final(monkeypatch, tmp_path):
|
||||
QueuedCommentaryAgent.calls = 0
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
QueuedCommentaryAgent,
|
||||
session_id="sess-queued-commentary",
|
||||
pending_text="queued follow-up",
|
||||
config_data={"display": {"interim_assistant_messages": True}},
|
||||
)
|
||||
|
||||
sent_texts = [call["content"] for call in adapter.sent]
|
||||
assert result["final_response"] == "final response 2"
|
||||
assert "I'll inspect the repo first." in sent_texts
|
||||
assert "final response 1" in sent_texts
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
@@ -45,6 +46,23 @@ class _DisabledAdapter(BasePlatformAdapter):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
class _SuccessfulAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.DISCORD)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
@@ -65,7 +83,7 @@ async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch,
|
||||
state = read_runtime_status()
|
||||
assert state["gateway_state"] == "startup_failed"
|
||||
assert "temporary DNS resolution failure" in state["exit_reason"]
|
||||
assert state["platforms"]["telegram"]["state"] == "fatal"
|
||||
assert state["platforms"]["telegram"]["state"] == "retrying"
|
||||
assert state["platforms"]["telegram"]["error_code"] == "telegram_connect_error"
|
||||
|
||||
|
||||
@@ -89,6 +107,64 @@ async def test_runner_allows_cron_only_mode_when_no_platforms_are_enabled(monkey
|
||||
assert state["gateway_state"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_records_connected_platform_state_on_success(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DISCORD: PlatformConfig(enabled=True, token="***")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
monkeypatch.setattr(runner, "_create_adapter", lambda platform, platform_config: _SuccessfulAdapter())
|
||||
monkeypatch.setattr(runner.hooks, "discover_and_load", lambda: None)
|
||||
monkeypatch.setattr(runner.hooks, "emit", AsyncMock())
|
||||
|
||||
ok = await runner.start()
|
||||
|
||||
assert ok is True
|
||||
state = read_runtime_status()
|
||||
assert state["gateway_state"] == "running"
|
||||
assert state["platforms"]["discord"]["state"] == "connected"
|
||||
assert state["platforms"]["discord"]["error_code"] is None
|
||||
assert state["platforms"]["discord"]["error_message"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_gateway_verbosity_imports_redacting_formatter(monkeypatch, tmp_path):
|
||||
"""Verbosity != None must not crash with NameError on RedactingFormatter (#8044)."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
class _CleanExitRunner:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.should_exit_cleanly = True
|
||||
self.exit_reason = None
|
||||
self.adapters = {}
|
||||
|
||||
async def start(self):
|
||||
return True
|
||||
|
||||
async def stop(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("gateway.status.get_running_pid", lambda: None)
|
||||
monkeypatch.setattr("tools.skills_sync.sync_skills", lambda quiet=True: None)
|
||||
monkeypatch.setattr("hermes_logging.setup_logging", lambda hermes_home, mode: tmp_path)
|
||||
monkeypatch.setattr("hermes_logging._add_rotating_handler", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr("gateway.run.GatewayRunner", _CleanExitRunner)
|
||||
|
||||
from gateway.run import start_gateway
|
||||
|
||||
# verbosity=1 triggers the code path that uses RedactingFormatter.
|
||||
# Before the fix this raised NameError.
|
||||
ok = await start_gateway(config=GatewayConfig(), replace=False, verbosity=1)
|
||||
|
||||
assert ok is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_gateway_replace_force_uses_terminate_pid(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from gateway.config import Platform
|
||||
@@ -130,3 +131,99 @@ def test_set_session_env_handles_missing_optional_fields():
|
||||
assert get_session_env("HERMES_SESSION_THREAD_ID") == ""
|
||||
|
||||
runner._clear_session_env(tokens)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SESSION_KEY contextvars tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_session_key_set_via_contextvars(monkeypatch):
|
||||
"""set_session_vars should set HERMES_SESSION_KEY via contextvars."""
|
||||
monkeypatch.delenv("HERMES_SESSION_KEY", raising=False)
|
||||
|
||||
tokens = set_session_vars(
|
||||
platform="telegram",
|
||||
chat_id="-1001",
|
||||
session_key="tg:-1001:17585",
|
||||
)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "tg:-1001:17585"
|
||||
|
||||
clear_session_vars(tokens)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == ""
|
||||
|
||||
|
||||
def test_session_key_falls_back_to_os_environ(monkeypatch):
|
||||
"""get_session_env for SESSION_KEY should fall back to os.environ."""
|
||||
monkeypatch.setenv("HERMES_SESSION_KEY", "env-session-123")
|
||||
|
||||
# No contextvar set — should read from os.environ
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "env-session-123"
|
||||
|
||||
# Set contextvar — should prefer it
|
||||
tokens = set_session_vars(session_key="ctx-session-456")
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "ctx-session-456"
|
||||
|
||||
# Restore — should fall back to os.environ
|
||||
clear_session_vars(tokens)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "env-session-123"
|
||||
|
||||
|
||||
def test_set_session_env_includes_session_key():
|
||||
"""_set_session_env should propagate session_key from SessionContext."""
|
||||
runner = object.__new__(GatewayRunner)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_name="Group",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
)
|
||||
context = SessionContext(
|
||||
source=source,
|
||||
connected_platforms=[],
|
||||
home_channels={},
|
||||
session_key="tg:-1001:17585",
|
||||
)
|
||||
|
||||
tokens = runner._set_session_env(context)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "tg:-1001:17585"
|
||||
runner._clear_session_env(tokens)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == ""
|
||||
|
||||
|
||||
def test_session_key_no_race_condition_with_contextvars(monkeypatch):
|
||||
"""Prove contextvars isolates SESSION_KEY across concurrent async tasks.
|
||||
|
||||
Two tasks set different session keys. With contextvars each task
|
||||
reads back its own value. With os.environ the second task would
|
||||
overwrite the first (the old bug).
|
||||
"""
|
||||
monkeypatch.delenv("HERMES_SESSION_KEY", raising=False)
|
||||
|
||||
results = {}
|
||||
|
||||
async def handler(key: str, delay: float):
|
||||
tokens = set_session_vars(session_key=key)
|
||||
try:
|
||||
await asyncio.sleep(delay)
|
||||
read_back = get_session_env("HERMES_SESSION_KEY")
|
||||
results[key] = read_back
|
||||
finally:
|
||||
clear_session_vars(tokens)
|
||||
|
||||
async def run():
|
||||
task_a = asyncio.create_task(handler("session-A", 0.15))
|
||||
await asyncio.sleep(0.05)
|
||||
task_b = asyncio.create_task(handler("session-B", 0.05))
|
||||
await asyncio.gather(task_a, task_b)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
# Both tasks must read back their own session key
|
||||
assert results["session-A"] == "session-A", (
|
||||
f"Session A got '{results['session-A']}' instead of 'session-A' — race condition!"
|
||||
)
|
||||
assert results["session-B"] == "session-B", (
|
||||
f"Session B got '{results['session-B']}' instead of 'session-B' — race condition!"
|
||||
)
|
||||
|
||||
@@ -104,6 +104,34 @@ class TestGatewayRuntimeStatus:
|
||||
assert payload["platforms"]["telegram"]["error_code"] == "telegram_polling_conflict"
|
||||
assert payload["platforms"]["telegram"]["error_message"] == "another poller is active"
|
||||
|
||||
def test_write_runtime_status_explicit_none_clears_stale_fields(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
status.write_runtime_status(
|
||||
gateway_state="startup_failed",
|
||||
exit_reason="stale error",
|
||||
platform="discord",
|
||||
platform_state="fatal",
|
||||
error_code="discord_timeout",
|
||||
error_message="stale platform error",
|
||||
)
|
||||
|
||||
status.write_runtime_status(
|
||||
gateway_state="running",
|
||||
exit_reason=None,
|
||||
platform="discord",
|
||||
platform_state="connected",
|
||||
error_code=None,
|
||||
error_message=None,
|
||||
)
|
||||
|
||||
payload = status.read_runtime_status()
|
||||
assert payload["gateway_state"] == "running"
|
||||
assert payload["exit_reason"] is None
|
||||
assert payload["platforms"]["discord"]["state"] == "connected"
|
||||
assert payload["platforms"]["discord"]["error_code"] is None
|
||||
assert payload["platforms"]["discord"]["error_message"] is None
|
||||
|
||||
|
||||
class TestTerminatePid:
|
||||
def test_force_uses_taskkill_on_windows(self, monkeypatch):
|
||||
|
||||
@@ -505,3 +505,81 @@ class TestSegmentBreakOnToolBoundary:
|
||||
assert len(sent_texts) == 3
|
||||
assert sent_texts[0].startswith(prefix)
|
||||
assert sum(len(t) for t in sent_texts[1:]) == len(tail)
|
||||
|
||||
|
||||
class TestInterimCommentaryMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_commentary_message_stays_separate_from_final_stream(self):
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock(side_effect=[
|
||||
SimpleNamespace(success=True, message_id="msg_1"),
|
||||
SimpleNamespace(success=True, message_id="msg_2"),
|
||||
])
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
consumer = GatewayStreamConsumer(
|
||||
adapter,
|
||||
"chat_123",
|
||||
StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5),
|
||||
)
|
||||
|
||||
consumer.on_commentary("I'll inspect the repository first.")
|
||||
consumer.on_delta("Done.")
|
||||
consumer.finish()
|
||||
|
||||
await consumer.run()
|
||||
|
||||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
assert sent_texts == ["I'll inspect the repository first.", "Done."]
|
||||
assert consumer.final_response_sent is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_final_send_does_not_mark_final_response_sent(self):
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock(return_value=SimpleNamespace(success=False, message_id=None))
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
consumer = GatewayStreamConsumer(
|
||||
adapter,
|
||||
"chat_123",
|
||||
StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5),
|
||||
)
|
||||
|
||||
consumer.on_delta("Done.")
|
||||
consumer.finish()
|
||||
|
||||
await consumer.run()
|
||||
|
||||
assert consumer.final_response_sent is False
|
||||
assert consumer.already_sent is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_without_message_id_marks_visible_and_sends_only_tail(self):
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock(side_effect=[
|
||||
SimpleNamespace(success=True, message_id=None),
|
||||
SimpleNamespace(success=True, message_id=None),
|
||||
])
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
consumer = GatewayStreamConsumer(
|
||||
adapter,
|
||||
"chat_123",
|
||||
StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉"),
|
||||
)
|
||||
|
||||
consumer.on_delta("Hello")
|
||||
task = asyncio.create_task(consumer.run())
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(" world")
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.finish()
|
||||
await task
|
||||
|
||||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
assert sent_texts == ["Hello ▉", "world"]
|
||||
assert consumer.already_sent is True
|
||||
assert consumer.final_response_sent is True
|
||||
|
||||
@@ -403,6 +403,56 @@ class TestWatchUpdateProgress:
|
||||
|
||||
# Should not crash; legacy notification handles this case
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_forwarded_only_once(self, tmp_path):
|
||||
"""Regression: prompt must not be re-sent on every poll cycle.
|
||||
|
||||
Before the fix, the watcher never deleted .update_prompt.json after
|
||||
forwarding, causing the same prompt to be sent every poll_interval.
|
||||
"""
|
||||
runner = _make_runner()
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
pending = {"platform": "telegram", "chat_id": "111", "user_id": "222",
|
||||
"session_key": "agent:main:telegram:dm:111"}
|
||||
(hermes_home / ".update_pending.json").write_text(json.dumps(pending))
|
||||
(hermes_home / ".update_output.txt").write_text("")
|
||||
|
||||
mock_adapter = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: mock_adapter}
|
||||
|
||||
# Write the prompt file up front (before the watcher starts).
|
||||
# The watcher should forward it exactly once, then delete it.
|
||||
prompt = {"prompt": "Would you like to configure new options now? Y/n",
|
||||
"default": "n", "id": "dup-test"}
|
||||
(hermes_home / ".update_prompt.json").write_text(json.dumps(prompt))
|
||||
|
||||
async def finish_after_polls():
|
||||
# Wait long enough for multiple poll cycles to occur, then
|
||||
# simulate a response + completion.
|
||||
await asyncio.sleep(1.0)
|
||||
(hermes_home / ".update_response").write_text("n")
|
||||
await asyncio.sleep(0.3)
|
||||
(hermes_home / ".update_exit_code").write_text("0")
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home):
|
||||
task = asyncio.create_task(finish_after_polls())
|
||||
await runner._watch_update_progress(
|
||||
poll_interval=0.1,
|
||||
stream_interval=0.2,
|
||||
timeout=10.0,
|
||||
)
|
||||
await task
|
||||
|
||||
# Count how many times the prompt text was sent
|
||||
all_sent = [str(c) for c in mock_adapter.send.call_args_list]
|
||||
prompt_sends = [s for s in all_sent if "configure new options" in s]
|
||||
assert len(prompt_sends) == 1, (
|
||||
f"Prompt was sent {len(prompt_sends)} times (expected 1). "
|
||||
f"All sends: {all_sent}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message interception for update prompts
|
||||
|
||||
@@ -63,7 +63,7 @@ class TestVerboseCommand:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enabled_cycles_mode(self, tmp_path, monkeypatch):
|
||||
"""When enabled, /verbose cycles tool_progress mode."""
|
||||
"""When enabled, /verbose cycles tool_progress mode per-platform."""
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
@@ -79,10 +79,11 @@ class TestVerboseCommand:
|
||||
|
||||
# all -> verbose
|
||||
assert "VERBOSE" in result
|
||||
assert "telegram" in result.lower() # per-platform feedback
|
||||
|
||||
# Verify config was saved
|
||||
# Verify config was saved to display.platforms.telegram
|
||||
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
assert saved["display"]["tool_progress"] == "verbose"
|
||||
assert saved["display"]["platforms"]["telegram"]["tool_progress"] == "verbose"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cycles_through_all_modes(self, tmp_path, monkeypatch):
|
||||
@@ -103,8 +104,9 @@ class TestVerboseCommand:
|
||||
for mode in expected:
|
||||
result = await runner._handle_verbose_command(_make_event())
|
||||
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
assert saved["display"]["tool_progress"] == mode, \
|
||||
f"Expected {mode}, got {saved['display']['tool_progress']}"
|
||||
actual = saved["display"]["platforms"]["telegram"]["tool_progress"]
|
||||
assert actual == mode, \
|
||||
f"Expected {mode}, got {actual}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_defaults_to_all_when_no_tool_progress_set(self, tmp_path, monkeypatch):
|
||||
@@ -122,10 +124,45 @@ class TestVerboseCommand:
|
||||
runner = _make_runner()
|
||||
result = await runner._handle_verbose_command(_make_event())
|
||||
|
||||
# default "all" -> verbose
|
||||
# Telegram default is "all" (high tier) → cycles to verbose
|
||||
assert "VERBOSE" in result
|
||||
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
assert saved["display"]["tool_progress"] == "verbose"
|
||||
assert saved["display"]["platforms"]["telegram"]["tool_progress"] == "verbose"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_per_platform_isolation(self, tmp_path, monkeypatch):
|
||||
"""Cycling /verbose on Telegram doesn't change Slack's setting.
|
||||
|
||||
Without a global tool_progress, each platform uses its built-in
|
||||
default: Telegram = 'all' (high tier), Slack = 'new' (medium tier).
|
||||
"""
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
# No global tool_progress → built-in platform defaults apply
|
||||
config_path.write_text(
|
||||
"display:\n tool_progress_command: true\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
|
||||
runner = _make_runner()
|
||||
|
||||
# Cycle on Telegram
|
||||
await runner._handle_verbose_command(
|
||||
_make_event(platform=Platform.TELEGRAM)
|
||||
)
|
||||
# Cycle on Slack
|
||||
await runner._handle_verbose_command(
|
||||
_make_event(platform=Platform.SLACK)
|
||||
)
|
||||
|
||||
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
platforms = saved["display"]["platforms"]
|
||||
# Telegram: all -> verbose (high tier default = all)
|
||||
assert platforms["telegram"]["tool_progress"] == "verbose"
|
||||
# Slack: new -> all (medium tier default = new, cycle to all)
|
||||
assert platforms["slack"]["tool_progress"] == "all"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_config_file_returns_disabled(self, tmp_path, monkeypatch):
|
||||
|
||||
185
tests/gateway/test_wecom_callback.py
Normal file
185
tests/gateway/test_wecom_callback.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""Tests for the WeCom callback-mode adapter."""
|
||||
|
||||
import asyncio
|
||||
from xml.etree import ElementTree as ET
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.wecom_callback import WecomCallbackAdapter
|
||||
from gateway.platforms.wecom_crypto import WXBizMsgCrypt
|
||||
|
||||
|
||||
def _app(name="test-app", corp_id="ww1234567890", agent_id="1000002"):
|
||||
return {
|
||||
"name": name,
|
||||
"corp_id": corp_id,
|
||||
"corp_secret": "test-secret",
|
||||
"agent_id": agent_id,
|
||||
"token": "test-callback-token",
|
||||
"encoding_aes_key": "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG",
|
||||
}
|
||||
|
||||
|
||||
def _config(apps=None):
|
||||
return PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"mode": "callback", "host": "127.0.0.1", "port": 0, "apps": apps or [_app()]},
|
||||
)
|
||||
|
||||
|
||||
class TestWecomCrypto:
|
||||
def test_roundtrip_encrypt_decrypt(self):
|
||||
app = _app()
|
||||
crypt = WXBizMsgCrypt(app["token"], app["encoding_aes_key"], app["corp_id"])
|
||||
encrypted_xml = crypt.encrypt(
|
||||
"<xml><Content>hello</Content></xml>", nonce="nonce123", timestamp="123456",
|
||||
)
|
||||
root = ET.fromstring(encrypted_xml)
|
||||
decrypted = crypt.decrypt(
|
||||
root.findtext("MsgSignature", default=""),
|
||||
root.findtext("TimeStamp", default=""),
|
||||
root.findtext("Nonce", default=""),
|
||||
root.findtext("Encrypt", default=""),
|
||||
)
|
||||
assert b"<Content>hello</Content>" in decrypted
|
||||
|
||||
def test_signature_mismatch_raises(self):
|
||||
app = _app()
|
||||
crypt = WXBizMsgCrypt(app["token"], app["encoding_aes_key"], app["corp_id"])
|
||||
encrypted_xml = crypt.encrypt("<xml/>", nonce="n", timestamp="1")
|
||||
root = ET.fromstring(encrypted_xml)
|
||||
from gateway.platforms.wecom_crypto import SignatureError
|
||||
with pytest.raises(SignatureError):
|
||||
crypt.decrypt("bad-sig", "1", "n", root.findtext("Encrypt", default=""))
|
||||
|
||||
|
||||
class TestWecomCallbackEventConstruction:
|
||||
def test_build_event_extracts_text_message(self):
|
||||
adapter = WecomCallbackAdapter(_config())
|
||||
xml_text = """
|
||||
<xml>
|
||||
<ToUserName>ww1234567890</ToUserName>
|
||||
<FromUserName>zhangsan</FromUserName>
|
||||
<CreateTime>1710000000</CreateTime>
|
||||
<MsgType>text</MsgType>
|
||||
<Content>\u4f60\u597d</Content>
|
||||
<MsgId>123456789</MsgId>
|
||||
</xml>
|
||||
"""
|
||||
event = adapter._build_event(_app(), xml_text)
|
||||
assert event is not None
|
||||
assert event.source is not None
|
||||
assert event.source.user_id == "zhangsan"
|
||||
assert event.source.chat_id == "ww1234567890:zhangsan"
|
||||
assert event.message_id == "123456789"
|
||||
assert event.text == "\u4f60\u597d"
|
||||
|
||||
def test_build_event_returns_none_for_subscribe(self):
|
||||
adapter = WecomCallbackAdapter(_config())
|
||||
xml_text = """
|
||||
<xml>
|
||||
<ToUserName>ww1234567890</ToUserName>
|
||||
<FromUserName>zhangsan</FromUserName>
|
||||
<CreateTime>1710000000</CreateTime>
|
||||
<MsgType>event</MsgType>
|
||||
<Event>subscribe</Event>
|
||||
</xml>
|
||||
"""
|
||||
event = adapter._build_event(_app(), xml_text)
|
||||
assert event is None
|
||||
|
||||
|
||||
class TestWecomCallbackRouting:
|
||||
def test_user_app_key_scopes_across_corps(self):
|
||||
adapter = WecomCallbackAdapter(_config())
|
||||
assert adapter._user_app_key("corpA", "alice") == "corpA:alice"
|
||||
assert adapter._user_app_key("corpB", "alice") == "corpB:alice"
|
||||
assert adapter._user_app_key("corpA", "alice") != adapter._user_app_key("corpB", "alice")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_selects_correct_app_for_scoped_chat_id(self):
|
||||
apps = [
|
||||
_app(name="corp-a", corp_id="corpA", agent_id="1001"),
|
||||
_app(name="corp-b", corp_id="corpB", agent_id="2002"),
|
||||
]
|
||||
adapter = WecomCallbackAdapter(_config(apps=apps))
|
||||
adapter._user_app_map["corpB:alice"] = "corp-b"
|
||||
adapter._access_tokens["corp-b"] = {"token": "tok-b", "expires_at": 9999999999}
|
||||
|
||||
calls = {}
|
||||
|
||||
class FakeResponse:
|
||||
def json(self):
|
||||
return {"errcode": 0, "msgid": "ok1"}
|
||||
|
||||
class FakeClient:
|
||||
async def post(self, url, json):
|
||||
calls["url"] = url
|
||||
calls["json"] = json
|
||||
return FakeResponse()
|
||||
|
||||
adapter._http_client = FakeClient()
|
||||
result = await adapter.send("corpB:alice", "hello")
|
||||
|
||||
assert result.success is True
|
||||
assert calls["json"]["touser"] == "alice"
|
||||
assert calls["json"]["agentid"] == 2002
|
||||
assert "tok-b" in calls["url"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_falls_back_from_bare_user_id_when_unique(self):
|
||||
apps = [_app(name="corp-a", corp_id="corpA", agent_id="1001")]
|
||||
adapter = WecomCallbackAdapter(_config(apps=apps))
|
||||
adapter._user_app_map["corpA:alice"] = "corp-a"
|
||||
adapter._access_tokens["corp-a"] = {"token": "tok-a", "expires_at": 9999999999}
|
||||
|
||||
calls = {}
|
||||
|
||||
class FakeResponse:
|
||||
def json(self):
|
||||
return {"errcode": 0, "msgid": "ok2"}
|
||||
|
||||
class FakeClient:
|
||||
async def post(self, url, json):
|
||||
calls["url"] = url
|
||||
calls["json"] = json
|
||||
return FakeResponse()
|
||||
|
||||
adapter._http_client = FakeClient()
|
||||
result = await adapter.send("alice", "hello")
|
||||
|
||||
assert result.success is True
|
||||
assert calls["json"]["agentid"] == 1001
|
||||
|
||||
|
||||
class TestWecomCallbackPollLoop:
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_loop_dispatches_handle_message(self, monkeypatch):
|
||||
adapter = WecomCallbackAdapter(_config())
|
||||
calls = []
|
||||
|
||||
async def fake_handle_message(event):
|
||||
calls.append(event.text)
|
||||
|
||||
monkeypatch.setattr(adapter, "handle_message", fake_handle_message)
|
||||
event = adapter._build_event(
|
||||
_app(),
|
||||
"""
|
||||
<xml>
|
||||
<ToUserName>ww1234567890</ToUserName>
|
||||
<FromUserName>lisi</FromUserName>
|
||||
<CreateTime>1710000000</CreateTime>
|
||||
<MsgType>text</MsgType>
|
||||
<Content>test</Content>
|
||||
<MsgId>m2</MsgId>
|
||||
</xml>
|
||||
""",
|
||||
)
|
||||
task = asyncio.create_task(adapter._poll_loop())
|
||||
await adapter._message_queue.put(event)
|
||||
await asyncio.sleep(0.05)
|
||||
task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
assert calls == ["test"]
|
||||
@@ -64,13 +64,44 @@ class TestWeixinFormatting:
|
||||
|
||||
|
||||
class TestWeixinChunking:
|
||||
def test_split_text_keeps_short_multiline_message_in_single_chunk(self):
|
||||
def test_split_text_splits_short_chatty_replies_into_separate_bubbles(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = adapter.format_message("第一行\n第二行\n第三行")
|
||||
chunks = adapter._split_text(content)
|
||||
|
||||
assert chunks == ["第一行\n第二行\n第三行"]
|
||||
assert chunks == ["第一行", "第二行", "第三行"]
|
||||
|
||||
def test_split_text_keeps_structured_table_block_together(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = adapter.format_message(
|
||||
"- Setting: Timeout\n Value: 30s\n- Setting: Retries\n Value: 3"
|
||||
)
|
||||
chunks = adapter._split_text(content)
|
||||
|
||||
assert chunks == ["- Setting: Timeout\n Value: 30s\n- Setting: Retries\n Value: 3"]
|
||||
|
||||
def test_split_text_keeps_four_line_structured_blocks_together(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = adapter.format_message(
|
||||
"今天结论:\n"
|
||||
"- 留存下降 3%\n"
|
||||
"- 转化上涨 8%\n"
|
||||
"- 主要问题在首日激活"
|
||||
)
|
||||
chunks = adapter._split_text(content)
|
||||
|
||||
assert chunks == ["今天结论:\n- 留存下降 3%\n- 转化上涨 8%\n- 主要问题在首日激活"]
|
||||
|
||||
def test_split_text_keeps_heading_with_body_together(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = adapter.format_message("## 结论\n这是正文")
|
||||
chunks = adapter._split_text(content)
|
||||
|
||||
assert chunks == ["**结论**\n这是正文"]
|
||||
|
||||
def test_split_text_keeps_short_reformatted_table_in_single_chunk(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
@@ -14,6 +14,7 @@ from hermes_cli.auth import (
|
||||
PROVIDER_REGISTRY,
|
||||
_read_codex_tokens,
|
||||
_save_codex_tokens,
|
||||
_write_codex_cli_tokens,
|
||||
_import_codex_cli_tokens,
|
||||
get_codex_auth_status,
|
||||
get_provider_auth_state,
|
||||
@@ -161,7 +162,7 @@ def test_import_codex_cli_tokens_missing(tmp_path, monkeypatch):
|
||||
|
||||
|
||||
def test_codex_tokens_not_written_to_shared_file(tmp_path, monkeypatch):
|
||||
"""Verify Hermes never writes to ~/.codex/auth.json."""
|
||||
"""Verify _save_codex_tokens writes only to Hermes auth store, not ~/.codex/."""
|
||||
hermes_home = tmp_path / "hermes"
|
||||
codex_home = tmp_path / "codex-cli"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
@@ -173,7 +174,7 @@ def test_codex_tokens_not_written_to_shared_file(tmp_path, monkeypatch):
|
||||
|
||||
_save_codex_tokens({"access_token": "hermes-at", "refresh_token": "hermes-rt"})
|
||||
|
||||
# ~/.codex/auth.json should NOT exist
|
||||
# ~/.codex/auth.json should NOT exist — _save_codex_tokens only touches Hermes store
|
||||
assert not (codex_home / "auth.json").exists()
|
||||
|
||||
# Hermes auth store should have the tokens
|
||||
@@ -181,6 +182,98 @@ def test_codex_tokens_not_written_to_shared_file(tmp_path, monkeypatch):
|
||||
assert data["tokens"]["access_token"] == "hermes-at"
|
||||
|
||||
|
||||
def test_write_codex_cli_tokens_creates_file(tmp_path, monkeypatch):
|
||||
"""_write_codex_cli_tokens creates ~/.codex/auth.json with refreshed tokens."""
|
||||
codex_home = tmp_path / "codex-cli"
|
||||
monkeypatch.setenv("CODEX_HOME", str(codex_home))
|
||||
|
||||
_write_codex_cli_tokens("new-access", "new-refresh", last_refresh="2026-04-12T00:00:00Z")
|
||||
|
||||
auth_path = codex_home / "auth.json"
|
||||
assert auth_path.exists()
|
||||
data = json.loads(auth_path.read_text())
|
||||
assert data["tokens"]["access_token"] == "new-access"
|
||||
assert data["tokens"]["refresh_token"] == "new-refresh"
|
||||
assert data["last_refresh"] == "2026-04-12T00:00:00Z"
|
||||
# Verify file permissions are restricted
|
||||
assert (auth_path.stat().st_mode & 0o777) == 0o600
|
||||
|
||||
|
||||
def test_write_codex_cli_tokens_preserves_existing(tmp_path, monkeypatch):
|
||||
"""_write_codex_cli_tokens preserves extra fields in existing auth.json."""
|
||||
codex_home = tmp_path / "codex-cli"
|
||||
codex_home.mkdir(parents=True, exist_ok=True)
|
||||
monkeypatch.setenv("CODEX_HOME", str(codex_home))
|
||||
|
||||
existing = {
|
||||
"tokens": {
|
||||
"access_token": "old-access",
|
||||
"refresh_token": "old-refresh",
|
||||
"extra_field": "preserved",
|
||||
},
|
||||
"last_refresh": "2026-01-01T00:00:00Z",
|
||||
"custom_key": "keep_me",
|
||||
}
|
||||
(codex_home / "auth.json").write_text(json.dumps(existing))
|
||||
|
||||
_write_codex_cli_tokens("updated-access", "updated-refresh")
|
||||
|
||||
data = json.loads((codex_home / "auth.json").read_text())
|
||||
assert data["tokens"]["access_token"] == "updated-access"
|
||||
assert data["tokens"]["refresh_token"] == "updated-refresh"
|
||||
assert data["tokens"]["extra_field"] == "preserved"
|
||||
assert data["custom_key"] == "keep_me"
|
||||
# last_refresh not updated since we didn't pass it
|
||||
assert data["last_refresh"] == "2026-01-01T00:00:00Z"
|
||||
|
||||
|
||||
def test_write_codex_cli_tokens_handles_missing_dir(tmp_path, monkeypatch):
|
||||
"""_write_codex_cli_tokens creates parent directories if missing."""
|
||||
codex_home = tmp_path / "does" / "not" / "exist"
|
||||
monkeypatch.setenv("CODEX_HOME", str(codex_home))
|
||||
|
||||
_write_codex_cli_tokens("at", "rt")
|
||||
|
||||
assert (codex_home / "auth.json").exists()
|
||||
data = json.loads((codex_home / "auth.json").read_text())
|
||||
assert data["tokens"]["access_token"] == "at"
|
||||
|
||||
|
||||
def test_refresh_codex_auth_tokens_writes_back_to_cli(tmp_path, monkeypatch):
|
||||
"""After refreshing, _refresh_codex_auth_tokens writes back to ~/.codex/auth.json."""
|
||||
from hermes_cli.auth import _refresh_codex_auth_tokens
|
||||
|
||||
hermes_home = tmp_path / "hermes"
|
||||
codex_home = tmp_path / "codex-cli"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
codex_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "auth.json").write_text(json.dumps({"version": 1, "providers": {}}))
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setenv("CODEX_HOME", str(codex_home))
|
||||
|
||||
# Write initial CLI tokens
|
||||
(codex_home / "auth.json").write_text(json.dumps({
|
||||
"tokens": {"access_token": "old-at", "refresh_token": "old-rt"},
|
||||
}))
|
||||
|
||||
# Mock the pure refresh to return new tokens
|
||||
monkeypatch.setattr("hermes_cli.auth.refresh_codex_oauth_pure", lambda *a, **kw: {
|
||||
"access_token": "refreshed-at",
|
||||
"refresh_token": "refreshed-rt",
|
||||
"last_refresh": "2026-04-12T01:00:00Z",
|
||||
})
|
||||
|
||||
_refresh_codex_auth_tokens(
|
||||
{"access_token": "old-at", "refresh_token": "old-rt"},
|
||||
timeout_seconds=10,
|
||||
)
|
||||
|
||||
# Verify CLI file was updated
|
||||
cli_data = json.loads((codex_home / "auth.json").read_text())
|
||||
assert cli_data["tokens"]["access_token"] == "refreshed-at"
|
||||
assert cli_data["tokens"]["refresh_token"] == "refreshed-rt"
|
||||
|
||||
|
||||
def test_resolve_returns_hermes_auth_store_source(tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
_setup_hermes_auth(hermes_home)
|
||||
|
||||
897
tests/hermes_cli/test_backup.py
Normal file
897
tests/hermes_cli/test_backup.py
Normal file
@@ -0,0 +1,897 @@
|
||||
"""Tests for hermes backup and import commands."""
|
||||
|
||||
import os
|
||||
import zipfile
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_hermes_tree(root: Path) -> None:
|
||||
"""Create a realistic ~/.hermes directory structure for testing."""
|
||||
(root / "config.yaml").write_text("model:\n provider: openrouter\n")
|
||||
(root / ".env").write_text("OPENROUTER_API_KEY=sk-test-123\n")
|
||||
(root / "memory_store.db").write_bytes(b"fake-sqlite")
|
||||
(root / "hermes_state.db").write_bytes(b"fake-state")
|
||||
|
||||
# Sessions
|
||||
(root / "sessions").mkdir(exist_ok=True)
|
||||
(root / "sessions" / "abc123.json").write_text("{}")
|
||||
|
||||
# Skills
|
||||
(root / "skills").mkdir(exist_ok=True)
|
||||
(root / "skills" / "my-skill").mkdir()
|
||||
(root / "skills" / "my-skill" / "SKILL.md").write_text("# My Skill\n")
|
||||
|
||||
# Skins
|
||||
(root / "skins").mkdir(exist_ok=True)
|
||||
(root / "skins" / "cyber.yaml").write_text("name: cyber\n")
|
||||
|
||||
# Cron
|
||||
(root / "cron").mkdir(exist_ok=True)
|
||||
(root / "cron" / "jobs.json").write_text("[]")
|
||||
|
||||
# Memories
|
||||
(root / "memories").mkdir(exist_ok=True)
|
||||
(root / "memories" / "notes.json").write_text("{}")
|
||||
|
||||
# Profiles
|
||||
(root / "profiles").mkdir(exist_ok=True)
|
||||
(root / "profiles" / "coder").mkdir()
|
||||
(root / "profiles" / "coder" / "config.yaml").write_text("model:\n provider: anthropic\n")
|
||||
(root / "profiles" / "coder" / ".env").write_text("ANTHROPIC_API_KEY=sk-ant-123\n")
|
||||
|
||||
# hermes-agent repo (should be EXCLUDED)
|
||||
(root / "hermes-agent").mkdir(exist_ok=True)
|
||||
(root / "hermes-agent" / "run_agent.py").write_text("# big file\n")
|
||||
(root / "hermes-agent" / ".git").mkdir()
|
||||
(root / "hermes-agent" / ".git" / "HEAD").write_text("ref: refs/heads/main\n")
|
||||
|
||||
# __pycache__ (should be EXCLUDED)
|
||||
(root / "plugins").mkdir(exist_ok=True)
|
||||
(root / "plugins" / "__pycache__").mkdir()
|
||||
(root / "plugins" / "__pycache__" / "mod.cpython-312.pyc").write_bytes(b"\x00")
|
||||
|
||||
# PID files (should be EXCLUDED)
|
||||
(root / "gateway.pid").write_text("12345")
|
||||
|
||||
# Logs (should be included)
|
||||
(root / "logs").mkdir(exist_ok=True)
|
||||
(root / "logs" / "agent.log").write_text("log line\n")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _should_exclude tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestShouldExclude:
|
||||
def test_excludes_hermes_agent(self):
|
||||
from hermes_cli.backup import _should_exclude
|
||||
assert _should_exclude(Path("hermes-agent/run_agent.py"))
|
||||
assert _should_exclude(Path("hermes-agent/.git/HEAD"))
|
||||
|
||||
def test_excludes_pycache(self):
|
||||
from hermes_cli.backup import _should_exclude
|
||||
assert _should_exclude(Path("plugins/__pycache__/mod.cpython-312.pyc"))
|
||||
|
||||
def test_excludes_pyc_files(self):
|
||||
from hermes_cli.backup import _should_exclude
|
||||
assert _should_exclude(Path("some/module.pyc"))
|
||||
|
||||
def test_excludes_pid_files(self):
|
||||
from hermes_cli.backup import _should_exclude
|
||||
assert _should_exclude(Path("gateway.pid"))
|
||||
assert _should_exclude(Path("cron.pid"))
|
||||
|
||||
def test_includes_config(self):
|
||||
from hermes_cli.backup import _should_exclude
|
||||
assert not _should_exclude(Path("config.yaml"))
|
||||
|
||||
def test_includes_env(self):
|
||||
from hermes_cli.backup import _should_exclude
|
||||
assert not _should_exclude(Path(".env"))
|
||||
|
||||
def test_includes_skills(self):
|
||||
from hermes_cli.backup import _should_exclude
|
||||
assert not _should_exclude(Path("skills/my-skill/SKILL.md"))
|
||||
|
||||
def test_includes_profiles(self):
|
||||
from hermes_cli.backup import _should_exclude
|
||||
assert not _should_exclude(Path("profiles/coder/config.yaml"))
|
||||
|
||||
def test_includes_sessions(self):
|
||||
from hermes_cli.backup import _should_exclude
|
||||
assert not _should_exclude(Path("sessions/abc.json"))
|
||||
|
||||
def test_includes_logs(self):
|
||||
from hermes_cli.backup import _should_exclude
|
||||
assert not _should_exclude(Path("logs/agent.log"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backup tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBackup:
|
||||
def test_creates_zip(self, tmp_path, monkeypatch):
|
||||
"""Backup creates a valid zip containing expected files."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
_make_hermes_tree(hermes_home)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
# get_default_hermes_root needs this
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
out_zip = tmp_path / "backup.zip"
|
||||
args = Namespace(output=str(out_zip))
|
||||
|
||||
from hermes_cli.backup import run_backup
|
||||
run_backup(args)
|
||||
|
||||
assert out_zip.exists()
|
||||
with zipfile.ZipFile(out_zip, "r") as zf:
|
||||
names = zf.namelist()
|
||||
# Config should be present
|
||||
assert "config.yaml" in names
|
||||
assert ".env" in names
|
||||
# Skills
|
||||
assert "skills/my-skill/SKILL.md" in names
|
||||
# Profiles
|
||||
assert "profiles/coder/config.yaml" in names
|
||||
assert "profiles/coder/.env" in names
|
||||
# Sessions
|
||||
assert "sessions/abc123.json" in names
|
||||
# Logs
|
||||
assert "logs/agent.log" in names
|
||||
# Skins
|
||||
assert "skins/cyber.yaml" in names
|
||||
|
||||
def test_excludes_hermes_agent(self, tmp_path, monkeypatch):
|
||||
"""Backup does NOT include hermes-agent/ directory."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
_make_hermes_tree(hermes_home)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
out_zip = tmp_path / "backup.zip"
|
||||
args = Namespace(output=str(out_zip))
|
||||
|
||||
from hermes_cli.backup import run_backup
|
||||
run_backup(args)
|
||||
|
||||
with zipfile.ZipFile(out_zip, "r") as zf:
|
||||
names = zf.namelist()
|
||||
agent_files = [n for n in names if "hermes-agent" in n]
|
||||
assert agent_files == [], f"hermes-agent files leaked into backup: {agent_files}"
|
||||
|
||||
def test_excludes_pycache(self, tmp_path, monkeypatch):
|
||||
"""Backup does NOT include __pycache__ dirs."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
_make_hermes_tree(hermes_home)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
out_zip = tmp_path / "backup.zip"
|
||||
args = Namespace(output=str(out_zip))
|
||||
|
||||
from hermes_cli.backup import run_backup
|
||||
run_backup(args)
|
||||
|
||||
with zipfile.ZipFile(out_zip, "r") as zf:
|
||||
names = zf.namelist()
|
||||
pycache_files = [n for n in names if "__pycache__" in n]
|
||||
assert pycache_files == []
|
||||
|
||||
def test_excludes_pid_files(self, tmp_path, monkeypatch):
|
||||
"""Backup does NOT include PID files."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
_make_hermes_tree(hermes_home)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
out_zip = tmp_path / "backup.zip"
|
||||
args = Namespace(output=str(out_zip))
|
||||
|
||||
from hermes_cli.backup import run_backup
|
||||
run_backup(args)
|
||||
|
||||
with zipfile.ZipFile(out_zip, "r") as zf:
|
||||
names = zf.namelist()
|
||||
pid_files = [n for n in names if n.endswith(".pid")]
|
||||
assert pid_files == []
|
||||
|
||||
def test_default_output_path(self, tmp_path, monkeypatch):
|
||||
"""When no output path given, zip goes to ~/hermes-backup-*.zip."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text("model: test\n")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
args = Namespace(output=None)
|
||||
|
||||
from hermes_cli.backup import run_backup
|
||||
run_backup(args)
|
||||
|
||||
# Should exist in home dir
|
||||
zips = list(tmp_path.glob("hermes-backup-*.zip"))
|
||||
assert len(zips) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestImport:
|
||||
def _make_backup_zip(self, zip_path: Path, files: dict[str, str | bytes]) -> None:
|
||||
"""Create a test zip with given files."""
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
for name, content in files.items():
|
||||
if isinstance(content, bytes):
|
||||
zf.writestr(name, content)
|
||||
else:
|
||||
zf.writestr(name, content)
|
||||
|
||||
def test_restores_files(self, tmp_path, monkeypatch):
|
||||
"""Import extracts files into hermes home."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
self._make_backup_zip(zip_path, {
|
||||
"config.yaml": "model:\n provider: openrouter\n",
|
||||
".env": "OPENROUTER_API_KEY=sk-test\n",
|
||||
"skills/my-skill/SKILL.md": "# My Skill\n",
|
||||
"profiles/coder/config.yaml": "model:\n provider: anthropic\n",
|
||||
})
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=True)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
run_import(args)
|
||||
|
||||
assert (hermes_home / "config.yaml").read_text() == "model:\n provider: openrouter\n"
|
||||
assert (hermes_home / ".env").read_text() == "OPENROUTER_API_KEY=sk-test\n"
|
||||
assert (hermes_home / "skills" / "my-skill" / "SKILL.md").read_text() == "# My Skill\n"
|
||||
assert (hermes_home / "profiles" / "coder" / "config.yaml").exists()
|
||||
|
||||
def test_strips_hermes_prefix(self, tmp_path, monkeypatch):
|
||||
"""Import strips .hermes/ prefix if all entries share it."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
self._make_backup_zip(zip_path, {
|
||||
".hermes/config.yaml": "model: test\n",
|
||||
".hermes/skills/a/SKILL.md": "# A\n",
|
||||
})
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=True)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
run_import(args)
|
||||
|
||||
assert (hermes_home / "config.yaml").read_text() == "model: test\n"
|
||||
assert (hermes_home / "skills" / "a" / "SKILL.md").read_text() == "# A\n"
|
||||
|
||||
def test_rejects_empty_zip(self, tmp_path, monkeypatch):
|
||||
"""Import rejects an empty zip."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
zip_path = tmp_path / "empty.zip"
|
||||
with zipfile.ZipFile(zip_path, "w"):
|
||||
pass # empty
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=True)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
with pytest.raises(SystemExit):
|
||||
run_import(args)
|
||||
|
||||
def test_rejects_non_hermes_zip(self, tmp_path, monkeypatch):
|
||||
"""Import rejects a zip that doesn't look like a hermes backup."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
zip_path = tmp_path / "random.zip"
|
||||
self._make_backup_zip(zip_path, {
|
||||
"some/random/file.txt": "hello",
|
||||
"another/thing.json": "{}",
|
||||
})
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=True)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
with pytest.raises(SystemExit):
|
||||
run_import(args)
|
||||
|
||||
def test_blocks_path_traversal(self, tmp_path, monkeypatch):
|
||||
"""Import blocks zip entries with path traversal."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
zip_path = tmp_path / "evil.zip"
|
||||
# Include a marker file so validation passes
|
||||
self._make_backup_zip(zip_path, {
|
||||
"config.yaml": "model: test\n",
|
||||
"../../etc/passwd": "root:x:0:0\n",
|
||||
})
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=True)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
run_import(args)
|
||||
|
||||
# config.yaml should be restored
|
||||
assert (hermes_home / "config.yaml").exists()
|
||||
# traversal file should NOT exist outside hermes home
|
||||
assert not (tmp_path / "etc" / "passwd").exists()
|
||||
|
||||
def test_confirmation_prompt_abort(self, tmp_path, monkeypatch):
|
||||
"""Import aborts when user says no to confirmation."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
# Pre-existing config triggers the confirmation
|
||||
(hermes_home / "config.yaml").write_text("existing: true\n")
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
self._make_backup_zip(zip_path, {
|
||||
"config.yaml": "model: restored\n",
|
||||
})
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=False)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
with patch("builtins.input", return_value="n"):
|
||||
run_import(args)
|
||||
|
||||
# Original config should be unchanged
|
||||
assert (hermes_home / "config.yaml").read_text() == "existing: true\n"
|
||||
|
||||
def test_force_skips_confirmation(self, tmp_path, monkeypatch):
|
||||
"""Import with --force skips confirmation and overwrites."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text("existing: true\n")
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
self._make_backup_zip(zip_path, {
|
||||
"config.yaml": "model: restored\n",
|
||||
})
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=True)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
run_import(args)
|
||||
|
||||
assert (hermes_home / "config.yaml").read_text() == "model: restored\n"
|
||||
|
||||
def test_missing_file_exits(self, tmp_path, monkeypatch):
|
||||
"""Import exits with error for nonexistent file."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
args = Namespace(zipfile=str(tmp_path / "nonexistent.zip"), force=True)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
with pytest.raises(SystemExit):
|
||||
run_import(args)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Round-trip test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRoundTrip:
|
||||
def test_backup_then_import(self, tmp_path, monkeypatch):
|
||||
"""Full round-trip: backup -> import to a new location -> verify."""
|
||||
# Source
|
||||
src_home = tmp_path / "source" / ".hermes"
|
||||
src_home.mkdir(parents=True)
|
||||
_make_hermes_tree(src_home)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(src_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path / "source")
|
||||
|
||||
# Backup
|
||||
out_zip = tmp_path / "roundtrip.zip"
|
||||
from hermes_cli.backup import run_backup, run_import
|
||||
|
||||
run_backup(Namespace(output=str(out_zip)))
|
||||
assert out_zip.exists()
|
||||
|
||||
# Import into a different location
|
||||
dst_home = tmp_path / "dest" / ".hermes"
|
||||
dst_home.mkdir(parents=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(dst_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path / "dest")
|
||||
|
||||
run_import(Namespace(zipfile=str(out_zip), force=True))
|
||||
|
||||
# Verify key files
|
||||
assert (dst_home / "config.yaml").read_text() == "model:\n provider: openrouter\n"
|
||||
assert (dst_home / ".env").read_text() == "OPENROUTER_API_KEY=sk-test-123\n"
|
||||
assert (dst_home / "skills" / "my-skill" / "SKILL.md").exists()
|
||||
assert (dst_home / "profiles" / "coder" / "config.yaml").exists()
|
||||
assert (dst_home / "sessions" / "abc123.json").exists()
|
||||
assert (dst_home / "logs" / "agent.log").exists()
|
||||
|
||||
# hermes-agent should NOT be present
|
||||
assert not (dst_home / "hermes-agent").exists()
|
||||
# __pycache__ should NOT be present
|
||||
assert not (dst_home / "plugins" / "__pycache__").exists()
|
||||
# PID files should NOT be present
|
||||
assert not (dst_home / "gateway.pid").exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Validate / detect-prefix unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFormatSize:
|
||||
def test_bytes(self):
|
||||
from hermes_cli.backup import _format_size
|
||||
assert _format_size(512) == "512 B"
|
||||
|
||||
def test_kilobytes(self):
|
||||
from hermes_cli.backup import _format_size
|
||||
assert "KB" in _format_size(2048)
|
||||
|
||||
def test_megabytes(self):
|
||||
from hermes_cli.backup import _format_size
|
||||
assert "MB" in _format_size(5 * 1024 * 1024)
|
||||
|
||||
def test_gigabytes(self):
|
||||
from hermes_cli.backup import _format_size
|
||||
assert "GB" in _format_size(3 * 1024 ** 3)
|
||||
|
||||
def test_terabytes(self):
|
||||
from hermes_cli.backup import _format_size
|
||||
assert "TB" in _format_size(2 * 1024 ** 4)
|
||||
|
||||
|
||||
class TestValidation:
|
||||
def test_validate_with_config(self):
|
||||
"""Zip with config.yaml passes validation."""
|
||||
import io
|
||||
from hermes_cli.backup import _validate_backup_zip
|
||||
|
||||
buf = io.BytesIO()
|
||||
with zipfile.ZipFile(buf, "w") as zf:
|
||||
zf.writestr("config.yaml", "test")
|
||||
buf.seek(0)
|
||||
with zipfile.ZipFile(buf, "r") as zf:
|
||||
ok, reason = _validate_backup_zip(zf)
|
||||
assert ok
|
||||
|
||||
def test_validate_with_env(self):
|
||||
"""Zip with .env passes validation."""
|
||||
import io
|
||||
from hermes_cli.backup import _validate_backup_zip
|
||||
|
||||
buf = io.BytesIO()
|
||||
with zipfile.ZipFile(buf, "w") as zf:
|
||||
zf.writestr(".env", "KEY=val")
|
||||
buf.seek(0)
|
||||
with zipfile.ZipFile(buf, "r") as zf:
|
||||
ok, reason = _validate_backup_zip(zf)
|
||||
assert ok
|
||||
|
||||
def test_validate_rejects_random(self):
|
||||
"""Zip without hermes markers fails validation."""
|
||||
import io
|
||||
from hermes_cli.backup import _validate_backup_zip
|
||||
|
||||
buf = io.BytesIO()
|
||||
with zipfile.ZipFile(buf, "w") as zf:
|
||||
zf.writestr("random/file.txt", "hello")
|
||||
buf.seek(0)
|
||||
with zipfile.ZipFile(buf, "r") as zf:
|
||||
ok, reason = _validate_backup_zip(zf)
|
||||
assert not ok
|
||||
|
||||
def test_detect_prefix_hermes(self):
|
||||
"""Detects .hermes/ prefix wrapping all entries."""
|
||||
import io
|
||||
from hermes_cli.backup import _detect_prefix
|
||||
|
||||
buf = io.BytesIO()
|
||||
with zipfile.ZipFile(buf, "w") as zf:
|
||||
zf.writestr(".hermes/config.yaml", "test")
|
||||
zf.writestr(".hermes/skills/a/SKILL.md", "skill")
|
||||
buf.seek(0)
|
||||
with zipfile.ZipFile(buf, "r") as zf:
|
||||
assert _detect_prefix(zf) == ".hermes/"
|
||||
|
||||
def test_detect_prefix_none(self):
|
||||
"""No prefix when entries are at root."""
|
||||
import io
|
||||
from hermes_cli.backup import _detect_prefix
|
||||
|
||||
buf = io.BytesIO()
|
||||
with zipfile.ZipFile(buf, "w") as zf:
|
||||
zf.writestr("config.yaml", "test")
|
||||
zf.writestr("skills/a/SKILL.md", "skill")
|
||||
buf.seek(0)
|
||||
with zipfile.ZipFile(buf, "r") as zf:
|
||||
assert _detect_prefix(zf) == ""
|
||||
|
||||
def test_detect_prefix_only_dirs(self):
|
||||
"""Prefix detection returns empty for zip with only directory entries."""
|
||||
import io
|
||||
from hermes_cli.backup import _detect_prefix
|
||||
|
||||
buf = io.BytesIO()
|
||||
with zipfile.ZipFile(buf, "w") as zf:
|
||||
# Only directory entries (trailing slash)
|
||||
zf.writestr(".hermes/", "")
|
||||
zf.writestr(".hermes/skills/", "")
|
||||
buf.seek(0)
|
||||
with zipfile.ZipFile(buf, "r") as zf:
|
||||
assert _detect_prefix(zf) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge case tests for uncovered paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBackupEdgeCases:
|
||||
def test_nonexistent_hermes_home(self, tmp_path, monkeypatch):
|
||||
"""Backup exits when hermes home doesn't exist."""
|
||||
fake_home = tmp_path / "nonexistent" / ".hermes"
|
||||
monkeypatch.setenv("HERMES_HOME", str(fake_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path / "nonexistent")
|
||||
|
||||
args = Namespace(output=str(tmp_path / "out.zip"))
|
||||
|
||||
from hermes_cli.backup import run_backup
|
||||
with pytest.raises(SystemExit):
|
||||
run_backup(args)
|
||||
|
||||
def test_output_is_directory(self, tmp_path, monkeypatch):
|
||||
"""When output path is a directory, zip is created inside it."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text("model: test\n")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
out_dir = tmp_path / "backups"
|
||||
out_dir.mkdir()
|
||||
|
||||
args = Namespace(output=str(out_dir))
|
||||
|
||||
from hermes_cli.backup import run_backup
|
||||
run_backup(args)
|
||||
|
||||
zips = list(out_dir.glob("hermes-backup-*.zip"))
|
||||
assert len(zips) == 1
|
||||
|
||||
def test_output_without_zip_suffix(self, tmp_path, monkeypatch):
|
||||
"""Output path without .zip gets suffix appended."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text("model: test\n")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
out_path = tmp_path / "mybackup.tar"
|
||||
args = Namespace(output=str(out_path))
|
||||
|
||||
from hermes_cli.backup import run_backup
|
||||
run_backup(args)
|
||||
|
||||
# Should have .tar.zip suffix
|
||||
assert (tmp_path / "mybackup.tar.zip").exists()
|
||||
|
||||
def test_empty_hermes_home(self, tmp_path, monkeypatch):
|
||||
"""Backup handles empty hermes home (no files to back up)."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
# Only excluded dirs, no actual files
|
||||
(hermes_home / "__pycache__").mkdir()
|
||||
(hermes_home / "__pycache__" / "foo.pyc").write_bytes(b"\x00")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
args = Namespace(output=str(tmp_path / "out.zip"))
|
||||
|
||||
from hermes_cli.backup import run_backup
|
||||
run_backup(args)
|
||||
|
||||
# No zip should be created
|
||||
assert not (tmp_path / "out.zip").exists()
|
||||
|
||||
def test_permission_error_during_backup(self, tmp_path, monkeypatch):
|
||||
"""Backup handles permission errors gracefully."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text("model: test\n")
|
||||
|
||||
# Create an unreadable file
|
||||
bad_file = hermes_home / "secret.db"
|
||||
bad_file.write_text("data")
|
||||
bad_file.chmod(0o000)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
out_zip = tmp_path / "out.zip"
|
||||
args = Namespace(output=str(out_zip))
|
||||
|
||||
from hermes_cli.backup import run_backup
|
||||
try:
|
||||
run_backup(args)
|
||||
finally:
|
||||
# Restore permissions for cleanup
|
||||
bad_file.chmod(0o644)
|
||||
|
||||
# Zip should still be created with the readable files
|
||||
assert out_zip.exists()
|
||||
|
||||
def test_skips_output_zip_inside_hermes(self, tmp_path, monkeypatch):
|
||||
"""Backup skips its own output zip if it's inside hermes root."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text("model: test\n")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
# Output inside hermes home
|
||||
out_zip = hermes_home / "backup.zip"
|
||||
args = Namespace(output=str(out_zip))
|
||||
|
||||
from hermes_cli.backup import run_backup
|
||||
run_backup(args)
|
||||
|
||||
# The zip should exist but not contain itself
|
||||
assert out_zip.exists()
|
||||
with zipfile.ZipFile(out_zip, "r") as zf:
|
||||
assert "backup.zip" not in zf.namelist()
|
||||
|
||||
|
||||
class TestImportEdgeCases:
|
||||
def _make_backup_zip(self, zip_path: Path, files: dict[str, str | bytes]) -> None:
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
for name, content in files.items():
|
||||
zf.writestr(name, content)
|
||||
|
||||
def test_not_a_zip(self, tmp_path, monkeypatch):
|
||||
"""Import rejects a non-zip file."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
not_zip = tmp_path / "fake.zip"
|
||||
not_zip.write_text("this is not a zip")
|
||||
|
||||
args = Namespace(zipfile=str(not_zip), force=True)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
with pytest.raises(SystemExit):
|
||||
run_import(args)
|
||||
|
||||
def test_eof_during_confirmation(self, tmp_path, monkeypatch):
|
||||
"""Import handles EOFError during confirmation prompt."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text("existing\n")
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
self._make_backup_zip(zip_path, {"config.yaml": "new\n"})
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=False)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
with patch("builtins.input", side_effect=EOFError):
|
||||
with pytest.raises(SystemExit):
|
||||
run_import(args)
|
||||
|
||||
def test_keyboard_interrupt_during_confirmation(self, tmp_path, monkeypatch):
|
||||
"""Import handles KeyboardInterrupt during confirmation prompt."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / ".env").write_text("KEY=val\n")
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
self._make_backup_zip(zip_path, {"config.yaml": "new\n"})
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=False)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
with patch("builtins.input", side_effect=KeyboardInterrupt):
|
||||
with pytest.raises(SystemExit):
|
||||
run_import(args)
|
||||
|
||||
def test_permission_error_during_import(self, tmp_path, monkeypatch):
|
||||
"""Import handles permission errors during extraction."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
# Create a read-only directory so extraction fails
|
||||
locked_dir = hermes_home / "locked"
|
||||
locked_dir.mkdir()
|
||||
locked_dir.chmod(0o555)
|
||||
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
self._make_backup_zip(zip_path, {
|
||||
"config.yaml": "model: test\n",
|
||||
"locked/secret.txt": "data",
|
||||
})
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=True)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
try:
|
||||
run_import(args)
|
||||
finally:
|
||||
locked_dir.chmod(0o755)
|
||||
|
||||
# config.yaml should still be restored despite the error
|
||||
assert (hermes_home / "config.yaml").exists()
|
||||
|
||||
def test_progress_with_many_files(self, tmp_path, monkeypatch):
|
||||
"""Import shows progress with 500+ files."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
zip_path = tmp_path / "big.zip"
|
||||
files = {"config.yaml": "model: test\n"}
|
||||
for i in range(600):
|
||||
files[f"sessions/s{i:04d}.json"] = "{}"
|
||||
|
||||
self._make_backup_zip(zip_path, files)
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=True)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
run_import(args)
|
||||
|
||||
assert (hermes_home / "config.yaml").exists()
|
||||
assert (hermes_home / "sessions" / "s0599.json").exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Profile restoration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProfileRestoration:
|
||||
def _make_backup_zip(self, zip_path: Path, files: dict[str, str | bytes]) -> None:
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
for name, content in files.items():
|
||||
zf.writestr(name, content)
|
||||
|
||||
def test_import_creates_profile_wrappers(self, tmp_path, monkeypatch):
|
||||
"""Import auto-creates wrapper scripts for restored profiles."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
# Mock the wrapper dir to be inside tmp_path
|
||||
wrapper_dir = tmp_path / ".local" / "bin"
|
||||
wrapper_dir.mkdir(parents=True)
|
||||
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
self._make_backup_zip(zip_path, {
|
||||
"config.yaml": "model:\n provider: openrouter\n",
|
||||
"profiles/coder/config.yaml": "model:\n provider: anthropic\n",
|
||||
"profiles/coder/.env": "ANTHROPIC_API_KEY=sk-test\n",
|
||||
"profiles/researcher/config.yaml": "model:\n provider: deepseek\n",
|
||||
})
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=True)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
run_import(args)
|
||||
|
||||
# Profile directories should exist
|
||||
assert (hermes_home / "profiles" / "coder" / "config.yaml").exists()
|
||||
assert (hermes_home / "profiles" / "researcher" / "config.yaml").exists()
|
||||
|
||||
# Wrapper scripts should be created
|
||||
assert (wrapper_dir / "coder").exists()
|
||||
assert (wrapper_dir / "researcher").exists()
|
||||
|
||||
# Wrappers should contain the right content
|
||||
coder_wrapper = (wrapper_dir / "coder").read_text()
|
||||
assert "hermes -p coder" in coder_wrapper
|
||||
|
||||
def test_import_skips_profile_dirs_without_config(self, tmp_path, monkeypatch):
|
||||
"""Import doesn't create wrappers for profile dirs without config."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
wrapper_dir = tmp_path / ".local" / "bin"
|
||||
wrapper_dir.mkdir(parents=True)
|
||||
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
self._make_backup_zip(zip_path, {
|
||||
"config.yaml": "model: test\n",
|
||||
"profiles/valid/config.yaml": "model: test\n",
|
||||
"profiles/empty/readme.txt": "nothing here\n",
|
||||
})
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=True)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
run_import(args)
|
||||
|
||||
# Only valid profile should get a wrapper
|
||||
assert (wrapper_dir / "valid").exists()
|
||||
assert not (wrapper_dir / "empty").exists()
|
||||
|
||||
def test_import_without_profiles_module(self, tmp_path, monkeypatch):
|
||||
"""Import gracefully handles missing profiles module (fresh install)."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
self._make_backup_zip(zip_path, {
|
||||
"config.yaml": "model: test\n",
|
||||
"profiles/coder/config.yaml": "model: test\n",
|
||||
})
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=True)
|
||||
|
||||
# Simulate profiles module not being available
|
||||
import hermes_cli.backup as backup_mod
|
||||
original_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__
|
||||
|
||||
def fake_import(name, *a, **kw):
|
||||
if name == "hermes_cli.profiles":
|
||||
raise ImportError("no profiles module")
|
||||
return original_import(name, *a, **kw)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
with patch("builtins.__import__", side_effect=fake_import):
|
||||
run_import(args)
|
||||
|
||||
# Files should still be restored even if wrappers can't be created
|
||||
assert (hermes_home / "profiles" / "coder" / "config.yaml").exists()
|
||||
@@ -58,13 +58,13 @@ class TestFindOpenclawDirs:
|
||||
def test_finds_legacy_dirs(self, tmp_path):
|
||||
clawdbot = tmp_path / ".clawdbot"
|
||||
clawdbot.mkdir()
|
||||
moldbot = tmp_path / ".moldbot"
|
||||
moldbot.mkdir()
|
||||
moltbot = tmp_path / ".moltbot"
|
||||
moltbot.mkdir()
|
||||
with patch("pathlib.Path.home", return_value=tmp_path):
|
||||
found = claw_mod._find_openclaw_dirs()
|
||||
assert len(found) == 2
|
||||
assert clawdbot in found
|
||||
assert moldbot in found
|
||||
assert moltbot in found
|
||||
|
||||
def test_returns_empty_when_none_exist(self, tmp_path):
|
||||
with patch("pathlib.Path.home", return_value=tmp_path):
|
||||
@@ -297,7 +297,6 @@ class TestCmdMigrate:
|
||||
patch.object(claw_mod, "_load_migration_module", return_value=fake_mod),
|
||||
patch.object(claw_mod, "get_config_path", return_value=config_path),
|
||||
patch.object(claw_mod, "prompt_yes_no", return_value=True),
|
||||
patch.object(claw_mod, "_offer_source_archival"),
|
||||
patch("sys.stdin", mock_stdin),
|
||||
):
|
||||
claw_mod._cmd_migrate(args)
|
||||
@@ -306,43 +305,8 @@ class TestCmdMigrate:
|
||||
assert "Migration Results" in captured.out
|
||||
assert "Migration complete!" in captured.out
|
||||
|
||||
def test_execute_offers_archival_on_success(self, tmp_path, capsys):
|
||||
"""After successful migration, _offer_source_archival should be called."""
|
||||
openclaw_dir = tmp_path / ".openclaw"
|
||||
openclaw_dir.mkdir()
|
||||
|
||||
fake_mod = ModuleType("openclaw_to_hermes")
|
||||
fake_mod.resolve_selected_options = MagicMock(return_value={"soul"})
|
||||
fake_migrator = MagicMock()
|
||||
fake_migrator.migrate.return_value = {
|
||||
"summary": {"migrated": 3, "skipped": 0, "conflict": 0, "error": 0},
|
||||
"items": [
|
||||
{"kind": "soul", "status": "migrated", "destination": str(tmp_path / "SOUL.md")},
|
||||
],
|
||||
}
|
||||
fake_mod.Migrator = MagicMock(return_value=fake_migrator)
|
||||
|
||||
args = Namespace(
|
||||
source=str(openclaw_dir),
|
||||
dry_run=False, preset="full", overwrite=False,
|
||||
migrate_secrets=False, workspace_target=None,
|
||||
skill_conflict="skip", yes=True,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
|
||||
patch.object(claw_mod, "_load_migration_module", return_value=fake_mod),
|
||||
patch.object(claw_mod, "get_config_path", return_value=tmp_path / "config.yaml"),
|
||||
patch.object(claw_mod, "save_config"),
|
||||
patch.object(claw_mod, "load_config", return_value={}),
|
||||
patch.object(claw_mod, "_offer_source_archival") as mock_archival,
|
||||
):
|
||||
claw_mod._cmd_migrate(args)
|
||||
|
||||
mock_archival.assert_called_once_with(openclaw_dir, True)
|
||||
|
||||
def test_dry_run_skips_archival(self, tmp_path, capsys):
|
||||
"""Dry run should not offer archival."""
|
||||
def test_dry_run_does_not_touch_source(self, tmp_path, capsys):
|
||||
"""Dry run should not modify the source directory."""
|
||||
openclaw_dir = tmp_path / ".openclaw"
|
||||
openclaw_dir.mkdir()
|
||||
|
||||
@@ -369,11 +333,10 @@ class TestCmdMigrate:
|
||||
patch.object(claw_mod, "get_config_path", return_value=tmp_path / "config.yaml"),
|
||||
patch.object(claw_mod, "save_config"),
|
||||
patch.object(claw_mod, "load_config", return_value={}),
|
||||
patch.object(claw_mod, "_offer_source_archival") as mock_archival,
|
||||
):
|
||||
claw_mod._cmd_migrate(args)
|
||||
|
||||
mock_archival.assert_not_called()
|
||||
assert openclaw_dir.is_dir() # Source untouched
|
||||
|
||||
def test_execute_cancelled_by_user(self, tmp_path, capsys):
|
||||
openclaw_dir = tmp_path / ".openclaw"
|
||||
@@ -506,73 +469,6 @@ class TestCmdMigrate:
|
||||
assert call_kwargs["migrate_secrets"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _offer_source_archival
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOfferSourceArchival:
|
||||
"""Test the post-migration archival offer."""
|
||||
|
||||
def test_archives_with_auto_yes(self, tmp_path, capsys):
|
||||
source = tmp_path / ".openclaw"
|
||||
source.mkdir()
|
||||
(source / "workspace").mkdir()
|
||||
(source / "workspace" / "todo.json").write_text("{}")
|
||||
|
||||
claw_mod._offer_source_archival(source, auto_yes=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Archived" in captured.out
|
||||
assert not source.exists()
|
||||
assert (tmp_path / ".openclaw.pre-migration").is_dir()
|
||||
|
||||
def test_skips_when_user_declines(self, tmp_path, capsys):
|
||||
source = tmp_path / ".openclaw"
|
||||
source.mkdir()
|
||||
|
||||
mock_stdin = MagicMock()
|
||||
mock_stdin.isatty.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(claw_mod, "prompt_yes_no", return_value=False),
|
||||
patch("sys.stdin", mock_stdin),
|
||||
):
|
||||
claw_mod._offer_source_archival(source, auto_yes=False)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Skipped" in captured.out
|
||||
assert source.is_dir() # Still exists
|
||||
|
||||
def test_noop_when_source_missing(self, tmp_path, capsys):
|
||||
claw_mod._offer_source_archival(tmp_path / "nonexistent", auto_yes=True)
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == "" # No output
|
||||
|
||||
def test_shows_state_files(self, tmp_path, capsys):
|
||||
source = tmp_path / ".openclaw"
|
||||
source.mkdir()
|
||||
ws = source / "workspace"
|
||||
ws.mkdir()
|
||||
(ws / "todo.json").write_text("{}")
|
||||
|
||||
with patch.object(claw_mod, "prompt_yes_no", return_value=False):
|
||||
claw_mod._offer_source_archival(source, auto_yes=False)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "todo.json" in captured.out
|
||||
|
||||
def test_handles_archive_error(self, tmp_path, capsys):
|
||||
source = tmp_path / ".openclaw"
|
||||
source.mkdir()
|
||||
|
||||
with patch.object(claw_mod, "_archive_directory", side_effect=OSError("permission denied")):
|
||||
claw_mod._offer_source_archival(source, auto_yes=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Could not archive" in captured.out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cmd_cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
254
tests/hermes_cli/test_cli_model_picker.py
Normal file
254
tests/hermes_cli/test_cli_model_picker.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""Tests for the interactive CLI /model picker (provider → model drill-down)."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class _FakeBuffer:
|
||||
def __init__(self, text="draft text"):
|
||||
self.text = text
|
||||
self.cursor_position = len(text)
|
||||
self.reset_calls = []
|
||||
|
||||
def reset(self, append_to_history=False):
|
||||
self.reset_calls.append(append_to_history)
|
||||
self.text = ""
|
||||
self.cursor_position = 0
|
||||
|
||||
|
||||
def _make_providers():
|
||||
return [
|
||||
{
|
||||
"slug": "openrouter",
|
||||
"name": "OpenRouter",
|
||||
"is_current": True,
|
||||
"is_user_defined": False,
|
||||
"models": ["anthropic/claude-opus-4.6", "openai/gpt-5.4"],
|
||||
"total_models": 2,
|
||||
"source": "built-in",
|
||||
},
|
||||
{
|
||||
"slug": "anthropic",
|
||||
"name": "Anthropic",
|
||||
"is_current": False,
|
||||
"is_user_defined": False,
|
||||
"models": ["claude-opus-4.6", "claude-sonnet-4.6"],
|
||||
"total_models": 2,
|
||||
"source": "built-in",
|
||||
},
|
||||
{
|
||||
"slug": "custom:my-ollama",
|
||||
"name": "My Ollama",
|
||||
"is_current": False,
|
||||
"is_user_defined": True,
|
||||
"models": ["llama3", "mistral"],
|
||||
"total_models": 2,
|
||||
"source": "user-config",
|
||||
"api_url": "http://localhost:11434/v1",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _make_picker_cli(picker_return_value):
|
||||
cli = MagicMock()
|
||||
cli._run_curses_picker = MagicMock(return_value=picker_return_value)
|
||||
cli._app = MagicMock()
|
||||
cli._status_bar_visible = True
|
||||
return cli
|
||||
|
||||
|
||||
def _make_modal_cli():
|
||||
from cli import HermesCLI
|
||||
|
||||
cli = HermesCLI.__new__(HermesCLI)
|
||||
cli.model = "gpt-5.4"
|
||||
cli.provider = "openrouter"
|
||||
cli.requested_provider = "openrouter"
|
||||
cli.base_url = ""
|
||||
cli.api_key = ""
|
||||
cli.api_mode = ""
|
||||
cli._explicit_api_key = ""
|
||||
cli._explicit_base_url = ""
|
||||
cli._pending_model_switch_note = None
|
||||
cli._model_picker_state = None
|
||||
cli._modal_input_snapshot = None
|
||||
cli._status_bar_visible = True
|
||||
cli._invalidate = MagicMock()
|
||||
cli.agent = None
|
||||
cli.config = {}
|
||||
cli.console = MagicMock()
|
||||
cli._app = SimpleNamespace(
|
||||
current_buffer=_FakeBuffer(),
|
||||
invalidate=MagicMock(),
|
||||
)
|
||||
return cli
|
||||
|
||||
|
||||
def test_provider_selection_returns_slug_on_choice():
|
||||
providers = _make_providers()
|
||||
cli = _make_picker_cli(1)
|
||||
from cli import HermesCLI
|
||||
|
||||
result = HermesCLI._interactive_provider_selection(cli, providers, "gpt-5.4", "OpenRouter")
|
||||
|
||||
assert result == "anthropic"
|
||||
cli._run_curses_picker.assert_called_once()
|
||||
|
||||
|
||||
def test_provider_selection_returns_none_on_cancel():
|
||||
providers = _make_providers()
|
||||
cli = _make_picker_cli(None)
|
||||
from cli import HermesCLI
|
||||
|
||||
result = HermesCLI._interactive_provider_selection(cli, providers, "gpt-5.4", "OpenRouter")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_provider_selection_default_is_current():
|
||||
providers = _make_providers()
|
||||
cli = _make_picker_cli(0)
|
||||
from cli import HermesCLI
|
||||
|
||||
HermesCLI._interactive_provider_selection(cli, providers, "gpt-5.4", "OpenRouter")
|
||||
|
||||
assert cli._run_curses_picker.call_args.kwargs["default_index"] == 0
|
||||
|
||||
|
||||
def test_model_selection_returns_model_on_choice():
|
||||
provider_data = _make_providers()[0]
|
||||
cli = _make_picker_cli(0)
|
||||
from cli import HermesCLI
|
||||
|
||||
result = HermesCLI._interactive_model_selection(cli, provider_data["models"], provider_data)
|
||||
|
||||
assert result == "anthropic/claude-opus-4.6"
|
||||
|
||||
|
||||
def test_model_selection_custom_entry_prompts_for_input():
|
||||
provider_data = _make_providers()[0]
|
||||
cli = _make_picker_cli(2)
|
||||
from cli import HermesCLI
|
||||
|
||||
cli._prompt_text_input = MagicMock(return_value="my-custom-model")
|
||||
result = HermesCLI._interactive_model_selection(cli, provider_data["models"], provider_data)
|
||||
|
||||
assert result == "my-custom-model"
|
||||
cli._prompt_text_input.assert_called_once_with(" Enter model name: ")
|
||||
|
||||
|
||||
def test_model_selection_empty_prompts_for_manual_input():
|
||||
provider_data = {
|
||||
"slug": "custom:empty",
|
||||
"name": "Empty Provider",
|
||||
"models": [],
|
||||
"total_models": 0,
|
||||
}
|
||||
cli = _make_picker_cli(None)
|
||||
from cli import HermesCLI
|
||||
|
||||
cli._prompt_text_input = MagicMock(return_value="my-model")
|
||||
result = HermesCLI._interactive_model_selection(cli, [], provider_data)
|
||||
|
||||
assert result == "my-model"
|
||||
cli._prompt_text_input.assert_called_once_with(" Enter model name manually (or Enter to cancel): ")
|
||||
|
||||
|
||||
def test_prompt_text_input_uses_run_in_terminal_when_app_active():
|
||||
from cli import HermesCLI
|
||||
|
||||
cli = _make_modal_cli()
|
||||
|
||||
with (
|
||||
patch("prompt_toolkit.application.run_in_terminal", side_effect=lambda fn: fn()) as run_mock,
|
||||
patch("builtins.input", return_value="manual-value"),
|
||||
):
|
||||
result = HermesCLI._prompt_text_input(cli, "Enter value: ")
|
||||
|
||||
assert result == "manual-value"
|
||||
run_mock.assert_called_once()
|
||||
assert cli._status_bar_visible is True
|
||||
|
||||
|
||||
def test_should_handle_model_command_inline_uses_command_name_resolution():
|
||||
from cli import HermesCLI
|
||||
|
||||
cli = _make_modal_cli()
|
||||
|
||||
with patch("hermes_cli.commands.resolve_command", return_value=SimpleNamespace(name="model")):
|
||||
assert HermesCLI._should_handle_model_command_inline(cli, "/model") is True
|
||||
|
||||
with patch("hermes_cli.commands.resolve_command", return_value=SimpleNamespace(name="help")):
|
||||
assert HermesCLI._should_handle_model_command_inline(cli, "/model") is False
|
||||
|
||||
assert HermesCLI._should_handle_model_command_inline(cli, "/model", has_images=True) is False
|
||||
|
||||
|
||||
def test_process_command_model_without_args_opens_modal_picker_and_captures_draft():
|
||||
from cli import HermesCLI
|
||||
|
||||
cli = _make_modal_cli()
|
||||
providers = _make_providers()
|
||||
|
||||
with (
|
||||
patch("hermes_cli.model_switch.list_authenticated_providers", return_value=providers),
|
||||
patch("cli._cprint"),
|
||||
):
|
||||
result = cli.process_command("/model")
|
||||
|
||||
assert result is True
|
||||
assert cli._model_picker_state is not None
|
||||
assert cli._model_picker_state["stage"] == "provider"
|
||||
assert cli._model_picker_state["selected"] == 0
|
||||
assert cli._modal_input_snapshot == {"text": "draft text", "cursor_position": len("draft text")}
|
||||
assert cli._app.current_buffer.text == ""
|
||||
|
||||
|
||||
def test_model_picker_provider_then_model_selection_applies_switch_result_and_restores_draft():
|
||||
from cli import HermesCLI
|
||||
|
||||
cli = _make_modal_cli()
|
||||
providers = _make_providers()
|
||||
|
||||
with (
|
||||
patch("hermes_cli.model_switch.list_authenticated_providers", return_value=providers),
|
||||
patch("cli._cprint"),
|
||||
):
|
||||
assert cli.process_command("/model") is True
|
||||
|
||||
cli._model_picker_state["selected"] = 1
|
||||
with patch("hermes_cli.models.provider_model_ids", return_value=["claude-opus-4.6", "claude-sonnet-4.6"]):
|
||||
HermesCLI._handle_model_picker_selection(cli)
|
||||
|
||||
assert cli._model_picker_state["stage"] == "model"
|
||||
assert cli._model_picker_state["provider_data"]["slug"] == "anthropic"
|
||||
assert cli._model_picker_state["model_list"] == ["claude-opus-4.6", "claude-sonnet-4.6"]
|
||||
|
||||
cli._model_picker_state["selected"] = 0
|
||||
switch_result = SimpleNamespace(
|
||||
success=True,
|
||||
error_message=None,
|
||||
new_model="claude-opus-4.6",
|
||||
target_provider="anthropic",
|
||||
api_key="",
|
||||
base_url="",
|
||||
api_mode="anthropic_messages",
|
||||
provider_label="Anthropic",
|
||||
model_info=None,
|
||||
warning_message=None,
|
||||
provider_changed=True,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("hermes_cli.model_switch.switch_model", return_value=switch_result) as switch_mock,
|
||||
patch("cli._cprint"),
|
||||
):
|
||||
HermesCLI._handle_model_picker_selection(cli)
|
||||
|
||||
assert cli._model_picker_state is None
|
||||
assert cli.model == "claude-opus-4.6"
|
||||
assert cli.provider == "anthropic"
|
||||
assert cli.requested_provider == "anthropic"
|
||||
assert cli._app.current_buffer.text == "draft text"
|
||||
switch_mock.assert_called_once()
|
||||
assert switch_mock.call_args.kwargs["explicit_provider"] == "anthropic"
|
||||
241
tests/hermes_cli/test_codex_cli_model_picker.py
Normal file
241
tests/hermes_cli/test_codex_cli_model_picker.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""Regression test: openai-codex must appear in /model picker when
|
||||
credentials are only in the Codex CLI shared file (~/.codex/auth.json)
|
||||
and haven't been migrated to the Hermes auth store yet.
|
||||
|
||||
Root cause: list_authenticated_providers() checked the raw Hermes auth
|
||||
store but didn't know about the Codex CLI fallback import path.
|
||||
|
||||
Fix: _seed_from_singletons() now imports from the Codex CLI when the
|
||||
Hermes auth store has no openai-codex tokens, and
|
||||
list_authenticated_providers() falls back to load_pool() for OAuth
|
||||
providers.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_fake_jwt(expiry_offset: int = 3600) -> str:
|
||||
"""Build a fake JWT with a future expiry."""
|
||||
header = base64.urlsafe_b64encode(b'{"alg":"RS256"}').rstrip(b"=").decode()
|
||||
exp = int(time.time()) + expiry_offset
|
||||
payload_bytes = json.dumps({"exp": exp, "sub": "test"}).encode()
|
||||
payload = base64.urlsafe_b64encode(payload_bytes).rstrip(b"=").decode()
|
||||
return f"{header}.{payload}.fakesig"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def codex_cli_only_env(tmp_path, monkeypatch):
|
||||
"""Set up an environment where Codex tokens exist only in ~/.codex/auth.json,
|
||||
NOT in the Hermes auth store."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
codex_home = tmp_path / ".codex"
|
||||
codex_home.mkdir()
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setenv("CODEX_HOME", str(codex_home))
|
||||
|
||||
# Empty Hermes auth store
|
||||
(hermes_home / "auth.json").write_text(
|
||||
json.dumps({"version": 2, "providers": {}})
|
||||
)
|
||||
|
||||
# Valid Codex CLI tokens
|
||||
fake_jwt = _make_fake_jwt()
|
||||
(codex_home / "auth.json").write_text(
|
||||
json.dumps({
|
||||
"tokens": {
|
||||
"access_token": fake_jwt,
|
||||
"refresh_token": "fake-refresh-token",
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
# Clear provider env vars so only OAuth is a detection path
|
||||
for var in [
|
||||
"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
|
||||
"NOUS_API_KEY", "DEEPSEEK_API_KEY", "COPILOT_GITHUB_TOKEN",
|
||||
"GH_TOKEN", "GEMINI_API_KEY",
|
||||
]:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
return hermes_home
|
||||
|
||||
|
||||
def test_codex_cli_tokens_detected_by_model_picker(codex_cli_only_env):
|
||||
"""openai-codex should appear when tokens only exist in ~/.codex/auth.json."""
|
||||
from hermes_cli.model_switch import list_authenticated_providers
|
||||
|
||||
providers = list_authenticated_providers(
|
||||
current_provider="openai-codex",
|
||||
max_models=10,
|
||||
)
|
||||
slugs = [p["slug"] for p in providers]
|
||||
assert "openai-codex" in slugs, (
|
||||
f"openai-codex not found in /model picker providers: {slugs}"
|
||||
)
|
||||
|
||||
codex = next(p for p in providers if p["slug"] == "openai-codex")
|
||||
assert codex["is_current"] is True
|
||||
assert codex["total_models"] > 0
|
||||
|
||||
|
||||
def test_codex_cli_tokens_migrated_after_detection(codex_cli_only_env):
|
||||
"""After the /model picker detects Codex CLI tokens, they should be
|
||||
migrated into the Hermes auth store for subsequent fast lookups."""
|
||||
from hermes_cli.model_switch import list_authenticated_providers
|
||||
|
||||
# First call triggers migration
|
||||
list_authenticated_providers(current_provider="openai-codex")
|
||||
|
||||
# Verify tokens are now in Hermes auth store
|
||||
auth_path = codex_cli_only_env / "auth.json"
|
||||
store = json.loads(auth_path.read_text())
|
||||
providers = store.get("providers", {})
|
||||
assert "openai-codex" in providers, (
|
||||
f"openai-codex not migrated to Hermes auth store: {list(providers.keys())}"
|
||||
)
|
||||
tokens = providers["openai-codex"].get("tokens", {})
|
||||
assert tokens.get("access_token"), "access_token missing after migration"
|
||||
assert tokens.get("refresh_token"), "refresh_token missing after migration"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def hermes_auth_only_env(tmp_path, monkeypatch):
|
||||
"""Tokens already in Hermes auth store (no Codex CLI needed)."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
# Point CODEX_HOME to nonexistent dir to prove it's not needed
|
||||
monkeypatch.setenv("CODEX_HOME", str(tmp_path / "no_codex"))
|
||||
|
||||
(hermes_home / "auth.json").write_text(json.dumps({
|
||||
"version": 2,
|
||||
"providers": {
|
||||
"openai-codex": {
|
||||
"tokens": {
|
||||
"access_token": _make_fake_jwt(),
|
||||
"refresh_token": "fake-refresh",
|
||||
},
|
||||
"last_refresh": "2026-04-12T00:00:00Z",
|
||||
}
|
||||
},
|
||||
}))
|
||||
|
||||
for var in [
|
||||
"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
|
||||
"NOUS_API_KEY", "DEEPSEEK_API_KEY",
|
||||
]:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
return hermes_home
|
||||
|
||||
|
||||
def test_normal_path_still_works(hermes_auth_only_env):
|
||||
"""openai-codex appears when tokens are already in Hermes auth store."""
|
||||
from hermes_cli.model_switch import list_authenticated_providers
|
||||
|
||||
providers = list_authenticated_providers(
|
||||
current_provider="openai-codex",
|
||||
max_models=10,
|
||||
)
|
||||
slugs = [p["slug"] for p in providers]
|
||||
assert "openai-codex" in slugs
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def claude_code_only_env(tmp_path, monkeypatch):
|
||||
"""Set up an environment where Anthropic credentials only exist in
|
||||
~/.claude/.credentials.json (Claude Code) — not in env vars or Hermes
|
||||
auth store."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
# No Codex CLI
|
||||
monkeypatch.setenv("CODEX_HOME", str(tmp_path / "no_codex"))
|
||||
|
||||
(hermes_home / "auth.json").write_text(
|
||||
json.dumps({"version": 2, "providers": {}})
|
||||
)
|
||||
|
||||
# Claude Code credentials in the correct format
|
||||
claude_dir = tmp_path / ".claude"
|
||||
claude_dir.mkdir()
|
||||
(claude_dir / ".credentials.json").write_text(json.dumps({
|
||||
"claudeAiOauth": {
|
||||
"accessToken": _make_fake_jwt(),
|
||||
"refreshToken": "fake-refresh",
|
||||
"expiresAt": int(time.time() * 1000) + 3_600_000,
|
||||
}
|
||||
}))
|
||||
|
||||
# Patch Path.home() so the adapter finds the file
|
||||
monkeypatch.setattr(Path, "home", classmethod(lambda cls: tmp_path))
|
||||
|
||||
for var in [
|
||||
"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
|
||||
"ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN",
|
||||
"NOUS_API_KEY", "DEEPSEEK_API_KEY",
|
||||
]:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
return hermes_home
|
||||
|
||||
|
||||
def test_claude_code_file_detected_by_model_picker(claude_code_only_env):
|
||||
"""anthropic should appear when credentials only exist in ~/.claude/.credentials.json."""
|
||||
from hermes_cli.model_switch import list_authenticated_providers
|
||||
|
||||
providers = list_authenticated_providers(
|
||||
current_provider="anthropic",
|
||||
max_models=10,
|
||||
)
|
||||
slugs = [p["slug"] for p in providers]
|
||||
assert "anthropic" in slugs, (
|
||||
f"anthropic not found in /model picker providers: {slugs}"
|
||||
)
|
||||
|
||||
anthropic = next(p for p in providers if p["slug"] == "anthropic")
|
||||
assert anthropic["is_current"] is True
|
||||
assert anthropic["total_models"] > 0
|
||||
|
||||
|
||||
def test_no_codex_when_no_credentials(tmp_path, monkeypatch):
|
||||
"""openai-codex should NOT appear when no credentials exist anywhere."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setenv("CODEX_HOME", str(tmp_path / "no_codex"))
|
||||
|
||||
(hermes_home / "auth.json").write_text(
|
||||
json.dumps({"version": 2, "providers": {}})
|
||||
)
|
||||
|
||||
for var in [
|
||||
"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
|
||||
"NOUS_API_KEY", "DEEPSEEK_API_KEY", "COPILOT_GITHUB_TOKEN",
|
||||
"GH_TOKEN", "GEMINI_API_KEY",
|
||||
]:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
from hermes_cli.model_switch import list_authenticated_providers
|
||||
|
||||
providers = list_authenticated_providers(
|
||||
current_provider="openrouter",
|
||||
max_models=10,
|
||||
)
|
||||
slugs = [p["slug"] for p in providers]
|
||||
assert "openai-codex" not in slugs, (
|
||||
"openai-codex should not appear without any credentials"
|
||||
)
|
||||
@@ -68,6 +68,7 @@ class TestLoadConfigDefaults:
|
||||
assert "max_turns" not in config
|
||||
assert "terminal" in config
|
||||
assert config["terminal"]["backend"] == "local"
|
||||
assert config["display"]["interim_assistant_messages"] is True
|
||||
|
||||
def test_legacy_root_level_max_turns_migrates_to_agent_config(self, tmp_path):
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
@@ -421,3 +422,25 @@ class TestAnthropicTokenMigration:
|
||||
}):
|
||||
migrate_config(interactive=False, quiet=True)
|
||||
assert load_env().get("ANTHROPIC_TOKEN") == "current-token"
|
||||
|
||||
|
||||
class TestInterimAssistantMessageConfig:
|
||||
"""Test the explicit gateway interim-message config gate."""
|
||||
|
||||
def test_default_config_enables_interim_assistant_messages(self):
|
||||
assert DEFAULT_CONFIG["display"]["interim_assistant_messages"] is True
|
||||
|
||||
def test_migrate_to_v15_adds_interim_assistant_message_gate(self, tmp_path):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config_path.write_text(
|
||||
yaml.safe_dump({"_config_version": 14, "display": {"tool_progress": "off"}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
migrate_config(interactive=False, quiet=True)
|
||||
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
|
||||
assert raw["_config_version"] == 16
|
||||
assert raw["display"]["tool_progress"] == "off"
|
||||
assert raw["display"]["interim_assistant_messages"] is True
|
||||
|
||||
342
tests/hermes_cli/test_container_aware_cli.py
Normal file
342
tests/hermes_cli/test_container_aware_cli.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""Tests for container-aware CLI routing (NixOS container mode).
|
||||
|
||||
When container.enable = true in the NixOS module, the activation script
|
||||
writes a .container-mode metadata file. The host CLI detects this and
|
||||
execs into the container instead of running locally.
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.config import (
|
||||
_is_inside_container,
|
||||
get_container_exec_info,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _is_inside_container
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_is_inside_container_dockerenv():
|
||||
"""Detects /.dockerenv marker file."""
|
||||
with patch("os.path.exists") as mock_exists:
|
||||
mock_exists.side_effect = lambda p: p == "/.dockerenv"
|
||||
assert _is_inside_container() is True
|
||||
|
||||
|
||||
def test_is_inside_container_containerenv():
|
||||
"""Detects Podman's /run/.containerenv marker."""
|
||||
with patch("os.path.exists") as mock_exists:
|
||||
mock_exists.side_effect = lambda p: p == "/run/.containerenv"
|
||||
assert _is_inside_container() is True
|
||||
|
||||
|
||||
def test_is_inside_container_cgroup_docker():
|
||||
"""Detects 'docker' in /proc/1/cgroup."""
|
||||
with patch("os.path.exists", return_value=False), \
|
||||
patch("builtins.open", create=True) as mock_open:
|
||||
mock_open.return_value.__enter__ = lambda s: s
|
||||
mock_open.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_open.return_value.read = MagicMock(
|
||||
return_value="12:memory:/docker/abc123\n"
|
||||
)
|
||||
assert _is_inside_container() is True
|
||||
|
||||
|
||||
def test_is_inside_container_false_on_host():
|
||||
"""Returns False when none of the container indicators are present."""
|
||||
with patch("os.path.exists", return_value=False), \
|
||||
patch("builtins.open", side_effect=OSError("no such file")):
|
||||
assert _is_inside_container() is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# get_container_exec_info
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def container_env(tmp_path, monkeypatch):
|
||||
"""Set up a fake HERMES_HOME with .container-mode file."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.delenv("HERMES_DEV", raising=False)
|
||||
|
||||
container_mode = hermes_home / ".container-mode"
|
||||
container_mode.write_text(
|
||||
"# Written by NixOS activation script. Do not edit manually.\n"
|
||||
"backend=podman\n"
|
||||
"container_name=hermes-agent\n"
|
||||
"exec_user=hermes\n"
|
||||
"hermes_bin=/data/current-package/bin/hermes\n"
|
||||
)
|
||||
return hermes_home
|
||||
|
||||
|
||||
def test_get_container_exec_info_returns_metadata(container_env):
|
||||
"""Reads .container-mode and returns all fields including exec_user."""
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is not None
|
||||
assert info["backend"] == "podman"
|
||||
assert info["container_name"] == "hermes-agent"
|
||||
assert info["exec_user"] == "hermes"
|
||||
assert info["hermes_bin"] == "/data/current-package/bin/hermes"
|
||||
|
||||
|
||||
def test_get_container_exec_info_none_inside_container(container_env):
|
||||
"""Returns None when we're already inside a container."""
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=True):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is None
|
||||
|
||||
|
||||
def test_get_container_exec_info_none_without_file(tmp_path, monkeypatch):
|
||||
"""Returns None when .container-mode doesn't exist (native mode)."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.delenv("HERMES_DEV", raising=False)
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is None
|
||||
|
||||
|
||||
def test_get_container_exec_info_skipped_when_hermes_dev(container_env, monkeypatch):
|
||||
"""Returns None when HERMES_DEV=1 is set (dev mode bypass)."""
|
||||
monkeypatch.setenv("HERMES_DEV", "1")
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is None
|
||||
|
||||
|
||||
def test_get_container_exec_info_not_skipped_when_hermes_dev_zero(container_env, monkeypatch):
|
||||
"""HERMES_DEV=0 does NOT trigger bypass — only '1' does."""
|
||||
monkeypatch.setenv("HERMES_DEV", "0")
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is not None
|
||||
|
||||
|
||||
def test_get_container_exec_info_defaults():
|
||||
"""Falls back to defaults for missing keys."""
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
hermes_home = Path(tmpdir) / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / ".container-mode").write_text(
|
||||
"# minimal file with no keys\n"
|
||||
)
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False), \
|
||||
patch("hermes_cli.config.get_hermes_home", return_value=hermes_home), \
|
||||
patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("HERMES_DEV", None)
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is not None
|
||||
assert info["backend"] == "docker"
|
||||
assert info["container_name"] == "hermes-agent"
|
||||
assert info["exec_user"] == "hermes"
|
||||
assert info["hermes_bin"] == "/data/current-package/bin/hermes"
|
||||
|
||||
|
||||
def test_get_container_exec_info_docker_backend(container_env):
|
||||
"""Correctly reads docker backend with custom exec_user."""
|
||||
(container_env / ".container-mode").write_text(
|
||||
"backend=docker\n"
|
||||
"container_name=hermes-custom\n"
|
||||
"exec_user=myuser\n"
|
||||
"hermes_bin=/opt/hermes/bin/hermes\n"
|
||||
)
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info["backend"] == "docker"
|
||||
assert info["container_name"] == "hermes-custom"
|
||||
assert info["exec_user"] == "myuser"
|
||||
assert info["hermes_bin"] == "/opt/hermes/bin/hermes"
|
||||
|
||||
|
||||
def test_get_container_exec_info_crashes_on_permission_error(container_env):
|
||||
"""PermissionError propagates instead of being silently swallowed."""
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False), \
|
||||
patch("builtins.open", side_effect=PermissionError("permission denied")):
|
||||
with pytest.raises(PermissionError):
|
||||
get_container_exec_info()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _exec_in_container
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def docker_container_info():
|
||||
return {
|
||||
"backend": "docker",
|
||||
"container_name": "hermes-agent",
|
||||
"exec_user": "hermes",
|
||||
"hermes_bin": "/data/current-package/bin/hermes",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def podman_container_info():
|
||||
return {
|
||||
"backend": "podman",
|
||||
"container_name": "hermes-agent",
|
||||
"exec_user": "hermes",
|
||||
"hermes_bin": "/data/current-package/bin/hermes",
|
||||
}
|
||||
|
||||
|
||||
def test_exec_in_container_calls_execvp(docker_container_info):
|
||||
"""Verifies os.execvp is called with correct args: runtime, tty flags,
|
||||
user, env vars, container name, binary, and CLI args."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
with patch("shutil.which", return_value="/usr/bin/docker"), \
|
||||
patch("subprocess.run") as mock_run, \
|
||||
patch("sys.stdin") as mock_stdin, \
|
||||
patch("os.execvp") as mock_execvp, \
|
||||
patch.dict(os.environ, {"TERM": "xterm-256color", "LANG": "en_US.UTF-8"},
|
||||
clear=False):
|
||||
mock_stdin.isatty.return_value = True
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
|
||||
_exec_in_container(docker_container_info, ["chat", "-m", "opus"])
|
||||
|
||||
mock_execvp.assert_called_once()
|
||||
cmd = mock_execvp.call_args[0][1]
|
||||
assert cmd[0] == "/usr/bin/docker"
|
||||
assert cmd[1] == "exec"
|
||||
assert "-it" in cmd
|
||||
idx_u = cmd.index("-u")
|
||||
assert cmd[idx_u + 1] == "hermes"
|
||||
e_indices = [i for i, v in enumerate(cmd) if v == "-e"]
|
||||
e_values = [cmd[i + 1] for i in e_indices]
|
||||
assert "TERM=xterm-256color" in e_values
|
||||
assert "LANG=en_US.UTF-8" in e_values
|
||||
assert "hermes-agent" in cmd
|
||||
assert "/data/current-package/bin/hermes" in cmd
|
||||
assert "chat" in cmd
|
||||
|
||||
|
||||
def test_exec_in_container_non_tty_uses_i_only(docker_container_info):
|
||||
"""Non-TTY mode uses -i instead of -it."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
with patch("shutil.which", return_value="/usr/bin/docker"), \
|
||||
patch("subprocess.run") as mock_run, \
|
||||
patch("sys.stdin") as mock_stdin, \
|
||||
patch("os.execvp") as mock_execvp:
|
||||
mock_stdin.isatty.return_value = False
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
|
||||
_exec_in_container(docker_container_info, ["sessions", "list"])
|
||||
|
||||
cmd = mock_execvp.call_args[0][1]
|
||||
assert "-i" in cmd
|
||||
assert "-it" not in cmd
|
||||
|
||||
|
||||
def test_exec_in_container_no_runtime_hard_fails(podman_container_info):
|
||||
"""Hard fails when runtime not found (no fallback)."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
with patch("shutil.which", return_value=None), \
|
||||
patch("subprocess.run") as mock_run, \
|
||||
patch("os.execvp") as mock_execvp, \
|
||||
pytest.raises(SystemExit) as exc_info:
|
||||
_exec_in_container(podman_container_info, ["chat"])
|
||||
|
||||
mock_run.assert_not_called()
|
||||
mock_execvp.assert_not_called()
|
||||
assert exc_info.value.code != 0
|
||||
|
||||
|
||||
def test_exec_in_container_sudo_probe_sets_prefix(podman_container_info):
|
||||
"""When first probe fails and sudo probe succeeds, execvp is called
|
||||
with sudo -n prefix."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
def which_side_effect(name):
|
||||
if name == "podman":
|
||||
return "/usr/bin/podman"
|
||||
if name == "sudo":
|
||||
return "/usr/bin/sudo"
|
||||
return None
|
||||
|
||||
with patch("shutil.which", side_effect=which_side_effect), \
|
||||
patch("subprocess.run") as mock_run, \
|
||||
patch("sys.stdin") as mock_stdin, \
|
||||
patch("os.execvp") as mock_execvp:
|
||||
mock_stdin.isatty.return_value = True
|
||||
mock_run.side_effect = [
|
||||
MagicMock(returncode=1), # direct probe fails
|
||||
MagicMock(returncode=0), # sudo probe succeeds
|
||||
]
|
||||
|
||||
_exec_in_container(podman_container_info, ["chat"])
|
||||
|
||||
mock_execvp.assert_called_once()
|
||||
cmd = mock_execvp.call_args[0][1]
|
||||
assert cmd[0] == "/usr/bin/sudo"
|
||||
assert cmd[1] == "-n"
|
||||
assert cmd[2] == "/usr/bin/podman"
|
||||
assert cmd[3] == "exec"
|
||||
|
||||
|
||||
def test_exec_in_container_probe_timeout_prints_message(docker_container_info):
|
||||
"""TimeoutExpired from probe produces a human-readable error, not a
|
||||
raw traceback."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
with patch("shutil.which", return_value="/usr/bin/docker"), \
|
||||
patch("subprocess.run", side_effect=subprocess.TimeoutExpired(
|
||||
cmd=["docker", "inspect"], timeout=15)), \
|
||||
patch("os.execvp") as mock_execvp, \
|
||||
pytest.raises(SystemExit) as exc_info:
|
||||
_exec_in_container(docker_container_info, ["chat"])
|
||||
|
||||
mock_execvp.assert_not_called()
|
||||
assert exc_info.value.code == 1
|
||||
|
||||
|
||||
def test_exec_in_container_container_not_running_no_sudo(docker_container_info):
|
||||
"""When runtime exists but container not found and no sudo available,
|
||||
prints helpful error about root containers."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
def which_side_effect(name):
|
||||
if name == "docker":
|
||||
return "/usr/bin/docker"
|
||||
return None
|
||||
|
||||
with patch("shutil.which", side_effect=which_side_effect), \
|
||||
patch("subprocess.run") as mock_run, \
|
||||
patch("os.execvp") as mock_execvp, \
|
||||
pytest.raises(SystemExit) as exc_info:
|
||||
mock_run.return_value = MagicMock(returncode=1)
|
||||
|
||||
_exec_in_container(docker_container_info, ["chat"])
|
||||
|
||||
mock_execvp.assert_not_called()
|
||||
assert exc_info.value.code == 1
|
||||
@@ -122,3 +122,54 @@ class TestCustomProviderModelSwitch:
|
||||
model = config.get("model")
|
||||
assert isinstance(model, dict)
|
||||
assert model["default"] == "model-X"
|
||||
|
||||
def test_api_mode_set_from_provider_info(self, config_home):
|
||||
"""When custom_providers entry has api_mode, it should be applied."""
|
||||
import yaml
|
||||
from hermes_cli.main import _model_flow_named_custom
|
||||
|
||||
provider_info = {
|
||||
"name": "Anthropic Proxy",
|
||||
"base_url": "https://proxy.example.com/anthropic",
|
||||
"api_key": "***",
|
||||
"model": "claude-3",
|
||||
"api_mode": "anthropic_messages",
|
||||
}
|
||||
|
||||
with patch("hermes_cli.models.fetch_api_models", return_value=["claude-3"]), \
|
||||
patch.dict("sys.modules", {"simple_term_menu": None}), \
|
||||
patch("builtins.input", return_value="1"), \
|
||||
patch("builtins.print"):
|
||||
_model_flow_named_custom({}, provider_info)
|
||||
|
||||
config = yaml.safe_load((config_home / "config.yaml").read_text()) or {}
|
||||
model = config.get("model")
|
||||
assert isinstance(model, dict)
|
||||
assert model.get("api_mode") == "anthropic_messages"
|
||||
|
||||
def test_api_mode_cleared_when_not_specified(self, config_home):
|
||||
"""When custom_providers entry has no api_mode, stale api_mode is removed."""
|
||||
import yaml
|
||||
from hermes_cli.main import _model_flow_named_custom
|
||||
|
||||
# Pre-seed a stale api_mode in config
|
||||
config_path = config_home / "config.yaml"
|
||||
config_path.write_text(yaml.dump({"model": {"api_mode": "anthropic_messages"}}))
|
||||
|
||||
provider_info = {
|
||||
"name": "My vLLM",
|
||||
"base_url": "https://vllm.example.com/v1",
|
||||
"api_key": "***",
|
||||
"model": "llama-3",
|
||||
}
|
||||
|
||||
with patch("hermes_cli.models.fetch_api_models", return_value=["llama-3"]), \
|
||||
patch.dict("sys.modules", {"simple_term_menu": None}), \
|
||||
patch("builtins.input", return_value="1"), \
|
||||
patch("builtins.print"):
|
||||
_model_flow_named_custom({}, provider_info)
|
||||
|
||||
config = yaml.safe_load((config_home / "config.yaml").read_text()) or {}
|
||||
model = config.get("model")
|
||||
assert isinstance(model, dict)
|
||||
assert "api_mode" not in model, "Stale api_mode should be removed"
|
||||
|
||||
@@ -1,288 +1,255 @@
|
||||
"""Tests for hermes_cli/logs.py — log viewing and filtering."""
|
||||
"""Tests for hermes_cli.logs — log viewing and filtering."""
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
from datetime import datetime, timedelta
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.logs import (
|
||||
LOG_FILES,
|
||||
_extract_level,
|
||||
_extract_logger_name,
|
||||
_line_matches_component,
|
||||
_matches_filters,
|
||||
_parse_line_timestamp,
|
||||
_parse_since,
|
||||
_read_last_n_lines,
|
||||
list_logs,
|
||||
tail_log,
|
||||
_read_tail,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def log_dir(tmp_path, monkeypatch):
|
||||
"""Create a fake HERMES_HOME with a logs/ directory."""
|
||||
home = Path(os.environ["HERMES_HOME"])
|
||||
logs = home / "logs"
|
||||
logs.mkdir(parents=True, exist_ok=True)
|
||||
return logs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent_log(log_dir):
|
||||
"""Write a realistic agent.log with mixed levels and sessions."""
|
||||
lines = textwrap.dedent("""\
|
||||
2026-04-05 10:00:00,000 INFO run_agent: conversation turn: session=sess_aaa model=claude provider=openrouter platform=cli history=0 msg='hello'
|
||||
2026-04-05 10:00:01,000 INFO run_agent: tool terminal completed (0.50s, 200 chars)
|
||||
2026-04-05 10:00:02,000 INFO run_agent: API call #1: model=claude provider=openrouter in=1000 out=200 total=1200 latency=1.5s
|
||||
2026-04-05 10:00:03,000 WARNING run_agent: Tool web_search returned error (2.00s): timeout
|
||||
2026-04-05 10:00:04,000 INFO run_agent: conversation turn: session=sess_bbb model=gpt-5 provider=openai platform=telegram history=5 msg='fix bug'
|
||||
2026-04-05 10:00:05,000 ERROR run_agent: API call failed after 3 retries. rate limited
|
||||
2026-04-05 10:00:06,000 INFO run_agent: tool read_file completed (0.01s, 500 chars)
|
||||
2026-04-05 10:00:07,000 DEBUG run_agent: verbose internal detail
|
||||
2026-04-05 10:00:08,000 INFO credential_pool: credential pool: marking key-1 exhausted (status=429), rotating
|
||||
2026-04-05 10:00:09,000 INFO credential_pool: credential pool: rotated to key-2
|
||||
""")
|
||||
path = log_dir / "agent.log"
|
||||
path.write_text(lines)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_errors_log(log_dir):
|
||||
"""Write a small errors.log."""
|
||||
lines = textwrap.dedent("""\
|
||||
2026-04-05 10:00:03,000 WARNING run_agent: Tool web_search returned error (2.00s): timeout
|
||||
2026-04-05 10:00:05,000 ERROR run_agent: API call failed after 3 retries. rate limited
|
||||
""")
|
||||
path = log_dir / "errors.log"
|
||||
path.write_text(lines)
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_since
|
||||
# Timestamp parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseSince:
|
||||
def test_hours(self):
|
||||
cutoff = _parse_since("2h")
|
||||
assert cutoff is not None
|
||||
assert (datetime.now() - cutoff).total_seconds() == pytest.approx(7200, abs=5)
|
||||
assert abs((datetime.now() - cutoff).total_seconds() - 7200) < 2
|
||||
|
||||
def test_minutes(self):
|
||||
cutoff = _parse_since("30m")
|
||||
assert cutoff is not None
|
||||
assert (datetime.now() - cutoff).total_seconds() == pytest.approx(1800, abs=5)
|
||||
assert abs((datetime.now() - cutoff).total_seconds() - 1800) < 2
|
||||
|
||||
def test_days(self):
|
||||
cutoff = _parse_since("1d")
|
||||
assert cutoff is not None
|
||||
assert (datetime.now() - cutoff).total_seconds() == pytest.approx(86400, abs=5)
|
||||
assert abs((datetime.now() - cutoff).total_seconds() - 86400) < 2
|
||||
|
||||
def test_seconds(self):
|
||||
cutoff = _parse_since("60s")
|
||||
cutoff = _parse_since("120s")
|
||||
assert cutoff is not None
|
||||
assert (datetime.now() - cutoff).total_seconds() == pytest.approx(60, abs=5)
|
||||
assert abs((datetime.now() - cutoff).total_seconds() - 120) < 2
|
||||
|
||||
def test_invalid_returns_none(self):
|
||||
assert _parse_since("abc") is None
|
||||
assert _parse_since("") is None
|
||||
assert _parse_since("10x") is None
|
||||
|
||||
def test_whitespace_handling(self):
|
||||
cutoff = _parse_since(" 1h ")
|
||||
def test_whitespace_tolerance(self):
|
||||
cutoff = _parse_since(" 5m ")
|
||||
assert cutoff is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_line_timestamp
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseLineTimestamp:
|
||||
def test_standard_format(self):
|
||||
ts = _parse_line_timestamp("2026-04-05 10:00:00,123 INFO something")
|
||||
assert ts is not None
|
||||
assert ts.year == 2026
|
||||
assert ts.hour == 10
|
||||
ts = _parse_line_timestamp("2026-04-11 10:23:45 INFO gateway.run: msg")
|
||||
assert ts == datetime(2026, 4, 11, 10, 23, 45)
|
||||
|
||||
def test_no_timestamp(self):
|
||||
assert _parse_line_timestamp("just some text") is None
|
||||
assert _parse_line_timestamp("no timestamp here") is None
|
||||
|
||||
def test_continuation_line(self):
|
||||
assert _parse_line_timestamp(" at module.function (line 42)") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_level
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractLevel:
|
||||
def test_info(self):
|
||||
assert _extract_level("2026-04-05 10:00:00 INFO run_agent: something") == "INFO"
|
||||
assert _extract_level("2026-01-01 00:00:00 INFO gateway.run: msg") == "INFO"
|
||||
|
||||
def test_warning(self):
|
||||
assert _extract_level("2026-04-05 10:00:00 WARNING run_agent: bad") == "WARNING"
|
||||
assert _extract_level("2026-01-01 00:00:00 WARNING tools.file: msg") == "WARNING"
|
||||
|
||||
def test_error(self):
|
||||
assert _extract_level("2026-04-05 10:00:00 ERROR run_agent: crash") == "ERROR"
|
||||
assert _extract_level("2026-01-01 00:00:00 ERROR run_agent: msg") == "ERROR"
|
||||
|
||||
def test_debug(self):
|
||||
assert _extract_level("2026-04-05 10:00:00 DEBUG run_agent: detail") == "DEBUG"
|
||||
assert _extract_level("2026-01-01 00:00:00 DEBUG agent.aux: msg") == "DEBUG"
|
||||
|
||||
def test_no_level(self):
|
||||
assert _extract_level("just a plain line") is None
|
||||
assert _extract_level("random text") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _matches_filters
|
||||
# Logger name extraction (new for component filtering)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractLoggerName:
|
||||
def test_standard_line(self):
|
||||
line = "2026-04-11 10:23:45 INFO gateway.run: Starting gateway"
|
||||
assert _extract_logger_name(line) == "gateway.run"
|
||||
|
||||
def test_nested_logger(self):
|
||||
line = "2026-04-11 10:23:45 INFO gateway.platforms.telegram: connected"
|
||||
assert _extract_logger_name(line) == "gateway.platforms.telegram"
|
||||
|
||||
def test_warning_level(self):
|
||||
line = "2026-04-11 10:23:45 WARNING tools.terminal_tool: timeout"
|
||||
assert _extract_logger_name(line) == "tools.terminal_tool"
|
||||
|
||||
def test_with_session_tag(self):
|
||||
line = "2026-04-11 10:23:45 INFO [abc123] tools.file_tools: reading file"
|
||||
assert _extract_logger_name(line) == "tools.file_tools"
|
||||
|
||||
def test_with_session_tag_and_error(self):
|
||||
line = "2026-04-11 10:23:45 ERROR [sess_xyz] agent.context_compressor: failed"
|
||||
assert _extract_logger_name(line) == "agent.context_compressor"
|
||||
|
||||
def test_top_level_module(self):
|
||||
line = "2026-04-11 10:23:45 INFO run_agent: starting conversation"
|
||||
assert _extract_logger_name(line) == "run_agent"
|
||||
|
||||
def test_no_match(self):
|
||||
assert _extract_logger_name("random text") is None
|
||||
|
||||
|
||||
class TestLineMatchesComponent:
|
||||
def test_gateway_component(self):
|
||||
line = "2026-04-11 10:23:45 INFO gateway.run: msg"
|
||||
assert _line_matches_component(line, ("gateway",))
|
||||
|
||||
def test_gateway_nested(self):
|
||||
line = "2026-04-11 10:23:45 INFO gateway.platforms.telegram: msg"
|
||||
assert _line_matches_component(line, ("gateway",))
|
||||
|
||||
def test_tools_component(self):
|
||||
line = "2026-04-11 10:23:45 INFO tools.terminal_tool: msg"
|
||||
assert _line_matches_component(line, ("tools",))
|
||||
|
||||
def test_agent_with_multiple_prefixes(self):
|
||||
prefixes = ("agent", "run_agent", "model_tools")
|
||||
assert _line_matches_component(
|
||||
"2026-04-11 10:23:45 INFO agent.context_compressor: msg", prefixes)
|
||||
assert _line_matches_component(
|
||||
"2026-04-11 10:23:45 INFO run_agent: msg", prefixes)
|
||||
assert _line_matches_component(
|
||||
"2026-04-11 10:23:45 INFO model_tools: msg", prefixes)
|
||||
|
||||
def test_no_match(self):
|
||||
line = "2026-04-11 10:23:45 INFO tools.browser: msg"
|
||||
assert not _line_matches_component(line, ("gateway",))
|
||||
|
||||
def test_with_session_tag(self):
|
||||
line = "2026-04-11 10:23:45 INFO [abc] gateway.run: msg"
|
||||
assert _line_matches_component(line, ("gateway",))
|
||||
|
||||
def test_unparseable_line(self):
|
||||
assert not _line_matches_component("random text", ("gateway",))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Combined filter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatchesFilters:
|
||||
def test_no_filters_always_matches(self):
|
||||
assert _matches_filters("any line") is True
|
||||
def test_no_filters_passes_everything(self):
|
||||
assert _matches_filters("any line")
|
||||
|
||||
def test_level_filter_passes(self):
|
||||
def test_level_filter(self):
|
||||
assert _matches_filters(
|
||||
"2026-04-05 10:00:00 WARNING something",
|
||||
min_level="WARNING",
|
||||
) is True
|
||||
"2026-01-01 00:00:00 WARNING x: msg", min_level="WARNING")
|
||||
assert not _matches_filters(
|
||||
"2026-01-01 00:00:00 INFO x: msg", min_level="WARNING")
|
||||
|
||||
def test_level_filter_rejects(self):
|
||||
def test_session_filter(self):
|
||||
assert _matches_filters(
|
||||
"2026-04-05 10:00:00 INFO something",
|
||||
min_level="WARNING",
|
||||
) is False
|
||||
"2026-01-01 00:00:00 INFO [abc123] x: msg", session_filter="abc123")
|
||||
assert not _matches_filters(
|
||||
"2026-01-01 00:00:00 INFO [xyz789] x: msg", session_filter="abc123")
|
||||
|
||||
def test_session_filter_passes(self):
|
||||
def test_component_filter(self):
|
||||
assert _matches_filters(
|
||||
"session=sess_aaa model=claude",
|
||||
session_filter="sess_aaa",
|
||||
) is True
|
||||
|
||||
def test_session_filter_rejects(self):
|
||||
assert _matches_filters(
|
||||
"session=sess_aaa model=claude",
|
||||
session_filter="sess_bbb",
|
||||
) is False
|
||||
|
||||
def test_since_filter_passes(self):
|
||||
# Line from the future should always pass
|
||||
assert _matches_filters(
|
||||
"2099-01-01 00:00:00 INFO future",
|
||||
since=datetime.now(),
|
||||
) is True
|
||||
|
||||
def test_since_filter_rejects(self):
|
||||
assert _matches_filters(
|
||||
"2020-01-01 00:00:00 INFO past",
|
||||
since=datetime.now(),
|
||||
) is False
|
||||
"2026-01-01 00:00:00 INFO gateway.run: msg",
|
||||
component_prefixes=("gateway",))
|
||||
assert not _matches_filters(
|
||||
"2026-01-01 00:00:00 INFO tools.file: msg",
|
||||
component_prefixes=("gateway",))
|
||||
|
||||
def test_combined_filters(self):
|
||||
line = "2099-01-01 00:00:00 WARNING run_agent: session=abc error"
|
||||
"""All filters must pass for a line to match."""
|
||||
line = "2026-04-11 10:00:00 WARNING [sess_1] gateway.run: connection lost"
|
||||
assert _matches_filters(
|
||||
line, min_level="WARNING", session_filter="abc",
|
||||
since=datetime.now(),
|
||||
) is True
|
||||
# Fails session filter
|
||||
line,
|
||||
min_level="WARNING",
|
||||
session_filter="sess_1",
|
||||
component_prefixes=("gateway",),
|
||||
)
|
||||
# Fails component filter
|
||||
assert not _matches_filters(
|
||||
line,
|
||||
min_level="WARNING",
|
||||
session_filter="sess_1",
|
||||
component_prefixes=("tools",),
|
||||
)
|
||||
|
||||
def test_since_filter(self):
|
||||
# Line with a very old timestamp should be filtered out
|
||||
assert not _matches_filters(
|
||||
"2020-01-01 00:00:00 INFO x: old msg",
|
||||
since=datetime.now() - timedelta(hours=1))
|
||||
# Line with a recent timestamp should pass
|
||||
recent = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
assert _matches_filters(
|
||||
line, min_level="WARNING", session_filter="xyz",
|
||||
) is False
|
||||
f"{recent} INFO x: recent msg",
|
||||
since=datetime.now() - timedelta(hours=1))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_last_n_lines
|
||||
# File reading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadLastNLines:
|
||||
def test_reads_correct_count(self, sample_agent_log):
|
||||
lines = _read_last_n_lines(sample_agent_log, 3)
|
||||
assert len(lines) == 3
|
||||
class TestReadTail:
|
||||
def test_read_small_file(self, tmp_path):
|
||||
log_file = tmp_path / "test.log"
|
||||
lines = [f"2026-01-01 00:00:0{i} INFO x: line {i}\n" for i in range(10)]
|
||||
log_file.write_text("".join(lines))
|
||||
|
||||
def test_reads_all_when_fewer(self, sample_agent_log):
|
||||
lines = _read_last_n_lines(sample_agent_log, 100)
|
||||
assert len(lines) == 10 # sample has 10 lines
|
||||
result = _read_last_n_lines(log_file, 5)
|
||||
assert len(result) == 5
|
||||
assert "line 9" in result[-1]
|
||||
|
||||
def test_empty_file(self, log_dir):
|
||||
empty = log_dir / "empty.log"
|
||||
empty.write_text("")
|
||||
lines = _read_last_n_lines(empty, 10)
|
||||
assert lines == []
|
||||
def test_read_with_component_filter(self, tmp_path):
|
||||
log_file = tmp_path / "test.log"
|
||||
lines = [
|
||||
"2026-01-01 00:00:00 INFO gateway.run: gw msg\n",
|
||||
"2026-01-01 00:00:01 INFO tools.file: tool msg\n",
|
||||
"2026-01-01 00:00:02 INFO gateway.session: session msg\n",
|
||||
"2026-01-01 00:00:03 INFO agent.compressor: agent msg\n",
|
||||
]
|
||||
log_file.write_text("".join(lines))
|
||||
|
||||
def test_last_line_content(self, sample_agent_log):
|
||||
lines = _read_last_n_lines(sample_agent_log, 1)
|
||||
assert "rotated to key-2" in lines[0]
|
||||
result = _read_tail(
|
||||
log_file, 50,
|
||||
has_filters=True,
|
||||
component_prefixes=("gateway",),
|
||||
)
|
||||
assert len(result) == 2
|
||||
assert "gw msg" in result[0]
|
||||
assert "session msg" in result[1]
|
||||
|
||||
def test_empty_file(self, tmp_path):
|
||||
log_file = tmp_path / "empty.log"
|
||||
log_file.write_text("")
|
||||
result = _read_last_n_lines(log_file, 10)
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tail_log
|
||||
# LOG_FILES registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTailLog:
|
||||
def test_basic_tail(self, sample_agent_log, capsys):
|
||||
tail_log("agent", num_lines=3)
|
||||
captured = capsys.readouterr()
|
||||
assert "agent.log" in captured.out
|
||||
# Should have the header + 3 lines
|
||||
lines = captured.out.strip().split("\n")
|
||||
assert len(lines) == 4 # 1 header + 3 content
|
||||
|
||||
def test_level_filter(self, sample_agent_log, capsys):
|
||||
tail_log("agent", num_lines=50, level="ERROR")
|
||||
captured = capsys.readouterr()
|
||||
assert "level>=ERROR" in captured.out
|
||||
# Only the ERROR line should appear
|
||||
content_lines = [l for l in captured.out.strip().split("\n") if not l.startswith("---")]
|
||||
assert len(content_lines) == 1
|
||||
assert "API call failed" in content_lines[0]
|
||||
|
||||
def test_session_filter(self, sample_agent_log, capsys):
|
||||
tail_log("agent", num_lines=50, session="sess_bbb")
|
||||
captured = capsys.readouterr()
|
||||
content_lines = [l for l in captured.out.strip().split("\n") if not l.startswith("---")]
|
||||
assert len(content_lines) == 1
|
||||
assert "sess_bbb" in content_lines[0]
|
||||
|
||||
def test_errors_log(self, sample_errors_log, capsys):
|
||||
tail_log("errors", num_lines=10)
|
||||
captured = capsys.readouterr()
|
||||
assert "errors.log" in captured.out
|
||||
assert "WARNING" in captured.out or "ERROR" in captured.out
|
||||
|
||||
def test_unknown_log_exits(self):
|
||||
with pytest.raises(SystemExit):
|
||||
tail_log("nonexistent")
|
||||
|
||||
def test_missing_file_exits(self, log_dir):
|
||||
with pytest.raises(SystemExit):
|
||||
tail_log("agent") # agent.log doesn't exist in clean log_dir
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_logs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListLogs:
|
||||
def test_lists_files(self, sample_agent_log, sample_errors_log, capsys):
|
||||
list_logs()
|
||||
captured = capsys.readouterr()
|
||||
assert "agent.log" in captured.out
|
||||
assert "errors.log" in captured.out
|
||||
|
||||
def test_empty_dir(self, log_dir, capsys):
|
||||
list_logs()
|
||||
captured = capsys.readouterr()
|
||||
assert "no log files yet" in captured.out
|
||||
|
||||
def test_shows_sizes(self, sample_agent_log, capsys):
|
||||
list_logs()
|
||||
captured = capsys.readouterr()
|
||||
# File is small, should show as bytes or KB
|
||||
assert "B" in captured.out or "KB" in captured.out
|
||||
class TestLogFiles:
|
||||
def test_known_log_files(self):
|
||||
assert "agent" in LOG_FILES
|
||||
assert "errors" in LOG_FILES
|
||||
assert "gateway" in LOG_FILES
|
||||
|
||||
@@ -46,6 +46,8 @@ def _make_args(**kwargs):
|
||||
"command": None,
|
||||
"args": None,
|
||||
"auth": None,
|
||||
"preset": None,
|
||||
"env": None,
|
||||
"mcp_action": None,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
@@ -269,6 +271,145 @@ class TestMcpAdd:
|
||||
config = load_config()
|
||||
assert config["mcp_servers"]["broken"]["enabled"] is False
|
||||
|
||||
def test_add_stdio_server_with_env(self, tmp_path, capsys, monkeypatch):
|
||||
"""Stdio servers can persist explicit environment variables."""
|
||||
fake_tools = [FakeTool("search", "Search repos")]
|
||||
|
||||
def mock_probe(name, config, **kw):
|
||||
assert config["env"] == {
|
||||
"MY_API_KEY": "secret123",
|
||||
"DEBUG": "true",
|
||||
}
|
||||
return [(t.name, t.description) for t in fake_tools]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._probe_single_server", mock_probe
|
||||
)
|
||||
monkeypatch.setattr("builtins.input", lambda _: "")
|
||||
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(
|
||||
name="github",
|
||||
command="npx",
|
||||
args=["@mcp/github"],
|
||||
env=["MY_API_KEY=secret123", "DEBUG=true"],
|
||||
))
|
||||
out = capsys.readouterr().out
|
||||
assert "Saved" in out
|
||||
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
srv = config["mcp_servers"]["github"]
|
||||
assert srv["env"] == {
|
||||
"MY_API_KEY": "secret123",
|
||||
"DEBUG": "true",
|
||||
}
|
||||
|
||||
def test_add_stdio_server_rejects_invalid_env_name(self, capsys):
|
||||
"""Invalid environment variable names are rejected up front."""
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(
|
||||
name="github",
|
||||
command="npx",
|
||||
args=["@mcp/github"],
|
||||
env=["BAD-NAME=value"],
|
||||
))
|
||||
out = capsys.readouterr().out
|
||||
assert "Invalid --env variable name" in out
|
||||
|
||||
def test_add_http_server_rejects_env_flag(self, capsys):
|
||||
"""The --env flag is only valid for stdio transports."""
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(
|
||||
name="ink",
|
||||
url="https://mcp.ml.ink/mcp",
|
||||
env=["DEBUG=true"],
|
||||
))
|
||||
out = capsys.readouterr().out
|
||||
assert "only supported for stdio MCP servers" in out
|
||||
|
||||
def test_add_preset_fills_transport(self, tmp_path, capsys, monkeypatch):
|
||||
"""A preset fills in command/args when no explicit transport given."""
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._MCP_PRESETS",
|
||||
{"testmcp": {"command": "npx", "args": ["-y", "test-mcp-server"], "display_name": "Test MCP"}},
|
||||
)
|
||||
fake_tools = [FakeTool("do_thing", "Does a thing")]
|
||||
|
||||
def mock_probe(name, config, **kw):
|
||||
assert name == "myserver"
|
||||
assert config["command"] == "npx"
|
||||
assert config["args"] == ["-y", "test-mcp-server"]
|
||||
assert "env" not in config
|
||||
return [(t.name, t.description) for t in fake_tools]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._probe_single_server", mock_probe
|
||||
)
|
||||
monkeypatch.setattr("builtins.input", lambda _: "")
|
||||
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
from hermes_cli.config import read_raw_config
|
||||
|
||||
cmd_mcp_add(_make_args(name="myserver", preset="testmcp"))
|
||||
out = capsys.readouterr().out
|
||||
assert "Saved" in out
|
||||
|
||||
config = read_raw_config()
|
||||
srv = config["mcp_servers"]["myserver"]
|
||||
assert srv["command"] == "npx"
|
||||
assert srv["args"] == ["-y", "test-mcp-server"]
|
||||
assert "env" not in srv
|
||||
|
||||
def test_preset_does_not_override_explicit_command(self, tmp_path, capsys, monkeypatch):
|
||||
"""Explicit transports win over presets."""
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._MCP_PRESETS",
|
||||
{"testmcp": {"command": "npx", "args": ["-y", "test-mcp-server"], "display_name": "Test MCP"}},
|
||||
)
|
||||
fake_tools = [FakeTool("search", "Search repos")]
|
||||
|
||||
def mock_probe(name, config, **kw):
|
||||
assert config["command"] == "uvx"
|
||||
assert config["args"] == ["custom-server"]
|
||||
assert "env" not in config
|
||||
return [(t.name, t.description) for t in fake_tools]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._probe_single_server", mock_probe
|
||||
)
|
||||
monkeypatch.setattr("builtins.input", lambda _: "")
|
||||
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
from hermes_cli.config import read_raw_config
|
||||
|
||||
cmd_mcp_add(_make_args(
|
||||
name="custom",
|
||||
preset="testmcp",
|
||||
command="uvx",
|
||||
args=["custom-server"],
|
||||
))
|
||||
out = capsys.readouterr().out
|
||||
assert "Saved" in out
|
||||
|
||||
config = read_raw_config()
|
||||
srv = config["mcp_servers"]["custom"]
|
||||
assert srv["command"] == "uvx"
|
||||
assert srv["args"] == ["custom-server"]
|
||||
assert "env" not in srv
|
||||
|
||||
def test_unknown_preset_rejected(self, capsys):
|
||||
"""An unknown preset name is rejected with a clear error."""
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(name="foo", preset="nonexistent"))
|
||||
out = capsys.readouterr().out
|
||||
assert "Unknown MCP preset" in out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: cmd_mcp_test
|
||||
|
||||
@@ -257,3 +257,76 @@ class TestProviderPersistsAfterModelSave:
|
||||
assert model.get("provider") == "opencode-go"
|
||||
assert model.get("default") == "minimax-m2.5"
|
||||
assert model.get("api_mode") == "anthropic_messages"
|
||||
|
||||
|
||||
class TestBaseUrlValidation:
|
||||
"""Reject non-URL values in the base URL prompt (e.g. shell commands)."""
|
||||
|
||||
def test_invalid_base_url_rejected(self, config_home, monkeypatch, capsys):
|
||||
"""Typing a non-URL string should not be saved as the base URL."""
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY
|
||||
|
||||
pconfig = PROVIDER_REGISTRY.get("zai")
|
||||
if not pconfig:
|
||||
pytest.skip("zai not in PROVIDER_REGISTRY")
|
||||
|
||||
monkeypatch.setenv("GLM_API_KEY", "test-key")
|
||||
|
||||
from hermes_cli.main import _model_flow_api_key_provider
|
||||
from hermes_cli.config import load_config, get_env_value
|
||||
|
||||
# User types a shell command instead of a URL at the base URL prompt
|
||||
with patch("hermes_cli.auth._prompt_model_selection", return_value="glm-5"), \
|
||||
patch("hermes_cli.auth.deactivate_provider"), \
|
||||
patch("builtins.input", return_value="nano ~/.hermes/.env"):
|
||||
_model_flow_api_key_provider(load_config(), "zai", "old-model")
|
||||
|
||||
# The garbage value should NOT have been saved
|
||||
saved = get_env_value("GLM_BASE_URL") or ""
|
||||
assert not saved or saved.startswith(("http://", "https://")), \
|
||||
f"Non-URL value was saved as GLM_BASE_URL: {saved}"
|
||||
captured = capsys.readouterr()
|
||||
assert "Invalid URL" in captured.out
|
||||
|
||||
def test_valid_base_url_accepted(self, config_home, monkeypatch):
|
||||
"""A proper URL should be saved normally."""
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY
|
||||
|
||||
pconfig = PROVIDER_REGISTRY.get("zai")
|
||||
if not pconfig:
|
||||
pytest.skip("zai not in PROVIDER_REGISTRY")
|
||||
|
||||
monkeypatch.setenv("GLM_API_KEY", "test-key")
|
||||
|
||||
from hermes_cli.main import _model_flow_api_key_provider
|
||||
from hermes_cli.config import load_config, get_env_value
|
||||
|
||||
with patch("hermes_cli.auth._prompt_model_selection", return_value="glm-5"), \
|
||||
patch("hermes_cli.auth.deactivate_provider"), \
|
||||
patch("builtins.input", return_value="https://custom.z.ai/api/paas/v4"):
|
||||
_model_flow_api_key_provider(load_config(), "zai", "old-model")
|
||||
|
||||
saved = get_env_value("GLM_BASE_URL") or ""
|
||||
assert saved == "https://custom.z.ai/api/paas/v4"
|
||||
|
||||
def test_empty_base_url_keeps_default(self, config_home, monkeypatch):
|
||||
"""Pressing Enter (empty) should not change the base URL."""
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY
|
||||
|
||||
pconfig = PROVIDER_REGISTRY.get("zai")
|
||||
if not pconfig:
|
||||
pytest.skip("zai not in PROVIDER_REGISTRY")
|
||||
|
||||
monkeypatch.setenv("GLM_API_KEY", "test-key")
|
||||
monkeypatch.delenv("GLM_BASE_URL", raising=False)
|
||||
|
||||
from hermes_cli.main import _model_flow_api_key_provider
|
||||
from hermes_cli.config import load_config, get_env_value
|
||||
|
||||
with patch("hermes_cli.auth._prompt_model_selection", return_value="glm-5"), \
|
||||
patch("hermes_cli.auth.deactivate_provider"), \
|
||||
patch("builtins.input", return_value=""):
|
||||
_model_flow_api_key_provider(load_config(), "zai", "old-model")
|
||||
|
||||
saved = get_env_value("GLM_BASE_URL") or ""
|
||||
assert saved == "", "Empty input should not save a base URL"
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from rich.console import Console
|
||||
|
||||
from cli import ChatConsole
|
||||
from hermes_cli.skills_hub import do_check, do_install, do_list, do_update, handle_skills_slash
|
||||
|
||||
|
||||
@@ -179,6 +181,21 @@ def test_do_update_reinstalls_outdated_skills(monkeypatch):
|
||||
assert "Updated 1 skill" in output
|
||||
|
||||
|
||||
def test_handle_skills_slash_search_accepts_chatconsole_without_status_errors():
|
||||
results = [type("R", (), {
|
||||
"name": "kubernetes",
|
||||
"description": "Cluster orchestration",
|
||||
"source": "skills.sh",
|
||||
"trust_level": "community",
|
||||
"identifier": "skills-sh/example/kubernetes",
|
||||
})()]
|
||||
|
||||
with patch("tools.skills_hub.unified_search", return_value=results), \
|
||||
patch("tools.skills_hub.create_source_router", return_value={}), \
|
||||
patch("tools.skills_hub.GitHubAuth"):
|
||||
handle_skills_slash("/skills search kubernetes", console=ChatConsole())
|
||||
|
||||
|
||||
def test_do_install_scans_with_resolved_identifier(monkeypatch, tmp_path, hub_env):
|
||||
import tools.skills_guard as guard
|
||||
import tools.skills_hub as hub
|
||||
|
||||
77
tests/hermes_cli/test_tips.py
Normal file
77
tests/hermes_cli/test_tips.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Tests for hermes_cli/tips.py — random tip display at session start."""
|
||||
|
||||
import pytest
|
||||
from hermes_cli.tips import TIPS, get_random_tip, get_tip_count
|
||||
|
||||
|
||||
class TestTipsCorpus:
|
||||
"""Validate the tip corpus itself."""
|
||||
|
||||
def test_has_at_least_200_tips(self):
|
||||
assert len(TIPS) >= 200, f"Expected 200+ tips, got {len(TIPS)}"
|
||||
|
||||
def test_no_duplicates(self):
|
||||
assert len(TIPS) == len(set(TIPS)), "Duplicate tips found"
|
||||
|
||||
def test_all_tips_are_strings(self):
|
||||
for i, tip in enumerate(TIPS):
|
||||
assert isinstance(tip, str), f"Tip {i} is not a string: {type(tip)}"
|
||||
|
||||
def test_no_empty_tips(self):
|
||||
for i, tip in enumerate(TIPS):
|
||||
assert tip.strip(), f"Tip {i} is empty or whitespace-only"
|
||||
|
||||
def test_max_length_reasonable(self):
|
||||
"""Tips should fit on a single terminal line (~120 chars max)."""
|
||||
for i, tip in enumerate(TIPS):
|
||||
assert len(tip) <= 150, (
|
||||
f"Tip {i} too long ({len(tip)} chars): {tip[:60]}..."
|
||||
)
|
||||
|
||||
def test_no_leading_trailing_whitespace(self):
|
||||
for i, tip in enumerate(TIPS):
|
||||
assert tip == tip.strip(), f"Tip {i} has leading/trailing whitespace"
|
||||
|
||||
|
||||
class TestGetRandomTip:
|
||||
"""Validate the get_random_tip() function."""
|
||||
|
||||
def test_returns_string(self):
|
||||
tip = get_random_tip()
|
||||
assert isinstance(tip, str)
|
||||
assert len(tip) > 0
|
||||
|
||||
def test_returns_tip_from_corpus(self):
|
||||
tip = get_random_tip()
|
||||
assert tip in TIPS
|
||||
|
||||
def test_randomness(self):
|
||||
"""Multiple calls should eventually return different tips."""
|
||||
seen = set()
|
||||
for _ in range(50):
|
||||
seen.add(get_random_tip())
|
||||
# With 200+ tips and 50 draws, we should see at least 10 unique
|
||||
assert len(seen) >= 10, f"Only got {len(seen)} unique tips in 50 draws"
|
||||
|
||||
|
||||
class TestGetTipCount:
|
||||
def test_matches_corpus_length(self):
|
||||
assert get_tip_count() == len(TIPS)
|
||||
|
||||
|
||||
class TestTipIntegrationInCLI:
|
||||
"""Test that the tip display code in cli.py works correctly."""
|
||||
|
||||
def test_tip_import_works(self):
|
||||
"""The import used in cli.py must succeed."""
|
||||
from hermes_cli.tips import get_random_tip
|
||||
assert callable(get_random_tip)
|
||||
|
||||
def test_tip_display_format(self):
|
||||
"""Verify the Rich markup format doesn't break."""
|
||||
tip = get_random_tip()
|
||||
color = "#B8860B"
|
||||
markup = f"[dim {color}]✦ Tip: {tip}[/]"
|
||||
# Should not contain nested/broken Rich tags
|
||||
assert markup.count("[/]") == 1
|
||||
assert "[dim #B8860B]" in markup
|
||||
@@ -798,3 +798,120 @@ class TestFindGatewayPidsExclude:
|
||||
pids = gateway_cli.find_gateway_pids()
|
||||
|
||||
assert pids == [100]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gateway mode writes exit code before restart (#8300)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGatewayModeWritesExitCodeEarly:
|
||||
"""When running as ``hermes update --gateway``, the exit code marker must be
|
||||
written *before* the gateway restart attempt. Without this, systemd's
|
||||
``KillMode=mixed`` kills the update process (and its wrapping shell) during
|
||||
the cgroup teardown, so the shell epilogue that normally writes the exit
|
||||
code never executes. The new gateway's update watcher then polls for 30
|
||||
minutes and sends a spurious timeout message.
|
||||
"""
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_exit_code_written_in_gateway_mode(
|
||||
self, mock_run, _mock_which, capsys, tmp_path, monkeypatch,
|
||||
):
|
||||
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
|
||||
|
||||
# Point HERMES_HOME at a temp dir so the marker file lands there
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
import hermes_cli.config as _cfg
|
||||
monkeypatch.setattr(_cfg, "get_hermes_home", lambda: hermes_home)
|
||||
# Also patch the module-level ref used by cmd_update
|
||||
import hermes_cli.main as _main_mod
|
||||
monkeypatch.setattr(_main_mod, "get_hermes_home", lambda: hermes_home)
|
||||
|
||||
mock_run.side_effect = _make_run_side_effect(commit_count="1")
|
||||
|
||||
args = SimpleNamespace(gateway=True)
|
||||
|
||||
with patch.object(gateway_cli, "find_gateway_pids", return_value=[]):
|
||||
cmd_update(args)
|
||||
|
||||
exit_code_path = hermes_home / ".update_exit_code"
|
||||
assert exit_code_path.exists(), ".update_exit_code not written in gateway mode"
|
||||
assert exit_code_path.read_text() == "0"
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_exit_code_not_written_in_normal_mode(
|
||||
self, mock_run, _mock_which, capsys, tmp_path, monkeypatch,
|
||||
):
|
||||
"""Non-gateway mode should NOT write the exit code (the shell does it)."""
|
||||
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
|
||||
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
import hermes_cli.config as _cfg
|
||||
monkeypatch.setattr(_cfg, "get_hermes_home", lambda: hermes_home)
|
||||
import hermes_cli.main as _main_mod
|
||||
monkeypatch.setattr(_main_mod, "get_hermes_home", lambda: hermes_home)
|
||||
|
||||
mock_run.side_effect = _make_run_side_effect(commit_count="1")
|
||||
|
||||
args = SimpleNamespace(gateway=False)
|
||||
|
||||
with patch.object(gateway_cli, "find_gateway_pids", return_value=[]):
|
||||
cmd_update(args)
|
||||
|
||||
exit_code_path = hermes_home / ".update_exit_code"
|
||||
assert not exit_code_path.exists(), ".update_exit_code should not be written outside gateway mode"
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_exit_code_written_before_restart_call(
|
||||
self, mock_run, _mock_which, capsys, tmp_path, monkeypatch,
|
||||
):
|
||||
"""Exit code must exist BEFORE systemctl restart is called."""
|
||||
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True)
|
||||
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
|
||||
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
import hermes_cli.config as _cfg
|
||||
monkeypatch.setattr(_cfg, "get_hermes_home", lambda: hermes_home)
|
||||
import hermes_cli.main as _main_mod
|
||||
monkeypatch.setattr(_main_mod, "get_hermes_home", lambda: hermes_home)
|
||||
|
||||
exit_code_path = hermes_home / ".update_exit_code"
|
||||
|
||||
# Track whether exit code exists when systemctl restart is called
|
||||
exit_code_existed_at_restart = []
|
||||
|
||||
original_side_effect = _make_run_side_effect(
|
||||
commit_count="1", systemd_active=True,
|
||||
)
|
||||
|
||||
def tracking_side_effect(cmd, **kwargs):
|
||||
joined = " ".join(str(c) for c in cmd)
|
||||
if "systemctl" in joined and "restart" in joined:
|
||||
exit_code_existed_at_restart.append(exit_code_path.exists())
|
||||
return original_side_effect(cmd, **kwargs)
|
||||
|
||||
mock_run.side_effect = tracking_side_effect
|
||||
|
||||
args = SimpleNamespace(gateway=True)
|
||||
|
||||
with patch.object(gateway_cli, "find_gateway_pids", return_value=[]):
|
||||
cmd_update(args)
|
||||
|
||||
assert exit_code_existed_at_restart, "systemctl restart was never called"
|
||||
assert exit_code_existed_at_restart[0] is True, \
|
||||
".update_exit_code must exist BEFORE systemctl restart (cgroup kill race)"
|
||||
|
||||
@@ -26,6 +26,7 @@ def _make_agent(
|
||||
agent.provider = "openrouter"
|
||||
agent.base_url = "https://openrouter.ai/api/v1"
|
||||
agent.api_key = "sk-test"
|
||||
agent.api_mode = "chat_completions"
|
||||
agent.quiet_mode = True
|
||||
agent.log_prefix = ""
|
||||
agent.compression_enabled = compression_enabled
|
||||
@@ -99,6 +100,36 @@ def test_no_warning_when_aux_context_sufficient(mock_get_client, mock_ctx_len):
|
||||
assert agent._compression_warning is None
|
||||
|
||||
|
||||
def test_feasibility_check_passes_live_main_runtime():
|
||||
"""Compression feasibility should probe using the live session runtime."""
|
||||
agent = _make_agent(main_context=200_000, threshold_percent=0.50)
|
||||
agent.model = "gpt-5.4"
|
||||
agent.provider = "openai-codex"
|
||||
agent.base_url = "https://chatgpt.com/backend-api/codex"
|
||||
agent.api_key = "codex-token"
|
||||
agent.api_mode = "codex_responses"
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://chatgpt.com/backend-api/codex"
|
||||
mock_client.api_key = "codex-token"
|
||||
|
||||
with patch("agent.auxiliary_client.get_text_auxiliary_client", return_value=(mock_client, "gpt-5.4")) as mock_get_client, \
|
||||
patch("agent.model_metadata.get_model_context_length", return_value=200_000):
|
||||
agent._emit_status = lambda msg: None
|
||||
agent._check_compression_model_feasibility()
|
||||
|
||||
mock_get_client.assert_called_once_with(
|
||||
"compression",
|
||||
main_runtime={
|
||||
"model": "gpt-5.4",
|
||||
"provider": "openai-codex",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_key": "codex-token",
|
||||
"api_mode": "codex_responses",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@patch("agent.auxiliary_client.get_text_auxiliary_client")
|
||||
def test_warns_when_no_auxiliary_provider(mock_get_client):
|
||||
"""Warning emitted when no auxiliary provider is configured."""
|
||||
|
||||
@@ -2742,74 +2742,12 @@ class TestSystemPromptStability:
|
||||
assert "Hermes Agent" in agent._cached_system_prompt
|
||||
|
||||
class TestBudgetPressure:
|
||||
"""Budget pressure warning system (issue #414)."""
|
||||
"""Budget exhaustion grace call system."""
|
||||
|
||||
def test_no_warning_below_caution(self, agent):
|
||||
agent.max_iterations = 60
|
||||
assert agent._get_budget_warning(30) is None
|
||||
|
||||
def test_caution_at_70_percent(self, agent):
|
||||
agent.max_iterations = 60
|
||||
msg = agent._get_budget_warning(42)
|
||||
assert msg is not None
|
||||
assert "[BUDGET:" in msg
|
||||
assert "18 iterations left" in msg
|
||||
|
||||
def test_warning_at_90_percent(self, agent):
|
||||
agent.max_iterations = 60
|
||||
msg = agent._get_budget_warning(54)
|
||||
assert "[BUDGET WARNING:" in msg
|
||||
assert "Provide your final response NOW" in msg
|
||||
|
||||
def test_last_iteration(self, agent):
|
||||
agent.max_iterations = 60
|
||||
msg = agent._get_budget_warning(59)
|
||||
assert "1 iteration(s) left" in msg
|
||||
|
||||
def test_disabled(self, agent):
|
||||
agent.max_iterations = 60
|
||||
agent._budget_pressure_enabled = False
|
||||
assert agent._get_budget_warning(55) is None
|
||||
|
||||
def test_zero_max_iterations(self, agent):
|
||||
agent.max_iterations = 0
|
||||
assert agent._get_budget_warning(0) is None
|
||||
|
||||
def test_injects_into_json_tool_result(self, agent):
|
||||
"""Warning should be injected as _budget_warning field in JSON tool results."""
|
||||
import json
|
||||
agent.max_iterations = 10
|
||||
messages = [
|
||||
{"role": "tool", "content": json.dumps({"output": "done", "exit_code": 0}), "tool_call_id": "tc1"}
|
||||
]
|
||||
warning = agent._get_budget_warning(9)
|
||||
assert warning is not None
|
||||
# Simulate the injection logic
|
||||
last_content = messages[-1]["content"]
|
||||
parsed = json.loads(last_content)
|
||||
parsed["_budget_warning"] = warning
|
||||
messages[-1]["content"] = json.dumps(parsed, ensure_ascii=False)
|
||||
result = json.loads(messages[-1]["content"])
|
||||
assert "_budget_warning" in result
|
||||
assert "BUDGET WARNING" in result["_budget_warning"]
|
||||
assert result["output"] == "done" # original content preserved
|
||||
|
||||
def test_appends_to_non_json_tool_result(self, agent):
|
||||
"""Warning should be appended as text for non-JSON tool results."""
|
||||
agent.max_iterations = 10
|
||||
messages = [
|
||||
{"role": "tool", "content": "plain text result", "tool_call_id": "tc1"}
|
||||
]
|
||||
warning = agent._get_budget_warning(9)
|
||||
# Simulate injection logic for non-JSON
|
||||
last_content = messages[-1]["content"]
|
||||
try:
|
||||
import json
|
||||
json.loads(last_content)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
messages[-1]["content"] = last_content + f"\n\n{warning}"
|
||||
assert "plain text result" in messages[-1]["content"]
|
||||
assert "BUDGET WARNING" in messages[-1]["content"]
|
||||
def test_grace_call_flags_initialized(self, agent):
|
||||
"""Agent should have budget grace call flags."""
|
||||
assert agent._budget_exhausted_injected is False
|
||||
assert agent._budget_grace_call is False
|
||||
|
||||
|
||||
class TestSafeWriter:
|
||||
|
||||
@@ -744,6 +744,44 @@ def test_normalize_codex_response_marks_commentary_only_message_as_incomplete(mo
|
||||
assert "inspect the repository" in (assistant_message.content or "")
|
||||
|
||||
|
||||
def test_interim_commentary_is_not_marked_already_streamed_without_callbacks(monkeypatch):
|
||||
agent = _build_agent(monkeypatch)
|
||||
observed = {}
|
||||
|
||||
agent._fire_stream_delta("short version: yes")
|
||||
agent.interim_assistant_callback = lambda text, *, already_streamed=False: observed.update(
|
||||
{"text": text, "already_streamed": already_streamed}
|
||||
)
|
||||
|
||||
agent._emit_interim_assistant_message({"role": "assistant", "content": "short version: yes"})
|
||||
|
||||
assert observed == {
|
||||
"text": "short version: yes",
|
||||
"already_streamed": False,
|
||||
}
|
||||
|
||||
|
||||
def test_interim_commentary_is_not_marked_already_streamed_when_stream_callback_fails(monkeypatch):
|
||||
agent = _build_agent(monkeypatch)
|
||||
observed = {}
|
||||
|
||||
def failing_callback(_text):
|
||||
raise RuntimeError("display failed")
|
||||
|
||||
agent.stream_delta_callback = failing_callback
|
||||
agent._fire_stream_delta("short version: yes")
|
||||
agent.interim_assistant_callback = lambda text, *, already_streamed=False: observed.update(
|
||||
{"text": text, "already_streamed": already_streamed}
|
||||
)
|
||||
|
||||
agent._emit_interim_assistant_message({"role": "assistant", "content": "short version: yes"})
|
||||
|
||||
assert observed == {
|
||||
"text": "short version: yes",
|
||||
"already_streamed": False,
|
||||
}
|
||||
|
||||
|
||||
def test_run_conversation_codex_continues_after_commentary_phase_message(monkeypatch):
|
||||
agent = _build_agent(monkeypatch)
|
||||
responses = [
|
||||
|
||||
@@ -185,6 +185,38 @@ def test_migrator_optionally_imports_supported_secrets_and_messaging_settings(tm
|
||||
assert "TELEGRAM_BOT_TOKEN=123:abc" in env_text
|
||||
|
||||
|
||||
def test_messaging_cwd_skipped_when_inside_source(tmp_path: Path):
|
||||
"""MESSAGING_CWD pointing inside the OpenClaw source dir should be skipped."""
|
||||
mod = load_module()
|
||||
source = tmp_path / ".openclaw"
|
||||
target = tmp_path / ".hermes"
|
||||
target.mkdir()
|
||||
|
||||
# Workspace path is inside the source directory
|
||||
ws_path = str(source / "workspace")
|
||||
(source / "credentials").mkdir(parents=True)
|
||||
(source / "openclaw.json").write_text(
|
||||
json.dumps({"agents": {"defaults": {"workspace": ws_path}}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
migrator = mod.Migrator(
|
||||
source_root=source,
|
||||
target_root=target,
|
||||
execute=True,
|
||||
workspace_target=None,
|
||||
overwrite=False,
|
||||
migrate_secrets=True,
|
||||
output_dir=target / "migration-report",
|
||||
selected_options={"messaging-settings"},
|
||||
)
|
||||
migrator.migrate()
|
||||
|
||||
env_path = target / ".env"
|
||||
if env_path.exists():
|
||||
assert "MESSAGING_CWD" not in env_path.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_migrator_can_execute_only_selected_categories(tmp_path: Path):
|
||||
mod = load_module()
|
||||
source = tmp_path / ".openclaw"
|
||||
@@ -722,3 +754,98 @@ def test_skill_installs_cleanly_under_skills_guard():
|
||||
KNOWN_FALSE_POSITIVES = {"agent_config_mod", "python_os_environ", "hermes_config_mod"}
|
||||
for f in result.findings:
|
||||
assert f.pattern_id in KNOWN_FALSE_POSITIVES, f"Unexpected finding: {f}"
|
||||
|
||||
|
||||
# ── rebrand_text tests ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_rebrand_text_replaces_openclaw_variants():
|
||||
mod = load_module()
|
||||
assert mod.rebrand_text("OpenClaw prefers Python 3.11") == "Hermes prefers Python 3.11"
|
||||
assert mod.rebrand_text("I told Open Claw to use dark mode") == "I told Hermes to use dark mode"
|
||||
assert mod.rebrand_text("Open-Claw config is great") == "Hermes config is great"
|
||||
assert mod.rebrand_text("openclaw should always respond concisely") == "Hermes should always respond concisely"
|
||||
assert mod.rebrand_text("OPENCLAW uses tools well") == "Hermes uses tools well"
|
||||
|
||||
|
||||
def test_rebrand_text_replaces_legacy_bot_names():
|
||||
mod = load_module()
|
||||
assert mod.rebrand_text("ClawdBot remembers my timezone") == "Hermes remembers my timezone"
|
||||
assert mod.rebrand_text("clawdbot prefers tabs") == "Hermes prefers tabs"
|
||||
assert mod.rebrand_text("MoltBot was configured for Spanish") == "Hermes was configured for Spanish"
|
||||
assert mod.rebrand_text("moltbot uses Python") == "Hermes uses Python"
|
||||
|
||||
|
||||
def test_rebrand_text_preserves_unrelated_content():
|
||||
mod = load_module()
|
||||
text = "User prefers dark mode and lives in Las Vegas"
|
||||
assert mod.rebrand_text(text) == text
|
||||
|
||||
|
||||
def test_rebrand_text_handles_multiple_replacements():
|
||||
mod = load_module()
|
||||
text = "OpenClaw said to ask ClawdBot about MoltBot settings"
|
||||
assert mod.rebrand_text(text) == "Hermes said to ask Hermes about Hermes settings"
|
||||
|
||||
|
||||
def test_migrate_memory_rebrands_entries(tmp_path):
|
||||
mod = load_module()
|
||||
source_root = tmp_path / "openclaw"
|
||||
source_root.mkdir()
|
||||
workspace = source_root / "workspace"
|
||||
workspace.mkdir()
|
||||
memory_md = workspace / "MEMORY.md"
|
||||
memory_md.write_text(
|
||||
"# Memory\n\n- OpenClaw should use Python 3.11\n- ClawdBot prefers dark mode\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
target_root = tmp_path / "hermes"
|
||||
target_root.mkdir()
|
||||
(target_root / "memories").mkdir()
|
||||
|
||||
migrator = mod.Migrator(
|
||||
source_root=source_root,
|
||||
target_root=target_root,
|
||||
execute=True,
|
||||
workspace_target=None,
|
||||
overwrite=False,
|
||||
migrate_secrets=False,
|
||||
output_dir=tmp_path / "report",
|
||||
selected_options={"memory"},
|
||||
)
|
||||
migrator.migrate()
|
||||
|
||||
result = (target_root / "memories" / "MEMORY.md").read_text(encoding="utf-8")
|
||||
assert "OpenClaw" not in result
|
||||
assert "ClawdBot" not in result
|
||||
assert "Hermes" in result
|
||||
|
||||
|
||||
def test_migrate_soul_rebrands_content(tmp_path):
|
||||
mod = load_module()
|
||||
source_root = tmp_path / "openclaw"
|
||||
source_root.mkdir()
|
||||
workspace = source_root / "workspace"
|
||||
workspace.mkdir()
|
||||
soul_md = workspace / "SOUL.md"
|
||||
soul_md.write_text("You are OpenClaw, an AI assistant made by SparkLab.", encoding="utf-8")
|
||||
|
||||
target_root = tmp_path / "hermes"
|
||||
target_root.mkdir()
|
||||
|
||||
migrator = mod.Migrator(
|
||||
source_root=source_root,
|
||||
target_root=target_root,
|
||||
execute=True,
|
||||
workspace_target=None,
|
||||
overwrite=False,
|
||||
migrate_secrets=False,
|
||||
output_dir=tmp_path / "report",
|
||||
selected_options={"soul"},
|
||||
)
|
||||
migrator.migrate()
|
||||
|
||||
result = (target_root / "SOUL.md").read_text(encoding="utf-8")
|
||||
assert "OpenClaw" not in result
|
||||
assert "You are Hermes" in result
|
||||
|
||||
120
tests/test_empty_model_fallback.py
Normal file
120
tests/test_empty_model_fallback.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Tests for empty model fallback — when provider is configured but model is missing."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
|
||||
class TestGetDefaultModelForProvider:
|
||||
"""Unit tests for hermes_cli.models.get_default_model_for_provider."""
|
||||
|
||||
def test_known_provider_returns_first_model(self):
|
||||
from hermes_cli.models import get_default_model_for_provider
|
||||
result = get_default_model_for_provider("openai-codex")
|
||||
# Should return first model from _PROVIDER_MODELS["openai-codex"]
|
||||
assert result
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_openrouter_returns_empty(self):
|
||||
"""OpenRouter uses dynamic model fetch, no static catalog entry."""
|
||||
from hermes_cli.models import get_default_model_for_provider
|
||||
# OpenRouter is not in _PROVIDER_MODELS — it uses live fetching
|
||||
result = get_default_model_for_provider("openrouter")
|
||||
assert result == ""
|
||||
|
||||
def test_unknown_provider_returns_empty(self):
|
||||
from hermes_cli.models import get_default_model_for_provider
|
||||
assert get_default_model_for_provider("nonexistent-provider") == ""
|
||||
|
||||
def test_custom_provider_returns_empty(self):
|
||||
"""Custom provider has no model catalog — should return empty."""
|
||||
from hermes_cli.models import get_default_model_for_provider
|
||||
# Custom providers don't have entries in _PROVIDER_MODELS
|
||||
assert get_default_model_for_provider("some-random-custom") == ""
|
||||
|
||||
|
||||
class TestGatewayEmptyModelFallback:
|
||||
"""Test that _resolve_session_agent_runtime fills in empty model from provider catalog."""
|
||||
|
||||
def test_empty_model_filled_from_provider(self):
|
||||
"""When config has no model but provider is openai-codex, use first codex model."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._session_model_overrides = {}
|
||||
|
||||
# Mock _resolve_gateway_model to return empty string
|
||||
# Mock _resolve_runtime_agent_kwargs to return openai-codex provider
|
||||
with patch("gateway.run._resolve_gateway_model", return_value=""), \
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={
|
||||
"provider": "openai-codex",
|
||||
"api_key": "test-key",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_mode": "codex_responses",
|
||||
}):
|
||||
model, kwargs = runner._resolve_session_agent_runtime()
|
||||
|
||||
# Model should have been filled in from provider catalog
|
||||
assert model, "Model should not be empty when provider is known"
|
||||
assert isinstance(model, str)
|
||||
assert kwargs["provider"] == "openai-codex"
|
||||
|
||||
def test_nonempty_model_not_overridden(self):
|
||||
"""When config has a model set, don't override it."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._session_model_overrides = {}
|
||||
|
||||
with patch("gateway.run._resolve_gateway_model", return_value="gpt-5.4"), \
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={
|
||||
"provider": "openai-codex",
|
||||
"api_key": "test-key",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_mode": "codex_responses",
|
||||
}):
|
||||
model, kwargs = runner._resolve_session_agent_runtime()
|
||||
|
||||
assert model == "gpt-5.4", "Explicit model should not be overridden"
|
||||
|
||||
def test_empty_model_no_provider_stays_empty(self):
|
||||
"""When both model and provider are empty, model stays empty."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._session_model_overrides = {}
|
||||
|
||||
with patch("gateway.run._resolve_gateway_model", return_value=""), \
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={
|
||||
"provider": "",
|
||||
"api_key": "test-key",
|
||||
"base_url": "https://example.com",
|
||||
"api_mode": "chat_completions",
|
||||
}):
|
||||
model, kwargs = runner._resolve_session_agent_runtime()
|
||||
|
||||
# Can't fill in a default without knowing the provider
|
||||
assert model == ""
|
||||
|
||||
|
||||
class TestResolveGatewayModel:
|
||||
"""Test _resolve_gateway_model reads model from config correctly."""
|
||||
|
||||
def test_returns_default_key(self):
|
||||
from gateway.run import _resolve_gateway_model
|
||||
assert _resolve_gateway_model({"model": {"default": "gpt-5.4"}}) == "gpt-5.4"
|
||||
|
||||
def test_returns_model_key_fallback(self):
|
||||
from gateway.run import _resolve_gateway_model
|
||||
assert _resolve_gateway_model({"model": {"model": "gpt-5.4"}}) == "gpt-5.4"
|
||||
|
||||
def test_returns_empty_when_missing(self):
|
||||
from gateway.run import _resolve_gateway_model
|
||||
assert _resolve_gateway_model({"model": {}}) == ""
|
||||
|
||||
def test_returns_empty_when_no_model_section(self):
|
||||
from gateway.run import _resolve_gateway_model
|
||||
assert _resolve_gateway_model({}) == ""
|
||||
|
||||
def test_string_model_config(self):
|
||||
from gateway.run import _resolve_gateway_model
|
||||
assert _resolve_gateway_model({"model": "my-model"}) == "my-model"
|
||||
@@ -3,6 +3,7 @@
|
||||
import logging
|
||||
import os
|
||||
import stat
|
||||
import threading
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
@@ -34,6 +35,8 @@ def _reset_logging_state():
|
||||
h.close()
|
||||
else:
|
||||
pre_existing.append(h)
|
||||
# Ensure the record factory is installed (it's idempotent).
|
||||
hermes_logging._install_session_record_factory()
|
||||
yield
|
||||
# Restore — remove any handlers added during the test.
|
||||
for h in list(root.handlers):
|
||||
@@ -41,6 +44,7 @@ def _reset_logging_state():
|
||||
root.removeHandler(h)
|
||||
h.close()
|
||||
hermes_logging._logging_initialized = False
|
||||
hermes_logging.clear_session_context()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -220,6 +224,294 @@ class TestSetupLogging:
|
||||
]
|
||||
assert agent_handlers[0].level == logging.WARNING
|
||||
|
||||
def test_record_factory_installed(self, hermes_home):
|
||||
"""The custom record factory injects session_tag on all records."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home)
|
||||
factory = logging.getLogRecordFactory()
|
||||
assert getattr(factory, "_hermes_session_injector", False), (
|
||||
"Record factory should have _hermes_session_injector marker"
|
||||
)
|
||||
# Verify session_tag exists on a fresh record
|
||||
record = factory("test", logging.INFO, "", 0, "msg", (), None)
|
||||
assert hasattr(record, "session_tag")
|
||||
|
||||
|
||||
class TestGatewayMode:
|
||||
"""setup_logging(mode='gateway') creates a filtered gateway.log."""
|
||||
|
||||
def test_gateway_log_created(self, hermes_home):
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
|
||||
root = logging.getLogger()
|
||||
|
||||
gw_handlers = [
|
||||
h for h in root.handlers
|
||||
if isinstance(h, RotatingFileHandler)
|
||||
and "gateway.log" in getattr(h, "baseFilename", "")
|
||||
]
|
||||
assert len(gw_handlers) == 1
|
||||
|
||||
def test_gateway_log_not_created_in_cli_mode(self, hermes_home):
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home, mode="cli")
|
||||
root = logging.getLogger()
|
||||
|
||||
gw_handlers = [
|
||||
h for h in root.handlers
|
||||
if isinstance(h, RotatingFileHandler)
|
||||
and "gateway.log" in getattr(h, "baseFilename", "")
|
||||
]
|
||||
assert len(gw_handlers) == 0
|
||||
|
||||
def test_gateway_log_receives_gateway_records(self, hermes_home):
|
||||
"""gateway.log captures records from gateway.* loggers."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
|
||||
|
||||
gw_logger = logging.getLogger("gateway.platforms.telegram")
|
||||
gw_logger.info("telegram connected")
|
||||
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
gw_log = hermes_home / "logs" / "gateway.log"
|
||||
assert gw_log.exists()
|
||||
assert "telegram connected" in gw_log.read_text()
|
||||
|
||||
def test_gateway_log_rejects_non_gateway_records(self, hermes_home):
|
||||
"""gateway.log does NOT capture records from tools.*, agent.*, etc."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
|
||||
|
||||
tool_logger = logging.getLogger("tools.terminal_tool")
|
||||
tool_logger.info("running command")
|
||||
|
||||
agent_logger = logging.getLogger("agent.context_compressor")
|
||||
agent_logger.info("compressing context")
|
||||
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
gw_log = hermes_home / "logs" / "gateway.log"
|
||||
if gw_log.exists():
|
||||
content = gw_log.read_text()
|
||||
assert "running command" not in content
|
||||
assert "compressing context" not in content
|
||||
|
||||
def test_agent_log_still_receives_all(self, hermes_home):
|
||||
"""agent.log (catch-all) still receives gateway AND tool records."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
|
||||
|
||||
logging.getLogger("gateway.run").info("gateway msg")
|
||||
logging.getLogger("tools.file_tools").info("file msg")
|
||||
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
agent_log = hermes_home / "logs" / "agent.log"
|
||||
content = agent_log.read_text()
|
||||
assert "gateway msg" in content
|
||||
assert "file msg" in content
|
||||
|
||||
|
||||
class TestSessionContext:
|
||||
"""set_session_context / clear_session_context + _SessionFilter."""
|
||||
|
||||
def test_session_tag_in_log_output(self, hermes_home):
|
||||
"""When session context is set, log lines include [session_id]."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home)
|
||||
hermes_logging.set_session_context("abc123")
|
||||
|
||||
test_logger = logging.getLogger("test.session_tag")
|
||||
test_logger.info("tagged message")
|
||||
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
agent_log = hermes_home / "logs" / "agent.log"
|
||||
content = agent_log.read_text()
|
||||
assert "[abc123]" in content
|
||||
assert "tagged message" in content
|
||||
|
||||
def test_no_session_tag_without_context(self, hermes_home):
|
||||
"""Without session context, log lines have no session tag."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home)
|
||||
hermes_logging.clear_session_context()
|
||||
|
||||
test_logger = logging.getLogger("test.no_session")
|
||||
test_logger.info("untagged message")
|
||||
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
agent_log = hermes_home / "logs" / "agent.log"
|
||||
content = agent_log.read_text()
|
||||
assert "untagged message" in content
|
||||
# Should not have any [xxx] session tag
|
||||
import re
|
||||
for line in content.splitlines():
|
||||
if "untagged message" in line:
|
||||
assert not re.search(r"\[.+?\]", line.split("INFO")[1].split("test.no_session")[0])
|
||||
|
||||
def test_clear_session_context(self, hermes_home):
|
||||
"""After clearing, session tag disappears."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home)
|
||||
hermes_logging.set_session_context("xyz789")
|
||||
hermes_logging.clear_session_context()
|
||||
|
||||
test_logger = logging.getLogger("test.cleared")
|
||||
test_logger.info("after clear")
|
||||
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
agent_log = hermes_home / "logs" / "agent.log"
|
||||
content = agent_log.read_text()
|
||||
assert "[xyz789]" not in content
|
||||
|
||||
def test_session_context_thread_isolated(self, hermes_home):
|
||||
"""Session context is per-thread — one thread's context doesn't leak."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home)
|
||||
|
||||
results = {}
|
||||
|
||||
def thread_a():
|
||||
hermes_logging.set_session_context("thread_a_session")
|
||||
logging.getLogger("test.thread_a").info("from thread A")
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
def thread_b():
|
||||
hermes_logging.set_session_context("thread_b_session")
|
||||
logging.getLogger("test.thread_b").info("from thread B")
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
ta = threading.Thread(target=thread_a)
|
||||
tb = threading.Thread(target=thread_b)
|
||||
ta.start()
|
||||
ta.join()
|
||||
tb.start()
|
||||
tb.join()
|
||||
|
||||
agent_log = hermes_home / "logs" / "agent.log"
|
||||
content = agent_log.read_text()
|
||||
|
||||
# Each thread's message should have its own session tag
|
||||
for line in content.splitlines():
|
||||
if "from thread A" in line:
|
||||
assert "[thread_a_session]" in line
|
||||
assert "[thread_b_session]" not in line
|
||||
if "from thread B" in line:
|
||||
assert "[thread_b_session]" in line
|
||||
assert "[thread_a_session]" not in line
|
||||
|
||||
|
||||
class TestRecordFactory:
|
||||
"""Unit tests for the custom LogRecord factory."""
|
||||
|
||||
def test_record_has_session_tag(self):
|
||||
"""Every record gets a session_tag attribute."""
|
||||
factory = logging.getLogRecordFactory()
|
||||
record = factory("test", logging.INFO, "", 0, "msg", (), None)
|
||||
assert hasattr(record, "session_tag")
|
||||
|
||||
def test_empty_tag_without_context(self):
|
||||
hermes_logging.clear_session_context()
|
||||
factory = logging.getLogRecordFactory()
|
||||
record = factory("test", logging.INFO, "", 0, "msg", (), None)
|
||||
assert record.session_tag == ""
|
||||
|
||||
def test_tag_with_context(self):
|
||||
hermes_logging.set_session_context("sess_42")
|
||||
factory = logging.getLogRecordFactory()
|
||||
record = factory("test", logging.INFO, "", 0, "msg", (), None)
|
||||
assert record.session_tag == " [sess_42]"
|
||||
|
||||
def test_idempotent_install(self):
|
||||
"""Calling _install_session_record_factory() twice doesn't double-wrap."""
|
||||
hermes_logging._install_session_record_factory()
|
||||
factory_a = logging.getLogRecordFactory()
|
||||
hermes_logging._install_session_record_factory()
|
||||
factory_b = logging.getLogRecordFactory()
|
||||
assert factory_a is factory_b
|
||||
|
||||
def test_works_with_any_handler(self):
|
||||
"""A handler using %(session_tag)s works even without _SessionFilter."""
|
||||
hermes_logging.set_session_context("any_handler_test")
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(logging.Formatter("%(session_tag)s %(message)s"))
|
||||
|
||||
logger = logging.getLogger("_test_any_handler")
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
try:
|
||||
# Should not raise KeyError
|
||||
logger.info("hello")
|
||||
finally:
|
||||
logger.removeHandler(handler)
|
||||
|
||||
|
||||
class TestComponentFilter:
|
||||
"""Unit tests for _ComponentFilter."""
|
||||
|
||||
def test_passes_matching_prefix(self):
|
||||
f = hermes_logging._ComponentFilter(("gateway",))
|
||||
record = logging.LogRecord(
|
||||
"gateway.run", logging.INFO, "", 0, "msg", (), None
|
||||
)
|
||||
assert f.filter(record) is True
|
||||
|
||||
def test_passes_nested_matching_prefix(self):
|
||||
f = hermes_logging._ComponentFilter(("gateway",))
|
||||
record = logging.LogRecord(
|
||||
"gateway.platforms.telegram", logging.INFO, "", 0, "msg", (), None
|
||||
)
|
||||
assert f.filter(record) is True
|
||||
|
||||
def test_blocks_non_matching(self):
|
||||
f = hermes_logging._ComponentFilter(("gateway",))
|
||||
record = logging.LogRecord(
|
||||
"tools.terminal_tool", logging.INFO, "", 0, "msg", (), None
|
||||
)
|
||||
assert f.filter(record) is False
|
||||
|
||||
def test_multiple_prefixes(self):
|
||||
f = hermes_logging._ComponentFilter(("agent", "run_agent", "model_tools"))
|
||||
assert f.filter(logging.LogRecord(
|
||||
"agent.compressor", logging.INFO, "", 0, "", (), None
|
||||
))
|
||||
assert f.filter(logging.LogRecord(
|
||||
"run_agent", logging.INFO, "", 0, "", (), None
|
||||
))
|
||||
assert f.filter(logging.LogRecord(
|
||||
"model_tools", logging.INFO, "", 0, "", (), None
|
||||
))
|
||||
assert not f.filter(logging.LogRecord(
|
||||
"tools.browser", logging.INFO, "", 0, "", (), None
|
||||
))
|
||||
|
||||
|
||||
class TestComponentPrefixes:
|
||||
"""COMPONENT_PREFIXES covers the expected components."""
|
||||
|
||||
def test_gateway_prefix(self):
|
||||
assert "gateway" in hermes_logging.COMPONENT_PREFIXES
|
||||
assert ("gateway",) == hermes_logging.COMPONENT_PREFIXES["gateway"]
|
||||
|
||||
def test_agent_prefix(self):
|
||||
prefixes = hermes_logging.COMPONENT_PREFIXES["agent"]
|
||||
assert "agent" in prefixes
|
||||
assert "run_agent" in prefixes
|
||||
assert "model_tools" in prefixes
|
||||
|
||||
def test_tools_prefix(self):
|
||||
assert ("tools",) == hermes_logging.COMPONENT_PREFIXES["tools"]
|
||||
|
||||
def test_cli_prefix(self):
|
||||
prefixes = hermes_logging.COMPONENT_PREFIXES["cli"]
|
||||
assert "hermes_cli" in prefixes
|
||||
assert "cli" in prefixes
|
||||
|
||||
def test_cron_prefix(self):
|
||||
assert ("cron",) == hermes_logging.COMPONENT_PREFIXES["cron"]
|
||||
|
||||
|
||||
class TestSetupVerboseLogging:
|
||||
"""setup_verbose_logging() adds a DEBUG-level console handler."""
|
||||
@@ -301,6 +593,59 @@ class TestAddRotatingHandler:
|
||||
logger.removeHandler(h)
|
||||
h.close()
|
||||
|
||||
def test_log_filter_attached(self, tmp_path):
|
||||
"""Optional log_filter is attached to the handler."""
|
||||
log_path = tmp_path / "filtered.log"
|
||||
logger = logging.getLogger("_test_rotating_filter")
|
||||
formatter = logging.Formatter("%(message)s")
|
||||
component_filter = hermes_logging._ComponentFilter(("test",))
|
||||
|
||||
hermes_logging._add_rotating_handler(
|
||||
logger, log_path,
|
||||
level=logging.INFO, max_bytes=1024, backup_count=1,
|
||||
formatter=formatter,
|
||||
log_filter=component_filter,
|
||||
)
|
||||
|
||||
handlers = [h for h in logger.handlers if isinstance(h, RotatingFileHandler)]
|
||||
assert len(handlers) == 1
|
||||
assert component_filter in handlers[0].filters
|
||||
# Clean up
|
||||
for h in list(logger.handlers):
|
||||
if isinstance(h, RotatingFileHandler):
|
||||
logger.removeHandler(h)
|
||||
h.close()
|
||||
|
||||
def test_no_session_filter_on_handler(self, tmp_path):
|
||||
"""Handlers rely on record factory, not per-handler _SessionFilter."""
|
||||
log_path = tmp_path / "no_session_filter.log"
|
||||
logger = logging.getLogger("_test_no_session_filter")
|
||||
formatter = logging.Formatter("%(session_tag)s%(message)s")
|
||||
|
||||
hermes_logging._add_rotating_handler(
|
||||
logger, log_path,
|
||||
level=logging.INFO, max_bytes=1024, backup_count=1,
|
||||
formatter=formatter,
|
||||
)
|
||||
|
||||
handlers = [h for h in logger.handlers if isinstance(h, RotatingFileHandler)]
|
||||
assert len(handlers) == 1
|
||||
# No _SessionFilter on the handler — record factory handles it
|
||||
assert len(handlers[0].filters) == 0
|
||||
|
||||
# But session_tag still works (via record factory)
|
||||
hermes_logging.set_session_context("factory_test")
|
||||
logger.info("test msg")
|
||||
handlers[0].flush()
|
||||
content = log_path.read_text()
|
||||
assert "[factory_test]" in content
|
||||
|
||||
# Clean up
|
||||
for h in list(logger.handlers):
|
||||
if isinstance(h, RotatingFileHandler):
|
||||
logger.removeHandler(h)
|
||||
h.close()
|
||||
|
||||
def test_managed_mode_initial_open_sets_group_writable(self, tmp_path):
|
||||
log_path = tmp_path / "managed-open.log"
|
||||
logger = logging.getLogger("_test_rotating_managed_open")
|
||||
|
||||
114
tests/test_ipv4_preference.py
Normal file
114
tests/test_ipv4_preference.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Tests for network.force_ipv4 — the socket.getaddrinfo monkey-patch."""
|
||||
|
||||
import importlib
|
||||
import socket
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _reload_constants():
|
||||
"""Reload hermes_constants to get a fresh apply_ipv4_preference."""
|
||||
import hermes_constants
|
||||
importlib.reload(hermes_constants)
|
||||
return hermes_constants
|
||||
|
||||
|
||||
class TestApplyIPv4Preference:
|
||||
"""Tests for apply_ipv4_preference()."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Save the original getaddrinfo before each test."""
|
||||
self._original = socket.getaddrinfo
|
||||
|
||||
def teardown_method(self):
|
||||
"""Restore the original getaddrinfo after each test."""
|
||||
socket.getaddrinfo = self._original
|
||||
|
||||
def test_noop_when_force_false(self):
|
||||
"""No patch when force=False."""
|
||||
from hermes_constants import apply_ipv4_preference
|
||||
original = socket.getaddrinfo
|
||||
apply_ipv4_preference(force=False)
|
||||
assert socket.getaddrinfo is original
|
||||
|
||||
def test_patches_getaddrinfo_when_forced(self):
|
||||
"""Patches socket.getaddrinfo when force=True."""
|
||||
from hermes_constants import apply_ipv4_preference
|
||||
original = socket.getaddrinfo
|
||||
apply_ipv4_preference(force=True)
|
||||
assert socket.getaddrinfo is not original
|
||||
assert getattr(socket.getaddrinfo, "_hermes_ipv4_patched", False) is True
|
||||
|
||||
def test_double_patch_is_safe(self):
|
||||
"""Calling apply twice doesn't double-wrap."""
|
||||
from hermes_constants import apply_ipv4_preference
|
||||
apply_ipv4_preference(force=True)
|
||||
first_patch = socket.getaddrinfo
|
||||
apply_ipv4_preference(force=True)
|
||||
assert socket.getaddrinfo is first_patch
|
||||
|
||||
def test_af_unspec_becomes_af_inet(self):
|
||||
"""AF_UNSPEC (default) calls get rewritten to AF_INET."""
|
||||
from hermes_constants import apply_ipv4_preference
|
||||
|
||||
calls = []
|
||||
original = socket.getaddrinfo
|
||||
|
||||
def mock_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
|
||||
calls.append(family)
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("93.184.216.34", 80))]
|
||||
|
||||
socket.getaddrinfo = mock_getaddrinfo
|
||||
apply_ipv4_preference(force=True)
|
||||
|
||||
# Call with default family (AF_UNSPEC = 0)
|
||||
socket.getaddrinfo("example.com", 80)
|
||||
assert calls[-1] == socket.AF_INET, "AF_UNSPEC should be rewritten to AF_INET"
|
||||
|
||||
def test_explicit_family_preserved(self):
|
||||
"""Explicit AF_INET6 requests are not intercepted."""
|
||||
from hermes_constants import apply_ipv4_preference
|
||||
|
||||
calls = []
|
||||
original = socket.getaddrinfo
|
||||
|
||||
def mock_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
|
||||
calls.append(family)
|
||||
return [(family, socket.SOCK_STREAM, 6, "", ("::1", 80))]
|
||||
|
||||
socket.getaddrinfo = mock_getaddrinfo
|
||||
apply_ipv4_preference(force=True)
|
||||
|
||||
socket.getaddrinfo("example.com", 80, family=socket.AF_INET6)
|
||||
assert calls[-1] == socket.AF_INET6, "Explicit AF_INET6 should pass through"
|
||||
|
||||
def test_fallback_on_gaierror(self):
|
||||
"""Falls back to AF_UNSPEC if AF_INET resolution fails."""
|
||||
from hermes_constants import apply_ipv4_preference
|
||||
|
||||
call_families = []
|
||||
|
||||
def mock_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
|
||||
call_families.append(family)
|
||||
if family == socket.AF_INET:
|
||||
raise socket.gaierror("No A record")
|
||||
# AF_UNSPEC fallback returns IPv6
|
||||
return [(socket.AF_INET6, socket.SOCK_STREAM, 6, "", ("::1", 80))]
|
||||
|
||||
socket.getaddrinfo = mock_getaddrinfo
|
||||
apply_ipv4_preference(force=True)
|
||||
|
||||
result = socket.getaddrinfo("ipv6only.example.com", 80)
|
||||
# Should have tried AF_INET first, then fallen back to AF_UNSPEC
|
||||
assert call_families == [socket.AF_INET, 0]
|
||||
assert result[0][0] == socket.AF_INET6
|
||||
|
||||
|
||||
class TestConfigDefault:
|
||||
"""Verify network section exists in DEFAULT_CONFIG."""
|
||||
|
||||
def test_network_section_in_default_config(self):
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
assert "network" in DEFAULT_CONFIG
|
||||
assert DEFAULT_CONFIG["network"]["force_ipv4"] is False
|
||||
@@ -59,8 +59,9 @@ class TestCamofoxConfigDefaults:
|
||||
browser_cfg = DEFAULT_CONFIG["browser"]
|
||||
assert browser_cfg["camofox"]["managed_persistence"] is False
|
||||
|
||||
def test_config_version_unchanged(self):
|
||||
def test_config_version_matches_current_schema(self):
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
|
||||
# managed_persistence is auto-merged by _deep_merge, no version bump needed
|
||||
assert DEFAULT_CONFIG["_config_version"] == 13
|
||||
# The current schema version is tracked globally; unrelated default
|
||||
# options may bump it after browser defaults are added.
|
||||
assert DEFAULT_CONFIG["_config_version"] == 15
|
||||
|
||||
@@ -380,7 +380,7 @@ class TestStubSchemaDrift(unittest.TestCase):
|
||||
# Parameters that are internal (injected by the handler, not user-facing)
|
||||
_INTERNAL_PARAMS = {"task_id", "user_task"}
|
||||
# Parameters intentionally blocked in the sandbox
|
||||
_BLOCKED_TERMINAL_PARAMS = {"background", "check_interval", "pty", "notify_on_complete"}
|
||||
_BLOCKED_TERMINAL_PARAMS = {"background", "pty", "notify_on_complete"}
|
||||
|
||||
def test_stubs_cover_all_schema_params(self):
|
||||
"""Every user-facing parameter in the real schema must appear in the
|
||||
|
||||
295
tests/tools/test_modal_bulk_upload.py
Normal file
295
tests/tools/test_modal_bulk_upload.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""Tests for Modal bulk upload via tar/base64 archive."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments import modal as modal_env
|
||||
|
||||
|
||||
def _make_mock_modal_env(monkeypatch, tmp_path):
|
||||
"""Create a minimal mock ModalEnvironment for testing upload methods.
|
||||
|
||||
Returns a ModalEnvironment-like object with _sandbox and _worker mocked.
|
||||
We don't call __init__ because it requires the Modal SDK.
|
||||
"""
|
||||
env = object.__new__(modal_env.ModalEnvironment)
|
||||
env._sandbox = MagicMock()
|
||||
env._worker = MagicMock()
|
||||
env._persistent = False
|
||||
env._task_id = "test"
|
||||
env._sync_manager = None
|
||||
return env
|
||||
|
||||
|
||||
def _make_mock_stdin():
|
||||
"""Create a mock stdin that captures written data."""
|
||||
stdin = MagicMock()
|
||||
written_chunks = []
|
||||
|
||||
def mock_write(data):
|
||||
written_chunks.append(data)
|
||||
|
||||
stdin.write = mock_write
|
||||
stdin.write_eof = MagicMock()
|
||||
stdin.drain = MagicMock()
|
||||
stdin.drain.aio = AsyncMock()
|
||||
stdin._written_chunks = written_chunks
|
||||
return stdin
|
||||
|
||||
|
||||
def _wire_async_exec(env, exec_calls=None):
|
||||
"""Wire mock sandbox.exec.aio and a real run_coroutine on the env.
|
||||
|
||||
Optionally captures exec call args into *exec_calls* list.
|
||||
Returns (exec_calls, run_kwargs, stdin_mock).
|
||||
"""
|
||||
if exec_calls is None:
|
||||
exec_calls = []
|
||||
run_kwargs: dict = {}
|
||||
stdin_mock = _make_mock_stdin()
|
||||
|
||||
async def mock_exec_fn(*args, **kwargs):
|
||||
exec_calls.append(args)
|
||||
proc = MagicMock()
|
||||
proc.wait = MagicMock()
|
||||
proc.wait.aio = AsyncMock(return_value=0)
|
||||
proc.stdin = stdin_mock
|
||||
proc.stderr = MagicMock()
|
||||
proc.stderr.read = MagicMock()
|
||||
proc.stderr.read.aio = AsyncMock(return_value="")
|
||||
return proc
|
||||
|
||||
env._sandbox.exec = MagicMock()
|
||||
env._sandbox.exec.aio = mock_exec_fn
|
||||
|
||||
def real_run_coroutine(coro, **kwargs):
|
||||
run_kwargs.update(kwargs)
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
env._worker.run_coroutine = real_run_coroutine
|
||||
return exec_calls, run_kwargs, stdin_mock
|
||||
|
||||
|
||||
class TestModalBulkUpload:
|
||||
"""Test _modal_bulk_upload method."""
|
||||
|
||||
def test_empty_files_is_noop(self, monkeypatch, tmp_path):
|
||||
"""Empty file list should not call worker.run_coroutine."""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
env._modal_bulk_upload([])
|
||||
env._worker.run_coroutine.assert_not_called()
|
||||
|
||||
def test_tar_archive_contains_all_files(self, monkeypatch, tmp_path):
|
||||
"""The tar archive sent via stdin should contain all files."""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
|
||||
src_a = tmp_path / "a.json"
|
||||
src_b = tmp_path / "b.py"
|
||||
src_a.write_text("cred_content")
|
||||
src_b.write_text("skill_content")
|
||||
|
||||
files = [
|
||||
(str(src_a), "/root/.hermes/credentials/a.json"),
|
||||
(str(src_b), "/root/.hermes/skills/b.py"),
|
||||
]
|
||||
|
||||
exec_calls, _, stdin_mock = _wire_async_exec(env)
|
||||
env._modal_bulk_upload(files)
|
||||
|
||||
# Verify the command reads from stdin (no echo with embedded payload)
|
||||
assert len(exec_calls) == 1
|
||||
args = exec_calls[0]
|
||||
assert args[0] == "bash"
|
||||
assert args[1] == "-c"
|
||||
cmd = args[2]
|
||||
assert "mkdir -p" in cmd
|
||||
assert "base64 -d" in cmd
|
||||
assert "tar xzf" in cmd
|
||||
assert "-C /" in cmd
|
||||
|
||||
# Reassemble the base64 payload from stdin chunks and verify tar contents
|
||||
payload = "".join(stdin_mock._written_chunks)
|
||||
tar_data = base64.b64decode(payload)
|
||||
buf = io.BytesIO(tar_data)
|
||||
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
|
||||
names = sorted(tar.getnames())
|
||||
assert "root/.hermes/credentials/a.json" in names
|
||||
assert "root/.hermes/skills/b.py" in names
|
||||
|
||||
# Verify content
|
||||
a_content = tar.extractfile("root/.hermes/credentials/a.json").read()
|
||||
assert a_content == b"cred_content"
|
||||
b_content = tar.extractfile("root/.hermes/skills/b.py").read()
|
||||
assert b_content == b"skill_content"
|
||||
|
||||
# Verify stdin was closed
|
||||
stdin_mock.write_eof.assert_called_once()
|
||||
|
||||
def test_mkdir_includes_all_parents(self, monkeypatch, tmp_path):
|
||||
"""Remote parent directories should be pre-created in the command."""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
|
||||
src = tmp_path / "f.txt"
|
||||
src.write_text("data")
|
||||
|
||||
files = [
|
||||
(str(src), "/root/.hermes/credentials/f.txt"),
|
||||
(str(src), "/root/.hermes/skills/deep/nested/f.txt"),
|
||||
]
|
||||
|
||||
exec_calls, _, _ = _wire_async_exec(env)
|
||||
env._modal_bulk_upload(files)
|
||||
|
||||
cmd = exec_calls[0][2]
|
||||
assert "/root/.hermes/credentials" in cmd
|
||||
assert "/root/.hermes/skills/deep/nested" in cmd
|
||||
|
||||
def test_single_exec_call(self, monkeypatch, tmp_path):
|
||||
"""Bulk upload should use exactly one exec call regardless of file count."""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
|
||||
files = []
|
||||
for i in range(20):
|
||||
src = tmp_path / f"file_{i}.txt"
|
||||
src.write_text(f"content_{i}")
|
||||
files.append((str(src), f"/root/.hermes/cache/file_{i}.txt"))
|
||||
|
||||
exec_calls, _, _ = _wire_async_exec(env)
|
||||
env._modal_bulk_upload(files)
|
||||
|
||||
# Should be exactly 1 exec call, not 20
|
||||
assert len(exec_calls) == 1
|
||||
|
||||
def test_bulk_upload_wired_in_filesyncmanager(self, monkeypatch):
|
||||
"""Verify ModalEnvironment passes bulk_upload_fn to FileSyncManager."""
|
||||
captured_kwargs = {}
|
||||
|
||||
def capture_fsm(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return type("M", (), {"sync": lambda self, **k: None})()
|
||||
|
||||
monkeypatch.setattr(modal_env, "FileSyncManager", capture_fsm)
|
||||
|
||||
# Create a minimal env without full __init__
|
||||
env = object.__new__(modal_env.ModalEnvironment)
|
||||
env._sandbox = MagicMock()
|
||||
env._worker = MagicMock()
|
||||
env._persistent = False
|
||||
env._task_id = "test"
|
||||
|
||||
# Manually call the part of __init__ that wires FileSyncManager
|
||||
from tools.environments.file_sync import iter_sync_files
|
||||
env._sync_manager = modal_env.FileSyncManager(
|
||||
get_files_fn=lambda: iter_sync_files("/root/.hermes"),
|
||||
upload_fn=env._modal_upload,
|
||||
delete_fn=env._modal_delete,
|
||||
bulk_upload_fn=env._modal_bulk_upload,
|
||||
)
|
||||
|
||||
assert "bulk_upload_fn" in captured_kwargs
|
||||
assert captured_kwargs["bulk_upload_fn"] is not None
|
||||
assert callable(captured_kwargs["bulk_upload_fn"])
|
||||
|
||||
def test_timeout_set_to_120(self, monkeypatch, tmp_path):
|
||||
"""Bulk upload uses a 120s timeout (not the per-file 15s)."""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
|
||||
src = tmp_path / "f.txt"
|
||||
src.write_text("data")
|
||||
files = [(str(src), "/root/.hermes/f.txt")]
|
||||
|
||||
_, run_kwargs, _ = _wire_async_exec(env)
|
||||
env._modal_bulk_upload(files)
|
||||
|
||||
assert run_kwargs.get("timeout") == 120
|
||||
|
||||
def test_nonzero_exit_raises(self, monkeypatch, tmp_path):
|
||||
"""Non-zero exit code from remote exec should raise RuntimeError."""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
|
||||
src = tmp_path / "f.txt"
|
||||
src.write_text("data")
|
||||
files = [(str(src), "/root/.hermes/f.txt")]
|
||||
|
||||
stdin_mock = _make_mock_stdin()
|
||||
|
||||
async def mock_exec_fn(*args, **kwargs):
|
||||
proc = MagicMock()
|
||||
proc.wait = MagicMock()
|
||||
proc.wait.aio = AsyncMock(return_value=1) # non-zero exit
|
||||
proc.stdin = stdin_mock
|
||||
proc.stderr = MagicMock()
|
||||
proc.stderr.read = MagicMock()
|
||||
proc.stderr.read.aio = AsyncMock(return_value="tar: error")
|
||||
return proc
|
||||
|
||||
env._sandbox.exec = MagicMock()
|
||||
env._sandbox.exec.aio = mock_exec_fn
|
||||
|
||||
def real_run_coroutine(coro, **kwargs):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
env._worker.run_coroutine = real_run_coroutine
|
||||
|
||||
with pytest.raises(RuntimeError, match="Modal bulk upload failed"):
|
||||
env._modal_bulk_upload(files)
|
||||
|
||||
def test_payload_not_in_command_string(self, monkeypatch, tmp_path):
|
||||
"""The base64 payload must NOT appear in the bash -c argument.
|
||||
|
||||
This is the core ARG_MAX fix: the payload goes through stdin,
|
||||
not embedded in the command string.
|
||||
"""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
|
||||
src = tmp_path / "f.txt"
|
||||
src.write_text("some data to upload")
|
||||
files = [(str(src), "/root/.hermes/f.txt")]
|
||||
|
||||
exec_calls, _, stdin_mock = _wire_async_exec(env)
|
||||
env._modal_bulk_upload(files)
|
||||
|
||||
# The command should NOT contain an echo with the payload
|
||||
cmd = exec_calls[0][2]
|
||||
assert "echo" not in cmd
|
||||
# The payload should go through stdin
|
||||
assert len(stdin_mock._written_chunks) > 0
|
||||
|
||||
def test_stdin_chunked_for_large_payloads(self, monkeypatch, tmp_path):
|
||||
"""Payloads larger than _STDIN_CHUNK_SIZE should be split into multiple writes."""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
|
||||
# Use random bytes so gzip cannot compress them -- ensures the
|
||||
# base64 payload exceeds one 1 MB chunk.
|
||||
import os as _os
|
||||
src = tmp_path / "large.bin"
|
||||
src.write_bytes(_os.urandom(1024 * 1024 + 512 * 1024))
|
||||
files = [(str(src), "/root/.hermes/large.bin")]
|
||||
|
||||
exec_calls, _, stdin_mock = _wire_async_exec(env)
|
||||
env._modal_bulk_upload(files)
|
||||
|
||||
# Should have multiple stdin write chunks
|
||||
assert len(stdin_mock._written_chunks) >= 2
|
||||
|
||||
# Reassembled payload should still decode to valid tar
|
||||
payload = "".join(stdin_mock._written_chunks)
|
||||
tar_data = base64.b64decode(payload)
|
||||
buf = io.BytesIO(tar_data)
|
||||
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
|
||||
names = tar.getnames()
|
||||
assert "root/.hermes/large.bin" in names
|
||||
@@ -289,3 +289,62 @@ class TestCodeExecutionBlocked:
|
||||
def test_notify_on_complete_blocked_in_sandbox(self):
|
||||
from tools.code_execution_tool import _TERMINAL_BLOCKED_PARAMS
|
||||
assert "notify_on_complete" in _TERMINAL_BLOCKED_PARAMS
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Completion consumed suppression
|
||||
# =========================================================================
|
||||
|
||||
class TestCompletionConsumed:
|
||||
"""Test that wait/poll/log suppress redundant completion notifications."""
|
||||
|
||||
def test_wait_marks_completion_consumed(self, registry):
|
||||
"""wait() returning exited status marks session as consumed."""
|
||||
s = _make_session(sid="proc_wait", notify_on_complete=True, output="done")
|
||||
s.exited = True
|
||||
s.exit_code = 0
|
||||
registry._running[s.id] = s
|
||||
with patch.object(registry, "_write_checkpoint"):
|
||||
registry._move_to_finished(s)
|
||||
|
||||
# Notification is in the queue
|
||||
assert not registry.completion_queue.empty()
|
||||
assert not registry.is_completion_consumed("proc_wait")
|
||||
|
||||
# Agent calls wait() — gets the result directly
|
||||
result = registry.wait("proc_wait", timeout=1)
|
||||
assert result["status"] == "exited"
|
||||
|
||||
# Now the completion is marked as consumed
|
||||
assert registry.is_completion_consumed("proc_wait")
|
||||
|
||||
def test_poll_marks_completion_consumed(self, registry):
|
||||
"""poll() returning exited status marks session as consumed."""
|
||||
s = _make_session(sid="proc_poll", notify_on_complete=True, output="done")
|
||||
s.exited = True
|
||||
s.exit_code = 0
|
||||
registry._finished[s.id] = s
|
||||
|
||||
result = registry.poll("proc_poll")
|
||||
assert result["status"] == "exited"
|
||||
assert registry.is_completion_consumed("proc_poll")
|
||||
|
||||
def test_log_marks_completion_consumed(self, registry):
|
||||
"""read_log() on exited session marks as consumed."""
|
||||
s = _make_session(sid="proc_log", notify_on_complete=True, output="line1\nline2")
|
||||
s.exited = True
|
||||
s.exit_code = 0
|
||||
registry._finished[s.id] = s
|
||||
|
||||
result = registry.read_log("proc_log")
|
||||
assert result["status"] == "exited"
|
||||
assert registry.is_completion_consumed("proc_log")
|
||||
|
||||
def test_running_process_not_consumed(self, registry):
|
||||
"""poll() on a still-running process does not mark as consumed."""
|
||||
s = _make_session(sid="proc_running", notify_on_complete=True, output="partial")
|
||||
registry._running[s.id] = s
|
||||
|
||||
result = registry.poll("proc_running")
|
||||
assert result["status"] == "running"
|
||||
assert not registry.is_completion_consumed("proc_running")
|
||||
|
||||
517
tests/tools/test_ssh_bulk_upload.py
Normal file
517
tests/tools/test_ssh_bulk_upload.py
Normal file
@@ -0,0 +1,517 @@
|
||||
"""Tests for SSH bulk upload via tar pipe."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments import ssh as ssh_env
|
||||
from tools.environments.file_sync import quoted_mkdir_command, unique_parent_dirs
|
||||
from tools.environments.ssh import SSHEnvironment
|
||||
|
||||
|
||||
def _mock_proc(*, returncode=0, poll_return=0, communicate_return=(b"", b""),
|
||||
stderr_read=b""):
|
||||
"""Create a MagicMock mimicking subprocess.Popen for tar/ssh pipes."""
|
||||
m = MagicMock()
|
||||
m.stdout = MagicMock()
|
||||
m.returncode = returncode
|
||||
m.poll.return_value = poll_return
|
||||
m.communicate.return_value = communicate_return
|
||||
m.stderr = MagicMock()
|
||||
m.stderr.read.return_value = stderr_read
|
||||
return m
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env(monkeypatch):
|
||||
"""Create an SSHEnvironment with mocked connection/sync."""
|
||||
monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/testuser")
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
|
||||
monkeypatch.setattr(
|
||||
ssh_env, "FileSyncManager",
|
||||
lambda **kw: type("M", (), {"sync": lambda self, **k: None})(),
|
||||
)
|
||||
return SSHEnvironment(host="example.com", user="testuser")
|
||||
|
||||
|
||||
class TestSSHBulkUpload:
|
||||
"""Unit tests for _ssh_bulk_upload — tar pipe mechanics."""
|
||||
|
||||
def test_empty_files_is_noop(self, mock_env):
|
||||
"""Empty file list should not spawn any subprocesses."""
|
||||
with patch.object(subprocess, "run") as mock_run, \
|
||||
patch.object(subprocess, "Popen") as mock_popen:
|
||||
mock_env._ssh_bulk_upload([])
|
||||
mock_run.assert_not_called()
|
||||
mock_popen.assert_not_called()
|
||||
|
||||
def test_mkdir_batched_into_single_call(self, mock_env, tmp_path):
|
||||
"""All parent directories should be created in one SSH call."""
|
||||
# Create test files
|
||||
f1 = tmp_path / "a.txt"
|
||||
f1.write_text("aaa")
|
||||
f2 = tmp_path / "b.txt"
|
||||
f2.write_text("bbb")
|
||||
|
||||
files = [
|
||||
(str(f1), "/home/testuser/.hermes/skills/a.txt"),
|
||||
(str(f2), "/home/testuser/.hermes/credentials/b.txt"),
|
||||
]
|
||||
|
||||
# Mock subprocess.run for mkdir and Popen for tar pipe
|
||||
mock_run = MagicMock(return_value=subprocess.CompletedProcess([], 0))
|
||||
|
||||
def make_proc(cmd, **kwargs):
|
||||
m = MagicMock()
|
||||
m.stdout = MagicMock()
|
||||
m.returncode = 0
|
||||
m.poll.return_value = 0
|
||||
m.communicate.return_value = (b"", b"")
|
||||
m.stderr = MagicMock()
|
||||
m.stderr.read.return_value = b""
|
||||
return m
|
||||
|
||||
with patch.object(subprocess, "run", mock_run), \
|
||||
patch.object(subprocess, "Popen", side_effect=make_proc):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
# Exactly one subprocess.run call for mkdir
|
||||
assert mock_run.call_count == 1
|
||||
mkdir_cmd = mock_run.call_args[0][0]
|
||||
# Should contain mkdir -p with both parent dirs
|
||||
mkdir_str = " ".join(mkdir_cmd)
|
||||
assert "mkdir -p" in mkdir_str
|
||||
assert "/home/testuser/.hermes/skills" in mkdir_str
|
||||
assert "/home/testuser/.hermes/credentials" in mkdir_str
|
||||
|
||||
def test_staging_symlinks_mirror_remote_layout(self, mock_env, tmp_path):
|
||||
"""Symlinks in staging dir should mirror the remote path structure."""
|
||||
f1 = tmp_path / "local_a.txt"
|
||||
f1.write_text("content a")
|
||||
|
||||
files = [
|
||||
(str(f1), "/home/testuser/.hermes/skills/my_skill.md"),
|
||||
]
|
||||
|
||||
staging_paths = []
|
||||
|
||||
def capture_tar_cmd(cmd, **kwargs):
|
||||
if cmd[0] == "tar":
|
||||
# Capture the staging dir from -C argument
|
||||
c_idx = cmd.index("-C")
|
||||
staging_dir = cmd[c_idx + 1]
|
||||
# Check the symlink exists
|
||||
expected = os.path.join(
|
||||
staging_dir, "home/testuser/.hermes/skills/my_skill.md"
|
||||
)
|
||||
staging_paths.append(expected)
|
||||
assert os.path.islink(expected), f"Expected symlink at {expected}"
|
||||
assert os.readlink(expected) == os.path.abspath(str(f1))
|
||||
|
||||
mock = MagicMock()
|
||||
mock.stdout = MagicMock()
|
||||
mock.returncode = 0
|
||||
mock.poll.return_value = 0
|
||||
mock.communicate.return_value = (b"", b"")
|
||||
mock.stderr = MagicMock()
|
||||
mock.stderr.read.return_value = b""
|
||||
return mock
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=capture_tar_cmd):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
assert len(staging_paths) == 1, "tar command should have been called"
|
||||
|
||||
def test_tar_pipe_commands(self, mock_env, tmp_path):
|
||||
"""Verify tar and SSH commands are wired correctly."""
|
||||
f1 = tmp_path / "x.txt"
|
||||
f1.write_text("x")
|
||||
|
||||
files = [(str(f1), "/home/testuser/.hermes/cache/x.txt")]
|
||||
|
||||
popen_cmds = []
|
||||
|
||||
def capture_popen(cmd, **kwargs):
|
||||
popen_cmds.append(cmd)
|
||||
mock = MagicMock()
|
||||
mock.stdout = MagicMock()
|
||||
mock.returncode = 0
|
||||
mock.poll.return_value = 0
|
||||
mock.communicate.return_value = (b"", b"")
|
||||
mock.stderr = MagicMock()
|
||||
mock.stderr.read.return_value = b""
|
||||
return mock
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=capture_popen):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
assert len(popen_cmds) == 2, "Should spawn tar + ssh processes"
|
||||
|
||||
tar_cmd = popen_cmds[0]
|
||||
ssh_cmd = popen_cmds[1]
|
||||
|
||||
# tar: create, dereference symlinks, to stdout
|
||||
assert tar_cmd[0] == "tar"
|
||||
assert "-chf" in tar_cmd
|
||||
assert "-" in tar_cmd # stdout
|
||||
assert "-C" in tar_cmd
|
||||
|
||||
# ssh: extract from stdin at /
|
||||
ssh_str = " ".join(ssh_cmd)
|
||||
assert "ssh" in ssh_str
|
||||
assert "tar xf - -C /" in ssh_str
|
||||
assert "testuser@example.com" in ssh_str
|
||||
|
||||
def test_mkdir_failure_raises(self, mock_env, tmp_path):
|
||||
"""mkdir failure should raise RuntimeError before tar pipe."""
|
||||
f1 = tmp_path / "y.txt"
|
||||
f1.write_text("y")
|
||||
files = [(str(f1), "/home/testuser/.hermes/skills/y.txt")]
|
||||
|
||||
failed_run = subprocess.CompletedProcess([], 1, stderr="Permission denied")
|
||||
with patch.object(subprocess, "run", return_value=failed_run):
|
||||
with pytest.raises(RuntimeError, match="remote mkdir failed"):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
def test_tar_create_failure_raises(self, mock_env, tmp_path):
|
||||
"""tar create failure should raise RuntimeError."""
|
||||
f1 = tmp_path / "z.txt"
|
||||
f1.write_text("z")
|
||||
files = [(str(f1), "/home/testuser/.hermes/skills/z.txt")]
|
||||
|
||||
mock_tar = MagicMock()
|
||||
mock_tar.stdout = MagicMock()
|
||||
mock_tar.returncode = 1
|
||||
mock_tar.poll.return_value = 1
|
||||
mock_tar.communicate.return_value = (b"tar: error", b"")
|
||||
mock_tar.stderr = MagicMock()
|
||||
mock_tar.stderr.read.return_value = b"tar: error"
|
||||
|
||||
mock_ssh = MagicMock()
|
||||
mock_ssh.communicate.return_value = (b"", b"")
|
||||
mock_ssh.returncode = 0
|
||||
|
||||
def popen_side_effect(cmd, **kwargs):
|
||||
if cmd[0] == "tar":
|
||||
return mock_tar
|
||||
return mock_ssh
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=popen_side_effect):
|
||||
with pytest.raises(RuntimeError, match="tar create failed"):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
def test_ssh_extract_failure_raises(self, mock_env, tmp_path):
|
||||
"""SSH tar extract failure should raise RuntimeError."""
|
||||
f1 = tmp_path / "w.txt"
|
||||
f1.write_text("w")
|
||||
files = [(str(f1), "/home/testuser/.hermes/skills/w.txt")]
|
||||
|
||||
mock_tar = MagicMock()
|
||||
mock_tar.stdout = MagicMock()
|
||||
mock_tar.returncode = 0
|
||||
mock_tar.poll.return_value = 0
|
||||
mock_tar.communicate.return_value = (b"", b"")
|
||||
mock_tar.stderr = MagicMock()
|
||||
mock_tar.stderr.read.return_value = b""
|
||||
|
||||
mock_ssh = MagicMock()
|
||||
mock_ssh.communicate.return_value = (b"", b"Permission denied")
|
||||
mock_ssh.returncode = 1
|
||||
|
||||
def popen_side_effect(cmd, **kwargs):
|
||||
if cmd[0] == "tar":
|
||||
return mock_tar
|
||||
return mock_ssh
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=popen_side_effect):
|
||||
with pytest.raises(RuntimeError, match="tar extract over SSH failed"):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
def test_ssh_command_uses_control_socket(self, mock_env, tmp_path):
|
||||
"""SSH command for tar extract should reuse ControlMaster socket."""
|
||||
f1 = tmp_path / "c.txt"
|
||||
f1.write_text("c")
|
||||
files = [(str(f1), "/home/testuser/.hermes/cache/c.txt")]
|
||||
|
||||
popen_cmds = []
|
||||
|
||||
def capture_popen(cmd, **kwargs):
|
||||
popen_cmds.append(cmd)
|
||||
mock = MagicMock()
|
||||
mock.stdout = MagicMock()
|
||||
mock.returncode = 0
|
||||
mock.poll.return_value = 0
|
||||
mock.communicate.return_value = (b"", b"")
|
||||
mock.stderr = MagicMock()
|
||||
mock.stderr.read.return_value = b""
|
||||
return mock
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=capture_popen):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
# The SSH command (second Popen call) should include ControlPath
|
||||
ssh_cmd = popen_cmds[1]
|
||||
assert f"ControlPath={mock_env.control_socket}" in " ".join(ssh_cmd)
|
||||
|
||||
def test_custom_port_and_key_in_ssh_command(self, monkeypatch, tmp_path):
|
||||
"""Bulk upload SSH command should include custom port and key."""
|
||||
monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u")
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
|
||||
monkeypatch.setattr(
|
||||
ssh_env, "FileSyncManager",
|
||||
lambda **kw: type("M", (), {"sync": lambda self, **k: None})(),
|
||||
)
|
||||
env = SSHEnvironment(host="h", user="u", port=2222, key_path="/my/key")
|
||||
|
||||
f1 = tmp_path / "d.txt"
|
||||
f1.write_text("d")
|
||||
files = [(str(f1), "/home/u/.hermes/skills/d.txt")]
|
||||
|
||||
run_cmds = []
|
||||
popen_cmds = []
|
||||
|
||||
def capture_run(cmd, **kwargs):
|
||||
run_cmds.append(cmd)
|
||||
return subprocess.CompletedProcess([], 0)
|
||||
|
||||
def capture_popen(cmd, **kwargs):
|
||||
popen_cmds.append(cmd)
|
||||
mock = MagicMock()
|
||||
mock.stdout = MagicMock()
|
||||
mock.returncode = 0
|
||||
mock.poll.return_value = 0
|
||||
mock.communicate.return_value = (b"", b"")
|
||||
mock.stderr = MagicMock()
|
||||
mock.stderr.read.return_value = b""
|
||||
return mock
|
||||
|
||||
with patch.object(subprocess, "run", side_effect=capture_run), \
|
||||
patch.object(subprocess, "Popen", side_effect=capture_popen):
|
||||
env._ssh_bulk_upload(files)
|
||||
|
||||
# Check mkdir SSH call includes port and key
|
||||
assert len(run_cmds) == 1
|
||||
mkdir_cmd = run_cmds[0]
|
||||
assert "-p" in mkdir_cmd and "2222" in mkdir_cmd
|
||||
assert "-i" in mkdir_cmd and "/my/key" in mkdir_cmd
|
||||
|
||||
# Check tar extract SSH call includes port and key
|
||||
ssh_cmd = popen_cmds[1]
|
||||
assert "-p" in ssh_cmd and "2222" in ssh_cmd
|
||||
assert "-i" in ssh_cmd and "/my/key" in ssh_cmd
|
||||
|
||||
def test_parent_dirs_deduplicated(self, mock_env, tmp_path):
|
||||
"""Multiple files in the same dir should produce one mkdir entry."""
|
||||
f1 = tmp_path / "a.txt"
|
||||
f1.write_text("a")
|
||||
f2 = tmp_path / "b.txt"
|
||||
f2.write_text("b")
|
||||
f3 = tmp_path / "c.txt"
|
||||
f3.write_text("c")
|
||||
|
||||
files = [
|
||||
(str(f1), "/home/testuser/.hermes/skills/a.txt"),
|
||||
(str(f2), "/home/testuser/.hermes/skills/b.txt"),
|
||||
(str(f3), "/home/testuser/.hermes/credentials/c.txt"),
|
||||
]
|
||||
|
||||
run_cmds = []
|
||||
|
||||
def capture_run(cmd, **kwargs):
|
||||
run_cmds.append(cmd)
|
||||
return subprocess.CompletedProcess([], 0)
|
||||
|
||||
def make_mock_proc(cmd, **kwargs):
|
||||
mock = MagicMock()
|
||||
mock.stdout = MagicMock()
|
||||
mock.returncode = 0
|
||||
mock.poll.return_value = 0
|
||||
mock.communicate.return_value = (b"", b"")
|
||||
mock.stderr = MagicMock()
|
||||
mock.stderr.read.return_value = b""
|
||||
return mock
|
||||
|
||||
with patch.object(subprocess, "run", side_effect=capture_run), \
|
||||
patch.object(subprocess, "Popen", side_effect=make_mock_proc):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
# Only one mkdir call
|
||||
assert len(run_cmds) == 1
|
||||
mkdir_str = " ".join(run_cmds[0])
|
||||
# skills dir should appear exactly once despite two files
|
||||
assert mkdir_str.count("/home/testuser/.hermes/skills") == 1
|
||||
assert "/home/testuser/.hermes/credentials" in mkdir_str
|
||||
|
||||
def test_tar_stdout_closed_for_sigpipe(self, mock_env, tmp_path):
|
||||
"""tar_proc.stdout must be closed so SIGPIPE propagates correctly."""
|
||||
f1 = tmp_path / "s.txt"
|
||||
f1.write_text("s")
|
||||
files = [(str(f1), "/home/testuser/.hermes/skills/s.txt")]
|
||||
|
||||
mock_tar_stdout = MagicMock()
|
||||
|
||||
def make_proc(cmd, **kwargs):
|
||||
mock = MagicMock()
|
||||
if cmd[0] == "tar":
|
||||
mock.stdout = mock_tar_stdout
|
||||
else:
|
||||
mock.stdout = MagicMock()
|
||||
mock.returncode = 0
|
||||
mock.poll.return_value = 0
|
||||
mock.communicate.return_value = (b"", b"")
|
||||
mock.stderr = MagicMock()
|
||||
mock.stderr.read.return_value = b""
|
||||
return mock
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=make_proc):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
mock_tar_stdout.close.assert_called_once()
|
||||
|
||||
def test_timeout_kills_both_processes(self, mock_env, tmp_path):
|
||||
"""TimeoutExpired during communicate should kill both processes."""
|
||||
f1 = tmp_path / "t.txt"
|
||||
f1.write_text("t")
|
||||
files = [(str(f1), "/home/testuser/.hermes/skills/t.txt")]
|
||||
|
||||
mock_tar = MagicMock()
|
||||
mock_tar.stdout = MagicMock()
|
||||
mock_tar.returncode = None
|
||||
mock_tar.poll.return_value = None
|
||||
|
||||
mock_ssh = MagicMock()
|
||||
mock_ssh.communicate.side_effect = subprocess.TimeoutExpired("ssh", 120)
|
||||
mock_ssh.returncode = None
|
||||
|
||||
def make_proc(cmd, **kwargs):
|
||||
if cmd[0] == "tar":
|
||||
return mock_tar
|
||||
return mock_ssh
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=make_proc):
|
||||
with pytest.raises(RuntimeError, match="SSH bulk upload timed out"):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
mock_tar.kill.assert_called_once()
|
||||
mock_ssh.kill.assert_called_once()
|
||||
|
||||
|
||||
class TestSSHBulkUploadWiring:
|
||||
"""Verify bulk_upload_fn is wired into FileSyncManager."""
|
||||
|
||||
def test_filesyncmanager_receives_bulk_upload_fn(self, monkeypatch):
|
||||
"""SSHEnvironment should pass _ssh_bulk_upload to FileSyncManager."""
|
||||
monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root")
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
|
||||
|
||||
captured_kwargs = {}
|
||||
|
||||
class FakeSyncManager:
|
||||
def __init__(self, **kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
|
||||
def sync(self, **kw):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(ssh_env, "FileSyncManager", FakeSyncManager)
|
||||
|
||||
env = SSHEnvironment(host="h", user="u")
|
||||
|
||||
assert "bulk_upload_fn" in captured_kwargs
|
||||
assert captured_kwargs["bulk_upload_fn"] is not None
|
||||
# Should be the bound method
|
||||
assert callable(captured_kwargs["bulk_upload_fn"])
|
||||
|
||||
|
||||
class TestSharedHelpers:
|
||||
"""Direct unit tests for file_sync.py helpers."""
|
||||
|
||||
def test_quoted_mkdir_command_basic(self):
|
||||
result = quoted_mkdir_command(["/a", "/b/c"])
|
||||
assert result == "mkdir -p /a /b/c"
|
||||
|
||||
def test_quoted_mkdir_command_quotes_special_chars(self):
|
||||
result = quoted_mkdir_command(["/path/with spaces", "/path/'quotes'"])
|
||||
assert "mkdir -p" in result
|
||||
# shlex.quote wraps in single quotes
|
||||
assert "'/path/with spaces'" in result
|
||||
|
||||
def test_quoted_mkdir_command_empty(self):
|
||||
result = quoted_mkdir_command([])
|
||||
assert result == "mkdir -p "
|
||||
|
||||
def test_unique_parent_dirs_deduplicates(self):
|
||||
files = [
|
||||
("/local/a.txt", "/remote/dir/a.txt"),
|
||||
("/local/b.txt", "/remote/dir/b.txt"),
|
||||
("/local/c.txt", "/remote/other/c.txt"),
|
||||
]
|
||||
result = unique_parent_dirs(files)
|
||||
assert result == ["/remote/dir", "/remote/other"]
|
||||
|
||||
def test_unique_parent_dirs_sorted(self):
|
||||
files = [
|
||||
("/local/z.txt", "/z/file.txt"),
|
||||
("/local/a.txt", "/a/file.txt"),
|
||||
]
|
||||
result = unique_parent_dirs(files)
|
||||
assert result == ["/a", "/z"]
|
||||
|
||||
def test_unique_parent_dirs_empty(self):
|
||||
assert unique_parent_dirs([]) == []
|
||||
|
||||
|
||||
class TestSSHBulkUploadEdgeCases:
|
||||
"""Edge cases for _ssh_bulk_upload."""
|
||||
|
||||
def test_ssh_popen_failure_kills_tar(self, mock_env, tmp_path):
|
||||
"""If SSH Popen raises, tar process must be killed and cleaned up."""
|
||||
f1 = tmp_path / "e.txt"
|
||||
f1.write_text("e")
|
||||
files = [(str(f1), "/home/testuser/.hermes/skills/e.txt")]
|
||||
|
||||
mock_tar = _mock_proc()
|
||||
|
||||
call_count = 0
|
||||
|
||||
def failing_ssh_popen(cmd, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return mock_tar # tar Popen succeeds
|
||||
raise OSError("SSH binary not found")
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=failing_ssh_popen):
|
||||
with pytest.raises(OSError, match="SSH binary not found"):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
mock_tar.kill.assert_called_once()
|
||||
mock_tar.wait.assert_called_once()
|
||||
@@ -24,6 +24,18 @@ class TestWriteAndRead:
|
||||
items[0]["content"] = "MUTATED"
|
||||
assert store.read()[0]["content"] == "Task"
|
||||
|
||||
def test_write_deduplicates_duplicate_ids(self):
|
||||
store = TodoStore()
|
||||
result = store.write([
|
||||
{"id": "1", "content": "First version", "status": "pending"},
|
||||
{"id": "2", "content": "Other task", "status": "pending"},
|
||||
{"id": "1", "content": "Latest version", "status": "in_progress"},
|
||||
])
|
||||
assert result == [
|
||||
{"id": "2", "content": "Other task", "status": "pending"},
|
||||
{"id": "1", "content": "Latest version", "status": "in_progress"},
|
||||
]
|
||||
|
||||
|
||||
class TestHasItems:
|
||||
def test_empty_store(self):
|
||||
|
||||
Reference in New Issue
Block a user