Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor

This commit is contained in:
Brooklyn Nicholson
2026-04-08 19:11:44 -05:00
391 changed files with 20699 additions and 2902 deletions

View File

@@ -1276,6 +1276,258 @@ class TestRoleAlternation:
assert [m["role"] for m in result] == ["user", "assistant", "user"]
# ---------------------------------------------------------------------------
# Thinking block signature management
# ---------------------------------------------------------------------------
class TestThinkingBlockSignatureManagement:
    """Tests for the thinking block handling strategy:
    strip from old turns, preserve latest signed, downgrade unsigned."""

    def test_thinking_stripped_from_non_last_assistant(self):
        """Thinking blocks are removed from all assistant messages except the last."""
        messages = [
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {"id": "tc_1", "function": {"name": "tool1", "arguments": "{}"}},
                ],
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Old reasoning.", "signature": "sig_old"},
                ],
            },
            {"role": "tool", "tool_call_id": "tc_1", "content": "result 1"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {"id": "tc_2", "function": {"name": "tool2", "arguments": "{}"}},
                ],
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Latest reasoning.", "signature": "sig_new"},
                ],
            },
            {"role": "tool", "tool_call_id": "tc_2", "content": "result 2"},
        ]
        _, result = convert_messages_to_anthropic(messages)

        # Find both assistant messages
        assistants = [m for m in result if m["role"] == "assistant"]
        assert len(assistants) == 2

        # First (non-last) assistant: no thinking blocks
        first_types = [b.get("type") for b in assistants[0]["content"]]
        assert "thinking" not in first_types
        assert "redacted_thinking" not in first_types
        assert "tool_use" in first_types  # tool_use should survive

        # Last assistant: thinking block preserved with signature
        last_blocks = assistants[1]["content"]
        thinking_blocks = [b for b in last_blocks if b.get("type") == "thinking"]
        assert len(thinking_blocks) == 1
        assert thinking_blocks[0]["thinking"] == "Latest reasoning."
        assert thinking_blocks[0]["signature"] == "sig_new"

    def test_signed_thinking_preserved_on_last_turn(self):
        """A signed thinking block on the last assistant message is kept."""
        messages = [
            {
                "role": "assistant",
                "content": "The answer is 42.",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Deep thought.", "signature": "sig_valid"},
                ],
            },
        ]
        _, result = convert_messages_to_anthropic(messages)
        blocks = result[0]["content"]
        thinking = [b for b in blocks if b.get("type") == "thinking"]
        assert len(thinking) == 1
        assert thinking[0]["signature"] == "sig_valid"

    def test_unsigned_thinking_downgraded_to_text_on_last_turn(self):
        """Unsigned thinking blocks on the last turn become text blocks."""
        messages = [
            {
                "role": "assistant",
                "content": "Response text.",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Unsigned reasoning."},
                    # No 'signature' field
                ],
            },
        ]
        _, result = convert_messages_to_anthropic(messages)
        blocks = result[0]["content"]

        # No thinking blocks should remain
        assert not any(b.get("type") == "thinking" for b in blocks)

        # The reasoning text should be preserved as a text block
        text_contents = [b.get("text", "") for b in blocks if b.get("type") == "text"]
        assert "Unsigned reasoning." in text_contents

    def test_redacted_thinking_with_data_preserved(self):
        """Redacted thinking with 'data' field is kept on last turn."""
        messages = [
            {
                "role": "assistant",
                "content": "Response.",
                "reasoning_details": [
                    {"type": "redacted_thinking", "data": "opaque_signature_data"},
                ],
            },
        ]
        _, result = convert_messages_to_anthropic(messages)
        blocks = result[0]["content"]
        redacted = [b for b in blocks if b.get("type") == "redacted_thinking"]
        assert len(redacted) == 1
        assert redacted[0]["data"] == "opaque_signature_data"

    def test_redacted_thinking_without_data_dropped(self):
        """Redacted thinking without 'data' is dropped — can't be validated."""
        messages = [
            {
                "role": "assistant",
                "content": "Response.",
                "reasoning_details": [
                    {"type": "redacted_thinking"},
                    # No 'data' field
                ],
            },
        ]
        _, result = convert_messages_to_anthropic(messages)
        blocks = result[0]["content"]
        assert not any(b.get("type") == "redacted_thinking" for b in blocks)

    def test_cache_control_stripped_from_thinking_blocks(self):
        """cache_control markers are removed from thinking/redacted_thinking blocks."""
        messages = [
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {"id": "tc_1", "function": {"name": "t", "arguments": "{}"}},
                ],
                "reasoning_details": [
                    {
                        "type": "thinking",
                        "thinking": "Reasoning.",
                        "signature": "sig_1",
                        "cache_control": {"type": "ephemeral"},
                    },
                ],
            },
            {"role": "tool", "tool_call_id": "tc_1", "content": "result"},
        ]
        _, result = convert_messages_to_anthropic(messages)
        assistant = next(m for m in result if m["role"] == "assistant")
        reasoning_blocks = [
            block
            for block in assistant["content"]
            if block.get("type") in ("thinking", "redacted_thinking")
        ]

        # Guard against a vacuous pass: this is the only (hence last) assistant
        # turn and its thinking block is signed, so it must survive conversion.
        # Without this check the test would also pass if the block were dropped.
        assert reasoning_blocks, "expected the signed thinking block to survive conversion"
        assert all("cache_control" not in block for block in reasoning_blocks)

    def test_thinking_stripped_from_merged_consecutive_assistants(self):
        """When consecutive assistants are merged, second one's thinking is dropped."""
        messages = [
            {
                "role": "assistant",
                "content": "First response.",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "First thought.", "signature": "sig_1"},
                ],
            },
            {
                "role": "assistant",
                "content": "Second response.",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Second thought.", "signature": "sig_2"},
                ],
            },
        ]
        _, result = convert_messages_to_anthropic(messages)

        # Should be merged into one assistant message
        assistants = [m for m in result if m["role"] == "assistant"]
        assert len(assistants) == 1

        # Only the first thinking block should remain (signed, on the last/only assistant)
        blocks = assistants[0]["content"]
        thinking = [b for b in blocks if b.get("type") == "thinking"]
        assert len(thinking) == 1
        assert thinking[0]["thinking"] == "First thought."

    def test_empty_content_after_strip_gets_placeholder(self):
        """If stripping thinking leaves an empty message, a placeholder is added."""
        messages = [
            {
                "role": "assistant",
                "content": "",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Only thinking, no text."},
                    # Unsigned, and the original content is an empty string
                ],
            },
            {"role": "user", "content": "Next message."},
            {"role": "assistant", "content": "Final."},
        ]
        _, result = convert_messages_to_anthropic(messages)

        # The first assistant is not the last assistant turn, so its thinking is
        # stripped (not downgraded to text). Its original content was empty, so
        # the converter is expected to supply placeholder content.
        first_assistant = result[0]
        assert first_assistant["role"] == "assistant"
        assert len(first_assistant["content"]) >= 1

    def test_multi_turn_conversation_preserves_only_last(self):
        """Full multi-turn conversation: only last assistant keeps thinking."""
        messages = [
            {"role": "user", "content": "Question 1"},
            {
                "role": "assistant",
                "content": "Answer 1",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Thought 1", "signature": "sig_1"},
                ],
            },
            {"role": "user", "content": "Question 2"},
            {
                "role": "assistant",
                "content": "Answer 2",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Thought 2", "signature": "sig_2"},
                ],
            },
            {"role": "user", "content": "Question 3"},
            {
                "role": "assistant",
                "content": "Answer 3",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Thought 3", "signature": "sig_3"},
                ],
            },
        ]
        _, result = convert_messages_to_anthropic(messages)
        assistants = [m for m in result if m["role"] == "assistant"]
        assert len(assistants) == 3

        # First two: no thinking blocks
        for a in assistants[:2]:
            assert not any(
                b.get("type") in ("thinking", "redacted_thinking")
                for b in a["content"]
                if isinstance(b, dict)
            )

        # Last one: thinking preserved
        last_thinking = [
            b for b in assistants[2]["content"]
            if isinstance(b, dict) and b.get("type") == "thinking"
        ]
        assert len(last_thinking) == 1
        assert last_thinking[0]["signature"] == "sig_3"
# ---------------------------------------------------------------------------
# Tool choice
# ---------------------------------------------------------------------------

View File

@@ -471,6 +471,23 @@ class TestExplicitProviderRouting:
client, model = resolve_provider_client("zai")
assert client is not None
def test_explicit_google_alias_uses_gemini_credentials(self):
    """The 'google' alias must build its client from the gemini API-key credentials."""
    gemini_base_url = "https://generativelanguage.googleapis.com/v1beta/openai"
    gemini_credentials = {"api_key": "gemini-key", "base_url": gemini_base_url}

    with (
        patch(
            "hermes_cli.auth.resolve_api_key_provider_credentials",
            return_value=gemini_credentials,
        ),
        patch("agent.auxiliary_client.OpenAI") as mock_openai,
    ):
        mock_openai.return_value = MagicMock()
        client, model = resolve_provider_client("google", model="gemini-3.1-pro-preview")

    # The mock keeps its recorded call after the patch context exits.
    openai_kwargs = mock_openai.call_args.kwargs
    assert client is not None
    assert model == "gemini-3.1-pro-preview"
    assert openai_kwargs["api_key"] == "gemini-key"
    assert openai_kwargs["base_url"] == gemini_base_url
def test_explicit_unknown_returns_none(self, monkeypatch):
"""Unknown provider should return None."""
client, model = resolve_provider_client("nonexistent-provider")
@@ -624,12 +641,15 @@ class TestVisionClientFallback:
assert client is None
assert model is None
def test_vision_auto_includes_anthropic_when_configured(self, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-key")
def test_vision_auto_includes_active_provider_when_configured(self, monkeypatch):
"""Active provider appears in available backends when credentials exist."""
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
with (
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
patch("agent.auxiliary_client._read_main_provider", return_value="anthropic"),
patch("agent.auxiliary_client._read_main_model", return_value="claude-sonnet-4"),
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api03-key"),
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="***"),
):
backends = get_available_vision_backends()
@@ -702,88 +722,51 @@ class TestAuxiliaryPoolAwareness:
assert call_kwargs["base_url"] == "https://api.githubcopilot.com"
assert call_kwargs["default_headers"]["Editor-Version"]
def test_vision_auto_uses_anthropic_when_no_higher_priority_backend(self, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-key")
def test_vision_auto_uses_active_provider_as_fallback(self, monkeypatch):
"""When no OpenRouter/Nous available, vision auto falls back to active provider."""
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
with (
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
patch("agent.auxiliary_client._read_main_provider", return_value="anthropic"),
patch("agent.auxiliary_client._read_main_model", return_value="claude-sonnet-4"),
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api03-key"),
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="***"),
):
client, model = get_vision_auxiliary_client()
assert client is not None
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
assert model == "claude-haiku-4-5-20251001"
def test_selected_anthropic_provider_is_preferred_for_vision_auto(self, monkeypatch):
def test_vision_auto_prefers_active_provider_over_openrouter(self, monkeypatch):
"""Active provider is tried before OpenRouter in vision auto."""
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-key")
def fake_load_config():
return {"model": {"provider": "anthropic", "default": "claude-sonnet-4-6"}}
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
with (
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
patch("agent.auxiliary_client._read_main_provider", return_value="anthropic"),
patch("agent.auxiliary_client._read_main_model", return_value="claude-sonnet-4"),
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api03-key"),
patch("agent.auxiliary_client.OpenAI") as mock_openai,
patch("hermes_cli.config.load_config", fake_load_config),
):
client, model = get_vision_auxiliary_client()
assert client is not None
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
assert model == "claude-haiku-4-5-20251001"
def test_selected_codex_provider_short_circuits_vision_auto(self, monkeypatch):
def fake_load_config():
return {"model": {"provider": "openai-codex", "default": "gpt-5.2-codex"}}
codex_client = MagicMock()
with (
patch("hermes_cli.config.load_config", fake_load_config),
patch("agent.auxiliary_client._try_codex", return_value=(codex_client, "gpt-5.2-codex")) as mock_codex,
patch("agent.auxiliary_client._try_openrouter") as mock_openrouter,
patch("agent.auxiliary_client._try_nous") as mock_nous,
patch("agent.auxiliary_client._try_anthropic") as mock_anthropic,
patch("agent.auxiliary_client._try_custom_endpoint") as mock_custom,
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="***"),
):
provider, client, model = resolve_vision_provider_client()
assert provider == "openai-codex"
assert client is codex_client
assert model == "gpt-5.2-codex"
mock_codex.assert_called_once()
mock_openrouter.assert_not_called()
mock_nous.assert_not_called()
mock_anthropic.assert_not_called()
mock_custom.assert_not_called()
# Active provider should win over OpenRouter
assert provider == "anthropic"
def test_vision_auto_includes_codex(self, codex_auth_dir):
"""Codex supports vision (gpt-5.3-codex), so auto mode should use it."""
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client.OpenAI"):
client, model = get_vision_auxiliary_client()
from agent.auxiliary_client import CodexAuxiliaryClient
assert isinstance(client, CodexAuxiliaryClient)
assert model == "gpt-5.2-codex"
def test_vision_auto_falls_back_to_custom_endpoint(self, monkeypatch):
"""Custom endpoint is used as fallback in vision auto mode.
Many local models (Qwen-VL, LLaVA, etc.) support vision.
When no OpenRouter/Nous/Codex is available, try the custom endpoint.
"""
def test_vision_auto_uses_named_custom_as_active_provider(self, monkeypatch):
"""Named custom provider works as active provider fallback in vision auto."""
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client._select_pool_entry", return_value=(False, None)), \
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
patch("agent.auxiliary_client._resolve_custom_runtime",
return_value=("http://localhost:1234/v1", "local-key")), \
patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = get_vision_auxiliary_client()
assert client is not None # Custom endpoint picked up as fallback
patch("agent.auxiliary_client._read_main_provider", return_value="custom:local"), \
patch("agent.auxiliary_client._read_main_model", return_value="my-local-model"), \
patch("agent.auxiliary_client.resolve_provider_client",
return_value=(MagicMock(), "my-local-model")) as mock_resolve:
provider, client, model = resolve_vision_provider_client()
assert client is not None
assert provider == "custom:local"
def test_vision_direct_endpoint_override(self, monkeypatch):
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
@@ -822,6 +805,31 @@ class TestAuxiliaryPoolAwareness:
assert model == "google/gemini-3-flash-preview"
assert client is not None
def test_vision_config_google_provider_uses_gemini_credentials(self, monkeypatch):
    """auxiliary.vision.provider='google' resolves to 'gemini' with gemini credentials."""
    gemini_base_url = "https://generativelanguage.googleapis.com/v1beta/openai"
    vision_settings = {"provider": "google", "model": "gemini-3.1-pro-preview"}
    config = {"auxiliary": {"vision": vision_settings}}
    monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)

    with (
        patch(
            "hermes_cli.auth.resolve_api_key_provider_credentials",
            return_value={"api_key": "gemini-key", "base_url": gemini_base_url},
        ),
        patch("agent.auxiliary_client.OpenAI") as mock_openai,
    ):
        resolved_provider, client, model = resolve_vision_provider_client()

    # The mock keeps its recorded call after the patch context exits.
    openai_kwargs = mock_openai.call_args.kwargs
    assert resolved_provider == "gemini"
    assert client is not None
    assert model == "gemini-3.1-pro-preview"
    assert openai_kwargs["api_key"] == "gemini-key"
    assert openai_kwargs["base_url"] == gemini_base_url
def test_vision_forced_main_uses_custom_endpoint(self, monkeypatch):
"""When explicitly forced to 'main', vision CAN use custom endpoint."""
config = {
@@ -846,7 +854,14 @@ class TestAuxiliaryPoolAwareness:
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "main")
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
# Clear client cache to avoid stale entries from previous tests
from agent.auxiliary_client import _client_cache
_client_cache.clear()
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client._read_main_provider", return_value=""), \
patch("agent.auxiliary_client._read_main_model", return_value=""), \
patch("agent.auxiliary_client._select_pool_entry", return_value=(False, None)), \
patch("agent.auxiliary_client._resolve_custom_runtime", return_value=(None, None)), \
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)):
client, model = get_vision_auxiliary_client()

View File

@@ -13,7 +13,7 @@ from unittest.mock import patch, MagicMock
import pytest
import yaml
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
def _run_auxiliary_bridge(config_dict, monkeypatch):
@@ -199,7 +199,7 @@ class TestGatewayBridgeCodeParity:
def test_gateway_has_auxiliary_bridge(self):
"""The gateway config bridge must include auxiliary.* bridging."""
gateway_path = Path(__file__).parent.parent / "gateway" / "run.py"
gateway_path = Path(__file__).parent.parent.parent / "gateway" / "run.py"
content = gateway_path.read_text()
# Check for key patterns that indicate the bridge is present
assert "AUXILIARY_VISION_PROVIDER" in content
@@ -213,7 +213,7 @@ class TestGatewayBridgeCodeParity:
def test_gateway_no_compression_env_bridge(self):
"""Gateway should NOT bridge compression config to env vars (config-only)."""
gateway_path = Path(__file__).parent.parent / "gateway" / "run.py"
gateway_path = Path(__file__).parent.parent.parent / "gateway" / "run.py"
content = gateway_path.read_text()
assert "CONTEXT_COMPRESSION_PROVIDER" not in content
assert "CONTEXT_COMPRESSION_MODEL" not in content

View File

@@ -0,0 +1,151 @@
"""Tests for named custom provider and 'main' alias resolution in auxiliary_client."""
import os
from unittest.mock import patch, MagicMock
import pytest
@pytest.fixture(autouse=True)
def _isolate(tmp_path, monkeypatch):
    """Point HERMES_HOME at a per-test directory seeded with a minimal config."""
    home_dir = tmp_path / ".hermes"
    home_dir.mkdir()
    # Seed a minimal config so load_config doesn't fail
    seed_config = home_dir / "config.yaml"
    seed_config.write_text("model:\n default: test-model\n")
    monkeypatch.setenv("HERMES_HOME", str(home_dir))
def _write_config(tmp_path, config_dict):
    """Dump *config_dict* as YAML to ``<tmp_path>/.hermes/config.yaml``."""
    import yaml

    serialized = yaml.dump(config_dict)
    (tmp_path / ".hermes" / "config.yaml").write_text(serialized)
class TestNormalizeVisionProvider:
    """Checks how _normalize_vision_provider maps 'main', aliases and 'auto'."""

    def test_main_resolves_to_named_custom(self, tmp_path):
        """'main' maps to the configured 'custom:<name>' provider."""
        config = {
            "model": {"default": "my-model", "provider": "custom:beans"},
            "custom_providers": [{"name": "beans", "base_url": "http://localhost/v1"}],
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import _normalize_vision_provider

        resolved = _normalize_vision_provider("main")
        assert resolved == "custom:beans"

    def test_main_resolves_to_openrouter(self, tmp_path):
        """'main' maps to 'openrouter' when that is the configured provider."""
        config = {
            "model": {"default": "anthropic/claude-sonnet-4", "provider": "openrouter"},
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import _normalize_vision_provider

        resolved = _normalize_vision_provider("main")
        assert resolved == "openrouter"

    def test_main_resolves_to_deepseek(self, tmp_path):
        """'main' maps to 'deepseek' when that is the configured provider."""
        config = {
            "model": {"default": "deepseek-chat", "provider": "deepseek"},
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import _normalize_vision_provider

        resolved = _normalize_vision_provider("main")
        assert resolved == "deepseek"

    def test_main_falls_back_to_custom_when_no_provider(self, tmp_path):
        """With no model.provider configured, 'main' maps to 'custom'."""
        config = {"model": {"default": "gpt-4o"}}
        _write_config(tmp_path, config)
        from agent.auxiliary_client import _normalize_vision_provider

        resolved = _normalize_vision_provider("main")
        assert resolved == "custom"

    def test_bare_provider_name_unchanged(self):
        """Plain provider names are returned as given."""
        from agent.auxiliary_client import _normalize_vision_provider

        for provider_name in ("beans", "deepseek"):
            assert _normalize_vision_provider(provider_name) == provider_name

    def test_codex_alias_still_works(self):
        """The 'codex' alias maps to 'openai-codex'."""
        from agent.auxiliary_client import _normalize_vision_provider

        resolved = _normalize_vision_provider("codex")
        assert resolved == "openai-codex"

    def test_auto_unchanged(self):
        """Both 'auto' and None normalize to 'auto'."""
        from agent.auxiliary_client import _normalize_vision_provider

        for raw_value in ("auto", None):
            assert _normalize_vision_provider(raw_value) == "auto"
class TestResolveProviderClientMainAlias:
    """resolve_provider_client('main', ...) must follow the configured main provider."""

    def test_main_resolves_to_named_custom_provider(self, tmp_path):
        """A bare custom provider name as model.provider is reachable through 'main'."""
        beans_entry = {"name": "beans", "base_url": "http://beans.local/v1", "api_key": "k"}
        config = {
            "model": {"default": "my-model", "provider": "beans"},
            "custom_providers": [beans_entry],
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import resolve_provider_client

        client, model = resolve_provider_client("main", "override-model")

        assert client is not None
        assert model == "override-model"
        assert "beans.local" in str(client.base_url)

    def test_main_with_custom_colon_prefix(self, tmp_path):
        """The 'custom:<name>' form of model.provider is reachable through 'main'."""
        beans_entry = {"name": "beans", "base_url": "http://beans.local/v1", "api_key": "k"}
        config = {
            "model": {"default": "my-model", "provider": "custom:beans"},
            "custom_providers": [beans_entry],
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import resolve_provider_client

        client, model = resolve_provider_client("main", "test")

        assert client is not None
        assert "beans.local" in str(client.base_url)
class TestResolveProviderClientNamedCustom:
    """resolve_provider_client must accept a custom provider's name directly."""

    def test_named_custom_provider(self, tmp_path):
        """A configured custom provider resolves by name with the requested model."""
        beans_entry = {"name": "beans", "base_url": "http://beans.local/v1", "api_key": "k"}
        config = {
            "model": {"default": "test-model"},
            "custom_providers": [beans_entry],
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import resolve_provider_client

        client, model = resolve_provider_client("beans", "my-model")

        assert client is not None
        assert model == "my-model"
        assert "beans.local" in str(client.base_url)

    def test_named_custom_provider_default_model(self, tmp_path):
        """Without an explicit model, the main model from the config is returned."""
        beans_entry = {"name": "beans", "base_url": "http://beans.local/v1", "api_key": "k"}
        config = {
            "model": {"default": "main-model"},
            "custom_providers": [beans_entry],
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import resolve_provider_client

        client, model = resolve_provider_client("beans")

        assert client is not None
        # Should use _read_main_model() fallback
        assert model == "main-model"

    def test_named_custom_no_api_key_uses_fallback(self, tmp_path):
        """A custom provider configured without api_key still yields a client."""
        keyless_entry = {"name": "local", "base_url": "http://localhost:8080/v1"}
        config = {
            "model": {"default": "test"},
            "custom_providers": [keyless_entry],
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import resolve_provider_client

        client, model = resolve_provider_client("local", "test")

        assert client is not None
        # no-key-required should be used
        # NOTE(review): the placeholder key itself is not asserted here.

    def test_nonexistent_named_custom_falls_through(self, tmp_path):
        """A name missing from custom_providers yields no client."""
        beans_entry = {"name": "beans", "base_url": "http://beans.local/v1"}
        config = {
            "model": {"default": "test"},
            "custom_providers": [beans_entry],
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import resolve_provider_client

        # "coffee" doesn't exist in custom_providers
        client, model = resolve_provider_client("coffee", "test")

        assert client is None

View File

@@ -197,6 +197,44 @@ class TestNonStringContent:
assert summary is not None
assert summary == SUMMARY_PREFIX
def test_summary_call_does_not_force_temperature(self):
    """_generate_summary must not pass a 'temperature' keyword to call_llm."""
    llm_reply = MagicMock()
    llm_reply.choices = [MagicMock()]
    llm_reply.choices[0].message.content = "ok"

    conversation = [
        {"role": "user", "content": "do something"},
        {"role": "assistant", "content": "ok"},
    ]

    with patch("agent.context_compressor.get_model_context_length", return_value=100000):
        compressor = ContextCompressor(model="test", quiet_mode=True)

    with patch("agent.context_compressor.call_llm", return_value=llm_reply) as mock_call:
        compressor._generate_summary(conversation)

    # The mock keeps its recorded call after the patch context exits.
    assert "temperature" not in mock_call.call_args.kwargs
class TestSummaryFailureCooldown:
    """A failed summary call must suppress an immediate retry."""

    def test_summary_failure_enters_cooldown_and_skips_retry(self):
        """After call_llm raises once, the next summary attempt does not call it again."""
        conversation = [
            {"role": "user", "content": "do something"},
            {"role": "assistant", "content": "ok"},
        ]

        with patch("agent.context_compressor.get_model_context_length", return_value=100000):
            compressor = ContextCompressor(model="test", quiet_mode=True)

        with patch("agent.context_compressor.call_llm", side_effect=Exception("boom")) as mock_call:
            summaries = [
                compressor._generate_summary(conversation),
                compressor._generate_summary(conversation),
            ]

        first, second = summaries
        assert first is None
        assert second is None
        assert mock_call.call_count == 1
class TestSummaryPrefixNormalization:
def test_legacy_prefix_is_replaced(self):

View File

@@ -947,7 +947,7 @@ def test_list_custom_pool_providers(tmp_path, monkeypatch):
"auth_type": "api_key",
"priority": 0,
"source": "manual",
"access_token": "sk-ant-xxx",
"access_token": "***",
}
],
"custom:together.ai": [
@@ -957,7 +957,7 @@ def test_list_custom_pool_providers(tmp_path, monkeypatch):
"auth_type": "api_key",
"priority": 0,
"source": "manual",
"access_token": "sk-tog-xxx",
"access_token": "***",
}
],
"custom:fireworks": [
@@ -967,7 +967,7 @@ def test_list_custom_pool_providers(tmp_path, monkeypatch):
"auth_type": "api_key",
"priority": 0,
"source": "manual",
"access_token": "sk-fw-xxx",
"access_token": "***",
}
],
"custom:empty": [],
@@ -980,3 +980,78 @@ def test_list_custom_pool_providers(tmp_path, monkeypatch):
result = list_custom_pool_providers()
assert result == ["custom:fireworks", "custom:together.ai"]
# "custom:empty" not included because it's empty
def test_acquire_lease_prefers_unleased_entry(tmp_path, monkeypatch):
    """Two successive leases go to two different credentials, one lease each."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))

    primary = {
        "id": "cred-1",
        "label": "primary",
        "auth_type": "api_key",
        "priority": 0,
        "source": "manual",
        "access_token": "***",
    }
    secondary = {
        "id": "cred-2",
        "label": "secondary",
        "auth_type": "api_key",
        "priority": 1,
        "source": "manual",
        "access_token": "***",
    }
    auth_store = {
        "version": 1,
        "credential_pool": {"openrouter": [primary, secondary]},
    }
    _write_auth_store(tmp_path, auth_store)

    from agent.credential_pool import load_pool

    pool = load_pool("openrouter")
    leased_ids = [pool.acquire_lease(), pool.acquire_lease()]

    first, second = leased_ids
    assert first == "cred-1"
    assert second == "cred-2"
    assert pool.active_lease_count("cred-1") == 1
    assert pool.active_lease_count("cred-2") == 1
def test_release_lease_decrements_counter(tmp_path, monkeypatch):
    """Releasing a lease returns the credential's active lease count to zero."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))

    only_credential = {
        "id": "cred-1",
        "label": "primary",
        "auth_type": "api_key",
        "priority": 0,
        "source": "manual",
        "access_token": "***",
    }
    auth_store = {
        "version": 1,
        "credential_pool": {"openrouter": [only_credential]},
    }
    _write_auth_store(tmp_path, auth_store)

    from agent.credential_pool import load_pool

    pool = load_pool("openrouter")

    leased = pool.acquire_lease()
    assert leased == "cred-1"
    assert pool.active_lease_count("cred-1") == 1

    pool.release_lease("cred-1")
    assert pool.active_lease_count("cred-1") == 0

View File

@@ -0,0 +1,289 @@
"""Tests for per-user memory scoping via user_id threading.
Verifies that gateway user_id flows from AIAgent -> MemoryManager -> plugins,
so each gateway user gets their own memory bucket instead of sharing a static one.
"""
import json
import os
import pytest
from unittest.mock import MagicMock, patch
from agent.memory_provider import MemoryProvider
from agent.memory_manager import MemoryManager
# ---------------------------------------------------------------------------
# Concrete test provider that records init kwargs
# ---------------------------------------------------------------------------
class RecordingProvider(MemoryProvider):
    """Test double that stores the arguments passed to initialize().

    Every other provider hook is a no-op returning an empty value.
    """

    def __init__(self, name="recording"):
        self._init_session_id = None
        self._init_kwargs = {}
        self._name = name

    @property
    def name(self) -> str:
        return self._name

    def is_available(self) -> bool:
        return True

    def initialize(self, session_id: str, **kwargs) -> None:
        # Store a copy of the keyword arguments, not the caller's dict.
        self._init_kwargs = {**kwargs}
        self._init_session_id = session_id

    def system_prompt_block(self) -> str:
        return ""

    def prefetch(self, query: str, *, session_id: str = "") -> str:
        return ""

    def sync_turn(self, user_content, assistant_content, *, session_id=""):
        return None

    def get_tool_schemas(self):
        return []

    def handle_tool_call(self, tool_name, args, **kwargs):
        empty_payload = {}
        return json.dumps(empty_payload)

    def shutdown(self):
        return None
# ---------------------------------------------------------------------------
# MemoryManager user_id threading tests
# ---------------------------------------------------------------------------
class TestMemoryManagerUserIdThreading:
    """Checks which initialize_all keyword arguments reach registered providers."""

    def test_user_id_forwarded_to_provider(self):
        """session_id, platform and user_id all arrive at the provider."""
        recorder = RecordingProvider()
        manager = MemoryManager()
        manager.add_provider(recorder)

        manager.initialize_all(
            session_id="sess-123",
            platform="telegram",
            user_id="tg_user_42",
        )

        received = recorder._init_kwargs
        assert received.get("user_id") == "tg_user_42"
        assert received.get("platform") == "telegram"
        assert recorder._init_session_id == "sess-123"

    def test_no_user_id_when_cli(self):
        """CLI sessions should not have user_id in kwargs."""
        recorder = RecordingProvider()
        manager = MemoryManager()
        manager.add_provider(recorder)

        manager.initialize_all(session_id="sess-456", platform="cli")

        received = recorder._init_kwargs
        assert "user_id" not in received
        assert received.get("platform") == "cli"

    def test_user_id_none_not_forwarded(self):
        """With user_id omitted from the call, providers see no user_id key."""
        recorder = RecordingProvider()
        manager = MemoryManager()
        manager.add_provider(recorder)

        # Simulates what happens when AIAgent passes user_id=None
        # (the agent code only adds user_id to kwargs when it's truthy),
        # so the call below simply leaves user_id out.
        manager.initialize_all(session_id="sess-789", platform="discord")

        assert "user_id" not in recorder._init_kwargs

    def test_multiple_providers_all_receive_user_id(self):
        """With a builtin provider also registered, the external one still gets user_id."""
        from agent.builtin_memory_provider import BuiltinMemoryProvider

        manager = MemoryManager()
        # Use builtin + one external (MemoryManager only allows one external)
        builtin = BuiltinMemoryProvider()
        external = RecordingProvider("external")
        for provider in (builtin, external):
            manager.add_provider(provider)

        manager.initialize_all(
            session_id="sess-multi",
            platform="slack",
            user_id="slack_U12345",
        )

        received = external._init_kwargs
        assert received.get("user_id") == "slack_U12345"
        assert received.get("platform") == "slack"
# ---------------------------------------------------------------------------
# Mem0 provider user_id tests
# ---------------------------------------------------------------------------
class TestMem0UserIdScoping:
    """Checks which user_id the Mem0 provider settles on after initialize()."""

    def test_gateway_user_id_overrides_default(self):
        """When user_id is passed via kwargs, it should override the config default."""
        from plugins.memory.mem0 import Mem0MemoryProvider

        # Stand-in for _load_config that carries a default user_id
        loaded_config = {
            "api_key": "test-key",
            "user_id": "hermes-user",
            "agent_id": "hermes",
            "rerank": True,
        }
        provider = Mem0MemoryProvider()
        with patch("plugins.memory.mem0._load_config", return_value=loaded_config):
            provider.initialize(session_id="test-sess", user_id="tg_user_99")

        assert provider._user_id == "tg_user_99"

    def test_no_user_id_falls_back_to_config(self):
        """Without user_id in kwargs, should use config default."""
        from plugins.memory.mem0 import Mem0MemoryProvider

        loaded_config = {
            "api_key": "test-key",
            "user_id": "custom-default",
            "agent_id": "hermes",
            "rerank": True,
        }
        provider = Mem0MemoryProvider()
        with patch("plugins.memory.mem0._load_config", return_value=loaded_config):
            provider.initialize(session_id="test-sess")

        assert provider._user_id == "custom-default"

    def test_no_user_id_no_config_uses_hermes_user(self):
        """Without user_id or config override, should default to 'hermes-user'."""
        from plugins.memory.mem0 import Mem0MemoryProvider

        # No "user_id" key in the loaded config
        loaded_config = {
            "api_key": "test-key",
            "agent_id": "hermes",
            "rerank": True,
        }
        provider = Mem0MemoryProvider()
        with patch("plugins.memory.mem0._load_config", return_value=loaded_config):
            provider.initialize(session_id="test-sess")

        assert provider._user_id == "hermes-user"

    def test_different_users_get_different_ids(self):
        """Two providers initialized with different user_ids should be scoped differently."""
        from plugins.memory.mem0 import Mem0MemoryProvider

        loaded_config = {
            "api_key": "test-key",
            "user_id": "hermes-user",
            "agent_id": "hermes",
            "rerank": True,
        }
        alice_provider = Mem0MemoryProvider()
        bob_provider = Mem0MemoryProvider()
        with patch("plugins.memory.mem0._load_config", return_value=loaded_config):
            alice_provider.initialize(session_id="sess-1", user_id="alice_123")
            bob_provider.initialize(session_id="sess-2", user_id="bob_456")

        assert alice_provider._user_id == "alice_123"
        assert bob_provider._user_id == "bob_456"
        assert alice_provider._user_id != bob_provider._user_id
# ---------------------------------------------------------------------------
# Honcho provider user_id tests
# ---------------------------------------------------------------------------
class TestHonchoUserIdScoping:
    """Check how ``initialize`` treats ``peer_name`` on the Honcho client config."""

    # Dotted path of the config factory that every test replaces.
    _CONFIG_FACTORY = (
        "plugins.memory.honcho.client.HonchoClientConfig.from_global_config"
    )

    @staticmethod
    def _fake_config(peer_name):
        """Return a mock config whose ``peer_name`` starts out as *peer_name*."""
        cfg = MagicMock()
        cfg.configure_mock(
            enabled=True,
            api_key="test-key",
            base_url=None,
            peer_name=peer_name,
            recall_mode="tools",  # Use tools mode to defer session init
        )
        return cfg

    def test_gateway_user_id_overrides_peer_name(self):
        """A ``user_id`` kwarg replaces the config's static ``peer_name``."""
        from plugins.memory.honcho import HonchoMemoryProvider

        provider = HonchoMemoryProvider()
        cfg = self._fake_config("static-user")

        with patch(self._CONFIG_FACTORY, return_value=cfg):
            provider.initialize(
                session_id="test-sess",
                user_id="discord_user_789",
                platform="discord",
            )

        # The config's peer_name should have been overridden with the user_id
        assert cfg.peer_name == "discord_user_789"

    def test_no_user_id_preserves_config_peer_name(self):
        """Without a ``user_id`` kwarg the config's ``peer_name`` is left alone."""
        from plugins.memory.honcho import HonchoMemoryProvider

        provider = HonchoMemoryProvider()
        cfg = self._fake_config("my-custom-peer")

        with patch(self._CONFIG_FACTORY, return_value=cfg):
            provider.initialize(
                session_id="test-sess",
                platform="cli",
            )

        # peer_name should not have been overridden
        assert cfg.peer_name == "my-custom-peer"
# ---------------------------------------------------------------------------
# AIAgent user_id propagation test
# ---------------------------------------------------------------------------
class TestAIAgentUserIdPropagation:
    """Exercise the ``_user_id`` attribute on a bare ``AIAgent`` instance.

    NOTE(review): both tests bypass ``__init__`` and assign ``_user_id`` by
    hand, so they only show the attribute round-trips; they do not exercise
    the constructor or the memory-init kwargs.
    """

    def test_user_id_stored_on_agent(self):
        """A ``_user_id`` set on the instance reads back unchanged."""
        fake_env = {"HERMES_HOME": "/tmp/test_hermes"}
        with patch.dict(os.environ, fake_env):
            from run_agent import AIAgent

            # Allocate without running __init__, then set the attribute manually.
            agent = object.__new__(AIAgent)
            agent._user_id = "test_user_42"
            assert agent._user_id == "test_user_42"

    def test_user_id_none_by_default(self):
        """A ``_user_id`` of ``None`` (the CLI-mode value) reads back as ``None``."""
        fake_env = {"HERMES_HOME": "/tmp/test_hermes"}
        with patch.dict(os.environ, fake_env):
            from run_agent import AIAgent

            agent = object.__new__(AIAgent)
            agent._user_id = None
            assert agent._user_id is None

View File

@@ -0,0 +1,42 @@
"""Tests for MiniMax auxiliary client URL normalization.
MiniMax and MiniMax-CN set inference_base_url to the /anthropic path.
The auxiliary client uses the OpenAI SDK, which needs /v1 instead.
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from agent.auxiliary_client import _to_openai_base_url
class TestToOpenaiBaseUrl:
    """Input/output pairs for ``_to_openai_base_url``."""

    def test_minimax_global_anthropic_suffix_replaced(self):
        converted = _to_openai_base_url("https://api.minimax.io/anthropic")
        assert converted == "https://api.minimax.io/v1"

    def test_minimax_cn_anthropic_suffix_replaced(self):
        converted = _to_openai_base_url("https://api.minimaxi.com/anthropic")
        assert converted == "https://api.minimaxi.com/v1"

    def test_trailing_slash_stripped_before_replace(self):
        converted = _to_openai_base_url("https://api.minimax.io/anthropic/")
        assert converted == "https://api.minimax.io/v1"

    def test_v1_url_unchanged(self):
        url = "https://api.openai.com/v1"
        assert _to_openai_base_url(url) == url

    def test_openrouter_url_unchanged(self):
        url = "https://openrouter.ai/api/v1"
        assert _to_openai_base_url(url) == url

    def test_anthropic_domain_unchanged(self):
        """api.anthropic.com doesn't end with /anthropic — should be untouched."""
        url = "https://api.anthropic.com"
        assert _to_openai_base_url(url) == url

    def test_anthropic_in_subpath_unchanged(self):
        url = "https://example.com/anthropic/extra"
        assert _to_openai_base_url(url) == url

    def test_empty_string(self):
        assert _to_openai_base_url("") == ""

    def test_none(self):
        # None is expected to come back as an empty string, not None.
        assert _to_openai_base_url(None) == ""

View File

@@ -0,0 +1,105 @@
"""Tests for MiniMax provider hardening — context lengths, thinking guard, catalog."""
class TestMinimaxContextLengths:
    """Pin the ``DEFAULT_CONTEXT_LENGTHS`` entries for MiniMax models."""

    def test_m1_variants_have_1m_context(self):
        """Every M1 variant is present and mapped to 1,000,000 tokens."""
        from agent.model_metadata import DEFAULT_CONTEXT_LENGTHS

        # Keys are lowercase because the lookup lowercases model names
        m1_models = (
            "minimax-m1",
            "minimax-m1-40k",
            "minimax-m1-80k",
            "minimax-m1-128k",
            "minimax-m1-256k",
        )
        for model in m1_models:
            assert model in DEFAULT_CONTEXT_LENGTHS, f"{model} missing from context lengths"
            assert DEFAULT_CONTEXT_LENGTHS[model] == 1_000_000, f"{model} expected 1M"

    def test_m2_variants_have_1m_context(self):
        """The M2.5 and M2.7 entries are present and mapped to 1,048,576 tokens."""
        from agent.model_metadata import DEFAULT_CONTEXT_LENGTHS

        # Keys are lowercase because the lookup lowercases model names
        m2_models = ("minimax-m2.5", "minimax-m2.7")
        for model in m2_models:
            assert model in DEFAULT_CONTEXT_LENGTHS, f"{model} missing from context lengths"
            assert DEFAULT_CONTEXT_LENGTHS[model] == 1_048_576, f"{model} expected 1048576"

    def test_minimax_prefix_fallback(self):
        """The bare "minimax" key is mapped to 1,048,576 tokens."""
        from agent.model_metadata import DEFAULT_CONTEXT_LENGTHS

        # The generic "minimax" prefix entry should be 1M for unknown models
        fallback = DEFAULT_CONTEXT_LENGTHS["minimax"]
        assert fallback == 1_048_576
class TestMinimaxThinkingGuard:
    """Check when ``build_anthropic_kwargs`` emits thinking parameters."""

    @staticmethod
    def _kwargs_for(model, effort):
        """Build request kwargs for *model* with reasoning enabled at *effort*."""
        from agent.anthropic_adapter import build_anthropic_kwargs

        return build_anthropic_kwargs(
            model=model,
            messages=[{"role": "user", "content": "hello"}],
            tools=None,
            max_tokens=4096,
            reasoning_config={"enabled": True, "effort": effort},
        )

    def test_no_thinking_for_minimax_m27(self):
        """MiniMax-M2.7 gets neither ``thinking`` nor ``output_config``."""
        built = self._kwargs_for("MiniMax-M2.7", "medium")
        assert "thinking" not in built
        assert "output_config" not in built

    def test_no_thinking_for_minimax_m1(self):
        """MiniMax-M1-128k gets no ``thinking`` key."""
        built = self._kwargs_for("MiniMax-M1-128k", "high")
        assert "thinking" not in built

    def test_thinking_still_works_for_claude(self):
        """A Claude model still receives the ``thinking`` key."""
        built = self._kwargs_for("claude-sonnet-4-20250514", "medium")
        assert "thinking" in built
class TestMinimaxAuxModel:
    """Pin the auxiliary model configured for both MiniMax providers."""

    _PROVIDERS = ("minimax", "minimax-cn")

    def test_minimax_aux_is_standard(self):
        """Both providers map to the MiniMax-M2.7 auxiliary model."""
        from agent.auxiliary_client import _API_KEY_PROVIDER_AUX_MODELS

        for provider in self._PROVIDERS:
            assert _API_KEY_PROVIDER_AUX_MODELS[provider] == "MiniMax-M2.7"

    def test_minimax_aux_not_highspeed(self):
        """Neither provider's auxiliary model name contains 'highspeed'."""
        from agent.auxiliary_client import _API_KEY_PROVIDER_AUX_MODELS

        for provider in self._PROVIDERS:
            assert "highspeed" not in _API_KEY_PROVIDER_AUX_MODELS[provider]
class TestMinimaxModelCatalog:
    """Pin which MiniMax model names appear in ``_PROVIDER_MODELS``."""

    _PROVIDERS = ("minimax", "minimax-cn")

    def test_catalog_includes_m1_family(self):
        """All five M1 variants are listed for both providers."""
        from hermes_cli.models import _PROVIDER_MODELS

        m1_family = (
            "MiniMax-M1",
            "MiniMax-M1-40k",
            "MiniMax-M1-80k",
            "MiniMax-M1-128k",
            "MiniMax-M1-256k",
        )
        for provider in self._PROVIDERS:
            models = _PROVIDER_MODELS[provider]
            for name in m1_family:
                assert name in models

    def test_catalog_excludes_deprecated(self):
        """MiniMax-M2.1 is listed for neither provider."""
        from hermes_cli.models import _PROVIDER_MODELS

        for provider in self._PROVIDERS:
            assert "MiniMax-M2.1" not in _PROVIDER_MODELS[provider]

    def test_catalog_excludes_highspeed(self):
        """The highspeed M2.7 / M2.5 variants are listed for neither provider."""
        from hermes_cli.models import _PROVIDER_MODELS

        highspeed = ("MiniMax-M2.7-highspeed", "MiniMax-M2.5-highspeed")
        for provider in self._PROVIDERS:
            models = _PROVIDER_MODELS[provider]
            for name in highspeed:
                assert name not in models

View File

@@ -423,7 +423,7 @@ class TestBuildNousSubscriptionPrompt:
"web": NousFeatureState("web", "Web tools", True, True, True, True, False, True, "firecrawl"),
"image_gen": NousFeatureState("image_gen", "Image generation", True, True, True, True, False, True, "Nous Subscription"),
"tts": NousFeatureState("tts", "OpenAI TTS", True, True, True, True, False, True, "OpenAI TTS"),
"browser": NousFeatureState("browser", "Browser automation", True, True, True, True, False, True, "Browserbase"),
"browser": NousFeatureState("browser", "Browser automation", True, True, True, True, False, True, "Browser Use"),
"modal": NousFeatureState("modal", "Modal execution", False, True, False, False, False, True, "local"),
},
),
@@ -431,9 +431,9 @@ class TestBuildNousSubscriptionPrompt:
prompt = build_nous_subscription_prompt({"web_search", "browser_navigate"})
assert "Browserbase" in prompt
assert "Browser Use" in prompt
assert "Modal execution is optional" in prompt
assert "do not ask the user for Firecrawl, FAL, OpenAI TTS, or Browserbase API keys" in prompt
assert "do not ask the user for Firecrawl, FAL, OpenAI TTS, or Browser-Use API keys" in prompt
def test_non_subscriber_prompt_includes_relevant_upgrade_guidance(self, monkeypatch):
monkeypatch.setenv("HERMES_ENABLE_NOUS_MANAGED_TOOLS", "1")

0
tests/cli/__init__.py Normal file
View File

View File

@@ -0,0 +1,46 @@
"""Tests for CLI browser CDP auto-launch helpers."""
import os
from unittest.mock import patch
from cli import HermesCLI
class TestChromeDebugLaunch:
    """Check how ``HermesCLI._try_launch_chrome_debug`` locates Chrome on Windows."""

    def test_windows_launch_uses_browser_found_on_path(self):
        """A ``chrome.exe`` reported by ``shutil.which`` is launched with the debug-port flag."""
        chrome_on_path = r"C:\Chrome\chrome.exe"
        launched = {}

        def which_stub(name):
            return chrome_on_path if name == "chrome.exe" else None

        def isfile_stub(path):
            return path == chrome_on_path

        def popen_stub(cmd, **kwargs):
            # Record the launch request instead of spawning a process.
            launched.update(cmd=cmd, kwargs=kwargs)
            return object()

        with patch("cli.shutil.which", side_effect=which_stub), \
                patch("cli.os.path.isfile", side_effect=isfile_stub), \
                patch("subprocess.Popen", side_effect=popen_stub):
            assert HermesCLI._try_launch_chrome_debug(9333, "Windows") is True

        assert launched["cmd"] == [chrome_on_path, "--remote-debugging-port=9333"]
        assert launched["kwargs"]["start_new_session"] is True

    def test_windows_launch_falls_back_to_common_install_dirs(self, monkeypatch):
        """With nothing on PATH, a Chrome under %ProgramFiles% is launched instead."""
        launched = {}
        program_files = r"C:\Program Files"
        # Use os.path.join so path separators match cross-platform
        installed = os.path.join(
            program_files, "Google", "Chrome", "Application", "chrome.exe"
        )

        def isfile_stub(path):
            return path == installed

        def popen_stub(cmd, **kwargs):
            launched.update(cmd=cmd, kwargs=kwargs)
            return object()

        # Only ProgramFiles is defined; the other install-root variables are removed.
        monkeypatch.setenv("ProgramFiles", program_files)
        for unset_var in ("ProgramFiles(x86)", "LOCALAPPDATA"):
            monkeypatch.delenv(unset_var, raising=False)

        with patch("cli.shutil.which", return_value=None), \
                patch("cli.os.path.isfile", side_effect=isfile_stub), \
                patch("subprocess.Popen", side_effect=popen_stub):
            assert HermesCLI._try_launch_chrome_debug(9222, "Windows") is True

        assert launched["cmd"] == [installed, "--remote-debugging-port=9222"]

View File

@@ -330,7 +330,7 @@ def test_model_flow_nous_prints_subscription_guidance_without_mutating_explicit_
"hermes_cli.auth.fetch_nous_models",
lambda *args, **kwargs: ["claude-opus-4-6"],
)
monkeypatch.setattr("hermes_cli.auth._prompt_model_selection", lambda model_ids, current_model="", pricing=None: "claude-opus-4-6")
monkeypatch.setattr("hermes_cli.auth._prompt_model_selection", lambda model_ids, current_model="", pricing=None, **kw: "claude-opus-4-6")
monkeypatch.setattr("hermes_cli.auth._save_model_choice", lambda model: None)
monkeypatch.setattr("hermes_cli.auth._update_config_for_provider", lambda provider, url: None)
monkeypatch.setattr(
@@ -368,7 +368,7 @@ def test_model_flow_nous_applies_managed_tts_default_when_unconfigured(monkeypat
"hermes_cli.auth.fetch_nous_models",
lambda *args, **kwargs: ["claude-opus-4-6"],
)
monkeypatch.setattr("hermes_cli.auth._prompt_model_selection", lambda model_ids, current_model="", pricing=None: "claude-opus-4-6")
monkeypatch.setattr("hermes_cli.auth._prompt_model_selection", lambda model_ids, current_model="", pricing=None, **kw: "claude-opus-4-6")
monkeypatch.setattr("hermes_cli.auth._save_model_choice", lambda model: None)
monkeypatch.setattr("hermes_cli.auth._update_config_for_provider", lambda provider, url: None)
monkeypatch.setattr(
@@ -538,7 +538,7 @@ def test_cmd_model_falls_back_to_auto_on_invalid_provider(monkeypatch, capsys):
return "openrouter"
monkeypatch.setattr("hermes_cli.auth.resolve_provider", _resolve_provider)
monkeypatch.setattr(hermes_main, "_prompt_provider_choice", lambda choices: len(choices) - 1)
monkeypatch.setattr(hermes_main, "_prompt_provider_choice", lambda choices, **kwargs: len(choices) - 1)
monkeypatch.setattr("sys.stdin", type("FakeTTY", (), {"isatty": lambda self: True})())
hermes_main.cmd_model(SimpleNamespace())
@@ -579,6 +579,7 @@ def test_model_flow_custom_saves_verified_v1_base_url(monkeypatch, capsys):
# "Use this model? [Y/n]:" — confirm with Enter, then context length.
answers = iter(["http://localhost:8000", "local-key", "", ""])
monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers))
monkeypatch.setattr("getpass.getpass", lambda _prompt="": next(answers))
hermes_main._model_flow_custom({})
output = capsys.readouterr().out
@@ -601,7 +602,7 @@ def test_cmd_model_forwards_nous_login_tls_options(monkeypatch):
monkeypatch.setattr("hermes_cli.config.save_env_value", lambda key, value: None)
monkeypatch.setattr("hermes_cli.auth.resolve_provider", lambda requested, **kwargs: "nous")
monkeypatch.setattr("hermes_cli.auth.get_provider_auth_state", lambda provider_id: None)
monkeypatch.setattr(hermes_main, "_prompt_provider_choice", lambda choices: 0)
monkeypatch.setattr(hermes_main, "_prompt_provider_choice", lambda choices, **kwargs: 0)
captured = {}

View File

@@ -1,6 +1,6 @@
"""Regression tests for CLI /retry history replacement semantics."""
from tests.test_cli_init import _make_cli
from tests.cli.test_cli_init import _make_cli
def test_retry_last_truncates_history_before_requeueing_message():

View File

@@ -0,0 +1,98 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from cli import HermesCLI, _rich_text_from_ansi
from hermes_cli.skin_engine import get_active_skin, set_active_skin
def _make_cli_stub():
    """Return a ``HermesCLI`` allocated without ``__init__`` and seeded with fixed state.

    All modal states are ``None``, all activity flags are ``False``, the
    spinner frame is empty, ``_app`` is a namespace with ``style=None`` and
    ``_invalidate`` is a ``MagicMock`` so calls to it can be asserted.
    """
    stub = HermesCLI.__new__(HermesCLI)

    for state_name in (
        "_sudo_state",
        "_secret_state",
        "_approval_state",
        "_clarify_state",
    ):
        setattr(stub, state_name, None)

    for flag_name in (
        "_clarify_freetext",
        "_command_running",
        "_agent_running",
        "_voice_recording",
        "_voice_processing",
        "_voice_mode",
    ):
        setattr(stub, flag_name, False)

    stub._command_spinner_frame = lambda: ""
    stub._tui_style_base = {
        "prompt": "#fff",
        "input-area": "#fff",
        "input-rule": "#aaa",
        "prompt-working": "#888 italic",
    }
    stub._app = SimpleNamespace(style=None)
    stub._invalidate = MagicMock()
    return stub
class TestCliSkinPromptIntegration:
    """Check prompt fragments and TUI styling against the active skin."""

    def test_default_prompt_fragments_use_default_symbol(self):
        stub = _make_cli_stub()
        set_active_skin("default")

        fragments = stub._get_tui_prompt_fragments()
        assert fragments == [("class:prompt", " ")]

    def test_ares_prompt_fragments_use_skin_symbol(self):
        stub = _make_cli_stub()
        set_active_skin("ares")

        # NOTE(review): the expected text is identical to the default-skin
        # case above — confirm against the skin definitions.
        fragments = stub._get_tui_prompt_fragments()
        assert fragments == [("class:prompt", " ")]

    def test_secret_prompt_fragments_preserve_secret_state(self):
        stub = _make_cli_stub()
        stub._secret_state = {"response_queue": object()}
        set_active_skin("ares")

        # With a secret prompt pending, the sudo-prompt class and key icon are expected.
        fragments = stub._get_tui_prompt_fragments()
        assert fragments == [("class:sudo-prompt", "🔑 ")]

    def test_icon_only_skin_symbol_still_visible_in_special_states(self):
        stub = _make_cli_stub()
        stub._secret_state = {"response_queue": object()}

        with patch("hermes_cli.skin_engine.get_active_prompt_symbol", return_value=""):
            fragments = stub._get_tui_prompt_fragments()
            assert fragments == [("class:sudo-prompt", "🔑 ⚔ ")]

    def test_build_tui_style_dict_uses_skin_overrides(self):
        stub = _make_cli_stub()
        set_active_skin("ares")
        skin = get_active_skin()

        styles = stub._build_tui_style_dict()

        assert styles["prompt"] == skin.get_color("prompt")
        assert styles["input-rule"] == skin.get_color("input_rule")
        assert styles["prompt-working"] == f"{skin.get_color('banner_dim')} italic"
        assert styles["approval-title"] == f"{skin.get_color('ui_warn')} bold"

    def test_apply_tui_skin_style_updates_running_app(self):
        stub = _make_cli_stub()
        set_active_skin("ares")

        applied = stub._apply_tui_skin_style()

        assert applied is True
        assert stub._app.style is not None
        stub._invalidate.assert_called_once_with(min_interval=0.0)

    def test_handle_skin_command_refreshes_live_tui(self, capsys):
        stub = _make_cli_stub()

        # Saving the config is stubbed out so nothing is written to disk.
        with patch("cli.save_config_value", return_value=True):
            stub._handle_skin_command("/skin ares")

        printed = capsys.readouterr().out
        assert "Skin set to: ares (saved)" in printed
        assert "Prompt + TUI colors updated." in printed
        assert stub._app.style is not None
class TestAnsiRichTextHelper:
    """Check the plain text produced by ``_rich_text_from_ansi``."""

    def test_preserves_literal_brackets(self):
        """Square-bracketed text comes through verbatim instead of being treated as markup."""
        raw = "[notatag] literal"
        assert _rich_text_from_ansi(raw).plain == "[notatag] literal"

    def test_strips_ansi_but_keeps_plain_text(self):
        """ANSI colour escapes are dropped from ``.plain``; the wrapped word remains."""
        raw = "\x1b[31mred\x1b[0m"
        assert _rich_text_from_ansi(raw).plain == "red"

View File

@@ -1,5 +1,6 @@
from datetime import datetime, timedelta
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from cli import HermesCLI
@@ -78,6 +79,92 @@ class TestCLIStatusBar:
assert "$0.06" not in text # cost hidden by default
assert "15m" in text
def test_input_height_counts_wide_characters_using_cell_width(self):
    """Ten double-width characters must wrap onto two rows in a 14-column terminal.

    ``_input_height`` below is a stand-in defined in this test: it measures
    each line with ``get_cwidth`` (terminal cells) rather than ``len``.
    With 14 columns and a prompt counted as 2 cells there are 12 usable
    cells, so 20 cells of text need two visual rows.
    """
    cli_obj = _make_cli()

    class _Doc:
        # Each of these characters occupies two terminal cells.  The sample
        # must be non-empty: an empty string measures zero cells, which
        # yields a height of 1 and makes the assertion below unreachable.
        lines = ["你" * 10]

    class _Buffer:
        document = _Doc()

    input_area = SimpleNamespace(buffer=_Buffer())

    def _input_height():
        try:
            from prompt_toolkit.application import get_app
            from prompt_toolkit.utils import get_cwidth

            doc = input_area.buffer.document
            # The prompt is never counted as narrower than two cells.
            prompt_width = max(2, get_cwidth(cli_obj._get_tui_prompt_text()))
            try:
                available_width = get_app().output.get_size().columns - prompt_width
            except Exception:
                import shutil
                available_width = shutil.get_terminal_size((80, 24)).columns - prompt_width
            if available_width < 10:
                available_width = 40
            visual_lines = 0
            for line in doc.lines:
                line_width = get_cwidth(line)
                if line_width <= 0:
                    visual_lines += 1
                else:
                    # Ceiling division: rows needed for this line's cell width.
                    visual_lines += max(1, -(-line_width // available_width))
            return min(max(visual_lines, 1), 8)
        except Exception:
            return 1

    mock_app = MagicMock()
    mock_app.output.get_size.return_value = MagicMock(columns=14)

    with patch.object(HermesCLI, "_get_tui_prompt_text", return_value=" "), \
            patch("prompt_toolkit.application.get_app", return_value=mock_app):
        assert _input_height() == 2
def test_input_height_uses_prompt_toolkit_width_over_shutil(self):
    """The width reported by prompt_toolkit is used; ``shutil`` is never consulted.

    ``_input_height`` below is a stand-in defined in this test.  The mocked
    application reports 14 columns, so 20 cells of text wrap onto two rows,
    and ``shutil.get_terminal_size`` (the fallback) must stay uncalled.
    """
    cli_obj = _make_cli()

    class _Doc:
        # Ten double-width characters (20 cells).  The sample must be
        # non-empty: an empty string measures zero cells and yields a
        # height of 1 whatever width is used, so the test could not pass.
        lines = ["你" * 10]

    class _Buffer:
        document = _Doc()

    input_area = SimpleNamespace(buffer=_Buffer())

    def _input_height():
        try:
            from prompt_toolkit.application import get_app
            from prompt_toolkit.utils import get_cwidth

            doc = input_area.buffer.document
            # The prompt is never counted as narrower than two cells.
            prompt_width = max(2, get_cwidth(cli_obj._get_tui_prompt_text()))
            try:
                available_width = get_app().output.get_size().columns - prompt_width
            except Exception:
                # Fallback path; the test asserts this is not taken.
                import shutil
                available_width = shutil.get_terminal_size((80, 24)).columns - prompt_width
            if available_width < 10:
                available_width = 40
            visual_lines = 0
            for line in doc.lines:
                line_width = get_cwidth(line)
                if line_width <= 0:
                    visual_lines += 1
                else:
                    # Ceiling division: rows needed for this line's cell width.
                    visual_lines += max(1, -(-line_width // available_width))
            return min(max(visual_lines, 1), 8)
        except Exception:
            return 1

    mock_app = MagicMock()
    mock_app.output.get_size.return_value = MagicMock(columns=14)

    with patch.object(HermesCLI, "_get_tui_prompt_text", return_value=" "), \
            patch("prompt_toolkit.application.get_app", return_value=mock_app), \
            patch("shutil.get_terminal_size") as mock_shutil:
        assert _input_height() == 2
        mock_shutil.assert_not_called()
def test_build_status_bar_text_no_cost_in_status_bar(self):
cli_obj = _attach_agent(
_make_cli(),

View File

@@ -0,0 +1,66 @@
import pytest
from unittest.mock import MagicMock, patch
from hermes_cli.plugins import VALID_HOOKS, PluginManager
import os
import shutil
import tempfile
from cli import HermesCLI
def test_session_hooks_in_valid_hooks():
    """Both session lifecycle hook names are members of ``VALID_HOOKS``."""
    for hook_name in ("on_session_finalize", "on_session_reset"):
        assert hook_name in VALID_HOOKS
@patch("hermes_cli.plugins.invoke_hook")
def test_session_finalize_on_reset(mock_invoke_hook):
    """Starting a new session fires finalize for the old id and reset for the new one."""
    hermes = HermesCLI()
    old_agent = MagicMock()
    old_agent.session_id = "test-session-id"
    hermes.agent = old_agent

    # Simulate /new command which triggers on_session_finalize for the old session
    hermes.new_session(silent=True)

    # The old session must have been finalized...
    finalize_kwargs = {"session_id": "test-session-id", "platform": "cli"}
    mock_invoke_hook.assert_any_call("on_session_finalize", **finalize_kwargs)

    # ...and the reset hook must carry the CLI's current session id.
    reset_kwargs = {"session_id": hermes.session_id, "platform": "cli"}
    mock_invoke_hook.assert_any_call("on_session_reset", **reset_kwargs)
@patch("hermes_cli.plugins.invoke_hook")
def test_session_finalize_on_cleanup(mock_invoke_hook):
    """Verify on_session_finalize fires during CLI exit cleanup.

    The ``cli`` module globals ``_active_agent_ref`` and ``_cleanup_done``
    are replaced only for the duration of the ``_run_cleanup`` call, so the
    mock agent and any flag change do not leak into other tests.
    """
    import cli as cli_mod

    mock_agent = MagicMock()
    mock_agent.session_id = "cleanup-session-id"

    # patch.object restores the previous values on exit, even if
    # _run_cleanup raises.  create=True keeps this working whether or not
    # the globals already exist at import time.
    with patch.object(cli_mod, "_active_agent_ref", mock_agent, create=True), \
            patch.object(cli_mod, "_cleanup_done", False, create=True):
        cli_mod._run_cleanup()

    mock_invoke_hook.assert_any_call(
        "on_session_finalize", session_id="cleanup-session-id", platform="cli"
    )
@patch("hermes_cli.plugins.invoke_hook")
def test_hook_errors_are_caught(mock_invoke_hook):
    """A hook callback that raises must not propagate out of ``PluginManager.invoke_hook``."""
    manager = PluginManager()

    def exploding_hook(**kwargs):
        raise Exception("Hook failed")

    # Register the failing callback as the only on_session_finalize hook.
    manager._hooks["on_session_finalize"] = [exploding_hook]

    # This should not raise
    outcome = manager.invoke_hook(
        "on_session_finalize", session_id="test", platform="cli"
    )
    assert outcome == []

View File

@@ -33,6 +33,13 @@ def git_repo(tmp_path):
["git", "commit", "-m", "Initial commit"],
cwd=repo, capture_output=True,
)
# Add a fake remote ref so cleanup logic sees the initial commit as
# "pushed". Without this, `git log HEAD --not --remotes` treats every
# commit as unpushed and cleanup refuses to delete worktrees.
subprocess.run(
["git", "update-ref", "refs/remotes/origin/main", "HEAD"],
cwd=repo, capture_output=True,
)
return repo
@@ -81,7 +88,11 @@ def _setup_worktree(repo_root):
def _cleanup_worktree(info):
"""Test version of _cleanup_worktree."""
"""Test version of _cleanup_worktree.
Preserves the worktree only if it has unpushed commits.
Dirty working tree alone is not enough to keep it.
"""
wt_path = info["path"]
branch = info["branch"]
repo_root = info["repo_root"]
@@ -89,15 +100,15 @@ def _cleanup_worktree(info):
if not Path(wt_path).exists():
return
# Check for uncommitted changes
status = subprocess.run(
["git", "status", "--porcelain"],
# Check for unpushed commits
result = subprocess.run(
["git", "log", "--oneline", "HEAD", "--not", "--remotes"],
capture_output=True, text=True, timeout=10, cwd=wt_path,
)
has_changes = bool(status.stdout.strip())
has_unpushed = bool(result.stdout.strip())
if has_changes:
return False # Did not clean up
if has_unpushed:
return False # Did not clean up — has unpushed commits
subprocess.run(
["git", "worktree", "remove", wt_path, "--force"],
@@ -204,20 +215,45 @@ class TestWorktreeCleanup:
assert result is True
assert not Path(info["path"]).exists()
def test_dirty_worktree_kept(self, git_repo):
def test_dirty_worktree_cleaned_when_no_unpushed(self, git_repo):
"""Dirty working tree without unpushed commits is cleaned up.
Agent sessions typically leave untracked files / artifacts behind.
Since all real work is in pushed commits, these don't warrant
keeping the worktree.
"""
info = _setup_worktree(str(git_repo))
assert info is not None
# Make uncommitted changes
# Make uncommitted changes (untracked file)
(Path(info["path"]) / "new-file.txt").write_text("uncommitted")
subprocess.run(
["git", "add", "new-file.txt"],
cwd=info["path"], capture_output=True,
)
# The git_repo fixture already has a fake remote ref so the initial
# commit is seen as "pushed". No unpushed commits → cleanup proceeds.
result = _cleanup_worktree(info)
assert result is False
assert Path(info["path"]).exists() # Still there
assert result is True # Cleaned up despite dirty working tree
assert not Path(info["path"]).exists()
def test_worktree_with_unpushed_commits_kept(self, git_repo):
    """Cleanup returns ``False`` and leaves the directory when a local-only commit exists."""
    info = _setup_worktree(str(git_repo))
    assert info is not None

    worktree_dir = info["path"]

    # Make a commit that is NOT on any remote
    (Path(worktree_dir) / "work.txt").write_text("real work")
    for git_cmd in (
        ["git", "add", "work.txt"],
        ["git", "commit", "-m", "agent work"],
    ):
        subprocess.run(git_cmd, cwd=worktree_dir, capture_output=True)

    cleaned = _cleanup_worktree(info)

    assert cleaned is False  # Kept — has unpushed commits
    assert Path(worktree_dir).exists()
def test_branch_deleted_on_cleanup(self, git_repo):
info = _setup_worktree(str(git_repo))
@@ -367,7 +403,7 @@ class TestMultipleWorktrees:
lines = [l for l in result.stdout.strip().splitlines() if l.strip()]
assert len(lines) == 11
# Cleanup all
# Cleanup all (git_repo fixture has a fake remote ref so cleanup works)
for info in worktrees:
# Discard changes first so cleanup works
subprocess.run(
@@ -492,33 +528,77 @@ class TestStaleWorktreePruning:
assert not pruned
assert Path(info["path"]).exists()
def test_keeps_dirty_old_worktree(self, git_repo):
"""Old worktrees with uncommitted changes should NOT be pruned."""
def test_keeps_old_worktree_with_unpushed_commits(self, git_repo):
"""Old worktrees (24-72h) with unpushed commits should NOT be pruned."""
import time
info = _setup_worktree(str(git_repo))
assert info is not None
# Make it dirty
(Path(info["path"]) / "dirty.txt").write_text("uncommitted")
# Make an unpushed commit
(Path(info["path"]) / "work.txt").write_text("real work")
subprocess.run(["git", "add", "work.txt"], cwd=info["path"], capture_output=True)
subprocess.run(
["git", "add", "dirty.txt"],
["git", "commit", "-m", "agent work"],
cwd=info["path"], capture_output=True,
)
# Make it old
# Make it old (25h — in the 24-72h soft tier)
old_time = time.time() - (25 * 3600)
os.utime(info["path"], (old_time, old_time))
# Check if it would be pruned
status = subprocess.run(
["git", "status", "--porcelain"],
# Check for unpushed commits (simulates prune logic)
result = subprocess.run(
["git", "log", "--oneline", "HEAD", "--not", "--remotes"],
capture_output=True, text=True, cwd=info["path"],
)
has_changes = bool(status.stdout.strip())
assert has_changes # Should be dirty → not pruned
has_unpushed = bool(result.stdout.strip())
assert has_unpushed # Has unpushed commits → not pruned in soft tier
assert Path(info["path"]).exists()
def test_force_prunes_very_old_worktree(self, git_repo):
    """A worktree last modified 73h ago falls past the 72h cutoff and is force-removed.

    The removal is performed inline with git commands; the test does not
    call the production prune function.
    """
    import time

    info = _setup_worktree(str(git_repo))
    assert info is not None

    worktree_dir = info["path"]
    repo_dir = str(git_repo)

    # Make an unpushed commit (would normally protect it)
    (Path(worktree_dir) / "work.txt").write_text("stale work")
    for git_cmd in (
        ["git", "add", "work.txt"],
        ["git", "commit", "-m", "old agent work"],
    ):
        subprocess.run(git_cmd, cwd=worktree_dir, capture_output=True)

    # Make it very old (73h — beyond the 72h hard threshold)
    aged_timestamp = time.time() - (73 * 3600)
    os.utime(worktree_dir, (aged_timestamp, aged_timestamp))

    # Simulate the force-prune tier check
    hard_cutoff = time.time() - (72 * 3600)
    last_modified = Path(worktree_dir).stat().st_mtime
    assert last_modified <= hard_cutoff  # Should qualify for force removal

    # Actually remove it (simulates _prune_stale_worktrees force path)
    current_branch = subprocess.run(
        ["git", "branch", "--show-current"],
        capture_output=True, text=True, timeout=5, cwd=worktree_dir,
    ).stdout.strip()

    subprocess.run(
        ["git", "worktree", "remove", worktree_dir, "--force"],
        capture_output=True, text=True, timeout=15, cwd=repo_dir,
    )
    # Only delete a branch when the worktree reported one.
    if current_branch:
        subprocess.run(
            ["git", "branch", "-D", current_branch],
            capture_output=True, text=True, timeout=10, cwd=repo_dir,
        )

    assert not Path(worktree_dir).exists()
class TestEdgeCases:
"""Test edge cases for robustness."""
@@ -611,6 +691,133 @@ class TestTerminalCWDIntegration:
assert result.stdout.strip() == "true"
class TestOrphanedBranchPruning:
"""Test cleanup of orphaned hermes/* and pr-* branches."""
def test_prunes_orphaned_hermes_branch(self, git_repo):
"""hermes/hermes-* branches with no worktree should be deleted."""
# Create a branch that looks like a worktree branch but has no worktree
subprocess.run(
["git", "branch", "hermes/hermes-deadbeef", "HEAD"],
cwd=str(git_repo), capture_output=True,
)
# Verify it exists
result = subprocess.run(
["git", "branch", "--list", "hermes/hermes-deadbeef"],
capture_output=True, text=True, cwd=str(git_repo),
)
assert "hermes/hermes-deadbeef" in result.stdout
# Simulate _prune_orphaned_branches logic
result = subprocess.run(
["git", "branch", "--format=%(refname:short)"],
capture_output=True, text=True, cwd=str(git_repo),
)
all_branches = [b.strip() for b in result.stdout.strip().split("\n") if b.strip()]
wt_result = subprocess.run(
["git", "worktree", "list", "--porcelain"],
capture_output=True, text=True, cwd=str(git_repo),
)
active_branches = {"main"}
for line in wt_result.stdout.split("\n"):
if line.startswith("branch refs/heads/"):
active_branches.add(line.split("branch refs/heads/", 1)[-1].strip())
orphaned = [
b for b in all_branches
if b not in active_branches
and (b.startswith("hermes/hermes-") or b.startswith("pr-"))
]
assert "hermes/hermes-deadbeef" in orphaned
# Delete them
if orphaned:
subprocess.run(
["git", "branch", "-D"] + orphaned,
capture_output=True, text=True, cwd=str(git_repo),
)
# Verify gone
result = subprocess.run(
["git", "branch", "--list", "hermes/hermes-deadbeef"],
capture_output=True, text=True, cwd=str(git_repo),
)
assert "hermes/hermes-deadbeef" not in result.stdout
def test_prunes_orphaned_pr_branch(self, git_repo):
"""pr-* branches should be deleted during pruning."""
subprocess.run(
["git", "branch", "pr-1234", "HEAD"],
cwd=str(git_repo), capture_output=True,
)
subprocess.run(
["git", "branch", "pr-5678", "HEAD"],
cwd=str(git_repo), capture_output=True,
)
result = subprocess.run(
["git", "branch", "--format=%(refname:short)"],
capture_output=True, text=True, cwd=str(git_repo),
)
all_branches = [b.strip() for b in result.stdout.strip().split("\n") if b.strip()]
active_branches = {"main"}
orphaned = [
b for b in all_branches
if b not in active_branches and b.startswith("pr-")
]
assert "pr-1234" in orphaned
assert "pr-5678" in orphaned
subprocess.run(
["git", "branch", "-D"] + orphaned,
capture_output=True, text=True, cwd=str(git_repo),
)
# Verify gone
result = subprocess.run(
["git", "branch", "--format=%(refname:short)"],
capture_output=True, text=True, cwd=str(git_repo),
)
remaining = result.stdout.strip()
assert "pr-1234" not in remaining
assert "pr-5678" not in remaining
def test_preserves_active_worktree_branch(self, git_repo):
    """The branch of a freshly created worktree is reported by ``git worktree list``."""
    info = _setup_worktree(str(git_repo))
    assert info is not None

    porcelain = subprocess.run(
        ["git", "worktree", "list", "--porcelain"],
        capture_output=True, text=True, cwd=str(git_repo),
    )
    marker = "branch refs/heads/"
    # Each "branch refs/heads/<name>" line names a branch checked out in a worktree.
    active_branches = {
        entry.split(marker, 1)[-1].strip()
        for entry in porcelain.stdout.split("\n")
        if entry.startswith(marker)
    }
    assert info["branch"] in active_branches  # Protected
def test_preserves_main_branch(self, git_repo):
    """``main`` is never part of the list of branches selected for pruning."""
    listing = subprocess.run(
        ["git", "branch", "--format=%(refname:short)"],
        capture_output=True, text=True, cwd=str(git_repo),
    )
    stripped_names = (raw.strip() for raw in listing.stdout.strip().split("\n"))
    all_branches = [name for name in stripped_names if name]

    active_branches = {"main"}
    prunable_prefixes = ("hermes/hermes-", "pr-")
    orphaned = []
    for name in all_branches:
        if name in active_branches:
            continue
        if name.startswith(prunable_prefixes):
            orphaned.append(name)

    assert "main" not in orphaned
class TestSystemPromptInjection:
"""Test that the agent gets worktree context in its system prompt."""
@@ -625,7 +832,7 @@ class TestSystemPromptInjection:
f"{info['path']}. Your branch is `{info['branch']}`. "
f"Changes here do not affect the main working tree or other agents. "
f"Remember to commit and push your changes, and create a PR if appropriate. "
f"The original repo is at {info['repo_root']}.]"
f"The original repo is at {info['repo_root']}.]\n"
)
assert info["path"] in wt_note

View File

@@ -152,11 +152,22 @@ def test_gateway_run_agent_codex_path_handles_internal_401_refresh(monkeypatch):
runner._provider_routing = {}
runner._fallback_model = None
runner._running_agents = {}
runner._smart_model_routing = {}
from unittest.mock import MagicMock, AsyncMock
runner.hooks = MagicMock()
runner.hooks.emit = AsyncMock()
runner.hooks.loaded_hooks = []
runner._session_db = None
# Ensure model resolution returns the codex model even if xdist
# leaked env vars cleared HERMES_MODEL.
monkeypatch.setattr(
gateway_run.GatewayRunner,
"_resolve_turn_agent_config",
lambda self, msg, model, runtime: {
"model": model or "gpt-5.3-codex",
"runtime": runtime,
},
)
source = SessionSource(
platform=Platform.LOCAL,

View File

@@ -339,6 +339,36 @@ class TestMarkJobRun:
assert updated["last_status"] == "error"
assert updated["last_error"] == "timeout"
def test_delivery_error_tracked_separately(self, tmp_cron_dir):
    """A successful run with a delivery failure records only the delivery error."""
    delivery_failure = "platform 'telegram' not configured"
    job_id = create_job(prompt="Report", schedule="every 1h")["id"]

    mark_job_run(job_id, success=True, delivery_error=delivery_failure)

    stored = get_job(job_id)
    assert stored["last_status"] == "ok"
    assert stored["last_error"] is None
    assert stored["last_delivery_error"] == delivery_failure
def test_delivery_error_cleared_on_success(self, tmp_cron_dir):
    """A later run with ``delivery_error=None`` resets ``last_delivery_error``."""
    job_id = create_job(prompt="Report", schedule="every 1h")["id"]

    # First run: the agent succeeds but delivery fails.
    mark_job_run(job_id, success=True, delivery_error="network timeout")
    after_failure = get_job(job_id)
    assert after_failure["last_delivery_error"] == "network timeout"

    # Next run delivers successfully, so the stored error must be gone.
    mark_job_run(job_id, success=True, delivery_error=None)
    after_success = get_job(job_id)
    assert after_success["last_delivery_error"] is None
def test_both_agent_and_delivery_error(self, tmp_cron_dir):
    """When the agent and the delivery both fail, each error is stored in its own field."""
    agent_failure = "model timeout"
    delivery_failure = "platform 'discord' not enabled"
    job_id = create_job(prompt="Report", schedule="every 1h")["id"]

    mark_job_run(
        job_id,
        success=False,
        error=agent_failure,
        delivery_error=delivery_failure,
    )

    stored = get_job(job_id)
    assert stored["last_status"] == "error"
    assert stored["last_error"] == agent_failure
    assert stored["last_delivery_error"] == delivery_failure
class TestAdvanceNextRun:
"""Tests for advance_next_run() — crash-safety for recurring jobs."""

View File

@@ -7,7 +7,7 @@ from unittest.mock import AsyncMock, patch, MagicMock
import pytest
from cron.scheduler import _resolve_origin, _resolve_delivery_target, _deliver_result, run_job, SILENT_MARKER, _build_job_prompt
from cron.scheduler import _resolve_origin, _resolve_delivery_target, _deliver_result, _send_media_via_adapter, run_job, SILENT_MARKER, _build_job_prompt
class TestResolveOrigin:
@@ -277,6 +277,188 @@ class TestDeliverResultWrapping:
# Media files should be forwarded separately
assert kwargs["media_files"] == [("/tmp/test-voice.ogg", False)]
def test_live_adapter_sends_media_as_attachments(self):
    """With a live adapter, the MEDIA tag is stripped from the text and the
    audio file is sent through ``send_voice`` rather than as literal
    'MEDIA:/path' text."""
    from concurrent.futures import Future
    from gateway.config import Platform

    def _completed_future(coro, _loop):
        # Stand-in for asyncio.run_coroutine_threadsafe: return an already
        # resolved concurrent.futures.Future and close the coroutine, which
        # avoids a "never awaited" warning.
        coro.close()
        done = Future()
        done.set_result(MagicMock(success=True))
        return done

    discord_adapter = AsyncMock()
    discord_adapter.send.return_value = MagicMock(success=True)
    discord_adapter.send_voice.return_value = MagicMock(success=True)

    gateway_cfg = MagicMock()
    gateway_cfg.platforms = {Platform.DISCORD: MagicMock(enabled=True)}

    running_loop = MagicMock()
    running_loop.is_running.return_value = True

    job = {
        "id": "tts-job",
        "deliver": "origin",
        "origin": {"platform": "discord", "chat_id": "9876"},
    }

    with patch("gateway.config.load_gateway_config", return_value=gateway_cfg), \
         patch("cron.scheduler.load_config", return_value={"cron": {"wrap_response": False}}), \
         patch("asyncio.run_coroutine_threadsafe", side_effect=_completed_future):
        _deliver_result(
            job,
            "Here is TTS\nMEDIA:/tmp/cron-voice.mp3",
            adapters={Platform.DISCORD: discord_adapter},
            loop=running_loop,
        )

    # The text message goes out once, without the MEDIA tag.
    discord_adapter.send.assert_called_once()
    sent_text = discord_adapter.send.call_args[0][1]
    assert "MEDIA:" not in sent_text
    assert "Here is TTS" in sent_text

    # The audio file goes out once, as a voice attachment.
    discord_adapter.send_voice.assert_called_once()
    voice_kwargs = discord_adapter.send_voice.call_args[1]
    assert voice_kwargs["audio_path"] == "/tmp/cron-voice.mp3"
def test_live_adapter_routes_image_to_send_image_file(self):
    """A ``.png`` MEDIA file is handed to ``send_image_file`` and never to ``send_voice``."""
    from concurrent.futures import Future
    from gateway.config import Platform

    def _completed_future(coro, _loop):
        # Replacement for asyncio.run_coroutine_threadsafe that resolves immediately.
        coro.close()
        done = Future()
        done.set_result(MagicMock(success=True))
        return done

    discord_adapter = AsyncMock()
    discord_adapter.send.return_value = MagicMock(success=True)
    discord_adapter.send_image_file.return_value = MagicMock(success=True)

    platform_cfg = MagicMock()
    platform_cfg.enabled = True
    gateway_cfg = MagicMock(platforms={Platform.DISCORD: platform_cfg})

    running_loop = MagicMock()
    running_loop.is_running.return_value = True

    job = {
        "id": "img-job",
        "deliver": "origin",
        "origin": {"platform": "discord", "chat_id": "1234"},
    }

    with patch("gateway.config.load_gateway_config", return_value=gateway_cfg), \
         patch("cron.scheduler.load_config", return_value={"cron": {"wrap_response": False}}), \
         patch("asyncio.run_coroutine_threadsafe", side_effect=_completed_future):
        _deliver_result(
            job,
            "Chart attached\nMEDIA:/tmp/chart.png",
            adapters={Platform.DISCORD: discord_adapter},
            loop=running_loop,
        )

    discord_adapter.send_image_file.assert_called_once()
    image_kwargs = discord_adapter.send_image_file.call_args[1]
    assert image_kwargs["image_path"] == "/tmp/chart.png"
    discord_adapter.send_voice.assert_not_called()
def test_live_adapter_media_only_no_text(self):
    """Content that is only a MEDIA tag sends the audio but no text message."""
    from concurrent.futures import Future
    from gateway.config import Platform

    def _completed_future(coro, _loop):
        # Replacement for asyncio.run_coroutine_threadsafe that resolves immediately.
        coro.close()
        done = Future()
        done.set_result(MagicMock(success=True))
        return done

    telegram_adapter = AsyncMock()
    telegram_adapter.send_voice.return_value = MagicMock(success=True)

    gateway_cfg = MagicMock()
    gateway_cfg.platforms = {Platform.TELEGRAM: MagicMock(enabled=True)}

    running_loop = MagicMock()
    running_loop.is_running.return_value = True

    job = {
        "id": "voice-only",
        "deliver": "origin",
        "origin": {"platform": "telegram", "chat_id": "999"},
    }

    with patch("gateway.config.load_gateway_config", return_value=gateway_cfg), \
         patch("cron.scheduler.load_config", return_value={"cron": {"wrap_response": False}}), \
         patch("asyncio.run_coroutine_threadsafe", side_effect=_completed_future):
        _deliver_result(
            job,
            "MEDIA:/tmp/voice.ogg",
            adapters={Platform.TELEGRAM: telegram_adapter},
            loop=running_loop,
        )

    # Nothing is left once the MEDIA tag is stripped, so no text send happens.
    telegram_adapter.send.assert_not_called()
    # The audio itself is still delivered.
    telegram_adapter.send_voice.assert_called_once()
def test_live_adapter_sends_cleaned_text_not_raw(self):
    """The text passed to ``adapter.send`` has its MEDIA tag removed and keeps
    the remaining message text."""
    from concurrent.futures import Future
    from gateway.config import Platform

    def _completed_future(coro, _loop):
        # Replacement for asyncio.run_coroutine_threadsafe that resolves immediately.
        coro.close()
        done = Future()
        done.set_result(MagicMock(success=True))
        return done

    telegram_adapter = AsyncMock()
    telegram_adapter.send.return_value = MagicMock(success=True)

    platform_cfg = MagicMock()
    platform_cfg.enabled = True
    gateway_cfg = MagicMock(platforms={Platform.TELEGRAM: platform_cfg})

    running_loop = MagicMock()
    running_loop.is_running.return_value = True

    job = {
        "id": "img-job",
        "deliver": "origin",
        "origin": {"platform": "telegram", "chat_id": "555"},
    }

    with patch("gateway.config.load_gateway_config", return_value=gateway_cfg), \
         patch("cron.scheduler.load_config", return_value={"cron": {"wrap_response": False}}), \
         patch("asyncio.run_coroutine_threadsafe", side_effect=_completed_future):
        _deliver_result(
            job,
            "Report\nMEDIA:/tmp/chart.png",
            adapters={Platform.TELEGRAM: telegram_adapter},
            loop=running_loop,
        )

    sent_text = telegram_adapter.send.call_args[0][1]
    assert "MEDIA:" not in sent_text
    assert "Report" in sent_text
def test_no_mirror_to_session_call(self):
"""Cron deliveries should NOT mirror into the gateway session."""
from gateway.config import Platform
@@ -326,6 +508,90 @@ class TestDeliverResultWrapping:
assert send_mock.call_args.kwargs["thread_id"] == "17585"
class TestDeliverResultErrorReturns:
    """``_deliver_result`` returns ``None`` on success and an error string on failure."""

    def test_returns_none_on_successful_delivery(self):
        from gateway.config import Platform

        gateway_cfg = MagicMock()
        gateway_cfg.platforms = {Platform.TELEGRAM: MagicMock(enabled=True)}
        send_stub = AsyncMock(return_value={"success": True})
        job = {
            "id": "ok-job",
            "deliver": "origin",
            "origin": {"platform": "telegram", "chat_id": "123"},
        }

        with patch("gateway.config.load_gateway_config", return_value=gateway_cfg), \
             patch("tools.send_message_tool._send_to_platform", new=send_stub):
            outcome = _deliver_result(job, "Output.")

        assert outcome is None

    def test_returns_none_for_local_delivery(self):
        """A ``deliver: local`` job is not delivered, which is not an error."""
        outcome = _deliver_result({"id": "local-job", "deliver": "local"}, "Output.")
        assert outcome is None

    def test_returns_error_for_unknown_platform(self):
        job = {
            "id": "bad-platform",
            "deliver": "origin",
            "origin": {"platform": "fax", "chat_id": "123"},
        }

        with patch("gateway.config.load_gateway_config"):
            outcome = _deliver_result(job, "Output.")

        assert outcome is not None
        assert "unknown platform" in outcome

    def test_returns_error_when_platform_disabled(self):
        from gateway.config import Platform

        gateway_cfg = MagicMock()
        gateway_cfg.platforms = {Platform.TELEGRAM: MagicMock(enabled=False)}
        job = {
            "id": "disabled",
            "deliver": "origin",
            "origin": {"platform": "telegram", "chat_id": "123"},
        }

        with patch("gateway.config.load_gateway_config", return_value=gateway_cfg):
            outcome = _deliver_result(job, "Output.")

        assert outcome is not None
        assert "not configured" in outcome

    def test_returns_error_on_send_failure(self):
        from gateway.config import Platform

        gateway_cfg = MagicMock()
        gateway_cfg.platforms = {Platform.TELEGRAM: MagicMock(enabled=True)}
        send_stub = AsyncMock(return_value={"error": "rate limited"})
        job = {
            "id": "rate-limited",
            "deliver": "origin",
            "origin": {"platform": "telegram", "chat_id": "123"},
        }

        with patch("gateway.config.load_gateway_config", return_value=gateway_cfg), \
             patch("tools.send_message_tool._send_to_platform", new=send_stub):
            outcome = _deliver_result(job, "Output.")

        assert outcome is not None
        assert "rate limited" in outcome

    def test_returns_error_for_unresolved_target(self, monkeypatch):
        """A non-local job whose target cannot be resolved yields an error string."""
        monkeypatch.delenv("TELEGRAM_HOME_CHANNEL", raising=False)

        outcome = _deliver_result({"id": "no-target", "deliver": "telegram"}, "Output.")

        assert outcome is not None
        assert "no delivery target" in outcome
class TestRunJobSessionPersistence:
def test_run_job_passes_session_db_and_cron_platform(self, tmp_path):
job = {
@@ -709,6 +975,18 @@ class TestSilentDelivery:
tick(verbose=False)
deliver_mock.assert_not_called()
def test_silent_trailing_suppresses_delivery(self):
    """A response that ends with [SILENT] after explanatory text is not delivered."""
    final_response = "2 deals filtered out (like<10, reply<15).\n\n[SILENT]"
    run_outcome = (True, "# output", final_response, None)

    with patch("cron.scheduler.get_due_jobs", return_value=[self._make_job()]), \
         patch("cron.scheduler.run_job", return_value=run_outcome), \
         patch("cron.scheduler.save_job_output", return_value="/tmp/out.md"), \
         patch("cron.scheduler._deliver_result") as deliver_mock, \
         patch("cron.scheduler.mark_job_run"):
        from cron.scheduler import tick

        tick(verbose=False)
        deliver_mock.assert_not_called()
def test_silent_is_case_insensitive(self):
with patch("cron.scheduler.get_due_jobs", return_value=[self._make_job()]), \
patch("cron.scheduler.run_job", return_value=(True, "# output", "[silent] nothing new", None)), \
@@ -850,3 +1128,57 @@ class TestTickAdvanceBeforeRun:
adv_mock.assert_called_once_with("test-advance")
# advance must happen before run
assert call_order == [("advance", "test-advance"), ("run", "test-advance")]
class TestSendMediaViaAdapter:
    """Unit tests for ``_send_media_via_adapter``: each file goes to a typed adapter method."""

    @staticmethod
    def _run_with_loop(adapter, chat_id, media_files, metadata, job):
        """Call ``_send_media_via_adapter`` against an event loop running in a daemon thread."""
        import asyncio
        import threading

        background_loop = asyncio.new_event_loop()
        loop_thread = threading.Thread(target=background_loop.run_forever, daemon=True)
        loop_thread.start()
        try:
            _send_media_via_adapter(
                adapter, chat_id, media_files, metadata, background_loop, job
            )
        finally:
            # Ask the loop to stop from its own thread, wait for the thread,
            # then release the loop.
            background_loop.call_soon_threadsafe(background_loop.stop)
            loop_thread.join(timeout=5)
            background_loop.close()

    def test_video_dispatched_to_send_video(self):
        adapter = MagicMock()
        adapter.send_video = AsyncMock()

        self._run_with_loop(adapter, "123", [("/tmp/clip.mp4", False)], None, {"id": "j1"})

        adapter.send_video.assert_called_once()
        video_kwargs = adapter.send_video.call_args[1]
        assert video_kwargs["video_path"] == "/tmp/clip.mp4"

    def test_unknown_ext_dispatched_to_send_document(self):
        adapter = MagicMock()
        adapter.send_document = AsyncMock()

        self._run_with_loop(adapter, "123", [("/tmp/report.pdf", False)], None, {"id": "j2"})

        adapter.send_document.assert_called_once()
        document_kwargs = adapter.send_document.call_args[1]
        assert document_kwargs["file_path"] == "/tmp/report.pdf"

    def test_multiple_media_files_all_delivered(self):
        adapter = MagicMock()
        adapter.send_voice = AsyncMock()
        adapter.send_image_file = AsyncMock()
        attachments = [("/tmp/voice.mp3", False), ("/tmp/photo.jpg", False)]

        self._run_with_loop(adapter, "123", attachments, None, {"id": "j3"})

        adapter.send_voice.assert_called_once()
        adapter.send_image_file.assert_called_once()

    def test_single_failure_does_not_block_others(self):
        adapter = MagicMock()
        # The voice send raises; the image send must still be attempted.
        adapter.send_voice = AsyncMock(side_effect=RuntimeError("network error"))
        adapter.send_image_file = AsyncMock()
        attachments = [("/tmp/voice.ogg", False), ("/tmp/photo.png", False)]

        self._run_with_loop(adapter, "123", attachments, None, {"id": "j4"})

        adapter.send_voice.assert_called_once()
        adapter.send_image_file.assert_called_once()

View File

@@ -0,0 +1,164 @@
"""Security tests for Terminal-Bench 2 archive extraction."""
import base64
import importlib
import io
import sys
import tarfile
import types
import pytest
def _stub_module(name: str, **attrs):
    """Return a new in-memory module called *name* with *attrs* set as attributes."""
    stub = types.ModuleType(name)
    for attr_name in attrs:
        setattr(stub, attr_name, attrs[attr_name])
    return stub
def _load_terminalbench_module(monkeypatch):
    """Import the Terminal-Bench 2 environment module against stubbed dependencies.

    Minimal stand-ins for ``atroposlib``, the ``environments`` helper modules
    and ``tools.terminal_tool`` are installed in ``sys.modules`` through
    *monkeypatch*, then the environment module is imported fresh so that it
    binds to those stand-ins.

    Every ``sys.modules`` change made here, including the entry for the
    freshly imported module, is registered with *monkeypatch*. Teardown
    therefore restores the previous state and the stub-backed module does not
    stay importable for later tests.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture.

    Returns:
        The freshly imported
        ``environments.benchmarks.terminalbench_2.terminalbench2_env`` module.
    """

    class _EvalHandlingEnum:
        STOP_TRAIN = "stop_train"

    class _APIServerConfig:
        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs

    class _AgentResult:
        pass

    class _HermesAgentLoop:
        pass

    class _HermesAgentBaseEnv:
        pass

    class _HermesAgentEnvConfig:
        pass

    class _ToolContext:
        pass

    stub_modules = {
        "atroposlib": _stub_module("atroposlib"),
        "atroposlib.envs": _stub_module("atroposlib.envs"),
        "atroposlib.envs.base": _stub_module(
            "atroposlib.envs.base",
            EvalHandlingEnum=_EvalHandlingEnum,
        ),
        "atroposlib.envs.server_handling": _stub_module("atroposlib.envs.server_handling"),
        "atroposlib.envs.server_handling.server_manager": _stub_module(
            "atroposlib.envs.server_handling.server_manager",
            APIServerConfig=_APIServerConfig,
        ),
        "environments.agent_loop": _stub_module(
            "environments.agent_loop",
            AgentResult=_AgentResult,
            HermesAgentLoop=_HermesAgentLoop,
        ),
        "environments.hermes_base_env": _stub_module(
            "environments.hermes_base_env",
            HermesAgentBaseEnv=_HermesAgentBaseEnv,
            HermesAgentEnvConfig=_HermesAgentEnvConfig,
        ),
        "environments.tool_context": _stub_module(
            "environments.tool_context",
            ToolContext=_ToolContext,
        ),
        "tools.terminal_tool": _stub_module(
            "tools.terminal_tool",
            register_task_env_overrides=lambda *args, **kwargs: None,
            clear_task_env_overrides=lambda *args, **kwargs: None,
            cleanup_vm=lambda *args, **kwargs: None,
        ),
    }

    # Link the stub packages to their submodules so attribute access
    # (``atroposlib.envs.base``) works as well as ``import`` statements.
    stub_modules["atroposlib"].envs = stub_modules["atroposlib.envs"]
    stub_modules["atroposlib.envs"].base = stub_modules["atroposlib.envs.base"]
    stub_modules["atroposlib.envs"].server_handling = stub_modules["atroposlib.envs.server_handling"]
    stub_modules["atroposlib.envs.server_handling"].server_manager = stub_modules[
        "atroposlib.envs.server_handling.server_manager"
    ]

    for name, module in stub_modules.items():
        monkeypatch.setitem(sys.modules, name, module)

    module_name = "environments.benchmarks.terminalbench_2.terminalbench2_env"
    # setitem() snapshots the current entry (or its absence), so teardown puts
    # back whatever was there before this test. The placeholder is removed
    # right away to force a fresh import that binds to the stubs above.
    monkeypatch.setitem(sys.modules, module_name, None)
    del sys.modules[module_name]
    return importlib.import_module(module_name)
def _build_tar_b64(entries):
    """Build a gzip-compressed tar archive from *entries* and return it base64-encoded.

    Each entry is a dict with a ``"kind"`` (``"dir"``, ``"file"`` or
    ``"symlink"``) and a ``"name"``; files also carry ``"data"`` (text,
    stored as UTF-8) and symlinks carry ``"target"``.

    Raises:
        ValueError: if an entry has any other ``"kind"``.
    """
    raw = io.BytesIO()
    with tarfile.open(fileobj=raw, mode="w:gz") as archive:
        for spec in entries:
            kind = spec["kind"]
            member = tarfile.TarInfo(spec["name"])
            if kind == "dir":
                member.type = tarfile.DIRTYPE
                archive.addfile(member)
            elif kind == "file":
                payload = spec["data"].encode("utf-8")
                member.size = len(payload)
                archive.addfile(member, io.BytesIO(payload))
            elif kind == "symlink":
                member.type = tarfile.SYMTYPE
                member.linkname = spec["target"]
                archive.addfile(member)
            else:
                raise ValueError(f"Unknown tar entry kind: {kind}")
    encoded = base64.b64encode(raw.getvalue())
    return encoded.decode("ascii")
def test_extract_base64_tar_allows_safe_files(tmp_path, monkeypatch):
    """A regular file inside a directory is extracted below the target directory."""
    terminalbench = _load_terminalbench_module(monkeypatch)
    entries = [
        {"kind": "dir", "name": "nested"},
        {"kind": "file", "name": "nested/hello.txt", "data": "hello"},
    ]
    archive = _build_tar_b64(entries)
    destination = tmp_path / "extract"

    terminalbench._extract_base64_tar(archive, destination)

    extracted = destination / "nested" / "hello.txt"
    assert extracted.read_text(encoding="utf-8") == "hello"
def test_extract_base64_tar_rejects_path_traversal(tmp_path, monkeypatch):
    """A member named ``../escape.txt`` is refused and nothing is written outside the target."""
    terminalbench = _load_terminalbench_module(monkeypatch)
    traversal_entry = {"kind": "file", "name": "../escape.txt", "data": "owned"}
    archive = _build_tar_b64([traversal_entry])
    destination = tmp_path / "extract"

    with pytest.raises(ValueError, match="Unsafe archive member path"):
        terminalbench._extract_base64_tar(archive, destination)

    escaped = tmp_path / "escape.txt"
    assert not escaped.exists()
def test_extract_base64_tar_rejects_symlinks(tmp_path, monkeypatch):
    """A symlink member is refused and no link is created in the target directory."""
    terminalbench = _load_terminalbench_module(monkeypatch)
    symlink_entry = {"kind": "symlink", "name": "link", "target": "../../escape.txt"}
    archive = _build_tar_b64([symlink_entry])
    destination = tmp_path / "extract"

    with pytest.raises(ValueError, match="Unsupported archive member type"):
        terminalbench._extract_base64_tar(archive, destination)

    created_link = destination / "link"
    assert not created_link.exists()

View File

@@ -439,7 +439,7 @@ class TestChatCompletionsEndpoint:
tp_cb = kwargs.get("tool_progress_callback")
# Simulate tool progress before streaming content
if tp_cb:
tp_cb("terminal", "ls -la", {"command": "ls -la"})
tp_cb("tool.started", "terminal", "ls -la", {"command": "ls -la"})
if cb:
await asyncio.sleep(0.05)
cb("Here are the files.")
@@ -476,8 +476,8 @@ class TestChatCompletionsEndpoint:
cb = kwargs.get("stream_delta_callback")
tp_cb = kwargs.get("tool_progress_callback")
if tp_cb:
tp_cb("_thinking", "some internal state", {})
tp_cb("web_search", "Python docs", {"query": "Python docs"})
tp_cb("tool.started", "_thinking", "some internal state", {})
tp_cb("tool.started", "web_search", "Python docs", {"query": "Python docs"})
if cb:
await asyncio.sleep(0.05)
cb("Found it.")

View File

@@ -39,7 +39,7 @@ class TestHermesApiServerToolset:
tools = resolve_toolset("hermes-api-server")
for tool in ["browser_navigate", "browser_snapshot", "browser_click",
"browser_type", "browser_scroll", "browser_back",
"browser_press", "browser_close"]:
"browser_press"]:
assert tool in tools, f"Missing browser tool: {tool}"
def test_toolset_includes_homeassistant_tools(self):

View File

@@ -0,0 +1,313 @@
"""Regression tests: slash commands must bypass the base adapter's active-session guard.
When an agent is running, the base adapter's Level 1 guard in
handle_message() intercepts all incoming messages and queues them as
pending. Certain commands (/stop, /new, /reset, /approve, /deny,
/status) must bypass this guard and be dispatched directly to the gateway
runner — otherwise they are queued as user text and either:
- leak into the conversation as agent input (/stop, /new), or
- deadlock (/approve, /deny — agent blocks on Event.wait)
These tests verify that the bypass works at the adapter level and that
the safety net in _run_agent discards leaked command text.
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType
from gateway.session import SessionSource, build_session_key
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _StubAdapter(BasePlatformAdapter):
    """Concrete ``BasePlatformAdapter`` subclass whose overrides do no work."""

    async def connect(self):
        """No-op."""
        return None

    async def disconnect(self):
        """No-op."""
        return None

    async def send(self, chat_id, text, **kwargs):
        """No-op; the arguments are ignored."""
        return None

    async def get_chat_info(self, chat_id):
        """Return an empty dict for every chat id."""
        return {}
def _make_adapter():
    """Build a ``_StubAdapter`` with a recording send path and an echoing handler.

    The handler returns ``handled:<command>`` for commands and
    ``handled:text:<text>`` otherwise. Everything passed to
    ``_send_with_retry`` is appended to ``adapter.sent_responses``.
    """
    stub = _StubAdapter(
        PlatformConfig(enabled=True, token="test-token"),
        Platform.TELEGRAM,
    )
    stub.sent_responses = []

    async def _echo_handler(event):
        command = event.get_command()
        if command:
            return f"handled:{command}"
        return f"handled:text:{event.text}"

    async def _record_send(chat_id, content, **kwargs):
        stub.sent_responses.append(content)

    stub._message_handler = _echo_handler
    stub._send_with_retry = _record_send
    return stub
def _make_event(text="/stop", chat_id="12345"):
    """Return a TEXT ``MessageEvent`` carrying *text* from a Telegram DM with *chat_id*."""
    dm_source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id=chat_id,
        chat_type="dm",
    )
    return MessageEvent(
        text=text,
        message_type=MessageType.TEXT,
        source=dm_source,
    )
def _session_key(chat_id="12345"):
    """Return the session key for a Telegram DM with *chat_id*."""
    return build_session_key(
        SessionSource(
            platform=Platform.TELEGRAM,
            chat_id=chat_id,
            chat_type="dm",
        )
    )
# ---------------------------------------------------------------------------
# Tests: commands bypass Level 1 when session is active
# ---------------------------------------------------------------------------
class TestCommandBypassActiveSession:
    """Commands that are dispatched directly even while a session is active."""

    @staticmethod
    async def _send_while_active(text):
        """Mark the default session active, feed *text* to a fresh adapter and
        return ``(adapter, session_key)``."""
        adapter = _make_adapter()
        session_key = _session_key()
        adapter._active_sessions[session_key] = asyncio.Event()
        await adapter.handle_message(_make_event(text))
        return adapter, session_key

    @pytest.mark.asyncio
    async def test_stop_bypasses_guard(self):
        """/stop is dispatched directly instead of being queued."""
        adapter, sk = await self._send_while_active("/stop")

        assert sk not in adapter._pending_messages, (
            "/stop was queued as a pending message instead of being dispatched"
        )
        assert any("handled:stop" in r for r in adapter.sent_responses), (
            "/stop response was not sent back to the user"
        )

    @pytest.mark.asyncio
    async def test_new_bypasses_guard(self):
        """/new is dispatched directly instead of being queued."""
        adapter, sk = await self._send_while_active("/new")

        assert sk not in adapter._pending_messages
        assert any("handled:new" in r for r in adapter.sent_responses)

    @pytest.mark.asyncio
    async def test_reset_bypasses_guard(self):
        """/reset (alias for /new) is dispatched directly."""
        adapter, sk = await self._send_while_active("/reset")

        assert sk not in adapter._pending_messages
        assert any("handled:reset" in r for r in adapter.sent_responses)

    @pytest.mark.asyncio
    async def test_approve_bypasses_guard(self):
        """/approve bypasses the guard (deadlock prevention)."""
        adapter, sk = await self._send_while_active("/approve")

        assert sk not in adapter._pending_messages
        assert any("handled:approve" in r for r in adapter.sent_responses)

    @pytest.mark.asyncio
    async def test_deny_bypasses_guard(self):
        """/deny bypasses the guard (deadlock prevention)."""
        adapter, sk = await self._send_while_active("/deny")

        assert sk not in adapter._pending_messages
        assert any("handled:deny" in r for r in adapter.sent_responses)

    @pytest.mark.asyncio
    async def test_status_bypasses_guard(self):
        """/status bypasses the guard so it gets a system response."""
        adapter, sk = await self._send_while_active("/status")

        assert sk not in adapter._pending_messages
        assert any("handled:status" in r for r in adapter.sent_responses)
# ---------------------------------------------------------------------------
# Tests: non-bypass messages still get queued
# ---------------------------------------------------------------------------
class TestNonBypassStillQueued:
    """Plain text and unrecognised commands are queued while a session is active."""

    @staticmethod
    async def _send_while_active(text):
        """Mark the default session active, feed *text* to a fresh adapter and
        return ``(adapter, session_key)``."""
        adapter = _make_adapter()
        session_key = _session_key()
        adapter._active_sessions[session_key] = asyncio.Event()
        await adapter.handle_message(_make_event(text))
        return adapter, session_key

    @pytest.mark.asyncio
    async def test_regular_text_queued(self):
        """Plain text sent while the agent is running ends up pending."""
        adapter, sk = await self._send_while_active("hello world")

        assert sk in adapter._pending_messages, (
            "Regular text was not queued — it should be pending"
        )
        assert len(adapter.sent_responses) == 0, (
            "Regular text should not produce a direct response"
        )

    @pytest.mark.asyncio
    async def test_unknown_command_queued(self):
        """An unknown /command is queued rather than dispatched."""
        adapter, sk = await self._send_while_active("/foobar")

        assert sk in adapter._pending_messages
        assert len(adapter.sent_responses) == 0

    @pytest.mark.asyncio
    async def test_file_path_not_treated_as_command(self):
        """A message such as '/path/to/file' does not bypass the guard."""
        adapter, sk = await self._send_while_active("/path/to/file.py")

        assert sk in adapter._pending_messages
        assert len(adapter.sent_responses) == 0
# ---------------------------------------------------------------------------
# Tests: no active session — commands go through normally
# ---------------------------------------------------------------------------
class TestNoActiveSessionNormalDispatch:
    """With no agent running, messages take the normal (non-queued) path."""

    @pytest.mark.asyncio
    async def test_stop_when_no_session_active(self):
        """/stop without an active session is not placed in the pending queue."""
        adapter = _make_adapter()
        session_key = _session_key()

        # Precondition: nothing is registered as active for this session.
        assert session_key not in adapter._active_sessions

        await adapter.handle_message(_make_event("/stop"))

        # The pending queue is only used for messages that arrive during an
        # active session, so it must not contain this key.
        assert session_key not in adapter._pending_messages
# ---------------------------------------------------------------------------
# Tests: safety net in _run_agent discards command text from pending queue
# ---------------------------------------------------------------------------
class TestPendingCommandSafetyNet:
    """Checks on ``resolve_command``, which the safety net described in the
    module docstring uses to recognise command text in the pending queue."""

    def test_stop_command_detected(self):
        """``stop`` resolves to the command named ``stop``."""
        from hermes_cli.commands import resolve_command

        resolved = resolve_command("stop")
        assert resolved is not None
        assert resolved.name == "stop"

    def test_new_command_detected(self):
        from hermes_cli.commands import resolve_command

        resolved = resolve_command("new")
        assert resolved is not None
        assert resolved.name == "new"

    def test_reset_alias_detected(self):
        from hermes_cli.commands import resolve_command

        resolved = resolve_command("reset")
        assert resolved is not None
        assert resolved.name == "new"  # alias

    def test_unknown_command_not_detected(self):
        from hermes_cli.commands import resolve_command

        assert resolve_command("foobar") is None

    def test_file_path_not_detected_as_command(self):
        """'path/to/file' does not resolve to a command."""
        from hermes_cli.commands import resolve_command

        # For the text '/path/to/file', the first whitespace-separated word
        # with the leading '/' removed is 'path/to/file'.
        assert resolve_command("path/to/file") is None
# ---------------------------------------------------------------------------
# Tests: bypass with @botname suffix (Telegram-style)
# ---------------------------------------------------------------------------
class TestBypassWithBotnameSuffix:
    """Commands carrying a Telegram-style ``@botname`` suffix still bypass the guard."""

    @staticmethod
    async def _send_while_active(text):
        """Mark the default session active, feed *text* to a fresh adapter and
        return ``(adapter, session_key)``."""
        adapter = _make_adapter()
        session_key = _session_key()
        adapter._active_sessions[session_key] = asyncio.Event()
        await adapter.handle_message(_make_event(text))
        return adapter, session_key

    @pytest.mark.asyncio
    async def test_stop_with_botname(self):
        """/stop@MyHermesBot bypasses the guard."""
        adapter, sk = await self._send_while_active("/stop@MyHermesBot")

        assert sk not in adapter._pending_messages, (
            "/stop@MyHermesBot was queued instead of bypassing"
        )
        assert any("handled:stop" in r for r in adapter.sent_responses)

    @pytest.mark.asyncio
    async def test_new_with_botname(self):
        """/new@MyHermesBot bypasses the guard."""
        adapter, sk = await self._send_while_active("/new@MyHermesBot")

        assert sk not in adapter._pending_messages
        assert any("handled:new" in r for r in adapter.sent_responses)

View File

@@ -0,0 +1,343 @@
"""Tests for Discord ignored_channels and no_thread_channels config."""
from types import SimpleNamespace
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock
import sys
import pytest
from gateway.config import PlatformConfig
def _ensure_discord_mock():
    """Register stand-in ``discord`` modules when discord.py isn't available.

    Returns immediately when ``sys.modules`` already holds a ``discord`` entry
    that exposes ``__file__``. Otherwise mock modules are built and registered
    with ``setdefault``, so entries that already exist are left in place.
    """
    existing = sys.modules.get("discord")
    if existing is not None and hasattr(existing, "__file__"):
        return

    discord_stub = MagicMock()
    discord_stub.Intents.default.return_value = MagicMock()

    # Attributes that are exposed as plain classes.
    discord_stub.Client = MagicMock
    discord_stub.File = MagicMock
    discord_stub.Embed = MagicMock
    discord_stub.Interaction = object
    for type_name in ("DMChannel", "Thread", "ForumChannel"):
        setattr(discord_stub, type_name, type(type_name, (), {}))

    # UI helpers: ``button`` is a decorator factory that returns the function unchanged.
    discord_stub.ui = SimpleNamespace(
        View=object,
        button=lambda *a, **k: (lambda fn: fn),
        Button=object,
    )
    discord_stub.ButtonStyle = SimpleNamespace(
        success=1,
        primary=2,
        secondary=2,
        danger=3,
        green=1,
        grey=2,
        blurple=2,
        red=3,
    )
    discord_stub.Color = SimpleNamespace(
        orange=lambda: 1,
        green=lambda: 2,
        blue=lambda: 3,
        red=lambda: 4,
        purple=lambda: 5,
    )
    discord_stub.app_commands = SimpleNamespace(
        describe=lambda **kwargs: (lambda fn: fn),
        choices=lambda **kwargs: (lambda fn: fn),
        Choice=lambda **kwargs: SimpleNamespace(**kwargs),
    )

    commands_stub = MagicMock()
    commands_stub.Bot = MagicMock
    ext_stub = MagicMock()
    ext_stub.commands = commands_stub

    sys.modules.setdefault("discord", discord_stub)
    sys.modules.setdefault("discord.ext", ext_stub)
    sys.modules.setdefault("discord.ext.commands", commands_stub)
_ensure_discord_mock()
import gateway.platforms.discord as discord_platform # noqa: E402
from gateway.platforms.discord import DiscordAdapter # noqa: E402
class FakeDMChannel:
    """Bare-bones DM channel double that carries only ``id`` and ``name``."""

    def __init__(self, channel_id: int = 1, name: str = "dm"):
        self.name = name
        self.id = channel_id
class FakeTextChannel:
    """Guild text channel double with an id, a name, a guild and no topic."""

    def __init__(self, channel_id: int = 1, name: str = "general", guild_name: str = "Hermes Server"):
        guild = SimpleNamespace(name=guild_name)
        self.id = channel_id
        self.name = name
        self.guild = guild
        self.topic = None
class FakeThread:
    """Thread double that mirrors its parent's id and guild when a parent is given."""

    def __init__(self, channel_id: int = 1, name: str = "thread", parent=None, guild_name: str = "Hermes Server"):
        parent_guild = getattr(parent, "guild", None)
        self.id = channel_id
        self.name = name
        self.parent = parent
        self.parent_id = getattr(parent, "id", None)
        # Reuse the parent's guild when it has a truthy one; otherwise make a fresh guild.
        self.guild = parent_guild if parent_guild else SimpleNamespace(name=guild_name)
        self.topic = None
@pytest.fixture
def adapter(monkeypatch):
    """Provide a DiscordAdapter using the fake channel types and a stub client.

    ``handle_message`` is replaced with an AsyncMock so tests can check whether
    ``_handle_message`` forwarded an event.
    """
    fake_types = (("DMChannel", FakeDMChannel), ("Thread", FakeThread))
    for attr_name, fake_type in fake_types:
        monkeypatch.setattr(discord_platform.discord, attr_name, fake_type, raising=False)

    discord_adapter = DiscordAdapter(PlatformConfig(enabled=True, token="fake-token"))
    discord_adapter._client = SimpleNamespace(user=SimpleNamespace(id=999))
    discord_adapter.handle_message = AsyncMock()
    return discord_adapter
def make_message(*, channel, content: str, mentions=None):
    """Build a message-like namespace authored by a fixed test user.

    ``mentions`` is copied into a new list (empty when omitted); the message
    has no attachments or reference and is timestamped with the current UTC time.
    """
    fields = {
        "id": 123,
        "content": content,
        "mentions": list(mentions or []),
        "attachments": [],
        "reference": None,
        "created_at": datetime.now(timezone.utc),
        "channel": channel,
        "author": SimpleNamespace(id=42, display_name="TestUser", name="TestUser"),
    }
    return SimpleNamespace(**fields)
# ── ignored_channels ─────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_ignored_channel_blocks_message(adapter, monkeypatch):
    """A message posted in an ignored channel never reaches handle_message."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500")
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

    ignored_channel = FakeTextChannel(channel_id=500)
    await adapter._handle_message(make_message(channel=ignored_channel, content="hello"))

    adapter.handle_message.assert_not_awaited()
@pytest.mark.asyncio
async def test_ignored_channel_blocks_even_with_mention(adapter, monkeypatch):
    """An @mention of the bot does not override an ignored channel."""
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500")
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")

    bot = adapter._client.user
    mention_text = f"<@{bot.id}> hello"
    message = make_message(
        channel=FakeTextChannel(channel_id=500),
        content=mention_text,
        mentions=[bot],
    )
    await adapter._handle_message(message)

    adapter.handle_message.assert_not_awaited()
@pytest.mark.asyncio
async def test_non_ignored_channel_processes_normally(adapter, monkeypatch):
    """A channel absent from the ignored list is forwarded exactly once."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500,600")
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

    other_channel = FakeTextChannel(channel_id=700)
    await adapter._handle_message(make_message(channel=other_channel, content="hello"))

    adapter.handle_message.assert_awaited_once()
@pytest.mark.asyncio
async def test_ignored_channels_csv_parsing(adapter, monkeypatch):
    """Every ID in a CSV value with stray spaces is treated as ignored."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500, 600 , 700")
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

    for channel_id in (500, 600, 700):
        adapter.handle_message.reset_mock()
        channel = FakeTextChannel(channel_id=channel_id)
        await adapter._handle_message(make_message(channel=channel, content="hello"))
        adapter.handle_message.assert_not_awaited()
@pytest.mark.asyncio
async def test_ignored_channels_empty_string_ignores_nothing(adapter, monkeypatch):
    """An empty DISCORD_IGNORED_CHANNELS value leaves every channel active."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "")
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

    channel = FakeTextChannel(channel_id=500)
    await adapter._handle_message(make_message(channel=channel, content="hello"))

    adapter.handle_message.assert_awaited_once()
@pytest.mark.asyncio
async def test_ignored_channel_thread_parent_match(adapter, monkeypatch):
    """A thread is ignored when its parent channel is on the ignored list."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500")
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

    ignored_parent = FakeTextChannel(channel_id=500, name="ignored-channel")
    child_thread = FakeThread(channel_id=501, name="thread-in-ignored", parent=ignored_parent)
    await adapter._handle_message(
        make_message(channel=child_thread, content="hello from thread")
    )

    adapter.handle_message.assert_not_awaited()
@pytest.mark.asyncio
async def test_dms_unaffected_by_ignored_channels(adapter, monkeypatch):
    """A DM is still forwarded even when its channel ID is on the ignored list."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500")

    dm_channel = FakeDMChannel(channel_id=500)
    await adapter._handle_message(make_message(channel=dm_channel, content="dm hello"))

    adapter.handle_message.assert_awaited_once()
# ── no_thread_channels ───────────────────────────────────────────────
@pytest.mark.asyncio
async def test_no_thread_channel_skips_auto_thread(adapter, monkeypatch):
    """A channel listed in no_thread_channels gets no auto-created thread."""
    for unset_name in (
        "DISCORD_AUTO_THREAD",
        "DISCORD_IGNORED_CHANNELS",
        "DISCORD_FREE_RESPONSE_CHANNELS",
    ):
        monkeypatch.delenv(unset_name, raising=False)
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "800")
    adapter._auto_create_thread = AsyncMock(return_value=FakeThread(channel_id=999))

    channel = FakeTextChannel(channel_id=800)
    await adapter._handle_message(make_message(channel=channel, content="hello"))

    adapter._auto_create_thread.assert_not_awaited()
    adapter.handle_message.assert_awaited_once()
    # The forwarded event is reported as a group chat, not a thread.
    forwarded_event = adapter.handle_message.await_args.args[0]
    assert forwarded_event.source.chat_type == "group"
@pytest.mark.asyncio
async def test_normal_channel_still_auto_threads(adapter, monkeypatch):
    """A channel outside no_thread_channels still gets an auto-created thread."""
    for unset_name in (
        "DISCORD_AUTO_THREAD",
        "DISCORD_IGNORED_CHANNELS",
        "DISCORD_FREE_RESPONSE_CHANNELS",
    ):
        monkeypatch.delenv(unset_name, raising=False)
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "800")
    created_thread = FakeThread(channel_id=999, name="auto-thread")
    adapter._auto_create_thread = AsyncMock(return_value=created_thread)

    channel = FakeTextChannel(channel_id=900)
    await adapter._handle_message(make_message(channel=channel, content="hello"))

    adapter._auto_create_thread.assert_awaited_once()
    adapter.handle_message.assert_awaited_once()
    forwarded_event = adapter.handle_message.await_args.args[0]
    assert forwarded_event.source.chat_type == "thread"
@pytest.mark.asyncio
async def test_no_thread_channels_csv_parsing(adapter, monkeypatch):
    """Each ID in a CSV DISCORD_NO_THREAD_CHANNELS value suppresses auto-threading."""
    for unset_name in (
        "DISCORD_AUTO_THREAD",
        "DISCORD_IGNORED_CHANNELS",
        "DISCORD_FREE_RESPONSE_CHANNELS",
    ):
        monkeypatch.delenv(unset_name, raising=False)
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "800, 900")
    adapter._auto_create_thread = AsyncMock(return_value=FakeThread(channel_id=999))

    for channel_id in (800, 900):
        adapter._auto_create_thread.reset_mock()
        adapter.handle_message.reset_mock()
        channel = FakeTextChannel(channel_id=channel_id)
        await adapter._handle_message(make_message(channel=channel, content="hello"))
        adapter._auto_create_thread.assert_not_awaited()
@pytest.mark.asyncio
async def test_no_thread_with_auto_thread_disabled_is_noop(adapter, monkeypatch):
    """With auto-threading off globally, no thread is created and the message is still forwarded."""
    for unset_name in ("DISCORD_IGNORED_CHANNELS", "DISCORD_FREE_RESPONSE_CHANNELS"):
        monkeypatch.delenv(unset_name, raising=False)
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "800")
    adapter._auto_create_thread = AsyncMock()

    channel = FakeTextChannel(channel_id=800)
    await adapter._handle_message(make_message(channel=channel, content="hello"))

    adapter._auto_create_thread.assert_not_awaited()
    adapter.handle_message.assert_awaited_once()
# ── config.py bridging ───────────────────────────────────────────────
def test_config_bridges_ignored_channels(monkeypatch, tmp_path):
    """load_gateway_config() exports discord.ignored_channels as a CSV env var."""
    import os

    import yaml

    from gateway.config import load_gateway_config

    discord_section = {"ignored_channels": ["111", "222"]}
    (tmp_path / "config.yaml").write_text(yaml.dump({"discord": discord_section}))
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    # Use setenv (not delenv) so monkeypatch registers cleanup even when
    # the var doesn't exist yet — load_gateway_config will overwrite it.
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "")

    load_gateway_config()

    assert os.getenv("DISCORD_IGNORED_CHANNELS") == "111,222"
def test_config_bridges_no_thread_channels(monkeypatch, tmp_path):
    """load_gateway_config() exports discord.no_thread_channels as an env var."""
    import os

    import yaml

    from gateway.config import load_gateway_config

    discord_section = {"no_thread_channels": ["333"]}
    (tmp_path / "config.yaml").write_text(yaml.dump({"discord": discord_section}))
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "")

    load_gateway_config()

    assert os.getenv("DISCORD_NO_THREAD_CHANNELS") == "333"
def test_config_env_var_takes_precedence(monkeypatch, tmp_path):
    """A non-empty env var is kept even when config.yaml sets the same option."""
    import os

    import yaml

    from gateway.config import load_gateway_config

    discord_section = {"ignored_channels": ["111"]}
    (tmp_path / "config.yaml").write_text(yaml.dump({"discord": discord_section}))
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "999")

    load_gateway_config()

    # Env var should NOT be overwritten
    assert os.getenv("DISCORD_IGNORED_CHANNELS") == "999"

View File

@@ -0,0 +1,277 @@
"""Tests for Discord reply_to_mode functionality.
Covers the threading behavior control for multi-chunk replies:
- "off": Never reply-reference to original message
- "first": Only first chunk uses reply reference (default)
- "all": All chunks reply-reference the original message
"""
import os
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock, AsyncMock, patch
import pytest
from gateway.config import PlatformConfig, GatewayConfig, Platform, _apply_env_overrides
def _ensure_discord_mock():
    """Install a mock ``discord`` package when discord.py isn't available.

    Nothing happens if ``sys.modules["discord"]`` already exists and exposes
    ``__file__``. Otherwise three mock modules are built and added with
    ``setdefault``, which keeps any entry that is already registered.
    """
    # hasattr(None, "__file__") is False, so a missing entry falls through.
    if hasattr(sys.modules.get("discord"), "__file__"):
        return

    button_style = SimpleNamespace(
        success=1, primary=2, secondary=2, danger=3,
        green=1, grey=2, blurple=2, red=3,
    )
    color = SimpleNamespace(
        orange=lambda: 1,
        green=lambda: 2,
        blue=lambda: 3,
        red=lambda: 4,
        purple=lambda: 5,
    )
    # Decorator factories below hand the decorated function back unchanged.
    ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
    app_commands = SimpleNamespace(
        describe=lambda **kwargs: (lambda fn: fn),
        choices=lambda **kwargs: (lambda fn: fn),
        Choice=lambda **kwargs: SimpleNamespace(**kwargs),
    )

    discord_mod = MagicMock()
    discord_mod.Intents.default.return_value = MagicMock()
    discord_mod.Client = MagicMock
    discord_mod.File = MagicMock
    discord_mod.DMChannel = type("DMChannel", (), {})
    discord_mod.Thread = type("Thread", (), {})
    discord_mod.ForumChannel = type("ForumChannel", (), {})
    discord_mod.ui = ui
    discord_mod.ButtonStyle = button_style
    discord_mod.Color = color
    discord_mod.Interaction = object
    discord_mod.Embed = MagicMock
    discord_mod.app_commands = app_commands

    commands_mod = MagicMock()
    commands_mod.Bot = MagicMock
    ext_mod = MagicMock()
    ext_mod.commands = commands_mod

    for module_name, stub in (
        ("discord", discord_mod),
        ("discord.ext", ext_mod),
        ("discord.ext.commands", commands_mod),
    ):
        sys.modules.setdefault(module_name, stub)
_ensure_discord_mock()
from gateway.platforms.discord import DiscordAdapter # noqa: E402
@pytest.fixture()
def adapter_factory():
    """Return a callable that builds a DiscordAdapter for a given reply_to_mode."""

    def _build(reply_to_mode: str = "first"):
        return DiscordAdapter(
            PlatformConfig(enabled=True, token="test-token", reply_to_mode=reply_to_mode)
        )

    return _build
class TestReplyToModeConfig:
    """Checks how DiscordAdapter picks up reply_to_mode from its config."""

    def test_default_mode_is_first(self, adapter_factory):
        built = adapter_factory()
        assert built._reply_to_mode == "first"

    def test_off_mode(self, adapter_factory):
        built = adapter_factory(reply_to_mode="off")
        assert built._reply_to_mode == "off"

    def test_first_mode(self, adapter_factory):
        built = adapter_factory(reply_to_mode="first")
        assert built._reply_to_mode == "first"

    def test_all_mode(self, adapter_factory):
        built = adapter_factory(reply_to_mode="all")
        assert built._reply_to_mode == "all"

    def test_invalid_mode_stored_as_is(self, adapter_factory):
        """Invalid modes are stored but send() handles them gracefully."""
        built = adapter_factory(reply_to_mode="invalid")
        assert built._reply_to_mode == "invalid"

    def test_none_mode_defaults_to_first(self):
        """Leaving reply_to_mode off the config yields "first"."""
        built = DiscordAdapter(PlatformConfig(enabled=True, token="test-token"))
        assert built._reply_to_mode == "first"

    def test_empty_string_mode_defaults_to_first(self):
        """An empty reply_to_mode string yields "first"."""
        built = DiscordAdapter(
            PlatformConfig(enabled=True, token="test-token", reply_to_mode="")
        )
        assert built._reply_to_mode == "first"
def _make_discord_adapter(reply_to_mode: str = "first"):
    """Build a DiscordAdapter whose client resolves every channel to one mock.

    Returns ``(adapter, channel, reference_message)``: the channel's
    ``fetch_message`` yields ``reference_message`` and its ``send`` yields a
    message mock whose ``id`` is 42.
    """
    reference_message = MagicMock()
    delivered = MagicMock()
    delivered.id = 42

    channel = AsyncMock()
    channel.fetch_message = AsyncMock(return_value=reference_message)
    channel.send = AsyncMock(return_value=delivered)

    client = MagicMock()
    client.get_channel = MagicMock(return_value=channel)

    adapter = DiscordAdapter(
        PlatformConfig(enabled=True, token="test-token", reply_to_mode=reply_to_mode)
    )
    adapter._client = client
    return adapter, channel, reference_message
class TestSendWithReplyToMode:
    """Checks which chunks send() attaches the reply reference to, per mode."""

    @pytest.mark.asyncio
    async def test_off_mode_no_reply_reference(self):
        adapter, channel, _reference = _make_discord_adapter("off")
        adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]

        await adapter.send("12345", "test content", reply_to="999")

        # Should never try to fetch the reference message
        channel.fetch_message.assert_not_called()
        # All chunks sent without reference
        references = [sent.kwargs.get("reference") for sent in channel.send.call_args_list]
        assert all(ref is None for ref in references)

    @pytest.mark.asyncio
    async def test_first_mode_only_first_chunk_references(self):
        adapter, channel, reference = _make_discord_adapter("first")
        adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]

        await adapter.send("12345", "test content", reply_to="999")

        # Should fetch the reference message
        channel.fetch_message.assert_called_once_with(999)
        sends = channel.send.call_args_list
        assert len(sends) == 3
        first, second, third = sends
        assert first.kwargs.get("reference") is reference
        assert second.kwargs.get("reference") is None
        assert third.kwargs.get("reference") is None

    @pytest.mark.asyncio
    async def test_all_mode_all_chunks_reference(self):
        adapter, channel, reference = _make_discord_adapter("all")
        adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]

        await adapter.send("12345", "test content", reply_to="999")

        channel.fetch_message.assert_called_once_with(999)
        sends = channel.send.call_args_list
        assert len(sends) == 3
        assert all(sent.kwargs.get("reference") is reference for sent in sends)

    @pytest.mark.asyncio
    async def test_no_reply_to_param_no_reference(self):
        adapter, channel, _reference = _make_discord_adapter("all")
        adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2"]

        await adapter.send("12345", "test content", reply_to=None)

        channel.fetch_message.assert_not_called()
        references = [sent.kwargs.get("reference") for sent in channel.send.call_args_list]
        assert all(ref is None for ref in references)

    @pytest.mark.asyncio
    async def test_single_chunk_respects_first_mode(self):
        adapter, channel, reference = _make_discord_adapter("first")
        adapter.truncate_message = lambda content, max_len: ["single chunk"]

        await adapter.send("12345", "test", reply_to="999")

        sends = channel.send.call_args_list
        assert len(sends) == 1
        assert sends[0].kwargs.get("reference") is reference

    @pytest.mark.asyncio
    async def test_single_chunk_off_mode(self):
        adapter, channel, _reference = _make_discord_adapter("off")
        adapter.truncate_message = lambda content, max_len: ["single chunk"]

        await adapter.send("12345", "test", reply_to="999")

        channel.fetch_message.assert_not_called()
        sends = channel.send.call_args_list
        assert len(sends) == 1
        assert sends[0].kwargs.get("reference") is None

    @pytest.mark.asyncio
    async def test_invalid_mode_falls_back_to_first_behavior(self):
        """Invalid mode behaves like 'first' — only first chunk gets reference."""
        adapter, channel, reference = _make_discord_adapter("banana")
        adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2"]

        await adapter.send("12345", "test", reply_to="999")

        sends = channel.send.call_args_list
        assert len(sends) == 2
        leading, trailing = sends
        assert leading.kwargs.get("reference") is reference
        assert trailing.kwargs.get("reference") is None
class TestConfigSerialization:
    """Round-trip checks for reply_to_mode on PlatformConfig (shared with Telegram)."""

    def test_to_dict_includes_reply_to_mode(self):
        serialized = PlatformConfig(enabled=True, token="test", reply_to_mode="all").to_dict()
        assert serialized["reply_to_mode"] == "all"

    def test_from_dict_loads_reply_to_mode(self):
        raw = {"enabled": True, "token": "***", "reply_to_mode": "off"}
        loaded = PlatformConfig.from_dict(raw)
        assert loaded.reply_to_mode == "off"

    def test_from_dict_defaults_to_first(self):
        raw = {"enabled": True, "token": "***"}
        loaded = PlatformConfig.from_dict(raw)
        assert loaded.reply_to_mode == "first"
class TestEnvVarOverride:
    """Tests for DISCORD_REPLY_TO_MODE environment variable override."""

    def _make_config(self):
        """Return a GatewayConfig that already has an enabled Discord platform."""
        gateway_config = GatewayConfig()
        gateway_config.platforms[Platform.DISCORD] = PlatformConfig(enabled=True, token="test")
        return gateway_config

    def _mode_after_override(self, env_value):
        """Apply env overrides with DISCORD_REPLY_TO_MODE set and return the resulting mode."""
        gateway_config = self._make_config()
        with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": env_value}, clear=False):
            _apply_env_overrides(gateway_config)
            return gateway_config.platforms[Platform.DISCORD].reply_to_mode

    def test_env_var_sets_off_mode(self):
        assert self._mode_after_override("off") == "off"

    def test_env_var_sets_all_mode(self):
        assert self._mode_after_override("all") == "all"

    def test_env_var_case_insensitive(self):
        assert self._mode_after_override("ALL") == "all"

    def test_env_var_invalid_value_ignored(self):
        assert self._mode_after_override("banana") == "first"

    def test_env_var_empty_value_ignored(self):
        assert self._mode_after_override("") == "first"

    def test_env_var_creates_platform_config_if_missing(self):
        """DISCORD_REPLY_TO_MODE creates PlatformConfig even without DISCORD_BOT_TOKEN."""
        gateway_config = GatewayConfig()
        assert Platform.DISCORD not in gateway_config.platforms

        with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": "off"}, clear=False):
            _apply_env_overrides(gateway_config)
            assert Platform.DISCORD in gateway_config.platforms
            assert gateway_config.platforms[Platform.DISCORD].reply_to_mode == "off"

View File

@@ -8,7 +8,7 @@ import time
import unittest
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch
try:
import lark_oapi
@@ -17,6 +17,18 @@ except ImportError:
_HAS_LARK_OAPI = False
def _mock_event_dispatcher_builder(mock_handler_class):
    """Attach a chainable mock builder to ``mock_handler_class`` and return it.

    Each ``register_*`` method on the builder returns the builder itself, so
    registrations can be chained; ``build()`` returns one plain ``object()``
    and ``mock_handler_class.builder()`` returns the builder.
    """
    builder = Mock()
    chained_registrations = (
        "register_p2_im_message_message_read_v1",
        "register_p2_im_message_receive_v1",
        "register_p2_im_message_reaction_created_v1",
        "register_p2_im_message_reaction_deleted_v1",
        "register_p2_card_action_trigger",
    )
    for method_name in chained_registrations:
        setattr(builder, method_name, Mock(return_value=builder))
    builder.build = Mock(return_value=object())
    mock_handler_class.builder = Mock(return_value=builder)
    return builder
class TestPlatformEnum(unittest.TestCase):
def test_feishu_in_platform_enum(self):
from gateway.config import Platform
@@ -262,12 +274,14 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
with (
patch("gateway.platforms.feishu.FEISHU_AVAILABLE", True),
patch("gateway.platforms.feishu.FEISHU_WEBHOOK_AVAILABLE", True),
patch("gateway.platforms.feishu.EventDispatcherHandler") as mock_handler_class,
patch("gateway.platforms.feishu.acquire_scoped_lock", return_value=(True, None)),
patch("gateway.platforms.feishu.release_scoped_lock"),
patch.object(adapter, "_hydrate_bot_identity", new=AsyncMock()),
patch.object(adapter, "_build_lark_client", return_value=SimpleNamespace()),
patch("gateway.platforms.feishu.web", web_module),
):
_mock_event_dispatcher_builder(mock_handler_class)
connected = asyncio.run(adapter.connect())
self.assertTrue(connected)
@@ -283,13 +297,13 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
ws_client = object()
ws_client = SimpleNamespace()
with (
patch("gateway.platforms.feishu.FEISHU_AVAILABLE", True),
patch("gateway.platforms.feishu.FEISHU_WEBSOCKET_AVAILABLE", True),
patch("gateway.platforms.feishu.lark", SimpleNamespace(LogLevel=SimpleNamespace(INFO="INFO", WARNING="WARNING"))),
patch("gateway.platforms.feishu.EventDispatcherHandler", object()),
patch("gateway.platforms.feishu.EventDispatcherHandler") as mock_handler_class,
patch("gateway.platforms.feishu.FeishuWSClient", return_value=ws_client),
patch("gateway.platforms.feishu._run_official_feishu_ws_client"),
patch("gateway.platforms.feishu.acquire_scoped_lock", return_value=(True, None)) as acquire_lock,
@@ -297,6 +311,8 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
patch.object(adapter, "_hydrate_bot_identity", new=AsyncMock()),
patch.object(adapter, "_build_lark_client", return_value=SimpleNamespace()),
):
_mock_event_dispatcher_builder(mock_handler_class)
loop = asyncio.new_event_loop()
future = loop.create_future()
future.set_result(None)
@@ -305,6 +321,9 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
def run_in_executor(self, *_args, **_kwargs):
return future
def is_closed(self):
return False
try:
with patch("gateway.platforms.feishu.asyncio.get_running_loop", return_value=_Loop()):
connected = asyncio.run(adapter.connect())
@@ -313,6 +332,7 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
loop.close()
self.assertTrue(connected)
self.assertIsNone(adapter._event_handler)
acquire_lock.assert_called_once_with(
"feishu-app-id",
"cli_app",
@@ -354,14 +374,14 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
ws_client = object()
ws_client = SimpleNamespace()
sleeps = []
with (
patch("gateway.platforms.feishu.FEISHU_AVAILABLE", True),
patch("gateway.platforms.feishu.FEISHU_WEBSOCKET_AVAILABLE", True),
patch("gateway.platforms.feishu.lark", SimpleNamespace(LogLevel=SimpleNamespace(INFO="INFO", WARNING="WARNING"))),
patch("gateway.platforms.feishu.EventDispatcherHandler", object()),
patch("gateway.platforms.feishu.EventDispatcherHandler") as mock_handler_class,
patch("gateway.platforms.feishu.FeishuWSClient", return_value=ws_client),
patch("gateway.platforms.feishu.acquire_scoped_lock", return_value=(True, None)),
patch("gateway.platforms.feishu.release_scoped_lock"),
@@ -369,6 +389,8 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
patch("gateway.platforms.feishu.asyncio.sleep", side_effect=lambda delay: sleeps.append(delay)),
patch.object(adapter, "_build_lark_client", return_value=SimpleNamespace()),
):
_mock_event_dispatcher_builder(mock_handler_class)
loop = asyncio.new_event_loop()
future = loop.create_future()
future.set_result(None)
@@ -383,6 +405,9 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
raise OSError("temporary websocket failure")
return future
def is_closed(self):
return False
fake_loop = _Loop()
try:
with patch("gateway.platforms.feishu.asyncio.get_running_loop", return_value=fake_loop):
@@ -536,6 +561,113 @@ class TestAdapterModule(unittest.TestCase):
self.assertIn("register_p2_im_message_reaction_deleted_v1", source)
self.assertIn("register_p2_card_action_trigger", source)
def test_load_settings_uses_sdk_defaults_for_invalid_ws_reconnect_values(self):
    """Invalid reconnect settings load as nonce 30 and interval 120."""
    from gateway.platforms.feishu import FeishuAdapter

    invalid_overrides = {"ws_reconnect_nonce": -1, "ws_reconnect_interval": "bad"}
    settings = FeishuAdapter._load_settings(invalid_overrides)

    self.assertEqual(settings.ws_reconnect_nonce, 30)
    self.assertEqual(settings.ws_reconnect_interval, 120)
def test_load_settings_accepts_custom_ws_reconnect_values(self):
    """Valid reconnect settings, including a nonce of 0, are kept as given."""
    from gateway.platforms.feishu import FeishuAdapter

    custom_overrides = {"ws_reconnect_nonce": 0, "ws_reconnect_interval": 3}
    settings = FeishuAdapter._load_settings(custom_overrides)

    self.assertEqual(settings.ws_reconnect_nonce, 0)
    self.assertEqual(settings.ws_reconnect_interval, 3)
def test_load_settings_accepts_custom_ws_ping_values(self):
    """Positive ping interval and timeout settings are kept as given."""
    from gateway.platforms.feishu import FeishuAdapter

    custom_overrides = {"ws_ping_interval": 10, "ws_ping_timeout": 8}
    settings = FeishuAdapter._load_settings(custom_overrides)

    self.assertEqual(settings.ws_ping_interval, 10)
    self.assertEqual(settings.ws_ping_timeout, 8)
def test_load_settings_ignores_invalid_ws_ping_values(self):
    """A zero ping interval and a negative ping timeout both load as None."""
    from gateway.platforms.feishu import FeishuAdapter

    invalid_overrides = {"ws_ping_interval": 0, "ws_ping_timeout": -1}
    settings = FeishuAdapter._load_settings(invalid_overrides)

    self.assertIsNone(settings.ws_ping_interval)
    self.assertIsNone(settings.ws_ping_timeout)
def test_runtime_ws_overrides_reapply_after_sdk_configure(self):
    """Adapter websocket overrides survive the SDK client's own ``_configure`` call.

    The fake client's ``start()`` pushes SDK-style values (99/88/77) through
    ``_configure`` and then raises to end the run. Afterwards the client must
    hold the adapter's reconnect nonce, reconnect interval and ping interval
    (2/3/4), and ``_configure`` must have been called exactly once.
    """
    import sys
    from types import ModuleType

    class _FakeWSClient:
        """Stand-in websocket client that records every ``_configure`` call."""

        def __init__(self):
            self._reconnect_nonce = 30
            self._reconnect_interval = 120
            self._ping_interval = 120
            self.configure_calls = []

        def _configure(self, conf):
            self.configure_calls.append(conf)
            self._reconnect_nonce = conf.ReconnectNonce
            self._reconnect_interval = conf.ReconnectInterval
            self._ping_interval = conf.PingInterval

        def start(self):
            conf = SimpleNamespace(ReconnectNonce=99, ReconnectInterval=88, PingInterval=77)
            self._configure(conf)
            raise RuntimeError("stop test client")

    fake_client = _FakeWSClient()
    fake_adapter = SimpleNamespace(
        _ws_thread_loop=None,
        _ws_reconnect_nonce=2,
        _ws_reconnect_interval=3,
        _ws_ping_interval=4,
        _ws_ping_timeout=5,
    )

    # Minimal fake ``lark_oapi.ws.client`` package tree for the runner to import.
    fake_client_module = ModuleType("lark_oapi.ws.client")
    fake_client_module.loop = None
    fake_client_module.websockets = SimpleNamespace(connect=AsyncMock())
    fake_ws_module = ModuleType("lark_oapi.ws")
    fake_ws_module.client = fake_client_module
    fake_root_module = ModuleType("lark_oapi")
    fake_root_module.ws = fake_ws_module

    fake_modules = {
        "lark_oapi": fake_root_module,
        "lark_oapi.ws": fake_ws_module,
        "lark_oapi.ws.client": fake_client_module,
    }
    # patch.dict snapshots sys.modules and restores it on exit, even when the
    # call under test raises, replacing the manual copy/clear/update dance.
    with patch.dict(sys.modules, fake_modules):
        from gateway.platforms.feishu import _run_official_feishu_ws_client

        _run_official_feishu_ws_client(fake_client, fake_adapter)

    self.assertEqual(len(fake_client.configure_calls), 1)
    self.assertEqual(fake_client._reconnect_nonce, 2)
    self.assertEqual(fake_client._reconnect_interval, 3)
    self.assertEqual(fake_client._ping_interval, 4)
class TestAdapterBehavior(unittest.TestCase):
@patch.dict(os.environ, {}, clear=True)
@@ -690,10 +822,10 @@ class TestAdapterBehavior(unittest.TestCase):
adapter = FeishuAdapter(PlatformConfig())
message = SimpleNamespace(mentions=[])
sender_id = SimpleNamespace(open_id="ou_any", user_id=None)
self.assertFalse(adapter._should_accept_group_message(message, sender_id))
self.assertFalse(adapter._should_accept_group_message(message, sender_id, ""))
message_with_mention = SimpleNamespace(mentions=[SimpleNamespace(key="@_user_1")])
self.assertFalse(adapter._should_accept_group_message(message_with_mention, sender_id))
self.assertFalse(adapter._should_accept_group_message(message_with_mention, sender_id, ""))
@patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "open"}, clear=True)
def test_group_message_with_other_user_mention_is_rejected_when_bot_identity_unknown(self):
@@ -707,7 +839,7 @@ class TestAdapterBehavior(unittest.TestCase):
id=SimpleNamespace(open_id="ou_other", user_id="u_other"),
)
self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[other_mention]), sender_id))
self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[other_mention]), sender_id, ""))
@patch.dict(
os.environ,
@@ -736,28 +868,222 @@ class TestAdapterBehavior(unittest.TestCase):
adapter._should_accept_group_message(
mentioned,
SimpleNamespace(open_id="ou_allowed", user_id=None),
"",
)
)
self.assertFalse(
adapter._should_accept_group_message(
mentioned,
SimpleNamespace(open_id="ou_blocked", user_id=None),
"",
)
)
@patch.dict(
    os.environ,
    {
        "FEISHU_GROUP_POLICY": "open",
        "FEISHU_BOT_OPEN_ID": "ou_bot",
    },
    clear=True,
)
def test_per_group_allowlist_policy_gates_by_sender(self):
    """With an allowlist rule on a chat, only listed senders are accepted there."""
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    chat_id = "oc_chat_a"
    group_rules = {
        chat_id: {
            "policy": "allowlist",
            "allowlist": ["ou_alice", "ou_bob"],
        }
    }
    adapter = FeishuAdapter(PlatformConfig(extra={"group_rules": group_rules}))
    adapter._bot_open_id = "ou_bot"

    # Every probe message mentions the bot; only the sender varies.
    bot_mention = SimpleNamespace(name="Bot", id=SimpleNamespace(open_id="ou_bot", user_id=None))
    message = SimpleNamespace(mentions=[bot_mention])

    def accepted(open_id):
        sender = SimpleNamespace(open_id=open_id, user_id=None)
        return adapter._should_accept_group_message(message, sender, chat_id)

    self.assertTrue(accepted("ou_alice"))
    self.assertFalse(accepted("ou_charlie"))
def test_per_group_blacklist_policy_blocks_specific_users(self):
    """Check sender gating for a chat whose rule is ``blacklist``.

    A sender absent from the chat's blacklist must be accepted; the
    blacklisted sender must be rejected. Both messages @-mention the bot.
    """
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    chat_id = "oc_chat_b"
    chat_rule = {"policy": "blacklist", "blacklist": ["ou_blocked"]}
    adapter = FeishuAdapter(PlatformConfig(extra={"group_rules": {chat_id: chat_rule}}))
    adapter._bot_open_id = "ou_bot"

    bot_mention = SimpleNamespace(name="Bot", id=SimpleNamespace(open_id="ou_bot", user_id=None))
    message = SimpleNamespace(mentions=[bot_mention])

    ordinary_sender = SimpleNamespace(open_id="ou_alice", user_id=None)
    self.assertTrue(adapter._should_accept_group_message(message, ordinary_sender, chat_id))

    banned_sender = SimpleNamespace(open_id="ou_blocked", user_id=None)
    self.assertFalse(adapter._should_accept_group_message(message, banned_sender, chat_id))
def test_per_group_admin_only_policy_requires_admin(self):
    """Check sender gating for a chat whose rule is ``admin_only``.

    The sender listed under ``admins`` must be accepted; a sender not
    listed there must be rejected. Both messages @-mention the bot.
    """
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    chat_id = "oc_chat_c"
    extra = {
        "admins": ["ou_admin"],
        "group_rules": {chat_id: {"policy": "admin_only"}},
    }
    adapter = FeishuAdapter(PlatformConfig(extra=extra))
    adapter._bot_open_id = "ou_bot"

    bot_mention = SimpleNamespace(name="Bot", id=SimpleNamespace(open_id="ou_bot", user_id=None))
    message = SimpleNamespace(mentions=[bot_mention])

    admin_sender = SimpleNamespace(open_id="ou_admin", user_id=None)
    self.assertTrue(adapter._should_accept_group_message(message, admin_sender, chat_id))

    regular_sender = SimpleNamespace(open_id="ou_regular", user_id=None)
    self.assertFalse(adapter._should_accept_group_message(message, regular_sender, chat_id))
def test_per_group_disabled_policy_blocks_all(self):
    """Check sender gating for a chat whose rule is ``disabled``.

    As asserted here, the sender listed under ``admins`` is still
    accepted, while a non-admin sender is rejected. Both messages
    @-mention the bot.
    """
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    chat_id = "oc_chat_d"
    extra = {
        "admins": ["ou_admin"],
        "group_rules": {chat_id: {"policy": "disabled"}},
    }
    adapter = FeishuAdapter(PlatformConfig(extra=extra))
    adapter._bot_open_id = "ou_bot"

    bot_mention = SimpleNamespace(name="Bot", id=SimpleNamespace(open_id="ou_bot", user_id=None))
    message = SimpleNamespace(mentions=[bot_mention])

    admin_sender = SimpleNamespace(open_id="ou_admin", user_id=None)
    self.assertTrue(adapter._should_accept_group_message(message, admin_sender, chat_id))

    regular_sender = SimpleNamespace(open_id="ou_regular", user_id=None)
    self.assertFalse(adapter._should_accept_group_message(message, regular_sender, chat_id))
def test_global_admins_bypass_all_group_rules(self):
    """An admin sender is accepted in a chat whose allowlist omits them.

    The chat's ``allowlist`` rule names only ``ou_alice``; the sender is
    ``ou_admin``, which appears only under the top-level ``admins`` key.
    """
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    chat_id = "oc_chat_e"
    extra = {
        "admins": ["ou_admin"],
        "group_rules": {chat_id: {"policy": "allowlist", "allowlist": ["ou_alice"]}},
    }
    adapter = FeishuAdapter(PlatformConfig(extra=extra))
    adapter._bot_open_id = "ou_bot"

    bot_mention = SimpleNamespace(name="Bot", id=SimpleNamespace(open_id="ou_bot", user_id=None))
    message = SimpleNamespace(mentions=[bot_mention])
    admin_sender = SimpleNamespace(open_id="ou_admin", user_id=None)

    accepted = adapter._should_accept_group_message(message, admin_sender, chat_id)
    self.assertTrue(accepted)
def test_default_group_policy_fallback_for_chats_without_explicit_rule(self):
    """A chat with no rule of its own is gated by ``default_group_policy``.

    With the default set to ``open``, an arbitrary sender who @-mentions
    the bot in an unconfigured chat must be accepted.
    """
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    adapter = FeishuAdapter(PlatformConfig(extra={"default_group_policy": "open"}))
    adapter._bot_open_id = "ou_bot"

    bot_mention = SimpleNamespace(name="Bot", id=SimpleNamespace(open_id="ou_bot", user_id=None))
    message = SimpleNamespace(mentions=[bot_mention])
    any_sender = SimpleNamespace(open_id="ou_anyone", user_id=None)

    accepted = adapter._should_accept_group_message(message, any_sender, "oc_chat_unknown")
    self.assertTrue(accepted)
@patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "open"}, clear=True)
def test_group_message_matches_bot_open_id_when_configured(self):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
adapter._bot_open_id = "ou_bot"
sender_id = SimpleNamespace(open_id="ou_any", user_id=None)
bot_mention = SimpleNamespace(
@@ -769,22 +1095,16 @@ class TestAdapterBehavior(unittest.TestCase):
id=SimpleNamespace(open_id="ou_other", user_id="u_other"),
)
self.assertTrue(adapter._should_accept_group_message(SimpleNamespace(mentions=[bot_mention]), sender_id))
self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[other_mention]), sender_id))
self.assertTrue(adapter._should_accept_group_message(SimpleNamespace(mentions=[bot_mention]), sender_id, ""))
self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[other_mention]), sender_id, ""))
@patch.dict(
os.environ,
{
"FEISHU_GROUP_POLICY": "open",
"FEISHU_BOT_NAME": "Hermes Bot",
},
clear=True,
)
@patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "open"}, clear=True)
def test_group_message_matches_bot_name_when_only_name_available(self):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
adapter._bot_name = "Hermes Bot"
sender_id = SimpleNamespace(open_id="ou_any", user_id=None)
named_mention = SimpleNamespace(
@@ -796,22 +1116,16 @@ class TestAdapterBehavior(unittest.TestCase):
id=SimpleNamespace(open_id="ou_other", user_id="u_other"),
)
self.assertTrue(adapter._should_accept_group_message(SimpleNamespace(mentions=[named_mention]), sender_id))
self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[different_mention]), sender_id))
self.assertTrue(adapter._should_accept_group_message(SimpleNamespace(mentions=[named_mention]), sender_id, ""))
self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[different_mention]), sender_id, ""))
@patch.dict(
os.environ,
{
"FEISHU_GROUP_POLICY": "open",
"FEISHU_BOT_OPEN_ID": "ou_bot",
},
clear=True,
)
@patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "open"}, clear=True)
def test_group_post_message_uses_parsed_mentions_when_sdk_mentions_missing(self):
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
adapter._bot_open_id = "ou_bot"
sender_id = SimpleNamespace(open_id="ou_any", user_id=None)
message = SimpleNamespace(
message_type="post",
@@ -819,7 +1133,7 @@ class TestAdapterBehavior(unittest.TestCase):
content='{"en_us":{"content":[[{"tag":"at","user_name":"Hermes","open_id":"ou_bot"}]]}}',
)
self.assertTrue(adapter._should_accept_group_message(message, sender_id))
self.assertTrue(adapter._should_accept_group_message(message, sender_id, ""))
@patch.dict(os.environ, {}, clear=True)
def test_extract_post_message_as_text(self):
@@ -1196,7 +1510,12 @@ class TestAdapterBehavior(unittest.TestCase):
from gateway.platforms.feishu import FeishuAdapter
adapter = FeishuAdapter(PlatformConfig())
adapter._loop = object()
class _Loop:
def is_closed(self):
return False
adapter._loop = _Loop()
message = SimpleNamespace(
message_id="om_text",
@@ -1210,6 +1529,7 @@ class TestAdapterBehavior(unittest.TestCase):
data = SimpleNamespace(event=SimpleNamespace(message=message, sender=sender))
future = SimpleNamespace(add_done_callback=lambda *_args, **_kwargs: None)
def _submit(coro, _loop):
coro.close()
return future
@@ -1219,6 +1539,30 @@ class TestAdapterBehavior(unittest.TestCase):
self.assertTrue(submit.called)
@patch.dict(os.environ, {}, clear=True)
def test_webhook_request_uses_same_message_dispatch_path(self):
    """A webhook ``im.message.receive_v1`` payload reaches ``_on_message_event``.

    The request is a stand-in object exposing only the attributes set
    here; the handler must answer with status 200 and call the (mocked)
    message-event hook exactly once.
    """
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    adapter = FeishuAdapter(PlatformConfig())
    adapter._on_message_event = Mock()

    payload = {
        "header": {"event_type": "im.message.receive_v1"},
        "event": {"message": {"message_id": "om_test"}},
    }
    raw_body = json.dumps(payload).encode("utf-8")
    request = SimpleNamespace(
        remote="127.0.0.1",
        content_length=None,
        headers={},
        read=AsyncMock(return_value=raw_body),
    )

    response = asyncio.run(adapter._handle_webhook_request(request))

    self.assertEqual(response.status, 200)
    adapter._on_message_event.assert_called_once()
@patch.dict(os.environ, {}, clear=True)
def test_process_inbound_message_uses_event_sender_identity_only(self):
from gateway.config import PlatformConfig
@@ -2456,7 +2800,7 @@ class TestGroupMentionAtAll(unittest.TestCase):
mentions=[],
)
sender_id = SimpleNamespace(open_id="ou_any", user_id=None)
self.assertTrue(adapter._should_accept_group_message(message, sender_id))
self.assertTrue(adapter._should_accept_group_message(message, sender_id, ""))
@patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "allowlist", "FEISHU_ALLOWED_USERS": "ou_allowed"}, clear=True)
def test_at_all_still_requires_policy_gate(self):
@@ -2468,10 +2812,10 @@ class TestGroupMentionAtAll(unittest.TestCase):
message = SimpleNamespace(content='{"text":"@_all attention"}', mentions=[])
# Non-allowlisted user — should be blocked even with @_all.
blocked_sender = SimpleNamespace(open_id="ou_blocked", user_id=None)
self.assertFalse(adapter._should_accept_group_message(message, blocked_sender))
self.assertFalse(adapter._should_accept_group_message(message, blocked_sender, ""))
# Allowlisted user — should pass.
allowed_sender = SimpleNamespace(open_id="ou_allowed", user_id=None)
self.assertTrue(adapter._should_accept_group_message(message, allowed_sender))
self.assertTrue(adapter._should_accept_group_message(message, allowed_sender, ""))
@unittest.skipUnless(_HAS_LARK_OAPI, "lark-oapi not installed")

View File

@@ -0,0 +1,432 @@
"""Tests for Feishu interactive card approval buttons."""
import asyncio
import json
import os
import sys
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
# ---------------------------------------------------------------------------
# Ensure the repo root is importable
# ---------------------------------------------------------------------------
_repo = str(Path(__file__).resolve().parents[2])
if _repo not in sys.path:
sys.path.insert(0, _repo)
# ---------------------------------------------------------------------------
# Minimal Feishu mock so FeishuAdapter can be imported without lark-oapi
# ---------------------------------------------------------------------------
def _ensure_feishu_mocks():
"""Provide stubs for lark-oapi / aiohttp.web so the import succeeds."""
if "lark_oapi" not in sys.modules:
mod = MagicMock()
for name in (
"lark_oapi", "lark_oapi.api.im.v1",
"lark_oapi.event", "lark_oapi.event.callback_type",
):
sys.modules.setdefault(name, mod)
if "aiohttp" not in sys.modules:
aio = MagicMock()
sys.modules.setdefault("aiohttp", aio)
sys.modules.setdefault("aiohttp.web", aio.web)
_ensure_feishu_mocks()
from gateway.config import PlatformConfig
from gateway.platforms.feishu import FeishuAdapter
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_adapter() -> FeishuAdapter:
    """Build an enabled FeishuAdapter whose ``_client`` is a ``MagicMock``."""
    adapter = FeishuAdapter(PlatformConfig(enabled=True))
    adapter._client = MagicMock()
    return adapter
def _make_card_action_data(
action_value: dict,
chat_id: str = "oc_12345",
open_id: str = "ou_user1",
token: str = "tok_abc",
) -> SimpleNamespace:
"""Create a mock Feishu card action callback data object."""
return SimpleNamespace(
event=SimpleNamespace(
token=token,
context=SimpleNamespace(open_chat_id=chat_id),
operator=SimpleNamespace(open_id=open_id),
action=SimpleNamespace(
tag="button",
value=action_value,
),
),
)
# ===========================================================================
# send_exec_approval — interactive card with buttons
# ===========================================================================
class TestFeishuExecApproval:
    """Test send_exec_approval sends an interactive card."""

    @staticmethod
    def _ok_response(message_id):
        """Return a stand-in for a successful send response carrying *message_id*."""
        return SimpleNamespace(
            success=lambda: True,
            data=SimpleNamespace(message_id=message_id),
        )

    @pytest.mark.asyncio
    async def test_sends_interactive_card(self):
        """The approval is sent as an orange interactive card with four buttons."""
        adapter = _make_adapter()
        send_patch = patch.object(
            adapter,
            "_feishu_send_with_retry",
            new_callable=AsyncMock,
            return_value=self._ok_response("msg_001"),
        )
        with send_patch as mock_send:
            result = await adapter.send_exec_approval(
                chat_id="oc_12345",
                command="rm -rf /important",
                session_key="agent:main:feishu:group:oc_12345",
                description="dangerous deletion",
            )

        assert result.success is True
        assert result.message_id == "msg_001"

        mock_send.assert_called_once()
        sent = mock_send.call_args.kwargs
        assert sent["chat_id"] == "oc_12345"
        assert sent["msg_type"] == "interactive"

        # The card body must carry both the command and its description.
        card = json.loads(sent["payload"])
        assert card["header"]["template"] == "orange"
        body_text = card["elements"][0]["content"]
        assert "rm -rf /important" in body_text
        assert "dangerous deletion" in body_text

        # The second element holds the four decision buttons, in this order.
        buttons = card["elements"][1]["actions"]
        assert len(buttons) == 4
        assert [button["value"]["hermes_action"] for button in buttons] == [
            "approve_once", "approve_session", "approve_always", "deny"
        ]

    @pytest.mark.asyncio
    async def test_stores_approval_state(self):
        """Sending an approval records session key, message id and chat id."""
        adapter = _make_adapter()
        send_patch = patch.object(
            adapter,
            "_feishu_send_with_retry",
            new_callable=AsyncMock,
            return_value=self._ok_response("msg_002"),
        )
        with send_patch:
            await adapter.send_exec_approval(
                chat_id="oc_12345",
                command="echo test",
                session_key="my-session-key",
            )

        assert len(adapter._approval_state) == 1
        state = next(iter(adapter._approval_state.values()))
        assert state["session_key"] == "my-session-key"
        assert state["message_id"] == "msg_002"
        assert state["chat_id"] == "oc_12345"

    @pytest.mark.asyncio
    async def test_not_connected(self):
        """Without a client the call reports failure."""
        adapter = _make_adapter()
        adapter._client = None

        result = await adapter.send_exec_approval(
            chat_id="oc_12345", command="ls", session_key="s"
        )

        assert result.success is False

    @pytest.mark.asyncio
    async def test_truncates_long_command(self):
        """A 5000-character command is shortened and marked with an ellipsis."""
        adapter = _make_adapter()
        send_patch = patch.object(
            adapter,
            "_feishu_send_with_retry",
            new_callable=AsyncMock,
            return_value=self._ok_response("msg_003"),
        )
        with send_patch as mock_send:
            oversized_command = "x" * 5000
            await adapter.send_exec_approval(
                chat_id="oc_12345", command=oversized_command, session_key="s"
            )

        card = json.loads(mock_send.call_args.kwargs["payload"])
        body_text = card["elements"][0]["content"]
        assert "..." in body_text
        assert len(body_text) < 5000

    @pytest.mark.asyncio
    async def test_multiple_approvals_get_unique_ids(self):
        """Two approvals produce two separate state entries."""
        adapter = _make_adapter()
        send_patch = patch.object(
            adapter,
            "_feishu_send_with_retry",
            new_callable=AsyncMock,
            return_value=self._ok_response("msg_x"),
        )
        with send_patch:
            await adapter.send_exec_approval(
                chat_id="oc_1", command="cmd1", session_key="s1"
            )
            await adapter.send_exec_approval(
                chat_id="oc_2", command="cmd2", session_key="s2"
            )

        assert len(adapter._approval_state) == 2
        first_id, second_id = adapter._approval_state
        assert first_id != second_id
# ===========================================================================
# _handle_card_action_event — approval button clicks
# ===========================================================================
class TestFeishuApprovalCallback:
    """Test the approval intercept in _handle_card_action_event."""

    @staticmethod
    def _profile(user_id, user_name):
        """Return the sender-profile dict the mocked resolver hands back."""
        return {"user_id": user_id, "user_name": user_name, "user_id_alt": None}

    @pytest.mark.asyncio
    async def test_resolves_approval_on_click(self):
        """"Approve once" resolves the session, updates the card, clears state."""
        adapter = _make_adapter()
        adapter._approval_state[1] = {
            "session_key": "agent:main:feishu:group:oc_12345",
            "message_id": "msg_001",
            "chat_id": "oc_12345",
        }
        click = _make_card_action_data(
            action_value={"hermes_action": "approve_once", "approval_id": 1},
        )

        with (
            patch.object(
                adapter, "_resolve_sender_profile", new_callable=AsyncMock,
                return_value=self._profile("ou_user1", "Norbert"),
            ),
            patch.object(adapter, "_update_approval_card", new_callable=AsyncMock) as card_update,
            patch("tools.approval.resolve_gateway_approval", return_value=1) as resolver,
        ):
            await adapter._handle_card_action_event(click)

        resolver.assert_called_once_with("agent:main:feishu:group:oc_12345", "once")
        card_update.assert_called_once_with("msg_001", "Approved once", "Norbert", "once")
        # The pending entry must be gone once the approval is resolved.
        assert 1 not in adapter._approval_state

    @pytest.mark.asyncio
    async def test_deny_button(self):
        """"Deny" resolves the session with ``deny`` and updates the card."""
        adapter = _make_adapter()
        adapter._approval_state[2] = {
            "session_key": "some-session",
            "message_id": "msg_002",
            "chat_id": "oc_12345",
        }
        click = _make_card_action_data(
            action_value={"hermes_action": "deny", "approval_id": 2},
            token="tok_deny",
        )

        with (
            patch.object(
                adapter, "_resolve_sender_profile", new_callable=AsyncMock,
                return_value=self._profile("ou_alice", "Alice"),
            ),
            patch.object(adapter, "_update_approval_card", new_callable=AsyncMock) as card_update,
            patch("tools.approval.resolve_gateway_approval", return_value=1) as resolver,
        ):
            await adapter._handle_card_action_event(click)

        resolver.assert_called_once_with("some-session", "deny")
        card_update.assert_called_once_with("msg_002", "Denied", "Alice", "deny")

    @pytest.mark.asyncio
    async def test_session_approval(self):
        """"Approve for session" resolves with ``session`` and updates the card."""
        adapter = _make_adapter()
        adapter._approval_state[3] = {
            "session_key": "sess-3",
            "message_id": "msg_003",
            "chat_id": "oc_99",
        }
        click = _make_card_action_data(
            action_value={"hermes_action": "approve_session", "approval_id": 3},
            token="tok_ses",
        )

        with (
            patch.object(
                adapter, "_resolve_sender_profile", new_callable=AsyncMock,
                return_value=self._profile("ou_u", "Bob"),
            ),
            patch.object(adapter, "_update_approval_card", new_callable=AsyncMock) as card_update,
            patch("tools.approval.resolve_gateway_approval", return_value=1) as resolver,
        ):
            await adapter._handle_card_action_event(click)

        resolver.assert_called_once_with("sess-3", "session")
        card_update.assert_called_once_with("msg_003", "Approved for session", "Bob", "session")

    @pytest.mark.asyncio
    async def test_always_approval(self):
        """"Approve always" resolves the session with ``always``."""
        adapter = _make_adapter()
        adapter._approval_state[4] = {
            "session_key": "sess-4",
            "message_id": "msg_004",
            "chat_id": "oc_55",
        }
        click = _make_card_action_data(
            action_value={"hermes_action": "approve_always", "approval_id": 4},
            token="tok_alw",
        )

        with (
            patch.object(
                adapter, "_resolve_sender_profile", new_callable=AsyncMock,
                return_value=self._profile("ou_u", "Carol"),
            ),
            patch.object(adapter, "_update_approval_card", new_callable=AsyncMock),
            patch("tools.approval.resolve_gateway_approval", return_value=1) as resolver,
        ):
            await adapter._handle_card_action_event(click)

        resolver.assert_called_once_with("sess-4", "always")

    @pytest.mark.asyncio
    async def test_already_resolved_drops_silently(self):
        """A click for an approval id with no stored state resolves nothing."""
        adapter = _make_adapter()
        # Nothing is stored under approval_id 99, as if it was already handled.
        click = _make_card_action_data(
            action_value={"hermes_action": "approve_once", "approval_id": 99},
            token="tok_gone",
        )

        with patch("tools.approval.resolve_gateway_approval") as resolver:
            await adapter._handle_card_action_event(click)

        resolver.assert_not_called()

    @pytest.mark.asyncio
    async def test_non_approval_actions_route_normally(self):
        """Non-approval card actions should still become synthetic commands."""
        adapter = _make_adapter()
        click = _make_card_action_data(
            action_value={"custom_action": "something_else"},
            token="tok_normal",
        )

        with (
            patch.object(
                adapter, "_resolve_sender_profile", new_callable=AsyncMock,
                return_value=self._profile("ou_u", "Dave"),
            ),
            patch.object(adapter, "get_chat_info", new_callable=AsyncMock, return_value={"name": "Test Chat"}),
            patch.object(adapter, "_handle_message_with_guards", new_callable=AsyncMock) as guarded_handler,
            patch("tools.approval.resolve_gateway_approval") as resolver,
        ):
            await adapter._handle_card_action_event(click)

        # No approval may be resolved for an unrelated button.
        resolver.assert_not_called()
        # The click is forwarded as a synthetic command instead.
        guarded_handler.assert_called_once()
        routed_event = guarded_handler.call_args.args[0]
        assert "/card button" in routed_event.text
# ===========================================================================
# _update_approval_card — card replacement after resolution
# ===========================================================================
class TestFeishuUpdateApprovalCard:
    """Test the card update after approval resolution."""

    @pytest.mark.asyncio
    async def test_updates_card_on_approve(self):
        """An approval dispatches the client's message-update call via ``asyncio.to_thread``."""
        adapter = _make_adapter()
        adapter._client.im.v1.message.update = MagicMock()

        with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
            await adapter._update_approval_card(
                "msg_001", "Approved once", "Norbert", "once"
            )

        mock_thread.assert_called_once()
        # The first positional argument handed to to_thread must be the
        # client's update callable.
        call_args = mock_thread.call_args
        assert call_args[0][0] == adapter._client.im.v1.message.update

    @pytest.mark.asyncio
    async def test_updates_card_on_deny(self):
        """A denial also triggers exactly one ``asyncio.to_thread`` call."""
        adapter = _make_adapter()

        with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
            await adapter._update_approval_card(
                "msg_002", "Denied", "Alice", "deny"
            )

        mock_thread.assert_called_once()

    @pytest.mark.asyncio
    async def test_skips_update_when_not_connected(self):
        """With ``_client`` set to None no update is attempted."""
        adapter = _make_adapter()
        adapter._client = None

        with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
            await adapter._update_approval_card(
                "msg_001", "Approved", "Bob", "once"
            )

        mock_thread.assert_not_called()

    @pytest.mark.asyncio
    async def test_skips_update_when_no_message_id(self):
        """An empty message id means no update is attempted."""
        adapter = _make_adapter()

        with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
            await adapter._update_approval_card(
                "", "Approved", "Bob", "once"
            )

        mock_thread.assert_not_called()

    @pytest.mark.asyncio
    async def test_swallows_update_errors(self):
        """An exception raised by the update call must not propagate."""
        adapter = _make_adapter()

        with patch("asyncio.to_thread", new_callable=AsyncMock, side_effect=Exception("API error")):
            # Should not raise
            await adapter._update_approval_card(
                "msg_001", "Approved", "Bob", "once"
            )

View File

@@ -2,12 +2,54 @@
import asyncio
import json
import re
import sys
import types
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from gateway.config import Platform, PlatformConfig
def _make_fake_nio():
"""Create a lightweight fake ``nio`` module with real response classes.
Tests that call production methods doing ``import nio`` / ``isinstance(resp, nio.XxxResponse)``
need real classes (not MagicMock auto-attributes) to satisfy isinstance checks.
Use via ``patch.dict("sys.modules", {"nio": _make_fake_nio()})``.
"""
mod = types.ModuleType("nio")
class RoomSendResponse:
def __init__(self, event_id="$fake"):
self.event_id = event_id
class RoomRedactResponse:
pass
class RoomCreateResponse:
def __init__(self, room_id="!fake:example.org"):
self.room_id = room_id
class RoomInviteResponse:
pass
class UploadResponse:
def __init__(self, content_uri="mxc://example.org/fake"):
self.content_uri = content_uri
# Minimal Api stub for code that checks nio.Api.RoomPreset
class _Api:
pass
mod.Api = _Api
mod.RoomSendResponse = RoomSendResponse
mod.RoomRedactResponse = RoomRedactResponse
mod.RoomCreateResponse = RoomCreateResponse
mod.RoomInviteResponse = RoomInviteResponse
mod.UploadResponse = UploadResponse
return mod
# ---------------------------------------------------------------------------
# Platform & Config
# ---------------------------------------------------------------------------
@@ -428,6 +470,7 @@ class TestMatrixRequirements:
def test_check_requirements_with_token(self, monkeypatch):
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test")
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
monkeypatch.delenv("MATRIX_ENCRYPTION", raising=False)
from gateway.platforms.matrix import check_matrix_requirements
try:
import nio # noqa: F401
@@ -448,6 +491,45 @@ class TestMatrixRequirements:
from gateway.platforms.matrix import check_matrix_requirements
assert check_matrix_requirements() is False
def test_check_requirements_encryption_true_no_e2ee_deps(self, monkeypatch):
    """With MATRIX_ENCRYPTION=true and ``_check_e2ee_deps`` returning False, the check fails."""
    env = {
        "MATRIX_ACCESS_TOKEN": "syt_test",
        "MATRIX_HOMESERVER": "https://matrix.example.org",
        "MATRIX_ENCRYPTION": "true",
    }
    for key, value in env.items():
        monkeypatch.setenv(key, value)

    from gateway.platforms import matrix as matrix_mod

    with patch.object(matrix_mod, "_check_e2ee_deps", return_value=False):
        verdict = matrix_mod.check_matrix_requirements()
        assert verdict is False
def test_check_requirements_encryption_false_no_e2ee_deps_ok(self, monkeypatch):
    """With MATRIX_ENCRYPTION unset, missing E2EE deps do not fail the check.

    The expected verdict still depends on whether ``nio`` itself can be
    imported in the running environment.
    """
    monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test")
    monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
    monkeypatch.delenv("MATRIX_ENCRYPTION", raising=False)

    from gateway.platforms import matrix as matrix_mod

    e2ee_missing = patch.object(matrix_mod, "_check_e2ee_deps", return_value=False)
    with e2ee_missing:
        try:
            import nio  # noqa: F401

            verdict = matrix_mod.check_matrix_requirements()
            assert verdict is True
        except ImportError:
            # nio itself is not importable, so the check must fail.
            verdict = matrix_mod.check_matrix_requirements()
            assert verdict is False
def test_check_requirements_encryption_true_with_e2ee_deps(self, monkeypatch):
    """With MATRIX_ENCRYPTION=true and ``_check_e2ee_deps`` returning True, the check passes.

    The expected verdict still depends on whether ``nio`` itself can be
    imported in the running environment.
    """
    env = {
        "MATRIX_ACCESS_TOKEN": "syt_test",
        "MATRIX_HOMESERVER": "https://matrix.example.org",
        "MATRIX_ENCRYPTION": "true",
    }
    for key, value in env.items():
        monkeypatch.setenv(key, value)

    from gateway.platforms import matrix as matrix_mod

    e2ee_present = patch.object(matrix_mod, "_check_e2ee_deps", return_value=True)
    with e2ee_present:
        try:
            import nio  # noqa: F401

            verdict = matrix_mod.check_matrix_requirements()
            assert verdict is True
        except ImportError:
            # nio itself is not importable, so the check must fail.
            verdict = matrix_mod.check_matrix_requirements()
            assert verdict is False
# ---------------------------------------------------------------------------
# Access-token auth / E2EE bootstrap
@@ -516,10 +598,12 @@ class TestMatrixAccessTokenAuth:
fake_nio.InviteMemberEvent = type("InviteMemberEvent", (), {})
fake_nio.MegolmEvent = type("MegolmEvent", (), {})
with patch.dict("sys.modules", {"nio": fake_nio}):
with patch.object(adapter, "_refresh_dm_cache", AsyncMock()):
with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)):
assert await adapter.connect() is True
from gateway.platforms import matrix as matrix_mod
with patch.object(matrix_mod, "_check_e2ee_deps", return_value=True):
with patch.dict("sys.modules", {"nio": fake_nio}):
with patch.object(adapter, "_refresh_dm_cache", AsyncMock()):
with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)):
assert await adapter.connect() is True
fake_client.restore_login.assert_called_once_with(
"@bot:example.org", "DEV123", "syt_test_access_token"
@@ -532,6 +616,326 @@ class TestMatrixAccessTokenAuth:
await adapter.disconnect()
class TestMatrixE2EEHardFail:
    """connect() must refuse to start when E2EE is requested but deps are missing."""

    @pytest.mark.asyncio
    async def test_connect_fails_when_encryption_true_but_no_e2ee_deps(self):
        """connect() returns False when ``_check_e2ee_deps`` reports False."""
        from gateway.platforms import matrix as matrix_mod
        from gateway.platforms.matrix import MatrixAdapter

        adapter = MatrixAdapter(
            PlatformConfig(
                enabled=True,
                token="syt_test_access_token",
                extra={
                    "homeserver": "https://matrix.example.org",
                    "user_id": "@bot:example.org",
                    "encryption": True,
                },
            )
        )

        nio_stub = MagicMock()
        nio_stub.AsyncClient = MagicMock()

        with (
            patch.object(matrix_mod, "_check_e2ee_deps", return_value=False),
            patch.dict("sys.modules", {"nio": nio_stub}),
        ):
            connected = await adapter.connect()

        assert connected is False

    @pytest.mark.asyncio
    async def test_connect_fails_when_olm_not_loaded_after_login(self):
        """Even if _check_e2ee_deps passes, if olm is None after auth, hard-fail."""
        from gateway.platforms import matrix as matrix_mod
        from gateway.platforms.matrix import MatrixAdapter

        adapter = MatrixAdapter(
            PlatformConfig(
                enabled=True,
                token="syt_test_access_token",
                extra={
                    "homeserver": "https://matrix.example.org",
                    "user_id": "@bot:example.org",
                    "encryption": True,
                },
            )
        )

        class FakeWhoamiResponse:
            def __init__(self, user_id, device_id):
                self.user_id = user_id
                self.device_id = device_id

        def _restore_login(user_id, device_id, access_token):
            client.user_id = user_id
            client.device_id = device_id
            client.access_token = access_token

        client = MagicMock()
        client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "DEV123"))
        client.close = AsyncMock()
        client.restore_login = MagicMock(side_effect=_restore_login)
        # olm is None — crypto store not loaded
        client.olm = None
        client.should_upload_keys = False

        nio_stub = MagicMock()
        nio_stub.AsyncClient = MagicMock(return_value=client)
        nio_stub.WhoamiResponse = FakeWhoamiResponse

        with (
            patch.object(matrix_mod, "_check_e2ee_deps", return_value=True),
            patch.dict("sys.modules", {"nio": nio_stub}),
        ):
            connected = await adapter.connect()

        assert connected is False
        client.close.assert_awaited_once()
class TestMatrixDeviceId:
    """MATRIX_DEVICE_ID should be used for stable device identity."""

    def test_device_id_from_config_extra(self):
        """``device_id`` in the config extras ends up on ``adapter._device_id``."""
        from gateway.platforms.matrix import MatrixAdapter

        extra = {
            "homeserver": "https://matrix.example.org",
            "device_id": "HERMES_BOT_STABLE",
        }
        adapter = MatrixAdapter(PlatformConfig(enabled=True, token="syt_test", extra=extra))

        assert adapter._device_id == "HERMES_BOT_STABLE"

    def test_device_id_from_env(self, monkeypatch):
        """With no ``device_id`` extra, MATRIX_DEVICE_ID from the environment is used."""
        monkeypatch.setenv("MATRIX_DEVICE_ID", "FROM_ENV")

        from gateway.platforms.matrix import MatrixAdapter

        extra = {"homeserver": "https://matrix.example.org"}
        adapter = MatrixAdapter(PlatformConfig(enabled=True, token="syt_test", extra=extra))

        assert adapter._device_id == "FROM_ENV"

    def test_device_id_config_takes_precedence_over_env(self, monkeypatch):
        """When both are set, the config extra wins over MATRIX_DEVICE_ID."""
        monkeypatch.setenv("MATRIX_DEVICE_ID", "FROM_ENV")

        from gateway.platforms.matrix import MatrixAdapter

        extra = {
            "homeserver": "https://matrix.example.org",
            "device_id": "FROM_CONFIG",
        }
        adapter = MatrixAdapter(PlatformConfig(enabled=True, token="syt_test", extra=extra))

        assert adapter._device_id == "FROM_CONFIG"

    @pytest.mark.asyncio
    async def test_connect_uses_configured_device_id_over_whoami(self):
        """When MATRIX_DEVICE_ID is set, it should be used instead of whoami device_id."""
        from gateway.platforms import matrix as matrix_mod
        from gateway.platforms.matrix import MatrixAdapter

        adapter = MatrixAdapter(
            PlatformConfig(
                enabled=True,
                token="syt_test_access_token",
                extra={
                    "homeserver": "https://matrix.example.org",
                    "user_id": "@bot:example.org",
                    "encryption": True,
                    "device_id": "MY_STABLE_DEVICE",
                },
            )
        )

        class FakeWhoamiResponse:
            def __init__(self, user_id, device_id):
                self.user_id = user_id
                self.device_id = device_id

        class FakeSyncResponse:
            def __init__(self):
                self.rooms = MagicMock(join={})

        def _restore_login(user_id, device_id, access_token):
            client.user_id = user_id
            client.device_id = device_id
            client.access_token = access_token

        client = MagicMock()
        # Awaitable client calls.
        client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "WHOAMI_DEV"))
        client.sync = AsyncMock(return_value=FakeSyncResponse())
        client.keys_upload = AsyncMock()
        client.keys_query = AsyncMock()
        client.keys_claim = AsyncMock()
        client.send_to_device_messages = AsyncMock(return_value=[])
        client.close = AsyncMock()
        # Synchronous client calls.
        client.get_users_for_key_claiming = MagicMock(return_value={})
        client.add_event_callback = MagicMock()
        client.restore_login = MagicMock(side_effect=_restore_login)
        # Plain state attributes.
        client.rooms = {}
        client.account_data = {}
        client.olm = object()
        client.should_upload_keys = False
        client.should_query_keys = False
        client.should_claim_keys = False

        nio_stub = MagicMock()
        nio_stub.AsyncClient = MagicMock(return_value=client)
        nio_stub.WhoamiResponse = FakeWhoamiResponse
        nio_stub.SyncResponse = FakeSyncResponse
        # Real (empty) classes for every event type named below.
        for type_name in (
            "LoginResponse",
            "RoomMessageText",
            "RoomMessageImage",
            "RoomMessageAudio",
            "RoomMessageVideo",
            "RoomMessageFile",
            "InviteMemberEvent",
            "MegolmEvent",
        ):
            setattr(nio_stub, type_name, type(type_name, (), {}))

        with (
            patch.object(matrix_mod, "_check_e2ee_deps", return_value=True),
            patch.dict("sys.modules", {"nio": nio_stub}),
            patch.object(adapter, "_refresh_dm_cache", AsyncMock()),
            patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)),
        ):
            assert await adapter.connect() is True

        # The configured device_id should override the whoami device_id
        client.restore_login.assert_called_once_with(
            "@bot:example.org", "MY_STABLE_DEVICE", "syt_test_access_token"
        )
        assert client.device_id == "MY_STABLE_DEVICE"

        # Verify device_id was passed to nio.AsyncClient constructor
        constructor_call = nio_stub.AsyncClient.call_args
        assert constructor_call.kwargs.get("device_id") == "MY_STABLE_DEVICE"

        await adapter.disconnect()
class TestMatrixE2EEClientConstructorFailure:
    """With encryption enabled, connect() must return False if nio.AsyncClient() raises."""

    @pytest.mark.asyncio
    async def test_connect_fails_when_e2ee_client_constructor_raises(self):
        from gateway.platforms.matrix import MatrixAdapter

        adapter = MatrixAdapter(
            PlatformConfig(
                enabled=True,
                token="syt_test_access_token",
                extra={
                    "homeserver": "https://matrix.example.org",
                    "user_id": "@bot:example.org",
                    "encryption": True,
                },
            )
        )

        # Stand-in ``nio`` module whose client constructor always raises.
        broken_nio = MagicMock()
        broken_nio.AsyncClient = MagicMock(side_effect=Exception("olm init failed"))

        from gateway.platforms import matrix as matrix_mod

        # Report the E2EE dependencies as present so the constructor is the
        # only thing that fails during connect().
        with patch.object(matrix_mod, "_check_e2ee_deps", return_value=True), \
                patch.dict("sys.modules", {"nio": broken_nio}):
            connected = await adapter.connect()

        assert connected is False
class TestMatrixPasswordLoginDeviceId:
    """Password login must still forward the configured device_id to nio.AsyncClient."""

    @pytest.mark.asyncio
    async def test_password_login_passes_device_id_to_constructor(self):
        from gateway.platforms.matrix import MatrixAdapter

        adapter = MatrixAdapter(
            PlatformConfig(
                enabled=True,
                extra={
                    "homeserver": "https://matrix.example.org",
                    "user_id": "@bot:example.org",
                    "password": "secret",
                    "device_id": "STABLE_PW_DEVICE",
                },
            )
        )

        class FakeLoginResponse:
            pass

        class FakeSyncResponse:
            def __init__(self):
                self.rooms = MagicMock(join={})

        # Client double: login/sync/close are awaitable, the rest plain values.
        stub_client = MagicMock()
        stub_client.login = AsyncMock(return_value=FakeLoginResponse())
        stub_client.sync = AsyncMock(return_value=FakeSyncResponse())
        stub_client.close = AsyncMock()
        stub_client.add_event_callback = MagicMock()
        stub_client.rooms = {}
        stub_client.account_data = {}

        # Module double installed as ``nio`` while connect() runs.
        stub_nio = MagicMock()
        stub_nio.AsyncClient = MagicMock(return_value=stub_client)
        stub_nio.LoginResponse = FakeLoginResponse
        stub_nio.SyncResponse = FakeSyncResponse
        # Event types are given real (empty) classes rather than MagicMock attributes.
        for event_name in (
            "RoomMessageText",
            "RoomMessageImage",
            "RoomMessageAudio",
            "RoomMessageVideo",
            "RoomMessageFile",
            "InviteMemberEvent",
        ):
            setattr(stub_nio, event_name, type(event_name, (), {}))

        with patch.dict("sys.modules", {"nio": stub_nio}), \
                patch.object(adapter, "_refresh_dm_cache", AsyncMock()), \
                patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)):
            assert await adapter.connect() is True

        # The constructor call must carry the configured device id.
        ctor_kwargs = stub_nio.AsyncClient.call_args.kwargs
        assert ctor_kwargs.get("device_id") == "STABLE_PW_DEVICE"

        await adapter.disconnect()
class TestMatrixDeviceIdConfig:
    """Checks how MATRIX_DEVICE_ID is reflected in the Matrix platform's ``extra`` config."""

    def test_device_id_in_config_extra(self, monkeypatch):
        """A set MATRIX_DEVICE_ID appears as ``extra["device_id"]``."""
        for name, value in (
            ("MATRIX_ACCESS_TOKEN", "syt_abc123"),
            ("MATRIX_HOMESERVER", "https://matrix.example.org"),
            ("MATRIX_DEVICE_ID", "HERMES_BOT"),
        ):
            monkeypatch.setenv(name, value)
        from gateway.config import GatewayConfig, _apply_env_overrides

        gateway_config = GatewayConfig()
        _apply_env_overrides(gateway_config)
        matrix_extra = gateway_config.platforms[Platform.MATRIX].extra
        assert matrix_extra.get("device_id") == "HERMES_BOT"

    def test_device_id_not_set_when_env_empty(self, monkeypatch):
        """Without MATRIX_DEVICE_ID in the environment, no ``device_id`` key is added."""
        for name, value in (
            ("MATRIX_ACCESS_TOKEN", "syt_abc123"),
            ("MATRIX_HOMESERVER", "https://matrix.example.org"),
        ):
            monkeypatch.setenv(name, value)
        monkeypatch.delenv("MATRIX_DEVICE_ID", raising=False)
        from gateway.config import GatewayConfig, _apply_env_overrides

        gateway_config = GatewayConfig()
        _apply_env_overrides(gateway_config)
        matrix_extra = gateway_config.platforms[Platform.MATRIX].extra
        assert "device_id" not in matrix_extra
class TestMatrixE2EEMaintenance:
@pytest.mark.asyncio
async def test_sync_loop_runs_e2ee_maintenance_requests(self):
@@ -1071,10 +1475,12 @@ class TestMatrixEncryptedMedia:
fake_nio.InviteMemberEvent = FakeInviteMemberEvent
fake_nio.MegolmEvent = FakeMegolmEvent
with patch.dict("sys.modules", {"nio": fake_nio}):
with patch.object(adapter, "_refresh_dm_cache", AsyncMock()):
with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)):
assert await adapter.connect() is True
from gateway.platforms import matrix as matrix_mod
with patch.object(matrix_mod, "_check_e2ee_deps", return_value=True):
with patch.dict("sys.modules", {"nio": fake_nio}):
with patch.object(adapter, "_refresh_dm_cache", AsyncMock()):
with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)):
assert await adapter.connect() is True
callback_classes = [call.args[1] for call in fake_client.add_event_callback.call_args_list]
assert FakeRoomEncryptedImage in callback_classes
@@ -1086,7 +1492,10 @@ class TestMatrixEncryptedMedia:
@pytest.mark.asyncio
async def test_on_room_message_media_decrypts_encrypted_image_and_passes_local_path(self):
from nio.crypto.attachments import encrypt_attachment
try:
from nio.crypto.attachments import encrypt_attachment
except (ImportError, ModuleNotFoundError):
pytest.skip("matrix-nio[e2e] required for encryption tests")
adapter = _make_adapter()
adapter._user_id = "@bot:example.org"
@@ -1154,7 +1563,10 @@ class TestMatrixEncryptedMedia:
@pytest.mark.asyncio
async def test_on_room_message_media_decrypts_encrypted_voice_and_caches_audio(self):
from nio.crypto.attachments import encrypt_attachment
try:
from nio.crypto.attachments import encrypt_attachment
except (ImportError, ModuleNotFoundError):
pytest.skip("matrix-nio[e2e] required for encryption tests")
adapter = _make_adapter()
adapter._user_id = "@bot:example.org"
@@ -1223,7 +1635,10 @@ class TestMatrixEncryptedMedia:
@pytest.mark.asyncio
async def test_on_room_message_media_decrypts_encrypted_file_and_caches_document(self):
from nio.crypto.attachments import encrypt_attachment
try:
from nio.crypto.attachments import encrypt_attachment
except (ImportError, ModuleNotFoundError):
pytest.skip("matrix-nio[e2e] required for encryption tests")
adapter = _make_adapter()
adapter._user_id = "@bot:example.org"
@@ -1519,14 +1934,15 @@ class TestMatrixReactions:
@pytest.mark.asyncio
async def test_send_reaction(self):
"""_send_reaction should call room_send with m.reaction."""
nio = pytest.importorskip("nio")
fake_nio = _make_fake_nio()
mock_client = MagicMock()
mock_client.room_send = AsyncMock(
return_value=MagicMock(spec=nio.RoomSendResponse)
return_value=fake_nio.RoomSendResponse("$reaction1")
)
self.adapter._client = mock_client
result = await self.adapter._send_reaction("!room:ex", "$event1", "👍")
with patch.dict("sys.modules", {"nio": fake_nio}):
result = await self.adapter._send_reaction("!room:ex", "$event1", "👍")
assert result is True
mock_client.room_send.assert_called_once()
args = mock_client.room_send.call_args
@@ -1538,7 +1954,8 @@ class TestMatrixReactions:
@pytest.mark.asyncio
async def test_send_reaction_no_client(self):
self.adapter._client = None
result = await self.adapter._send_reaction("!room:ex", "$ev", "👍")
with patch.dict("sys.modules", {"nio": _make_fake_nio()}):
result = await self.adapter._send_reaction("!room:ex", "$ev", "👍")
assert result is False
@pytest.mark.asyncio
@@ -1635,21 +2052,23 @@ class TestMatrixRedaction:
@pytest.mark.asyncio
async def test_redact_message(self):
nio = pytest.importorskip("nio")
fake_nio = _make_fake_nio()
mock_client = MagicMock()
mock_client.room_redact = AsyncMock(
return_value=MagicMock(spec=nio.RoomRedactResponse)
return_value=fake_nio.RoomRedactResponse()
)
self.adapter._client = mock_client
result = await self.adapter.redact_message("!room:ex", "$ev1", "oops")
with patch.dict("sys.modules", {"nio": fake_nio}):
result = await self.adapter.redact_message("!room:ex", "$ev1", "oops")
assert result is True
mock_client.room_redact.assert_called_once()
@pytest.mark.asyncio
async def test_redact_no_client(self):
self.adapter._client = None
result = await self.adapter.redact_message("!room:ex", "$ev1")
with patch.dict("sys.modules", {"nio": _make_fake_nio()}):
result = await self.adapter.redact_message("!room:ex", "$ev1")
assert result is False
@@ -1663,33 +2082,35 @@ class TestMatrixRoomManagement:
@pytest.mark.asyncio
async def test_create_room(self):
nio = pytest.importorskip("nio")
mock_resp = MagicMock(spec=nio.RoomCreateResponse)
mock_resp.room_id = "!new:example.org"
fake_nio = _make_fake_nio()
mock_resp = fake_nio.RoomCreateResponse(room_id="!new:example.org")
mock_client = MagicMock()
mock_client.room_create = AsyncMock(return_value=mock_resp)
self.adapter._client = mock_client
room_id = await self.adapter.create_room(name="Test Room", topic="A test")
with patch.dict("sys.modules", {"nio": fake_nio}):
room_id = await self.adapter.create_room(name="Test Room", topic="A test")
assert room_id == "!new:example.org"
assert "!new:example.org" in self.adapter._joined_rooms
@pytest.mark.asyncio
async def test_invite_user(self):
nio = pytest.importorskip("nio")
fake_nio = _make_fake_nio()
mock_client = MagicMock()
mock_client.room_invite = AsyncMock(
return_value=MagicMock(spec=nio.RoomInviteResponse)
return_value=fake_nio.RoomInviteResponse()
)
self.adapter._client = mock_client
result = await self.adapter.invite_user("!room:ex", "@user:ex")
with patch.dict("sys.modules", {"nio": fake_nio}):
result = await self.adapter.invite_user("!room:ex", "@user:ex")
assert result is True
@pytest.mark.asyncio
async def test_create_room_no_client(self):
self.adapter._client = None
result = await self.adapter.create_room()
with patch.dict("sys.modules", {"nio": _make_fake_nio()}):
result = await self.adapter.create_room()
assert result is None
@@ -1735,28 +2156,28 @@ class TestMatrixMessageTypes:
@pytest.mark.asyncio
async def test_send_emote(self):
nio = pytest.importorskip("nio")
fake_nio = _make_fake_nio()
mock_client = MagicMock()
mock_resp = MagicMock(spec=nio.RoomSendResponse)
mock_resp.event_id = "$emote1"
mock_resp = fake_nio.RoomSendResponse(event_id="$emote1")
mock_client.room_send = AsyncMock(return_value=mock_resp)
self.adapter._client = mock_client
result = await self.adapter.send_emote("!room:ex", "waves hello")
with patch.dict("sys.modules", {"nio": fake_nio}):
result = await self.adapter.send_emote("!room:ex", "waves hello")
assert result.success is True
call_args = mock_client.room_send.call_args[0]
assert call_args[2]["msgtype"] == "m.emote"
@pytest.mark.asyncio
async def test_send_notice(self):
nio = pytest.importorskip("nio")
fake_nio = _make_fake_nio()
mock_client = MagicMock()
mock_resp = MagicMock(spec=nio.RoomSendResponse)
mock_resp.event_id = "$notice1"
mock_resp = fake_nio.RoomSendResponse(event_id="$notice1")
mock_client.room_send = AsyncMock(return_value=mock_resp)
self.adapter._client = mock_client
result = await self.adapter.send_notice("!room:ex", "System message")
with patch.dict("sys.modules", {"nio": fake_nio}):
result = await self.adapter.send_notice("!room:ex", "System message")
assert result.success is True
call_args = mock_client.room_send.call_args[0]
assert call_args[2]["msgtype"] == "m.notice"
@@ -1764,5 +2185,6 @@ class TestMatrixMessageTypes:
@pytest.mark.asyncio
async def test_send_emote_empty_text(self):
self.adapter._client = MagicMock()
result = await self.adapter.send_emote("!room:ex", "")
with patch.dict("sys.modules", {"nio": _make_fake_nio()}):
result = await self.adapter.send_emote("!room:ex", "")
assert result.success is False

View File

@@ -1,10 +1,18 @@
"""Tests for Matrix voice message support (MSC3245)."""
import io
import types
import pytest
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
nio = pytest.importorskip("nio", reason="matrix-nio not installed")
# Try importing real nio; skip entire file if not available.
# A MagicMock in sys.modules (from another test) is not the real package.
try:
import nio as _nio_probe
if not isinstance(_nio_probe, types.ModuleType) or not hasattr(_nio_probe, "__file__"):
pytest.skip("nio in sys.modules is a mock, not the real package", allow_module_level=True)
except ImportError:
pytest.skip("matrix-nio not installed", allow_module_level=True)
from gateway.platforms.base import MessageType

View File

@@ -504,7 +504,8 @@ class TestMattermostFileUpload:
self.adapter._session = MagicMock()
@pytest.mark.asyncio
async def test_send_image_downloads_and_uploads(self):
@patch("tools.url_safety.is_safe_url", return_value=True)
async def test_send_image_downloads_and_uploads(self, _mock_safe):
"""send_image should download the URL, upload via /api/v4/files, then post."""
# Mock the download (GET)
mock_dl_resp = AsyncMock()

View File

@@ -596,10 +596,11 @@ def _make_aiohttp_resp(status: int, content: bytes = b"file bytes",
return resp
@patch("tools.url_safety.is_safe_url", return_value=True)
class TestMattermostSendUrlAsFile:
"""Tests for MattermostAdapter._send_url_as_file"""
def test_success_on_first_attempt(self):
def test_success_on_first_attempt(self, _mock_safe):
"""200 on first attempt → file uploaded and post created."""
adapter = _make_mm_adapter()
resp = _make_aiohttp_resp(200)
@@ -616,7 +617,7 @@ class TestMattermostSendUrlAsFile:
adapter._upload_file.assert_called_once()
adapter._api_post.assert_called_once()
def test_retries_on_429_then_succeeds(self):
def test_retries_on_429_then_succeeds(self, _mock_safe):
"""429 on first attempt is retried; 200 on second attempt succeeds."""
adapter = _make_mm_adapter()
@@ -637,7 +638,7 @@ class TestMattermostSendUrlAsFile:
assert adapter._session.get.call_count == 2
mock_sleep.assert_called_once()
def test_retries_on_500_then_succeeds(self):
def test_retries_on_500_then_succeeds(self, _mock_safe):
"""5xx on first attempt is retried; 200 on second attempt succeeds."""
adapter = _make_mm_adapter()
@@ -655,7 +656,7 @@ class TestMattermostSendUrlAsFile:
assert result.success
assert adapter._session.get.call_count == 2
def test_falls_back_to_text_after_max_retries_on_5xx(self):
def test_falls_back_to_text_after_max_retries_on_5xx(self, _mock_safe):
"""Three consecutive 500s exhaust retries; falls back to send() with URL text."""
adapter = _make_mm_adapter()
@@ -674,7 +675,7 @@ class TestMattermostSendUrlAsFile:
text_arg = adapter.send.call_args[0][1]
assert "http://cdn.example.com/img.png" in text_arg
def test_falls_back_on_client_error(self):
def test_falls_back_on_client_error(self, _mock_safe):
"""aiohttp.ClientError on every attempt falls back to send() with URL."""
import aiohttp
@@ -699,7 +700,7 @@ class TestMattermostSendUrlAsFile:
text_arg = adapter.send.call_args[0][1]
assert "http://cdn.example.com/img.png" in text_arg
def test_non_retryable_404_falls_back_immediately(self):
def test_non_retryable_404_falls_back_immediately(self, _mock_safe):
"""404 is non-retryable (< 500, != 429); send() is called right away."""
adapter = _make_mm_adapter()

View File

@@ -8,6 +8,7 @@ from gateway.platforms.base import (
GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE,
MessageEvent,
MessageType,
_safe_url_for_log,
)
@@ -18,6 +19,31 @@ class TestSecretCaptureGuidance:
assert "~/.hermes/.env" in message
class TestSafeUrlForLog:
    """Expectations for ``_safe_url_for_log``: redaction and length capping."""

    def test_strips_query_fragment_and_userinfo(self):
        """Credentials, query string and fragment are absent from the logged form."""
        sensitive_url = (
            "https://user:pass@example.com/private/path/image.png"
            "?X-Amz-Signature=supersecret&token=abc#frag"
        )
        logged = _safe_url_for_log(sensitive_url)
        assert logged == "https://example.com/.../image.png"
        for secret in ("supersecret", "token=abc", "user:pass@"):
            assert secret not in logged

    def test_truncates_long_values(self):
        """An over-long URL is cut to exactly ``max_len`` and ends with an ellipsis."""
        oversized = "https://example.com/" + ("a" * 300)
        logged = _safe_url_for_log(oversized, max_len=40)
        assert len(logged) == 40
        assert logged.endswith("...")

    def test_handles_small_and_non_positive_max_len(self):
        """Very small limits yield a (possibly shortened) ellipsis; zero yields ''."""
        url = "https://example.com/very/long/path/file.png?token=secret"
        for limit, expected in ((3, "..."), (2, ".."), (0, "")):
            assert _safe_url_for_log(url, max_len=limit) == expected
# ---------------------------------------------------------------------------
# MessageEvent — command parsing
# ---------------------------------------------------------------------------

View File

@@ -59,6 +59,7 @@ def _make_runner():
runner._honcho_managers = {}
runner._honcho_configs = {}
runner._shutdown_all_gateway_honcho = lambda: None
runner.session_store = MagicMock()
return runner

View File

@@ -87,7 +87,6 @@ class TestReasoningCommand:
)
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
monkeypatch.delenv("HERMES_REASONING_EFFORT", raising=False)
runner = _make_runner()
runner._reasoning_config = {"enabled": True, "effort": "xhigh"}
@@ -108,7 +107,6 @@ class TestReasoningCommand:
config_path.write_text("agent:\n reasoning_effort: medium\n", encoding="utf-8")
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
monkeypatch.delenv("HERMES_REASONING_EFFORT", raising=False)
runner = _make_runner()
runner._reasoning_config = {"enabled": True, "effort": "medium"}
@@ -138,7 +136,6 @@ class TestReasoningCommand:
"api_key": "test-key",
},
)
monkeypatch.delenv("HERMES_REASONING_EFFORT", raising=False)
fake_run_agent = types.ModuleType("run_agent")
fake_run_agent.AIAgent = _CapturingAgent
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
@@ -170,55 +167,6 @@ class TestReasoningCommand:
assert _CapturingAgent.last_init is not None
assert _CapturingAgent.last_init["reasoning_config"] == {"enabled": True, "effort": "low"}
def test_run_agent_prefers_config_over_stale_reasoning_env(self, tmp_path, monkeypatch):
hermes_home = tmp_path / "hermes"
hermes_home.mkdir()
(hermes_home / "config.yaml").write_text("agent:\n reasoning_effort: none\n", encoding="utf-8")
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
monkeypatch.setattr(gateway_run, "_env_path", hermes_home / ".env")
monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None)
monkeypatch.setattr(
gateway_run,
"_resolve_runtime_agent_kwargs",
lambda: {
"provider": "openrouter",
"api_mode": "chat_completions",
"base_url": "https://openrouter.ai/api/v1",
"api_key": "test-key",
},
)
monkeypatch.setenv("HERMES_REASONING_EFFORT", "low")
fake_run_agent = types.ModuleType("run_agent")
fake_run_agent.AIAgent = _CapturingAgent
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
_CapturingAgent.last_init = None
runner = _make_runner()
source = SessionSource(
platform=Platform.LOCAL,
chat_id="cli",
chat_name="CLI",
chat_type="dm",
user_id="user-1",
)
result = asyncio.run(
runner._run_agent(
message="ping",
context_prompt="",
history=[],
source=source,
session_id="session-1",
session_key="agent:main:local:dm",
)
)
assert result["final_response"] == "ok"
assert _CapturingAgent.last_init is not None
assert _CapturingAgent.last_init["reasoning_config"] == {"enabled": False}
def test_run_agent_includes_enabled_mcp_servers_in_gateway_toolsets(self, tmp_path, monkeypatch):
hermes_home = tmp_path / "hermes"
hermes_home.mkdir()

View File

@@ -201,8 +201,8 @@ class TestHandleResumeCommand:
db.close()
@pytest.mark.asyncio
async def test_resume_flushes_memories_with_gateway_session_key(self, tmp_path):
"""Resume should preserve the gateway session key for Honcho flushes."""
async def test_resume_flushes_memories(self, tmp_path):
"""Resume should flush memories from the current session before switching."""
from hermes_state import SessionDB
db = SessionDB(db_path=tmp_path / "state.db")
@@ -221,6 +221,5 @@ class TestHandleResumeCommand:
runner._async_flush_memories.assert_called_once_with(
"current_session_001",
_session_key_for_event(event),
)
db.close()

View File

@@ -71,6 +71,24 @@ class FakeAgent:
}
class LongPreviewAgent:
    """Fake agent that reports one ``terminal`` tool start carrying a very long preview."""

    LONG_CMD = "cd /home/teknium/.hermes/hermes-agent/.worktrees/hermes-d8860339 && source .venv/bin/activate && python -m pytest tests/gateway/test_run_progress_topics.py -n0 -q"

    def __init__(self, **kwargs):
        # Only the progress callback is read; any other kwargs are ignored.
        self.tool_progress_callback = kwargs.get("tool_progress_callback")
        self.tools = []

    def run_conversation(self, message, conversation_history=None, task_id=None):
        """Emit a single ``tool.started`` event, pause briefly, and return a canned result."""
        notify = self.tool_progress_callback
        notify("tool.started", "terminal", self.LONG_CMD, {})
        time.sleep(0.35)
        outcome = dict(final_response="done", messages=[], api_calls=1)
        return outcome
def _make_runner(adapter):
gateway_run = importlib.import_module("gateway.run")
GatewayRunner = gateway_run.GatewayRunner
@@ -217,3 +235,102 @@ async def test_run_agent_progress_uses_event_message_id_for_slack_dm(monkeypatch
assert adapter.sent
assert adapter.sent[0]["metadata"] == {"thread_id": "1234567890.000001"}
assert all(call["metadata"] == {"thread_id": "1234567890.000001"} for call in adapter.typing)
# ---------------------------------------------------------------------------
# Preview truncation tests (all/new mode respects tool_preview_length)
# ---------------------------------------------------------------------------
def _run_long_preview_helper(monkeypatch, tmp_path, preview_length=0):
    """Run the gateway agent once with ``LongPreviewAgent`` and capture progress output.

    Shared setup for the long-preview truncation tests.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture used for env/module patching.
        tmp_path: directory that receives ``config.yaml`` and is patched in as
            ``gateway.run._hermes_home``.
        preview_length: value written to ``display.tool_preview_length`` in
            that config file.

    Returns:
        ``(adapter, result)`` -- the ``ProgressCaptureAdapter`` handed to the
        runner and the value returned by ``runner._run_agent``.
    """
    import asyncio

    import yaml

    monkeypatch.setenv("HERMES_TOOL_PROGRESS_MODE", "all")

    # Replace ``dotenv`` with a module whose load_dotenv is a no-op.
    fake_dotenv = types.ModuleType("dotenv")
    fake_dotenv.load_dotenv = lambda *args, **kwargs: None
    monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)

    # Replace ``run_agent`` so ``AIAgent`` resolves to the long-preview fake.
    fake_run_agent = types.ModuleType("run_agent")
    fake_run_agent.AIAgent = LongPreviewAgent
    monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)

    # Write config.yaml so _run_agent picks up tool_preview_length
    config = {"display": {"tool_preview_length": preview_length}}
    (tmp_path / "config.yaml").write_text(yaml.dump(config), encoding="utf-8")

    adapter = ProgressCaptureAdapter()
    runner = _make_runner(adapter)
    gateway_run = importlib.import_module("gateway.run")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})

    source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="12345",
        chat_type="dm",
        thread_id=None,
    )

    # asyncio.run() creates and closes its own event loop.  The previous
    # asyncio.get_event_loop().run_until_complete(...) form relies on an
    # implicit current loop, which is deprecated and fails when no loop is
    # set for the thread (e.g. after another test has torn its loop down).
    result = asyncio.run(
        runner._run_agent(
            message="hello",
            context_prompt="",
            history=[],
            source=source,
            session_id="sess-trunc",
            session_key="agent:main:telegram:dm:12345",
        )
    )
    return adapter, result
def test_all_mode_default_truncation_40_chars(monkeypatch, tmp_path):
    """With tool_preview_length=0, the quoted preview in the first message is at most 40 chars."""
    import re

    adapter, result = _run_long_preview_helper(monkeypatch, tmp_path, preview_length=0)

    assert result["final_response"] == "done"
    assert adapter.sent

    content = adapter.sent[0]["content"]
    # A truncated preview carries an ellipsis marker.
    assert "..." in content

    # The preview is whatever sits between the outermost double quotes.
    quoted = re.search(r'"(.+)"', content)
    assert quoted, f"No quoted preview found in: {content}"
    preview_text = quoted.group(1)
    assert len(preview_text) <= 40, f"Preview too long ({len(preview_text)}): {preview_text}"
def test_all_mode_respects_custom_preview_length(monkeypatch, tmp_path):
    """With tool_preview_length=120, the quoted preview is longer than 40 but at most 120 chars."""
    import re

    adapter, result = _run_long_preview_helper(monkeypatch, tmp_path, preview_length=120)

    assert result["final_response"] == "done"
    assert adapter.sent

    content = adapter.sent[0]["content"]
    # LONG_CMD exceeds 120 characters, so some truncation is still expected.
    quoted = re.search(r'"(.+)"', content)
    assert quoted, f"No quoted preview found in: {content}"
    preview_text = quoted.group(1)

    # Longer than the 40-char default ...
    assert len(preview_text) > 40, f"Preview suspiciously short ({len(preview_text)}): {preview_text}"
    # ... yet never beyond the configured 120.
    assert len(preview_text) <= 120, f"Preview too long ({len(preview_text)}): {preview_text}"
def test_all_mode_no_truncation_when_preview_fits(monkeypatch, tmp_path):
    """A preview shorter than the configured cap is sent without an ellipsis."""
    # 200 is a generous cap: LongPreviewAgent.LONG_CMD is shorter than that.
    adapter, result = _run_long_preview_helper(monkeypatch, tmp_path, preview_length=200)

    assert result["final_response"] == "done"
    assert adapter.sent

    content = adapter.sent[0]["content"]
    # No ellipsis means the command was passed through whole.
    assert "..." not in content, f"Preview was truncated when it shouldn't be: {content}"

View File

@@ -0,0 +1,158 @@
"""Tests that on_session_finalize and on_session_reset plugin hooks fire in the gateway."""
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent
from gateway.session import SessionEntry, SessionSource, build_session_key
def _make_source() -> SessionSource:
    """Return the fixed Telegram DM source (user ``u1`` in chat ``c1``) used by these tests."""
    fields = {
        "platform": Platform.TELEGRAM,
        "user_id": "u1",
        "chat_id": "c1",
        "user_name": "tester",
        "chat_type": "dm",
    }
    return SessionSource(**fields)
def _make_event(text: str) -> MessageEvent:
    """Wrap *text* in a ``MessageEvent`` with message id ``m1`` from the shared test source."""
    origin = _make_source()
    event = MessageEvent(text=text, source=origin, message_id="m1")
    return event
def _make_runner():
    """Build a ``GatewayRunner`` without running ``__init__`` and attach test doubles.

    The session store holds an existing entry (``sess-old``) under the test
    source's session key, and returns a fresh entry (``sess-new``) from both
    ``get_or_create_session`` and ``reset_session``.
    """
    from gateway.run import GatewayRunner

    session_key = build_session_key(_make_source())

    def _entry(session_id):
        # Both entries share key/platform/chat type; only the id differs.
        return SessionEntry(
            session_key=session_key,
            session_id=session_id,
            created_at=datetime.now(),
            updated_at=datetime.now(),
            platform=Platform.TELEGRAM,
            chat_type="dm",
        )

    previous_entry = _entry("sess-old")
    replacement_entry = _entry("sess-new")

    telegram_adapter = MagicMock()
    telegram_adapter.send = AsyncMock()

    store = MagicMock()
    store.get_or_create_session.return_value = replacement_entry
    store.reset_session.return_value = replacement_entry
    store._entries = {session_key: previous_entry}
    store._generate_session_key.return_value = session_key

    # object.__new__ skips GatewayRunner.__init__; every attribute the tests
    # rely on is assigned by hand below.
    runner = object.__new__(GatewayRunner)
    runner.config = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
    )
    runner.adapters = {Platform.TELEGRAM: telegram_adapter}
    runner._voice_mode = {}
    runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
    runner._session_model_overrides = {}
    runner._pending_model_notes = {}
    runner._background_tasks = set()
    runner.session_store = store
    runner._running_agents = {}
    runner._pending_messages = {}
    runner._pending_approvals = {}
    runner._session_db = None
    runner._agent_cache_lock = None
    runner._is_user_authorized = lambda _source: True
    runner._format_session_info = lambda: ""
    return runner
@pytest.mark.asyncio
@patch("hermes_cli.plugins.invoke_hook")
async def test_reset_fires_finalize_hook(mock_invoke_hook):
    """Handling ``/new`` calls on_session_finalize with the previous session id (sess-old)."""
    gateway = _make_runner()
    reset_event = _make_event("/new")
    await gateway._handle_reset_command(reset_event)

    expected = {"session_id": "sess-old", "platform": "telegram"}
    mock_invoke_hook.assert_any_call("on_session_finalize", **expected)
@pytest.mark.asyncio
@patch("hermes_cli.plugins.invoke_hook")
async def test_reset_fires_reset_hook(mock_invoke_hook):
    """Handling ``/new`` calls on_session_reset with the replacement session id (sess-new)."""
    gateway = _make_runner()
    reset_event = _make_event("/new")
    await gateway._handle_reset_command(reset_event)

    expected = {"session_id": "sess-new", "platform": "telegram"}
    mock_invoke_hook.assert_any_call("on_session_reset", **expected)
@pytest.mark.asyncio
@patch("hermes_cli.plugins.invoke_hook")
async def test_finalize_before_reset(mock_invoke_hook):
    """Among the session hooks, on_session_finalize is invoked first, then on_session_reset."""
    gateway = _make_runner()
    await gateway._handle_reset_command(_make_event("/new"))

    tracked = ("on_session_finalize", "on_session_reset")
    # The hook name is the first positional argument of each recorded call;
    # calls for any other hook are ignored.
    observed = [
        recorded.args[0]
        for recorded in mock_invoke_hook.call_args_list
        if recorded.args[0] in tracked
    ]
    assert observed == ["on_session_finalize", "on_session_reset"]
@pytest.mark.asyncio
@patch("hermes_cli.plugins.invoke_hook")
async def test_shutdown_fires_finalize_for_active_agents(mock_invoke_hook):
    """``stop()`` invokes on_session_finalize once per distinct running-agent session id."""
    from gateway.run import GatewayRunner

    # Bare runner (no __init__) carrying only the attributes this test assigns.
    gateway = object.__new__(GatewayRunner)
    gateway._running = True
    gateway._background_tasks = set()
    gateway._pending_messages = {}
    gateway._pending_approvals = {}
    gateway._shutdown_event = MagicMock()
    gateway.adapters = {}
    gateway._exit_reason = "test"

    first_agent = MagicMock()
    first_agent.session_id = "sess-a"
    second_agent = MagicMock()
    second_agent.session_id = "sess-b"
    gateway._running_agents = {"key-a": first_agent, "key-b": second_agent}

    # Replace the pid-file / runtime-status helpers with mocks for the
    # duration of stop().
    with patch("gateway.status.remove_pid_file"), \
            patch("gateway.status.write_runtime_status"):
        await gateway.stop()

    finalized_ids = {
        recorded.kwargs["session_id"]
        for recorded in mock_invoke_hook.call_args_list
        if recorded.args[0] == "on_session_finalize"
    }
    assert finalized_ids == {"sess-a", "sess-b"}
@pytest.mark.asyncio
@patch("hermes_cli.plugins.invoke_hook", side_effect=Exception("boom"))
async def test_hook_error_does_not_break_reset(mock_invoke_hook):
    """A raising plugin hook must not stop ``/new`` from returning its success text."""
    gateway = _make_runner()
    reply = await gateway._handle_reset_command(_make_event("/new"))

    # Either wording counts as a successful reset despite the hook failure.
    accepted_phrases = ("Session reset", "New session")
    assert any(phrase in reply for phrase in accepted_phrases)

View File

@@ -36,11 +36,16 @@ def _make_runner():
)
runner.adapters = {Platform.TELEGRAM: _FakeAdapter()}
runner._running_agents = {}
runner._running_agents_ts = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner._voice_mode = {}
runner._background_tasks = set()
runner._is_user_authorized = lambda _source: True
runner.hooks = MagicMock()
runner.hooks.emit = AsyncMock()
runner.session_store = MagicMock()
runner.delivery_router = MagicMock()
return runner

View File

@@ -699,6 +699,147 @@ class TestReactions:
assert remove_calls[0].kwargs["name"] == "eyes"
# ---------------------------------------------------------------------------
# TestThreadReplyHandling
# ---------------------------------------------------------------------------
class TestThreadReplyHandling:
"""Test thread reply processing without explicit bot mentions."""
@pytest.fixture()
def mock_session_store(self):
    """Provide a mock session store with an empty ``_entries`` dict and per-user group sessions on."""
    session_store = MagicMock()
    session_store._entries = {}
    session_store._ensure_loaded = MagicMock()
    store_config = MagicMock()
    store_config.group_sessions_per_user = True
    session_store.config = store_config
    return session_store
@pytest.fixture()
def adapter_with_session_store(self, mock_session_store):
    """Provide a running ``SlackAdapter`` with mocked app/client, bot id ``U_BOT`` and the mock store."""
    slack = SlackAdapter(PlatformConfig(enabled=True, token="***"))

    app = MagicMock()
    app.client = AsyncMock()
    slack._app = app

    slack._bot_user_id = "U_BOT"
    slack._team_bot_user_ids = {"T_TEAM": "U_BOT"}
    slack._running = True
    # handle_message is replaced so tests can inspect what gets forwarded.
    slack.handle_message = AsyncMock()
    slack.set_session_store(mock_session_store)
    return slack
@pytest.mark.asyncio
async def test_thread_reply_without_mention_no_session_ignored(
    self, adapter_with_session_store, mock_session_store
):
    """A thread reply with no mention and no stored session is not forwarded."""
    slack = adapter_with_session_store
    mock_session_store._entries = {}  # no sessions recorded

    # thread_ts differs from ts, which marks this as a reply inside a thread.
    reply_event = {
        "text": "Just replying in the thread",
        "user": "U_USER",
        "channel": "C123",
        "ts": "123.456",
        "thread_ts": "123.000",
        "channel_type": "channel",
        "team": "T_TEAM",
    }

    await slack._handle_slack_message(reply_event)

    slack.handle_message.assert_not_called()
@pytest.mark.asyncio
async def test_thread_reply_without_mention_with_session_processed(
    self, adapter_with_session_store, mock_session_store
):
    """A thread reply with no mention is forwarded when a session entry exists for the thread."""
    slack = adapter_with_session_store

    # Register a session entry keyed on channel C123, thread 123.000, user U_USER.
    thread_session_key = "agent:main:slack:group:C123:123.000:U_USER"
    mock_session_store._entries = {thread_session_key: MagicMock()}

    # Reply posted inside thread 123.000.
    reply_event = {
        "text": "Follow-up question",
        "user": "U_USER",
        "channel": "C123",
        "ts": "123.456",
        "thread_ts": "123.000",
        "channel_type": "channel",
        "team": "T_TEAM",
    }

    await slack._handle_slack_message(reply_event)

    slack.handle_message.assert_called_once()
    # With no mention present, the forwarded text is the original text.
    forwarded = slack.handle_message.call_args[0][0]
    assert forwarded.text == "Follow-up question"
@pytest.mark.asyncio
async def test_thread_reply_with_mention_strips_bot_id(
self, adapter_with_session_store, mock_session_store
):
"""Thread replies with @mention should still strip the bot ID."""
# Even with a session, mentions should be stripped
session_key = "agent:main:slack:group:C123:123.000:U_USER"
mock_session_store._entries = {session_key: MagicMock()}
event = {
"text": "<@U_BOT> thanks for the help",
"user": "U_USER",
"channel": "C123",
"ts": "123.456",
"thread_ts": "123.000",
"channel_type": "channel",
"team": "T_TEAM",
}
await adapter_with_session_store._handle_slack_message(event)
adapter_with_session_store.handle_message.assert_called_once()
msg_event = adapter_with_session_store.handle_message.call_args[0][0]
assert "<@U_BOT>" not in msg_event.text
assert msg_event.text == "thanks for the help"
@pytest.mark.asyncio
async def test_top_level_message_requires_mention_even_with_session(
self, adapter_with_session_store, mock_session_store
):
"""Top-level channel messages should require mention even if session exists."""
# Session exists but this is a top-level message (no thread_ts)
session_key = "agent:main:slack:group:C123:123.000:U_USER"
mock_session_store._entries = {session_key: MagicMock()}
event = {
"text": "New question without mention",
"user": "U_USER",
"channel": "C123",
"ts": "456.789",
# No thread_ts - this is a top-level message
"channel_type": "channel",
"team": "T_TEAM",
}
await adapter_with_session_store._handle_slack_message(event)
adapter_with_session_store.handle_message.assert_not_called()
@pytest.mark.asyncio
async def test_no_session_store_ignores_thread_replies(
self, adapter
):
"""If no session store is attached, thread replies without mention should be ignored."""
# adapter fixture has no session store attached
event = {
"text": "Thread reply without mention",
"user": "U_USER",
"channel": "C123",
"ts": "123.456",
"thread_ts": "123.000",
"channel_type": "channel",
"team": "T_TEAM",
}
await adapter._handle_slack_message(event)
adapter.handle_message.assert_not_called()
# ---------------------------------------------------------------------------
# TestUserNameResolution
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,426 @@
"""Tests for Slack Block Kit approval buttons and thread context fetching."""
import asyncio
import os
import sys
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# Ensure the repo root is importable
# ---------------------------------------------------------------------------
_repo = str(Path(__file__).resolve().parents[2])
if _repo not in sys.path:
sys.path.insert(0, _repo)
# ---------------------------------------------------------------------------
# Minimal Slack SDK mock so SlackAdapter can be imported
# ---------------------------------------------------------------------------
def _ensure_slack_mock():
"""Wire up the minimal mocks required to import SlackAdapter."""
if "slack_bolt" in sys.modules:
return
slack_bolt = MagicMock()
slack_bolt.async_app.AsyncApp = MagicMock
sys.modules["slack_bolt"] = slack_bolt
sys.modules["slack_bolt.async_app"] = slack_bolt.async_app
handler_mod = MagicMock()
handler_mod.AsyncSocketModeHandler = MagicMock
sys.modules["slack_bolt.adapter"] = MagicMock()
sys.modules["slack_bolt.adapter.socket_mode"] = MagicMock()
sys.modules["slack_bolt.adapter.socket_mode.async_handler"] = handler_mod
sdk_mod = MagicMock()
sdk_mod.web = MagicMock()
sdk_mod.web.async_client = MagicMock()
sdk_mod.web.async_client.AsyncWebClient = MagicMock
sys.modules["slack_sdk"] = sdk_mod
sys.modules["slack_sdk.web"] = sdk_mod.web
sys.modules["slack_sdk.web.async_client"] = sdk_mod.web.async_client
_ensure_slack_mock()
from gateway.platforms.slack import SlackAdapter
from gateway.config import Platform, PlatformConfig
def _make_adapter():
    """Create a SlackAdapter instance with mocked internals.

    The adapter gets a mocked ``_app``, bot user ``U_BOT``, one ``AsyncMock``
    client registered for team ``T1``, and channel ``C1`` mapped to that team.
    """
    slack_adapter = SlackAdapter(PlatformConfig(enabled=True, token="xoxb-test-token"))
    slack_adapter._app = MagicMock()
    slack_adapter._bot_user_id = "U_BOT"
    slack_adapter._team_clients = {"T1": AsyncMock()}
    slack_adapter._team_bot_user_ids = {"T1": "U_BOT"}
    slack_adapter._channel_team = {"C1": "T1"}
    return slack_adapter
# ===========================================================================
# send_exec_approval — Block Kit buttons
# ===========================================================================
class TestSlackExecApproval:
    """Test the send_exec_approval method sends Block Kit buttons."""

    @pytest.mark.asyncio
    async def test_sends_blocks_with_buttons(self):
        """The prompt is posted as one section block plus four action buttons."""
        adapter = _make_adapter()
        post_message = AsyncMock(return_value={"ts": "1234.5678"})
        adapter._team_clients["T1"].chat_postMessage = post_message

        result = await adapter.send_exec_approval(
            chat_id="C1",
            command="rm -rf /important",
            session_key="agent:main:slack:group:C1:1111",
            description="dangerous deletion",
        )

        assert result.success is True
        assert result.message_id == "1234.5678"

        # The post must carry a Block Kit payload.
        post_message.assert_called_once()
        posted = post_message.call_args.kwargs
        assert "blocks" in posted
        blocks = posted["blocks"]
        assert len(blocks) == 2

        section_text = blocks[0]["text"]["text"]
        assert blocks[0]["type"] == "section"
        assert "rm -rf /important" in section_text
        assert "dangerous deletion" in section_text

        assert blocks[1]["type"] == "actions"
        buttons = blocks[1]["elements"]
        assert len(buttons) == 4
        action_ids = [button["action_id"] for button in buttons]
        for expected_id in (
            "hermes_approve_once",
            "hermes_approve_session",
            "hermes_approve_always",
            "hermes_deny",
        ):
            assert expected_id in action_ids
        # Every button carries the session key as its value.
        for button in buttons:
            assert button["value"] == "agent:main:slack:group:C1:1111"

    @pytest.mark.asyncio
    async def test_sends_in_thread(self):
        """``metadata["thread_id"]`` is passed on as the ``thread_ts`` keyword."""
        adapter = _make_adapter()
        post_message = AsyncMock(return_value={"ts": "1234.5678"})
        adapter._team_clients["T1"].chat_postMessage = post_message

        await adapter.send_exec_approval(
            chat_id="C1",
            command="echo test",
            session_key="test-session",
            metadata={"thread_id": "9999.0000"},
        )

        posted = post_message.call_args.kwargs
        assert posted.get("thread_ts") == "9999.0000"

    @pytest.mark.asyncio
    async def test_not_connected(self):
        """With ``_app`` set to ``None`` the call reports ``success`` as False."""
        adapter = _make_adapter()
        adapter._app = None

        result = await adapter.send_exec_approval(
            chat_id="C1", command="ls", session_key="s"
        )

        assert result.success is False

    @pytest.mark.asyncio
    async def test_truncates_long_command(self):
        """A 5000-character command is shortened and marked with an ellipsis."""
        adapter = _make_adapter()
        post_message = AsyncMock(return_value={"ts": "1.2"})
        adapter._team_clients["T1"].chat_postMessage = post_message
        oversized_command = "x" * 5000

        await adapter.send_exec_approval(
            chat_id="C1", command=oversized_command, session_key="s"
        )

        section_text = post_message.call_args.kwargs["blocks"][0]["text"]["text"]
        assert "..." in section_text
        assert len(section_text) < 5000
# ===========================================================================
# _handle_approval_action — button click handler
# ===========================================================================
class TestSlackApprovalAction:
    """Test the approval button click handler."""

    @pytest.mark.asyncio
    async def test_resolves_approval(self):
        """An "approve once" click resolves the approval and updates the message."""
        adapter = _make_adapter()
        adapter._approval_resolved["1234.5678"] = False
        ack = AsyncMock()
        click_body = {
            "message": {
                "ts": "1234.5678",
                "blocks": [
                    {"type": "section", "text": {"type": "mrkdwn", "text": "original text"}},
                    {"type": "actions", "elements": []},
                ],
            },
            "channel": {"id": "C1"},
            "user": {"name": "norbert"},
        }
        clicked = {
            "action_id": "hermes_approve_once",
            "value": "agent:main:slack:group:C1:1111",
        }
        update_message = AsyncMock()
        adapter._team_clients["T1"].chat_update = update_message

        with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
            await adapter._handle_approval_action(ack, click_body, clicked)

        ack.assert_called_once()
        mock_resolve.assert_called_once_with("agent:main:slack:group:C1:1111", "once")
        # The original message is rewritten to show the decision.
        update_message.assert_called_once()
        assert "Approved once by norbert" in update_message.call_args.kwargs["text"]

    @pytest.mark.asyncio
    async def test_prevents_double_click(self):
        """A click on an already-resolved message is acked but not resolved again."""
        adapter = _make_adapter()
        adapter._approval_resolved["1234.5678"] = True  # marked resolved up front
        ack = AsyncMock()
        click_body = {
            "message": {"ts": "1234.5678", "blocks": []},
            "channel": {"id": "C1"},
            "user": {"name": "norbert"},
        }
        clicked = {
            "action_id": "hermes_approve_once",
            "value": "some-session",
        }

        with patch("tools.approval.resolve_gateway_approval") as mock_resolve:
            await adapter._handle_approval_action(ack, click_body, clicked)

        # Acknowledged, yet the resolver must stay untouched.
        ack.assert_called_once()
        mock_resolve.assert_not_called()

    @pytest.mark.asyncio
    async def test_deny_action(self):
        """A "deny" click resolves with ``"deny"`` and reports who denied."""
        adapter = _make_adapter()
        adapter._approval_resolved["1.2"] = False
        ack = AsyncMock()
        click_body = {
            "message": {"ts": "1.2", "blocks": [
                {"type": "section", "text": {"type": "mrkdwn", "text": "cmd"}},
            ]},
            "channel": {"id": "C1"},
            "user": {"name": "alice"},
        }
        clicked = {"action_id": "hermes_deny", "value": "session-key"}
        update_message = AsyncMock()
        adapter._team_clients["T1"].chat_update = update_message

        with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
            await adapter._handle_approval_action(ack, click_body, clicked)

        mock_resolve.assert_called_once_with("session-key", "deny")
        assert "Denied by alice" in update_message.call_args.kwargs["text"]
# ===========================================================================
# _fetch_thread_context
# ===========================================================================
class TestSlackThreadContext:
    """Test thread context fetching."""

    @pytest.mark.asyncio
    async def test_fetches_and_formats_context(self):
        """Earlier thread messages are rendered; the triggering one is left out."""
        adapter = _make_adapter()
        thread_messages = [
            {"ts": "1000.0", "user": "U1", "text": "This is the parent message"},
            {"ts": "1000.1", "user": "U2", "text": "I think we should refactor"},
            {"ts": "1000.2", "user": "U1", "text": "Good idea, <@U_BOT> what do you think?"},
        ]
        adapter._team_clients["T1"].conversations_replies = AsyncMock(
            return_value={"messages": thread_messages}
        )
        # Pre-seed the name cache so user IDs resolve to display names.
        adapter._user_name_cache = {"U1": "Alice", "U2": "Bob"}

        rendered = await adapter._fetch_thread_context(
            channel_id="C1",
            thread_ts="1000.0",
            current_ts="1000.2",  # the message that triggered the fetch
            team_id="T1",
        )

        assert "[Thread context" in rendered
        assert "[thread parent] Alice: This is the parent message" in rendered
        assert "Bob: I think we should refactor" in rendered
        # The triggering message itself must not be echoed back.
        assert "what do you think" not in rendered
        # The bot mention must not survive into the context.
        assert "<@U_BOT>" not in rendered

    @pytest.mark.asyncio
    async def test_skips_bot_messages(self):
        """Messages carrying a ``bot_id`` are omitted from the context."""
        adapter = _make_adapter()
        thread_messages = [
            {"ts": "1000.0", "user": "U1", "text": "Parent"},
            {"ts": "1000.1", "bot_id": "B1", "text": "Bot reply (should be skipped)"},
            {"ts": "1000.2", "user": "U1", "text": "Current"},
        ]
        adapter._team_clients["T1"].conversations_replies = AsyncMock(
            return_value={"messages": thread_messages}
        )
        adapter._user_name_cache = {"U1": "Alice"}

        rendered = await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.2", team_id="T1"
        )

        assert "Bot reply" not in rendered
        assert "Alice: Parent" in rendered

    @pytest.mark.asyncio
    async def test_empty_thread(self):
        """An empty ``messages`` list yields an empty context string."""
        adapter = _make_adapter()
        adapter._team_clients["T1"].conversations_replies = AsyncMock(
            return_value={"messages": []}
        )

        rendered = await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.1", team_id="T1"
        )

        assert rendered == ""

    @pytest.mark.asyncio
    async def test_api_failure_returns_empty(self):
        """An exception from ``conversations_replies`` yields an empty string."""
        adapter = _make_adapter()
        adapter._team_clients["T1"].conversations_replies = AsyncMock(
            side_effect=Exception("API error")
        )

        rendered = await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.1", team_id="T1"
        )

        assert rendered == ""
# ===========================================================================
# _has_active_session_for_thread — session key fix (#5833)
# ===========================================================================
class TestSessionKeyFix:
    """Test that _has_active_session_for_thread uses build_session_key."""

    def test_uses_build_session_key(self):
        """Verify the fix uses build_session_key instead of manual key construction."""
        adapter = _make_adapter()

        # Session store holding one entry whose key has no user-id suffix.
        store = MagicMock()
        store._entries = {
            "agent:main:slack:group:C1:1000.0": MagicMock()
        }
        store._ensure_loaded = MagicMock()
        store.config = MagicMock()
        store.config.group_sessions_per_user = False  # threads don't include user_id
        store.config.thread_sessions_per_user = False
        adapter._session_store = store

        # The lookup is expected to match the stored key even though a user_id
        # is supplied, because both per-user flags are switched off.
        found = adapter._has_active_session_for_thread(
            channel_id="C1", thread_ts="1000.0", user_id="U123"
        )

        assert found is True

    def test_no_session_returns_false(self):
        """With an empty ``_entries`` dict the lookup reports False."""
        adapter = _make_adapter()
        store = MagicMock()
        store._entries = {}
        store._ensure_loaded = MagicMock()
        store.config = MagicMock()
        store.config.group_sessions_per_user = True
        store.config.thread_sessions_per_user = False
        adapter._session_store = store

        found = adapter._has_active_session_for_thread(
            channel_id="C1", thread_ts="1000.0", user_id="U123"
        )

        assert found is False

    def test_no_session_store(self):
        """Without assigning ``_session_store`` the lookup reports False."""
        adapter = _make_adapter()
        # Deliberately no ``_session_store`` assignment here.

        found = adapter._has_active_session_for_thread(
            channel_id="C1", thread_ts="1000.0", user_id="U123"
        )

        assert found is False
# ===========================================================================
# Thread engagement — bot-started threads & mentioned threads
# ===========================================================================
class TestThreadEngagement:
    """Test _bot_message_ts and _mentioned_threads tracking."""

    @pytest.mark.asyncio
    async def test_send_tracks_bot_message_ts(self):
        """Bot's sent messages are tracked so thread replies work without @mention."""
        adapter = _make_adapter()
        adapter._team_clients["T1"].chat_postMessage = AsyncMock(
            return_value={"ts": "9000.1"}
        )

        await adapter.send(chat_id="C1", content="Hello!", metadata={"thread_id": "8000.0"})

        assert "9000.1" in adapter._bot_message_ts
        # The thread root must be remembered as well.
        assert "8000.0" in adapter._bot_message_ts

    @pytest.mark.asyncio
    async def test_bot_message_ts_cap(self):
        """Verify memory is bounded when many messages are sent."""
        adapter = _make_adapter()
        adapter._BOT_TS_MAX = 10  # small cap so the test stays cheap
        client = adapter._team_clients["T1"]

        for seq in range(20):
            # Each send gets its own timestamp from a fresh mock.
            client.chat_postMessage = AsyncMock(return_value={"ts": f"{seq}.0"})
            await adapter.send(chat_id="C1", content=f"msg {seq}")

        assert len(adapter._bot_message_ts) <= 10

    def test_mentioned_threads_populated_on_mention(self):
        """When bot is @mentioned in a thread, that thread is tracked."""
        adapter = _make_adapter()

        # Mirrors what _handle_slack_message is described as doing on a mention.
        adapter._mentioned_threads.add("1000.0")

        assert "1000.0" in adapter._mentioned_threads

    def test_mentioned_threads_cap(self):
        """Verify _mentioned_threads is bounded."""
        adapter = _make_adapter()
        adapter._MENTIONED_THREADS_MAX = 10

        for seq in range(15):
            adapter._mentioned_threads.add(f"{seq}.0")
            if len(adapter._mentioned_threads) <= adapter._MENTIONED_THREADS_MAX:
                continue
            # Over the cap: drop half a cap's worth of entries.
            evict_count = adapter._MENTIONED_THREADS_MAX // 2
            for stale in list(adapter._mentioned_threads)[:evict_count]:
                adapter._mentioned_threads.discard(stale)

        assert len(adapter._mentioned_threads) <= 10

View File

@@ -51,7 +51,8 @@ def _make_runner(session_entry: SessionEntry):
runner._running_agents = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner._session_db = None
runner._session_db = MagicMock()
runner._session_db.get_session_title.return_value = None
runner._reasoning_config = None
runner._provider_routing = {}
runner._fallback_model = None
@@ -82,12 +83,34 @@ async def test_status_command_reports_running_agent_without_interrupt(monkeypatc
result = await runner._handle_message(_make_event("/status"))
assert "**Session ID:** `sess-1`" in result
assert "**Tokens:** 321" in result
assert "**Agent Running:** Yes ⚡" in result
assert "**Title:**" not in result
running_agent.interrupt.assert_not_called()
assert runner._pending_messages == {}
@pytest.mark.asyncio
async def test_status_command_includes_session_title_when_present():
    """``/status`` output carries a ``**Title:**`` line when the DB mock returns a title."""
    key = build_session_key(_make_source())
    entry = SessionEntry(
        session_key=key,
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
        total_tokens=321,
    )
    runner = _make_runner(entry)
    # Make the mocked session DB report a title for this session.
    runner._session_db.get_session_title.return_value = "My titled session"

    status_text = await runner._handle_message(_make_event("/status"))

    assert "**Session ID:** `sess-1`" in status_text
    assert "**Title:** My titled session" in status_text
@pytest.mark.asyncio
async def test_handle_message_persists_agent_token_counts(monkeypatch):
import gateway.run as gateway_run

View File

@@ -177,3 +177,238 @@ class TestStreamRunMediaStripping:
assert "MEDIA:" not in sent_text, f"MEDIA: leaked into display: {sent_text!r}"
assert consumer.already_sent
# ── Segment break (tool boundary) tests ──────────────────────────────────
class TestSegmentBreakOnToolBoundary:
    """Verify that on_delta(None) finalizes the current message and starts a
    new one so the final response appears below tool-progress messages.

    Every test drives a ``GatewayStreamConsumer`` against a mocked adapter and
    inspects the ``send`` / ``edit_message`` calls it produced.
    """

    @pytest.mark.asyncio
    async def test_segment_break_creates_new_message(self):
        """After a None boundary, next text creates a fresh message."""
        adapter = MagicMock()
        send_result_1 = SimpleNamespace(success=True, message_id="msg_1")
        send_result_2 = SimpleNamespace(success=True, message_id="msg_2")
        edit_result = SimpleNamespace(success=True)
        adapter.send = AsyncMock(side_effect=[send_result_1, send_result_2])
        adapter.edit_message = AsyncMock(return_value=edit_result)
        adapter.MAX_MESSAGE_LENGTH = 4096
        config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5)
        consumer = GatewayStreamConsumer(adapter, "chat_123", config)

        # Phase 1: intermediate text before tool calls
        consumer.on_delta("Let me search for that...")
        # Tool boundary — model is about to call tools
        consumer.on_delta(None)
        # Phase 2: final response text after tools finished
        consumer.on_delta("Here are the results.")
        consumer.finish()
        await consumer.run()

        # Should have sent TWO separate messages (two adapter.send calls),
        # not just edited the first one.
        assert adapter.send.call_count == 2
        first_text = adapter.send.call_args_list[0][1]["content"]
        second_text = adapter.send.call_args_list[1][1]["content"]
        assert "search" in first_text
        assert "results" in second_text

    @pytest.mark.asyncio
    async def test_segment_break_no_text_before(self):
        """A None boundary with no preceding text is a no-op."""
        adapter = MagicMock()
        send_result = SimpleNamespace(success=True, message_id="msg_1")
        adapter.send = AsyncMock(return_value=send_result)
        adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
        adapter.MAX_MESSAGE_LENGTH = 4096
        config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5)
        consumer = GatewayStreamConsumer(adapter, "chat_123", config)

        # No text before the boundary — model went straight to tool calls
        consumer.on_delta(None)
        consumer.on_delta("Final answer.")
        consumer.finish()
        await consumer.run()

        # Only one send call (the final answer)
        assert adapter.send.call_count == 1
        assert "Final answer" in adapter.send.call_args_list[0][1]["content"]

    @pytest.mark.asyncio
    async def test_segment_break_removes_cursor(self):
        """The finalized segment message should not have a cursor."""
        adapter = MagicMock()
        send_result = SimpleNamespace(success=True, message_id="msg_1")
        edit_result = SimpleNamespace(success=True)
        adapter.send = AsyncMock(return_value=send_result)
        adapter.edit_message = AsyncMock(return_value=edit_result)
        adapter.MAX_MESSAGE_LENGTH = 4096
        # The cursor must be a non-empty marker: with an empty string the
        # ``not in`` check below could never pass, because "" is a substring
        # of every string.
        cursor = " ▉"
        config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=cursor)
        consumer = GatewayStreamConsumer(adapter, "chat_123", config)

        consumer.on_delta("Thinking...")
        consumer.on_delta(None)
        consumer.on_delta("Done.")
        consumer.finish()
        await consumer.run()

        # The first segment should have been finalized without cursor.
        # Check all edit_message calls + the initial send for the first segment.
        # The last state of msg_1 should NOT have the cursor.
        all_texts = [call[1].get("content", "") for call in adapter.send.call_args_list]
        all_texts.extend(
            call[1].get("content", "") for call in adapter.edit_message.call_args_list
        )
        # Find the text(s) that contain "Thinking" — the finalized version
        # should not have the cursor.
        thinking_texts = [t for t in all_texts if "Thinking" in t]
        assert thinking_texts, "Expected at least one message with 'Thinking'"
        # The LAST occurrence is the finalized version
        assert cursor not in thinking_texts[-1], (
            f"Cursor found in finalized segment: {thinking_texts[-1]!r}"
        )

    @pytest.mark.asyncio
    async def test_multiple_segment_breaks(self):
        """Multiple tool boundaries create multiple message segments."""
        adapter = MagicMock()
        msg_counter = iter(["msg_1", "msg_2", "msg_3"])
        adapter.send = AsyncMock(
            side_effect=lambda **kw: SimpleNamespace(success=True, message_id=next(msg_counter))
        )
        adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
        adapter.MAX_MESSAGE_LENGTH = 4096
        config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5)
        consumer = GatewayStreamConsumer(adapter, "chat_123", config)

        consumer.on_delta("Phase 1")
        consumer.on_delta(None)  # tool boundary
        consumer.on_delta("Phase 2")
        consumer.on_delta(None)  # another tool boundary
        consumer.on_delta("Phase 3")
        consumer.finish()
        await consumer.run()

        # Three separate messages
        assert adapter.send.call_count == 3

    @pytest.mark.asyncio
    async def test_already_sent_stays_true_after_segment(self):
        """already_sent remains True after a segment break."""
        adapter = MagicMock()
        send_result = SimpleNamespace(success=True, message_id="msg_1")
        adapter.send = AsyncMock(return_value=send_result)
        adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
        adapter.MAX_MESSAGE_LENGTH = 4096
        config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5)
        consumer = GatewayStreamConsumer(adapter, "chat_123", config)

        consumer.on_delta("Text")
        consumer.on_delta(None)
        consumer.finish()
        await consumer.run()

        assert consumer.already_sent

    @pytest.mark.asyncio
    async def test_edit_failure_sends_only_unsent_tail_at_finish(self):
        """If an edit fails mid-stream, send only the missing tail once at finish."""
        adapter = MagicMock()
        send_results = [
            SimpleNamespace(success=True, message_id="msg_1"),
            SimpleNamespace(success=True, message_id="msg_2"),
        ]
        adapter.send = AsyncMock(side_effect=send_results)
        # Every edit reports failure so the consumer has to fall back to send.
        adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
        adapter.MAX_MESSAGE_LENGTH = 4096
        # NOTE(review): this cursor literal may have lost a glyph as well —
        # the assertions below hold either way; confirm against upstream.
        config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor="")
        consumer = GatewayStreamConsumer(adapter, "chat_123", config)

        consumer.on_delta("Hello")
        task = asyncio.create_task(consumer.run())
        # Short pauses give the running consumer a chance to process each delta.
        await asyncio.sleep(0.08)
        consumer.on_delta(" world")
        await asyncio.sleep(0.08)
        consumer.finish()
        await task

        assert adapter.send.call_count == 2
        first_text = adapter.send.call_args_list[0][1]["content"]
        second_text = adapter.send.call_args_list[1][1]["content"]
        assert "Hello" in first_text
        assert second_text.strip() == "world"
        assert consumer.already_sent

    @pytest.mark.asyncio
    async def test_segment_break_clears_failed_edit_fallback_state(self):
        """A tool boundary after edit failure must not duplicate the next segment."""
        adapter = MagicMock()
        send_results = [
            SimpleNamespace(success=True, message_id="msg_1"),
            SimpleNamespace(success=True, message_id="msg_2"),
        ]
        adapter.send = AsyncMock(side_effect=send_results)
        adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
        adapter.MAX_MESSAGE_LENGTH = 4096
        # The expected first send below is "Hello ▉", so the configured cursor
        # has to be " ▉" for the expectation to be reachable.
        config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉")
        consumer = GatewayStreamConsumer(adapter, "chat_123", config)

        consumer.on_delta("Hello")
        task = asyncio.create_task(consumer.run())
        await asyncio.sleep(0.08)
        consumer.on_delta(" world")
        await asyncio.sleep(0.08)
        consumer.on_delta(None)
        consumer.on_delta("Next segment")
        consumer.finish()
        await task

        sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
        assert sent_texts == ["Hello ▉", "Next segment"]

    @pytest.mark.asyncio
    async def test_fallback_final_splits_long_continuation_without_dropping_text(self):
        """Long continuation tails should be chunked when fallback final-send runs."""
        adapter = MagicMock()
        adapter.send = AsyncMock(side_effect=[
            SimpleNamespace(success=True, message_id="msg_1"),
            SimpleNamespace(success=True, message_id="msg_2"),
            SimpleNamespace(success=True, message_id="msg_3"),
        ])
        adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
        adapter.MAX_MESSAGE_LENGTH = 610
        # NOTE(review): this cursor literal may have lost a glyph as well —
        # left unchanged because the chunk arithmetic cannot be checked from
        # this file alone; confirm against upstream.
        config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor="")
        consumer = GatewayStreamConsumer(adapter, "chat_123", config)
        prefix = "abc"
        tail = "x" * 620  # longer than MAX_MESSAGE_LENGTH on purpose

        consumer.on_delta(prefix)
        task = asyncio.create_task(consumer.run())
        await asyncio.sleep(0.08)
        consumer.on_delta(tail)
        await asyncio.sleep(0.08)
        consumer.finish()
        await task

        sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
        assert len(sent_texts) == 3
        assert sent_texts[0].startswith(prefix)
        # The continuation sends together must carry exactly the tail's length.
        assert sum(len(t) for t in sent_texts[1:]) == len(tail)

View File

@@ -0,0 +1,291 @@
"""Tests for Telegram inline keyboard approval buttons."""
import asyncio
import os
import sys
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# Ensure the repo root is importable
# ---------------------------------------------------------------------------
_repo = str(Path(__file__).resolve().parents[2])
if _repo not in sys.path:
sys.path.insert(0, _repo)
# ---------------------------------------------------------------------------
# Minimal Telegram mock so TelegramAdapter can be imported
# ---------------------------------------------------------------------------
def _ensure_telegram_mock():
"""Wire up the minimal mocks required to import TelegramAdapter."""
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
return
mod = MagicMock()
mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
mod.constants.ParseMode.MARKDOWN = "Markdown"
mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
mod.constants.ParseMode.HTML = "HTML"
mod.constants.ChatType.PRIVATE = "private"
mod.constants.ChatType.GROUP = "group"
mod.constants.ChatType.SUPERGROUP = "supergroup"
mod.constants.ChatType.CHANNEL = "channel"
# Provide real exception classes so ``except (NetworkError, ...)`` in
# connect() doesn't blow up under xdist when this mock leaks.
mod.error.NetworkError = type("NetworkError", (OSError,), {})
mod.error.TimedOut = type("TimedOut", (OSError,), {})
mod.error.BadRequest = type("BadRequest", (Exception,), {})
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
sys.modules.setdefault(name, mod)
sys.modules.setdefault("telegram.error", mod.error)
_ensure_telegram_mock()
from gateway.platforms.telegram import TelegramAdapter
from gateway.config import Platform, PlatformConfig
def _make_adapter():
    """Create a TelegramAdapter with mocked internals.

    ``_bot`` is replaced by an ``AsyncMock`` and ``_app`` by a ``MagicMock``.
    """
    telegram_adapter = TelegramAdapter(PlatformConfig(enabled=True, token="test-token"))
    telegram_adapter._bot = AsyncMock()
    telegram_adapter._app = MagicMock()
    return telegram_adapter
# ===========================================================================
# send_exec_approval — inline keyboard buttons
# ===========================================================================
class TestTelegramExecApproval:
    """Test the send_exec_approval method sends InlineKeyboard buttons."""

    @pytest.mark.asyncio
    async def test_sends_inline_keyboard(self):
        """The prompt goes out once, with command, description and a reply markup."""
        adapter = _make_adapter()
        sent_message = MagicMock()
        sent_message.message_id = 42
        send_message = AsyncMock(return_value=sent_message)
        adapter._bot.send_message = send_message

        result = await adapter.send_exec_approval(
            chat_id="12345",
            command="rm -rf /important",
            session_key="agent:main:telegram:group:12345:99",
            description="dangerous deletion",
        )

        assert result.success is True
        assert result.message_id == "42"

        send_message.assert_called_once()
        sent = send_message.call_args.kwargs
        assert sent["chat_id"] == 12345
        assert "rm -rf /important" in sent["text"]
        assert "dangerous deletion" in sent["text"]
        assert sent["reply_markup"] is not None  # InlineKeyboardMarkup

    @pytest.mark.asyncio
    async def test_stores_approval_state(self):
        """Sending a prompt records exactly one approval id for the session key."""
        adapter = _make_adapter()
        sent_message = MagicMock()
        sent_message.message_id = 42
        adapter._bot.send_message = AsyncMock(return_value=sent_message)

        await adapter.send_exec_approval(
            chat_id="12345",
            command="echo test",
            session_key="my-session-key",
        )

        # The single stored approval id must map to the session key.
        assert len(adapter._approval_state) == 1
        approval_id = next(iter(adapter._approval_state))
        assert adapter._approval_state[approval_id] == "my-session-key"

    @pytest.mark.asyncio
    async def test_sends_in_thread(self):
        """``metadata["thread_id"]`` is passed on as an int ``message_thread_id``."""
        adapter = _make_adapter()
        sent_message = MagicMock()
        sent_message.message_id = 42
        send_message = AsyncMock(return_value=sent_message)
        adapter._bot.send_message = send_message

        await adapter.send_exec_approval(
            chat_id="12345",
            command="ls",
            session_key="s",
            metadata={"thread_id": "999"},
        )

        sent = send_message.call_args.kwargs
        assert sent.get("message_thread_id") == 999

    @pytest.mark.asyncio
    async def test_not_connected(self):
        """With ``_bot`` set to ``None`` the call reports ``success`` as False."""
        adapter = _make_adapter()
        adapter._bot = None

        result = await adapter.send_exec_approval(
            chat_id="12345", command="ls", session_key="s"
        )

        assert result.success is False

    @pytest.mark.asyncio
    async def test_truncates_long_command(self):
        """A 5000-character command is shortened and marked with an ellipsis."""
        adapter = _make_adapter()
        sent_message = MagicMock()
        sent_message.message_id = 1
        send_message = AsyncMock(return_value=sent_message)
        adapter._bot.send_message = send_message
        oversized_command = "x" * 5000

        await adapter.send_exec_approval(
            chat_id="12345", command=oversized_command, session_key="s"
        )

        prompt_text = send_message.call_args.kwargs["text"]
        assert "..." in prompt_text
        assert len(prompt_text) < 5000
# ===========================================================================
# _handle_callback_query — approval button clicks
# ===========================================================================
class TestTelegramApprovalCallback:
"""Test the approval callback handling in _handle_callback_query."""
@pytest.mark.asyncio
async def test_resolves_approval_on_click(self):
    """A click on ``ea:once:1`` resolves the stored session and clears its state."""
    adapter = _make_adapter()
    # Pending approval 1 belongs to this session key.
    adapter._approval_state[1] = "agent:main:telegram:group:12345:99"

    # Callback query as delivered for the button press.
    callback = AsyncMock()
    callback.data = "ea:once:1"
    callback.message = MagicMock()
    callback.message.chat_id = 12345
    callback.from_user = MagicMock()
    callback.from_user.first_name = "Norbert"
    callback.answer = AsyncMock()
    callback.edit_message_text = AsyncMock()

    update = MagicMock()
    update.callback_query = callback
    context = MagicMock()

    with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
        await adapter._handle_callback_query(update, context)

    mock_resolve.assert_called_once_with("agent:main:telegram:group:12345:99", "once")
    callback.answer.assert_called_once()
    callback.edit_message_text.assert_called_once()
    # The pending entry must be gone afterwards.
    assert 1 not in adapter._approval_state
@pytest.mark.asyncio
async def test_deny_button(self):
adapter = _make_adapter()
adapter._approval_state[2] = "some-session"
query = AsyncMock()
query.data = "ea:deny:2"
query.message = MagicMock()
query.message.chat_id = 12345
query.from_user = MagicMock()
query.from_user.first_name = "Alice"
query.answer = AsyncMock()
query.edit_message_text = AsyncMock()
update = MagicMock()
update.callback_query = query
context = MagicMock()
with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
await adapter._handle_callback_query(update, context)
mock_resolve.assert_called_once_with("some-session", "deny")
edit_kwargs = query.edit_message_text.call_args[1]
assert "Denied" in edit_kwargs["text"]
@pytest.mark.asyncio
async def test_already_resolved(self):
adapter = _make_adapter()
# No state for approval_id 99 — already resolved
query = AsyncMock()
query.data = "ea:once:99"
query.message = MagicMock()
query.message.chat_id = 12345
query.from_user = MagicMock()
query.from_user.first_name = "Bob"
query.answer = AsyncMock()
update = MagicMock()
update.callback_query = query
context = MagicMock()
with patch("tools.approval.resolve_gateway_approval") as mock_resolve:
await adapter._handle_callback_query(update, context)
# Should NOT resolve — already handled
mock_resolve.assert_not_called()
# Should still ack with "already resolved" message
query.answer.assert_called_once()
assert "already been resolved" in query.answer.call_args[1]["text"]
@pytest.mark.asyncio
async def test_model_picker_callback_not_affected(self):
"""Ensure model picker callbacks still route correctly."""
adapter = _make_adapter()
query = AsyncMock()
query.data = "mp:some_provider"
query.message = MagicMock()
query.message.chat_id = 12345
query.from_user = MagicMock()
update = MagicMock()
update.callback_query = query
context = MagicMock()
# Model picker callback should be handled (not crash)
# We just verify it doesn't try to resolve an approval
with patch("tools.approval.resolve_gateway_approval") as mock_resolve:
with patch.object(adapter, "_handle_model_picker_callback", new_callable=AsyncMock):
await adapter._handle_callback_query(update, context)
mock_resolve.assert_not_called()
@pytest.mark.asyncio
async def test_update_prompt_callback_not_affected(self):
"""Ensure update prompt callbacks still work."""
adapter = _make_adapter()
query = AsyncMock()
query.data = "update_prompt:y"
query.message = MagicMock()
query.message.chat_id = 12345
query.from_user = MagicMock()
query.from_user.id = 123
query.answer = AsyncMock()
query.edit_message_text = AsyncMock()
update = MagicMock()
update.callback_query = query
context = MagicMock()
with patch("tools.approval.resolve_gateway_approval") as mock_resolve:
with patch("hermes_constants.get_hermes_home", return_value=Path("/tmp/test")):
try:
await adapter._handle_callback_query(update, context)
except Exception:
pass # May fail on file write, that's fine
# Should NOT have triggered approval resolution
mock_resolve.assert_not_called()

View File

@@ -0,0 +1,77 @@
"""Tests for TelegramPlatform._merge_caption caption deduplication logic."""
import pytest
from gateway.platforms.telegram import TelegramAdapter
merge = TelegramAdapter._merge_caption
class TestMergeCaptionBasic:
    """Baseline cases: missing/empty existing text, exact duplicates, distinct captions."""

    def test_no_existing_text(self):
        """With ``None`` as existing text the new caption is returned as-is."""
        merged = merge(None, "Hello")
        assert merged == "Hello"

    def test_empty_existing_text(self):
        """An empty existing string behaves like no existing text."""
        merged = merge("", "Hello")
        assert merged == "Hello"

    def test_exact_duplicate_dropped(self):
        """Merging a caption identical to the existing text leaves it unchanged."""
        merged = merge("Revenue", "Revenue")
        assert merged == "Revenue"

    def test_different_captions_merged(self):
        """Two distinct captions are joined by a blank line."""
        expected = "Q3 Results\n\nQ4 Projections"
        assert merge("Q3 Results", "Q4 Projections") == expected
class TestMergeCaptionSubstringBug:
    """These are the exact scenarios that the old substring check got wrong."""

    def test_shorter_caption_not_dropped_when_substring(self):
        """A caption that is a substring of the existing text must still be appended."""
        # Old bug: "Meeting" in "Meeting agenda" was True, so the caption was silently lost.
        merged = merge("Meeting agenda", "Meeting")
        assert merged == "Meeting agenda\n\nMeeting"

    def test_longer_caption_not_dropped_when_contains_existing(self):
        """A caption that contains the existing text is still a different caption."""
        merged = merge("Revenue", "Revenue and Profit")
        assert merged == "Revenue\n\nRevenue and Profit"

    def test_prefix_caption_not_dropped(self):
        """A caption that is a prefix of the existing text is appended, not dropped."""
        merged = merge("Q3 Results - Revenue", "Q3 Results")
        assert merged == "Q3 Results - Revenue\n\nQ3 Results"
class TestMergeCaptionWhitespace:
    """Captions that differ only by surrounding whitespace."""

    def test_trailing_space_treated_as_duplicate(self):
        """A trailing space does not make the caption distinct."""
        merged = merge("Revenue", "Revenue ")
        assert merged == "Revenue"

    def test_leading_space_treated_as_duplicate(self):
        """A leading space does not make the caption distinct."""
        merged = merge("Revenue", " Revenue")
        assert merged == "Revenue"

    def test_whitespace_only_new_text_not_added(self):
        """Whitespace-only new text: either outcome below is accepted."""
        # NOTE(review): despite the test name, the assertion tolerates both
        # results — the text merged with a blank-line separator, or the
        # existing text returned unchanged. Per the original author's note,
        # callers guard with `if event.text:`, so this is an edge case that
        # should not arise through them.
        merged = merge("Revenue", " ")
        assert "\n\n" in merged or merged == "Revenue"
class TestMergeCaptionMultipleItems:
    """Captions accumulated over several successive merges."""

    def test_three_unique_captions_all_present(self):
        """Three distinct captions end up joined in order by blank lines."""
        accumulated = None
        for caption in ("A", "B", "C"):
            accumulated = merge(accumulated, caption)
        assert accumulated == "A\n\nB\n\nC"

    def test_duplicate_in_middle_dropped(self):
        """Re-adding an earlier caption after another one does not duplicate it."""
        accumulated = None
        for caption in ("A", "B", "A"):  # the final "A" repeats the first caption
            accumulated = merge(accumulated, caption)
        assert accumulated == "A\n\nB"

    def test_album_scenario_revenue_profit(self):
        """Album item 2 ("Revenue") survives although item 1 contains it as a substring."""
        # Old bug: "Revenue" in ["Revenue and Profit"] evaluated True and was lost.
        first_item = merge(None, "Revenue and Profit")
        both_items = merge(first_item, "Revenue")
        assert both_items == "Revenue and Profit\n\nRevenue"

View File

@@ -20,8 +20,16 @@ def _ensure_telegram_mock():
telegram_mod.constants.ChatType.CHANNEL = "channel"
telegram_mod.constants.ChatType.PRIVATE = "private"
# Provide real exception classes so ``except (NetworkError, ...)`` in
# connect() doesn't blow up with "catching classes that do not inherit
# from BaseException" when another xdist worker pollutes sys.modules.
telegram_mod.error.NetworkError = type("NetworkError", (OSError,), {})
telegram_mod.error.TimedOut = type("TimedOut", (OSError,), {})
telegram_mod.error.BadRequest = type("BadRequest", (Exception,), {})
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
sys.modules.setdefault(name, telegram_mod)
sys.modules.setdefault("telegram.error", telegram_mod.error)
_ensure_telegram_mock()

View File

@@ -0,0 +1,260 @@
"""Tests for Telegram message reactions tied to processing lifecycle hooks."""
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, MessageType
from gateway.session import SessionSource
def _make_adapter(**extra_env):
    """Return a ``TelegramAdapter`` created without ``__init__`` and wired to a mock bot.

    ``extra_env`` is accepted but not used by this helper.
    """
    from gateway.platforms.telegram import TelegramAdapter

    # Prepare the bot first; its reaction method is an explicit AsyncMock so
    # tests can assert on awaits.
    mock_bot = AsyncMock()
    mock_bot.set_message_reaction = AsyncMock()

    adapter = object.__new__(TelegramAdapter)
    adapter.platform = Platform.TELEGRAM
    adapter.config = PlatformConfig(enabled=True, token="fake-token")
    adapter._bot = mock_bot
    return adapter
def _make_event(chat_id: str = "123", message_id: str = "456") -> MessageEvent:
    """Build a private-chat Telegram text event with the given chat and message ids."""
    origin = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id=chat_id,
        chat_type="private",
        user_id="42",
        user_name="TestUser",
    )
    return MessageEvent(
        text="hello",
        message_type=MessageType.TEXT,
        source=origin,
        message_id=message_id,
    )
# ── _reactions_enabled ───────────────────────────────────────────────
def test_reactions_disabled_by_default(monkeypatch):
    """With TELEGRAM_REACTIONS unset, ``_reactions_enabled`` is False."""
    monkeypatch.delenv("TELEGRAM_REACTIONS", raising=False)
    enabled = _make_adapter()._reactions_enabled()
    assert enabled is False
def test_reactions_enabled_when_set_true(monkeypatch):
    """TELEGRAM_REACTIONS=true turns reactions on."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    enabled = _make_adapter()._reactions_enabled()
    assert enabled is True
def test_reactions_enabled_with_1(monkeypatch):
    """The value "1" is accepted as an enabling setting."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "1")
    enabled = _make_adapter()._reactions_enabled()
    assert enabled is True
def test_reactions_disabled_with_false(monkeypatch):
    """The value "false" keeps reactions off."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "false")
    enabled = _make_adapter()._reactions_enabled()
    assert enabled is False
def test_reactions_disabled_with_0(monkeypatch):
    """The value "0" keeps reactions off."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "0")
    enabled = _make_adapter()._reactions_enabled()
    assert enabled is False
def test_reactions_disabled_with_no(monkeypatch):
    """The value "no" keeps reactions off."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "no")
    enabled = _make_adapter()._reactions_enabled()
    assert enabled is False
# ── _set_reaction ────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_set_reaction_calls_bot_api(monkeypatch):
    """``_set_reaction`` awaits the bot API with integer ids and the given emoji."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()

    succeeded = await adapter._set_reaction("123", "456", "\U0001f440")
    assert succeeded is True

    # String ids passed in are expected as integers on the bot call.
    expected_kwargs = {"chat_id": 123, "message_id": 456, "reaction": "\U0001f440"}
    adapter._bot.set_message_reaction.assert_awaited_once_with(**expected_kwargs)
@pytest.mark.asyncio
async def test_set_reaction_returns_false_without_bot(monkeypatch):
    """With ``_bot`` set to None, ``_set_reaction`` returns False."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()
    adapter._bot = None

    succeeded = await adapter._set_reaction("123", "456", "\U0001f440")

    assert succeeded is False
@pytest.mark.asyncio
async def test_set_reaction_handles_api_error_gracefully(monkeypatch):
    """A ``RuntimeError`` from the bot call yields False instead of propagating."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()
    failing_call = AsyncMock(side_effect=RuntimeError("no perms"))
    adapter._bot.set_message_reaction = failing_call

    succeeded = await adapter._set_reaction("123", "456", "\U0001f440")

    assert succeeded is False
# ── on_processing_start ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_on_processing_start_adds_eyes_reaction(monkeypatch):
    """When enabled, processing start sets the eyes emoji on the event's message."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()

    await adapter.on_processing_start(_make_event())

    expected_kwargs = {"chat_id": 123, "message_id": 456, "reaction": "\U0001f440"}
    adapter._bot.set_message_reaction.assert_awaited_once_with(**expected_kwargs)
@pytest.mark.asyncio
async def test_on_processing_start_skipped_when_disabled(monkeypatch):
    """With TELEGRAM_REACTIONS unset, processing start makes no reaction call."""
    monkeypatch.delenv("TELEGRAM_REACTIONS", raising=False)
    adapter = _make_adapter()

    await adapter.on_processing_start(_make_event())

    adapter._bot.set_message_reaction.assert_not_awaited()
@pytest.mark.asyncio
async def test_on_processing_start_handles_missing_ids(monkeypatch):
    """An event whose chat_id and message_id are None triggers no reaction call."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()

    # A bare namespace stands in for the source; it only carries chat_id=None.
    idless_source = SimpleNamespace(chat_id=None)
    idless_event = MessageEvent(
        text="hello",
        message_type=MessageType.TEXT,
        source=idless_source,
        message_id=None,
    )

    await adapter.on_processing_start(idless_event)

    adapter._bot.set_message_reaction.assert_not_awaited()
# ── on_processing_complete ───────────────────────────────────────────
@pytest.mark.asyncio
async def test_on_processing_complete_success(monkeypatch):
    """A successful completion sets the check-mark emoji on the event's message."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()

    await adapter.on_processing_complete(_make_event(), success=True)

    expected_kwargs = {"chat_id": 123, "message_id": 456, "reaction": "\u2705"}
    adapter._bot.set_message_reaction.assert_awaited_once_with(**expected_kwargs)
@pytest.mark.asyncio
async def test_on_processing_complete_failure(monkeypatch):
    """A failed completion sets the cross-mark emoji on the event's message."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()

    await adapter.on_processing_complete(_make_event(), success=False)

    expected_kwargs = {"chat_id": 123, "message_id": 456, "reaction": "\u274c"}
    adapter._bot.set_message_reaction.assert_awaited_once_with(**expected_kwargs)
@pytest.mark.asyncio
async def test_on_processing_complete_skipped_when_disabled(monkeypatch):
    """With TELEGRAM_REACTIONS unset, processing completion makes no reaction call."""
    monkeypatch.delenv("TELEGRAM_REACTIONS", raising=False)
    adapter = _make_adapter()

    await adapter.on_processing_complete(_make_event(), success=True)

    adapter._bot.set_message_reaction.assert_not_awaited()
# ── config.py bridging ───────────────────────────────────────────────
def test_config_bridges_telegram_reactions(monkeypatch, tmp_path):
    """Loading a config with ``telegram.reactions: true`` sets TELEGRAM_REACTIONS to "true"."""
    import os

    import yaml

    settings = {"telegram": {"reactions": True}}
    (tmp_path / "config.yaml").write_text(yaml.dump(settings))

    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    # setenv (rather than delenv) makes monkeypatch register cleanup even if
    # the variable did not exist before; load_gateway_config overwrites it.
    monkeypatch.setenv("TELEGRAM_REACTIONS", "")

    from gateway.config import load_gateway_config

    load_gateway_config()

    assert os.getenv("TELEGRAM_REACTIONS") == "true"
def test_config_reactions_env_takes_precedence(monkeypatch, tmp_path):
    """A pre-set TELEGRAM_REACTIONS=false survives a config that enables reactions."""
    import os

    import yaml

    settings = {"telegram": {"reactions": True}}
    (tmp_path / "config.yaml").write_text(yaml.dump(settings))

    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    monkeypatch.setenv("TELEGRAM_REACTIONS", "false")

    from gateway.config import load_gateway_config

    load_gateway_config()

    assert os.getenv("TELEGRAM_REACTIONS") == "false"

View File

@@ -590,8 +590,15 @@ class TestSessionIsolation:
class TestDeliveryCleanup:
@pytest.mark.asyncio
async def test_delivery_info_cleaned_after_send(self):
"""send() pops delivery_info so the entry doesn't leak memory."""
async def test_delivery_info_survives_multiple_sends(self):
"""send() must NOT pop delivery_info.
Interim status messages (fallback notifications, context-pressure
warnings, etc.) flow through the same send() path as the final
response. If the entry were popped on the first send, the final
response would silently downgrade to the ``log`` deliver type.
Regression test for that bug.
"""
adapter = _make_adapter()
chat_id = "webhook:test:d-xyz"
adapter._delivery_info[chat_id] = {
@@ -599,10 +606,40 @@ class TestDeliveryCleanup:
"deliver_extra": {},
"payload": {"x": 1},
}
adapter._delivery_info_created[chat_id] = time.time()
result = await adapter.send(chat_id, "Agent response here")
assert result.success is True
assert chat_id not in adapter._delivery_info
# First send (e.g. an interim status message)
result1 = await adapter.send(chat_id, "Status: switching to fallback")
assert result1.success is True
# Entry must still be present so the final send can read it
assert chat_id in adapter._delivery_info
# Second send (the final agent response)
result2 = await adapter.send(chat_id, "Final agent response")
assert result2.success is True
assert chat_id in adapter._delivery_info
@pytest.mark.asyncio
async def test_delivery_info_pruned_via_ttl(self):
    """``_prune_delivery_info`` drops entries older than the TTL and keeps fresh ones."""
    adapter = _make_adapter()
    adapter._idempotency_ttl = 60  # short TTL for the test
    now = time.time()

    # One entry older than the 60 s TTL, one well inside it.
    created_at = {
        "webhook:test:old": now - 120,
        "webhook:test:new": now - 5,
    }
    for chat_id, created in created_at.items():
        adapter._delivery_info[chat_id] = {"deliver": "log"}
        adapter._delivery_info_created[chat_id] = created

    adapter._prune_delivery_info(now)

    # Both bookkeeping dicts must agree on what was pruned and what survived.
    for store in (adapter._delivery_info, adapter._delivery_info_created):
        assert "webhook:test:old" not in store
    for store in (adapter._delivery_info, adapter._delivery_info_created):
        assert "webhook:test:new" in store
# ===================================================================
@@ -617,3 +654,107 @@ class TestCheckRequirements:
@patch("gateway.platforms.webhook.AIOHTTP_AVAILABLE", False)
def test_returns_false_without_aiohttp(self):
    """With ``AIOHTTP_AVAILABLE`` patched to False the requirements check is False."""
    requirements_met = check_webhook_requirements()
    assert requirements_met is False
# ===================================================================
# __raw__ template token
# ===================================================================
class TestRawTemplateToken:
    """Behaviour of the special ``{__raw__}`` token in ``_render_prompt``."""

    def test_raw_resolves_to_full_json_payload(self):
        """``{__raw__}`` expands to the payload serialised as indented JSON."""
        payload = {"action": "opened", "number": 42}
        adapter = _make_adapter()

        rendered = adapter._render_prompt("Payload: {__raw__}", payload, "push", "test")

        assert rendered == "Payload: " + json.dumps(payload, indent=2)

    def test_raw_truncated_at_4000_chars(self):
        """A payload whose JSON exceeds 4000 characters renders to at most 4000."""
        adapter = _make_adapter()
        oversized = {"data": "x" * 5000}  # JSON form is longer than 4000 chars

        rendered = adapter._render_prompt("{__raw__}", oversized, "push", "test")

        assert len(rendered) <= 4000

    def test_raw_mixed_with_other_variables(self):
        """``{__raw__}`` can sit next to ordinary template variables."""
        adapter = _make_adapter()
        payload = {"action": "closed", "number": 7}
        template = "Action={action} Raw={__raw__}"

        rendered = adapter._render_prompt(template, payload, "push", "test")

        assert rendered.startswith("Action=closed Raw=")
        for fragment in ('"action": "closed"', '"number": 7'):
            assert fragment in rendered
# ===================================================================
# Cross-platform delivery thread_id passthrough
# ===================================================================
class TestDeliverCrossPlatformThreadId:
    """How ``_deliver_cross_platform`` forwards thread ids to the target adapter."""

    def _setup_adapter_with_mock_target(self):
        """Return ``(adapter, target)`` with a mocked runner exposing a telegram target."""
        target = AsyncMock()
        target.send = AsyncMock(return_value=SendResult(success=True))

        runner = MagicMock()
        runner.adapters = {Platform("telegram"): target}
        runner.config.get_home_channel.return_value = None

        adapter = _make_adapter()
        adapter.gateway_runner = runner
        return adapter, target

    @pytest.mark.asyncio
    async def test_thread_id_passed_as_metadata(self):
        """``thread_id`` in ``deliver_extra`` arrives as ``metadata["thread_id"]``."""
        adapter, target = self._setup_adapter_with_mock_target()
        extra = {"chat_id": "12345", "thread_id": "999"}

        await adapter._deliver_cross_platform(
            "telegram", "hello", {"deliver_extra": extra}
        )

        expected_metadata = {"thread_id": "999"}
        target.send.assert_awaited_once_with("12345", "hello", metadata=expected_metadata)

    @pytest.mark.asyncio
    async def test_message_thread_id_passed_as_thread_id(self):
        """``message_thread_id`` in ``deliver_extra`` is sent under the ``thread_id`` key."""
        adapter, target = self._setup_adapter_with_mock_target()
        extra = {"chat_id": "12345", "message_thread_id": "888"}

        await adapter._deliver_cross_platform(
            "telegram", "hello", {"deliver_extra": extra}
        )

        expected_metadata = {"thread_id": "888"}
        target.send.assert_awaited_once_with("12345", "hello", metadata=expected_metadata)

    @pytest.mark.asyncio
    async def test_no_thread_id_sends_no_metadata(self):
        """Without any thread id, ``send`` is awaited with ``metadata=None``."""
        adapter, target = self._setup_adapter_with_mock_target()
        extra = {"chat_id": "12345"}

        await adapter._deliver_cross_platform(
            "telegram", "hello", {"deliver_extra": extra}
        )

        target.send.assert_awaited_once_with("12345", "hello", metadata=None)

View File

@@ -257,10 +257,11 @@ class TestCrossPlatformDelivery:
assert result.success is True
mock_tg_adapter.send.assert_awaited_once_with(
"12345", "I've acknowledged the alert."
"12345", "I've acknowledged the alert.", metadata=None
)
# Delivery info should be cleaned up
assert chat_id not in adapter._delivery_info
# Delivery info is retained after send() so interim status messages
# don't strand the final response (TTL-based cleanup happens on POST).
assert chat_id in adapter._delivery_info
# ===================================================================
@@ -333,5 +334,6 @@ class TestGitHubCommentDelivery:
text=True,
timeout=30,
)
# Delivery info cleaned up
assert chat_id not in adapter._delivery_info
# Delivery info is retained after send() so interim status messages
# don't strand the final response (TTL-based cleanup happens on POST).
assert chat_id in adapter._delivery_info

View File

@@ -40,6 +40,7 @@ def test_run_anthropic_oauth_flow_manual_token_still_persists(tmp_path, monkeypa
monkeypatch.setattr("agent.anthropic_adapter.read_claude_code_credentials", lambda: None)
monkeypatch.setattr("agent.anthropic_adapter.is_claude_code_token_valid", lambda creds: False)
monkeypatch.setattr("builtins.input", lambda _prompt="": "sk-ant-oat01-manual-token")
monkeypatch.setattr("getpass.getpass", lambda _prompt="": "sk-ant-oat01-manual-token")
from hermes_cli.main import _run_anthropic_oauth_flow

View File

@@ -350,6 +350,7 @@ class TestResolveApiKeyProviderCredentials:
def test_resolve_zai_with_key(self, monkeypatch):
monkeypatch.setenv("GLM_API_KEY", "glm-secret-key")
monkeypatch.setattr("hermes_cli.auth.detect_zai_endpoint", lambda *a, **kw: None)
creds = resolve_api_key_provider_credentials("zai")
assert creds["provider"] == "zai"
assert creds["api_key"] == "glm-secret-key"
@@ -471,6 +472,7 @@ class TestResolveApiKeyProviderCredentials:
"""GLM_API_KEY takes priority over ZAI_API_KEY."""
monkeypatch.setenv("GLM_API_KEY", "primary")
monkeypatch.setenv("ZAI_API_KEY", "secondary")
monkeypatch.setattr("hermes_cli.auth.detect_zai_endpoint", lambda *a, **kw: None)
creds = resolve_api_key_provider_credentials("zai")
assert creds["api_key"] == "primary"
assert creds["source"] == "GLM_API_KEY"
@@ -478,6 +480,7 @@ class TestResolveApiKeyProviderCredentials:
def test_zai_key_fallback(self, monkeypatch):
"""ZAI_API_KEY used when GLM_API_KEY not set."""
monkeypatch.setenv("ZAI_API_KEY", "secondary")
monkeypatch.setattr("hermes_cli.auth.detect_zai_endpoint", lambda *a, **kw: None)
creds = resolve_api_key_provider_credentials("zai")
assert creds["api_key"] == "secondary"
assert creds["source"] == "ZAI_API_KEY"
@@ -830,11 +833,58 @@ class TestKimiCodeCredentialAutoDetect:
def test_non_kimi_providers_unaffected(self, monkeypatch):
"""Ensure the auto-detect logic doesn't leak to other providers."""
monkeypatch.setenv("GLM_API_KEY", "sk-kimi-looks-like-kimi-but-isnt")
monkeypatch.setenv("GLM_API_KEY", "sk-kim...isnt")
monkeypatch.setattr("hermes_cli.auth.detect_zai_endpoint", lambda *a, **kw: None)
creds = resolve_api_key_provider_credentials("zai")
assert creds["base_url"] == "https://api.z.ai/api/paas/v4"
class TestZaiEndpointAutoDetect:
    """Test that resolve_api_key_provider_credentials auto-detects Z.AI endpoints."""

    def test_probe_success_returns_detected_url(self, monkeypatch):
        """When the probe reports an endpoint, its ``base_url`` is used."""
        monkeypatch.setenv("GLM_API_KEY", "glm-coding-key")
        # An ambient override would bypass the probe (see
        # test_env_override_skips_probe), so make sure none is set.
        monkeypatch.delenv("GLM_BASE_URL", raising=False)
        monkeypatch.setattr(
            "hermes_cli.auth.detect_zai_endpoint",
            lambda *a, **kw: {
                "id": "coding-global",
                "base_url": "https://api.z.ai/api/coding/paas/v4",
                "model": "glm-4.7",
                "label": "Global (Coding Plan)",
            },
        )
        creds = resolve_api_key_provider_credentials("zai")
        assert creds["base_url"] == "https://api.z.ai/api/coding/paas/v4"

    def test_probe_failure_falls_back_to_default(self, monkeypatch):
        """When the probe returns None, the default Z.AI base URL is used."""
        monkeypatch.setenv("GLM_API_KEY", "glm-key")
        monkeypatch.delenv("GLM_BASE_URL", raising=False)
        monkeypatch.setattr("hermes_cli.auth.detect_zai_endpoint", lambda *a, **kw: None)
        creds = resolve_api_key_provider_credentials("zai")
        assert creds["base_url"] == "https://api.z.ai/api/paas/v4"

    def test_env_override_skips_probe(self, monkeypatch):
        """GLM_BASE_URL should always win without probing."""
        monkeypatch.setenv("GLM_API_KEY", "glm-key")
        monkeypatch.setenv("GLM_BASE_URL", "https://custom.example/v4")
        probe_calls = []

        def _record_probe(*a, **kw):
            probe_calls.append((a, kw))
            return None

        monkeypatch.setattr("hermes_cli.auth.detect_zai_endpoint", _record_probe)
        creds = resolve_api_key_provider_credentials("zai")
        assert creds["base_url"] == "https://custom.example/v4"
        assert not probe_calls

    def test_no_key_skips_probe(self, monkeypatch):
        """Without an API key, no probe should occur."""
        # Clear the key variables this file sets elsewhere so the test does not
        # depend on the ambient environment.
        # NOTE(review): if the provider reads further key variables, clear them too.
        monkeypatch.delenv("GLM_API_KEY", raising=False)
        monkeypatch.delenv("ZAI_API_KEY", raising=False)
        probe_calls = []

        def _record_probe(*a, **kw):
            probe_calls.append((a, kw))
            return None

        monkeypatch.setattr("hermes_cli.auth.detect_zai_endpoint", _record_probe)
        creds = resolve_api_key_provider_credentials("zai")
        assert creds["api_key"] == ""
        # Actually verify the claim in the test name: the probe never ran.
        assert not probe_calls
# =============================================================================
# Kimi / Moonshot model list isolation tests
# =============================================================================

View File

@@ -0,0 +1,399 @@
"""Tests for Qwen OAuth provider authentication (hermes_cli/auth.py).
Covers: _qwen_cli_auth_path, _read_qwen_cli_tokens, _save_qwen_cli_tokens,
_qwen_access_token_is_expiring, _refresh_qwen_cli_tokens,
resolve_qwen_runtime_credentials, get_qwen_auth_status.
"""
import json
import os
import stat
import time
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from hermes_cli.auth import (
AuthError,
DEFAULT_QWEN_BASE_URL,
QWEN_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
_qwen_cli_auth_path,
_read_qwen_cli_tokens,
_save_qwen_cli_tokens,
_qwen_access_token_is_expiring,
_refresh_qwen_cli_tokens,
resolve_qwen_runtime_credentials,
get_qwen_auth_status,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_qwen_tokens(
    access_token="test-access-token",
    refresh_token="test-refresh-token",
    expiry_date=None,
    **extra,
):
    """Return a minimal Qwen CLI OAuth credential dict.

    ``expiry_date`` is in milliseconds and defaults to one hour from now.
    Any ``extra`` keyword arguments are merged last, so they add to or
    override the base fields.
    """
    expiry_ms = (
        int((time.time() + 3600) * 1000) if expiry_date is None else expiry_date
    )
    base_fields = {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "Bearer",
        "expiry_date": expiry_ms,
        "resource_url": "portal.qwen.ai",
    }
    return {**base_fields, **extra}
def _write_qwen_creds(tmp_path, tokens=None):
    """Write *tokens* as JSON to ``<tmp_path>/.qwen/oauth_creds.json`` and return that path.

    When *tokens* is None, a default token dict from ``_make_qwen_tokens`` is written.
    """
    payload = _make_qwen_tokens() if tokens is None else tokens
    creds_path = tmp_path / ".qwen" / "oauth_creds.json"
    creds_path.parent.mkdir(parents=True, exist_ok=True)
    creds_path.write_text(json.dumps(payload), encoding="utf-8")
    return creds_path
@pytest.fixture()
def qwen_env(tmp_path, monkeypatch):
    """Point ``_qwen_cli_auth_path`` at ``tmp_path/.qwen/oauth_creds.json``; yield ``tmp_path``."""
    redirected = tmp_path / ".qwen" / "oauth_creds.json"

    def _fake_auth_path():
        return redirected

    monkeypatch.setattr("hermes_cli.auth._qwen_cli_auth_path", _fake_auth_path)
    return tmp_path
# ---------------------------------------------------------------------------
# _qwen_cli_auth_path
# ---------------------------------------------------------------------------
def test_qwen_cli_auth_path_returns_expected_location():
    """The credentials path is ``~/.qwen/oauth_creds.json``."""
    expected = Path.home() / ".qwen" / "oauth_creds.json"
    assert _qwen_cli_auth_path() == expected
# ---------------------------------------------------------------------------
# _read_qwen_cli_tokens
# ---------------------------------------------------------------------------
def test_read_qwen_cli_tokens_success(qwen_env):
    """Tokens written to the credentials file are read back field for field."""
    _write_qwen_creds(qwen_env, _make_qwen_tokens(access_token="my-access"))

    loaded = _read_qwen_cli_tokens()

    assert loaded["access_token"] == "my-access"
    assert loaded["refresh_token"] == "test-refresh-token"
def test_read_qwen_cli_tokens_missing_file(qwen_env):
    """A missing credentials file raises ``AuthError`` with code ``qwen_auth_missing``."""
    with pytest.raises(AuthError) as excinfo:
        _read_qwen_cli_tokens()

    error = excinfo.value
    assert error.code == "qwen_auth_missing"
def test_read_qwen_cli_tokens_invalid_json(qwen_env):
    """Unparseable file content raises ``AuthError`` with code ``qwen_auth_read_failed``."""
    qwen_dir = qwen_env / ".qwen"
    qwen_dir.mkdir(parents=True, exist_ok=True)
    (qwen_dir / "oauth_creds.json").write_text("not json{{{", encoding="utf-8")

    with pytest.raises(AuthError) as excinfo:
        _read_qwen_cli_tokens()

    assert excinfo.value.code == "qwen_auth_read_failed"
def test_read_qwen_cli_tokens_non_dict(qwen_env):
    """Valid JSON that is not an object raises ``AuthError`` with code ``qwen_auth_invalid``."""
    qwen_dir = qwen_env / ".qwen"
    qwen_dir.mkdir(parents=True, exist_ok=True)
    list_payload = json.dumps(["a", "b"])
    (qwen_dir / "oauth_creds.json").write_text(list_payload, encoding="utf-8")

    with pytest.raises(AuthError) as excinfo:
        _read_qwen_cli_tokens()

    assert excinfo.value.code == "qwen_auth_invalid"
# ---------------------------------------------------------------------------
# _save_qwen_cli_tokens
# ---------------------------------------------------------------------------
def test_save_qwen_cli_tokens_roundtrip(qwen_env):
    """Saved tokens exist on disk and parse back with the same access token."""
    target = _save_qwen_cli_tokens(_make_qwen_tokens(access_token="saved-token"))

    assert target.exists()
    on_disk = json.loads(target.read_text(encoding="utf-8"))
    assert on_disk["access_token"] == "saved-token"
def test_save_qwen_cli_tokens_creates_parent(qwen_env):
    """After saving, the parent directory of the returned path exists."""
    target = _save_qwen_cli_tokens(_make_qwen_tokens())
    parent_dir = target.parent
    assert parent_dir.exists()
def test_save_qwen_cli_tokens_permissions(qwen_env):
    """The saved file is owner-readable and -writable, with no group/other read bit."""
    target = _save_qwen_cli_tokens(_make_qwen_tokens())
    mode = target.stat().st_mode

    owner_rw = stat.S_IRUSR | stat.S_IWUSR
    assert (mode & owner_rw) == owner_rw  # owner read and write both set

    group_or_other_read = stat.S_IRGRP | stat.S_IROTH
    assert (mode & group_or_other_read) == 0  # neither group nor other may read
# ---------------------------------------------------------------------------
# _qwen_access_token_is_expiring
# ---------------------------------------------------------------------------
def test_expiring_token_not_expired():
    """An expiry one hour ahead (in milliseconds) is not reported as expiring."""
    one_hour_ahead_ms = int((time.time() + 3600) * 1000)
    expiring = _qwen_access_token_is_expiring(one_hour_ahead_ms)
    assert not expiring
def test_expiring_token_already_expired():
    """An expiry one hour in the past (in milliseconds) is reported as expiring."""
    one_hour_ago_ms = int((time.time() - 3600) * 1000)
    expiring = _qwen_access_token_is_expiring(one_hour_ago_ms)
    assert expiring
def test_expiring_token_within_skew():
    """An expiry five seconds inside the refresh-skew window counts as expiring."""
    seconds_left = QWEN_ACCESS_TOKEN_REFRESH_SKEW_SECONDS - 5
    inside_skew_ms = int((time.time() + seconds_left) * 1000)
    expiring = _qwen_access_token_is_expiring(inside_skew_ms)
    assert expiring
def test_expiring_token_none_returns_true():
    """A missing expiry (``None``) is treated as expiring."""
    expiring = _qwen_access_token_is_expiring(None)
    assert expiring
def test_expiring_token_non_numeric_returns_true():
    """A non-numeric expiry value is treated as expiring."""
    expiring = _qwen_access_token_is_expiring("not-a-number")
    assert expiring
# ---------------------------------------------------------------------------
# _refresh_qwen_cli_tokens
# ---------------------------------------------------------------------------
def test_refresh_qwen_cli_tokens_success(qwen_env):
    """A 200 response yields the new access and refresh tokens plus an expiry date."""
    stale_tokens = _make_qwen_tokens(refresh_token="old-refresh")
    response = MagicMock(status_code=200)
    response.json.return_value = {
        "access_token": "new-access",
        "refresh_token": "new-refresh",
        "expires_in": 7200,
    }

    with patch("hermes_cli.auth.httpx") as fake_httpx:
        fake_httpx.post.return_value = response
        refreshed = _refresh_qwen_cli_tokens(stale_tokens)

    assert refreshed["access_token"] == "new-access"
    assert refreshed["refresh_token"] == "new-refresh"
    assert "expiry_date" in refreshed
def test_refresh_qwen_cli_tokens_preserves_old_refresh_if_not_in_response(qwen_env):
    """If the response carries no refresh token, the previous one is kept."""
    stale_tokens = _make_qwen_tokens(refresh_token="keep-me")
    response = MagicMock(status_code=200)
    # Deliberately no "refresh_token" key in the response body.
    response.json.return_value = {
        "access_token": "new-access",
        "expires_in": 3600,
    }

    with patch("hermes_cli.auth.httpx") as fake_httpx:
        fake_httpx.post.return_value = response
        refreshed = _refresh_qwen_cli_tokens(stale_tokens)

    assert refreshed["refresh_token"] == "keep-me"
def test_refresh_qwen_cli_tokens_missing_refresh_token():
    """An empty refresh token raises AuthError ``qwen_refresh_token_missing``."""
    tokens_without_refresh = {"access_token": "at", "refresh_token": ""}

    with pytest.raises(AuthError) as raised:
        _refresh_qwen_cli_tokens(tokens_without_refresh)

    assert raised.value.code == "qwen_refresh_token_missing"
def test_refresh_qwen_cli_tokens_http_error(qwen_env):
    """A 401 response raises AuthError ``qwen_refresh_failed``."""
    current = _make_qwen_tokens()
    rejected = MagicMock(status_code=401, text="unauthorized")

    with patch("hermes_cli.auth.httpx") as mock_httpx:
        mock_httpx.post.return_value = rejected
        with pytest.raises(AuthError) as raised:
            _refresh_qwen_cli_tokens(current)

    assert raised.value.code == "qwen_refresh_failed"
def test_refresh_qwen_cli_tokens_network_error(qwen_env):
    """A ConnectionError from the POST raises AuthError ``qwen_refresh_failed``."""
    current = _make_qwen_tokens()

    with patch("hermes_cli.auth.httpx") as mock_httpx:
        # Make the patched POST itself blow up instead of returning a response.
        mock_httpx.post.side_effect = ConnectionError("timeout")
        with pytest.raises(AuthError) as raised:
            _refresh_qwen_cli_tokens(current)

    assert raised.value.code == "qwen_refresh_failed"
def test_refresh_qwen_cli_tokens_invalid_json_response(qwen_env):
    """A 200 response whose ``json()`` raises yields ``qwen_refresh_invalid_json``."""
    current = _make_qwen_tokens()
    unparsable = MagicMock(status_code=200)
    unparsable.json.side_effect = ValueError("bad json")

    with patch("hermes_cli.auth.httpx") as mock_httpx:
        mock_httpx.post.return_value = unparsable
        with pytest.raises(AuthError) as raised:
            _refresh_qwen_cli_tokens(current)

    assert raised.value.code == "qwen_refresh_invalid_json"
def test_refresh_qwen_cli_tokens_missing_access_token_in_response(qwen_env):
    """A 200 payload without ``access_token`` yields ``qwen_refresh_invalid_response``."""
    current = _make_qwen_tokens()
    incomplete = MagicMock(status_code=200)
    incomplete.json.return_value = {"something": "but no access_token"}

    with patch("hermes_cli.auth.httpx") as mock_httpx:
        mock_httpx.post.return_value = incomplete
        with pytest.raises(AuthError) as raised:
            _refresh_qwen_cli_tokens(current)

    assert raised.value.code == "qwen_refresh_invalid_response"
def test_refresh_qwen_cli_tokens_default_expires_in(qwen_env):
    """When expires_in is missing, the expiry lands roughly six hours out."""
    current = _make_qwen_tokens()
    response = MagicMock(status_code=200)
    response.json.return_value = {"access_token": "new"}

    with patch("hermes_cli.auth.httpx") as mock_httpx:
        mock_httpx.post.return_value = response
        refreshed = _refresh_qwen_cli_tokens(current)

    # Allow up to 60 seconds of drift between the refresh and this check.
    six_hours_ms = 6 * 60 * 60 * 1000
    expected_ms = int(time.time() * 1000) + six_hours_ms
    drift_ms = abs(refreshed["expiry_date"] - expected_ms)
    assert drift_ms < 60_000
def test_refresh_qwen_cli_tokens_saves_to_disk(qwen_env):
    """A successful refresh writes the new access token to oauth_creds.json."""
    current = _make_qwen_tokens()
    response = MagicMock(status_code=200)
    response.json.return_value = {"access_token": "disk-check", "expires_in": 3600}

    with patch("hermes_cli.auth.httpx") as mock_httpx:
        mock_httpx.post.return_value = response
        _refresh_qwen_cli_tokens(current)

    # The credentials file under the test environment must now exist.
    qwen_dir = qwen_env / ".qwen"
    creds_file = qwen_dir / "oauth_creds.json"
    assert creds_file.exists()

    persisted = json.loads(creds_file.read_text(encoding="utf-8"))
    assert persisted["access_token"] == "disk-check"
# ---------------------------------------------------------------------------
# resolve_qwen_runtime_credentials
# ---------------------------------------------------------------------------
def test_resolve_qwen_runtime_credentials_fresh_token(qwen_env):
    """Stored credentials are returned as-is when refreshing is disabled."""
    _write_qwen_creds(qwen_env, _make_qwen_tokens(access_token="fresh-at"))

    resolved = resolve_qwen_runtime_credentials(refresh_if_expiring=False)

    expected_fields = {
        "provider": "qwen-oauth",
        "api_key": "fresh-at",
        "base_url": DEFAULT_QWEN_BASE_URL,
        "source": "qwen-cli",
    }
    for field, expected in expected_fields.items():
        assert resolved[field] == expected
def test_resolve_qwen_runtime_credentials_triggers_refresh(qwen_env):
    """An expired stored token causes exactly one refresh call."""
    # Store a token whose expiry lies one hour in the past.
    one_hour_ago_ms = int((time.time() - 3600) * 1000)
    stale = _make_qwen_tokens(access_token="old", expiry_date=one_hour_ago_ms)
    _write_qwen_creds(qwen_env, stale)

    replacement = _make_qwen_tokens(access_token="refreshed-at")
    refresh_target = "hermes_cli.auth._refresh_qwen_cli_tokens"
    with patch(refresh_target, return_value=replacement) as mock_refresh:
        resolved = resolve_qwen_runtime_credentials()

    mock_refresh.assert_called_once()
    assert resolved["api_key"] == "refreshed-at"
def test_resolve_qwen_runtime_credentials_force_refresh(qwen_env):
    """``force_refresh=True`` triggers exactly one refresh call."""
    _write_qwen_creds(qwen_env, _make_qwen_tokens(access_token="old-at"))

    replacement = _make_qwen_tokens(access_token="force-refreshed")
    refresh_target = "hermes_cli.auth._refresh_qwen_cli_tokens"
    with patch(refresh_target, return_value=replacement) as mock_refresh:
        resolved = resolve_qwen_runtime_credentials(force_refresh=True)

    mock_refresh.assert_called_once()
    assert resolved["api_key"] == "force-refreshed"
def test_resolve_qwen_runtime_credentials_missing_access_token(qwen_env):
    """An empty stored access token raises ``qwen_access_token_missing``."""
    _write_qwen_creds(qwen_env, _make_qwen_tokens(access_token=""))

    with pytest.raises(AuthError) as raised:
        resolve_qwen_runtime_credentials(refresh_if_expiring=False)

    assert raised.value.code == "qwen_access_token_missing"
def test_resolve_qwen_runtime_credentials_base_url_env_override(qwen_env, monkeypatch):
    """``HERMES_QWEN_BASE_URL`` replaces the base URL in the resolved credentials."""
    custom_url = "https://custom.qwen.ai/v1"
    _write_qwen_creds(qwen_env, _make_qwen_tokens(access_token="at"))
    monkeypatch.setenv("HERMES_QWEN_BASE_URL", custom_url)

    resolved = resolve_qwen_runtime_credentials(refresh_if_expiring=False)

    assert resolved["base_url"] == custom_url
# ---------------------------------------------------------------------------
# get_qwen_auth_status
# ---------------------------------------------------------------------------
def test_get_qwen_auth_status_logged_in(qwen_env):
    """With stored credentials the status is logged-in and exposes the api key."""
    _write_qwen_creds(qwen_env, _make_qwen_tokens(access_token="status-at"))

    report = get_qwen_auth_status()

    assert report["logged_in"] is True
    assert report["api_key"] == "status-at"
def test_get_qwen_auth_status_not_logged_in(qwen_env):
    """Without a credentials file the status is logged-out and carries an error."""
    # Nothing is written into qwen_env here, so no credentials file exists.
    report = get_qwen_auth_status()

    assert report["logged_in"] is False
    assert "error" in report

View File

@@ -0,0 +1,63 @@
from unittest.mock import MagicMock, patch
def test_format_banner_version_label_without_git_state():
    """Without git state the label is just the version and release date."""
    from hermes_cli import banner

    no_git_state = patch.object(banner, "get_git_banner_state", return_value=None)
    with no_git_state:
        label = banner.format_banner_version_label()

    expected = f"Hermes Agent v{banner.VERSION} ({banner.RELEASE_DATE})"
    assert label == expected
def test_format_banner_version_label_on_upstream_main():
    """When local equals upstream, only the upstream hash is appended."""
    from hermes_cli import banner

    in_sync = {"upstream": "b2f477a3", "local": "b2f477a3", "ahead": 0}
    with patch.object(banner, "get_git_banner_state", return_value=in_sync):
        label = banner.format_banner_version_label()

    assert label.endswith("· upstream b2f477a3")
    assert "local" not in label
def test_format_banner_version_label_with_carried_commits():
    """When ahead of upstream, the label shows both hashes and the commit count."""
    from hermes_cli import banner

    ahead_state = {"upstream": "b2f477a3", "local": "af8aad31", "ahead": 3}
    with patch.object(banner, "get_git_banner_state", return_value=ahead_state):
        label = banner.format_banner_version_label()

    for fragment in ("upstream b2f477a3", "local af8aad31", "+3 carried commits"):
        assert fragment in label
def test_get_git_banner_state_reads_origin_and_head(tmp_path):
    """State is assembled from three git commands: two rev-parse, one rev-list."""
    from hermes_cli import banner

    repo_dir = tmp_path / "repo"
    git_dir = repo_dir / ".git"
    git_dir.mkdir(parents=True)

    def completed(stdout):
        """Build a stand-in for a successful subprocess result."""
        return MagicMock(returncode=0, stdout=stdout)

    # Canned results keyed by the exact argv tuple expected from the banner code.
    canned = {
        ("git", "rev-parse", "--short=8", "origin/main"): completed("b2f477a3\n"),
        ("git", "rev-parse", "--short=8", "HEAD"): completed("af8aad31\n"),
        ("git", "rev-list", "--count", "origin/main..HEAD"): completed("3\n"),
    }

    def fake_run(cmd, **kwargs):
        outcome = canned.get(tuple(cmd))
        if outcome is None:
            # Any command outside the canned set fails the test loudly.
            raise AssertionError(f"unexpected command: {cmd}")
        return outcome

    with patch("hermes_cli.banner.subprocess.run", side_effect=fake_run):
        state = banner.get_git_banner_state(repo_dir)

    assert state == {"upstream": "b2f477a3", "local": "af8aad31", "ahead": 3}

View File

@@ -136,3 +136,73 @@ def test_check_gateway_service_linger_skips_when_service_not_installed(monkeypat
out = capsys.readouterr().out
assert out == ""
assert issues == []
# ── Memory provider section (doctor should only check the *active* provider) ──
class TestDoctorMemoryProviderSection:
    """Doctor's Memory Provider section must follow the memory.provider setting."""

    def _make_hermes_home(self, tmp_path, provider=""):
        """Create ``tmp_path/.hermes`` with a minimal config.yaml and return it."""
        import yaml

        home = tmp_path / ".hermes"
        home.mkdir(parents=True, exist_ok=True)

        # An empty provider produces an empty memory section.
        memory_section = {"provider": provider} if provider else {}
        config_text = yaml.dump({"memory": memory_section})
        (home / "config.yaml").write_text(config_text)
        return home

    def _run_doctor_and_capture(self, monkeypatch, tmp_path, provider=""):
        """Run ``run_doctor`` against a temporary home and return its stdout."""
        import contextlib
        import io

        home = self._make_hermes_home(tmp_path, provider)
        project_root = tmp_path / "project"
        monkeypatch.setattr(doctor_mod, "HERMES_HOME", home)
        monkeypatch.setattr(doctor_mod, "PROJECT_ROOT", project_root)
        monkeypatch.setattr(doctor_mod, "_DHH", str(home))
        project_root.mkdir(exist_ok=True)

        # Stub tool availability (returns empty) so doctor runs past it.
        stub_model_tools = types.SimpleNamespace(
            check_tool_availability=lambda *a, **kw: ([], []),
            TOOLSET_REQUIREMENTS={},
        )
        monkeypatch.setitem(sys.modules, "model_tools", stub_model_tools)

        # Stub auth checks to avoid real API calls; best-effort by design.
        try:
            from hermes_cli import auth as _auth_mod

            monkeypatch.setattr(_auth_mod, "get_nous_auth_status", lambda: {})
            monkeypatch.setattr(_auth_mod, "get_codex_auth_status", lambda: {})
        except Exception:
            pass

        captured = io.StringIO()
        with contextlib.redirect_stdout(captured):
            doctor_mod.run_doctor(Namespace(fix=False))
        return captured.getvalue()

    def test_no_provider_shows_builtin_ok(self, monkeypatch, tmp_path):
        """With no provider set, the built-in memory message is printed."""
        output = self._run_doctor_and_capture(monkeypatch, tmp_path, provider="")

        assert "Memory Provider" in output
        assert "Built-in memory active" in output
        # Neither Honcho nor Mem0 diagnostics may appear.
        assert "Honcho API key" not in output
        assert "Mem0" not in output

    def test_honcho_provider_not_installed_shows_fail(self, monkeypatch, tmp_path):
        """With honcho selected but unimportable, built-in memory is not reported."""
        # A None entry in sys.modules makes importing that module fail.
        monkeypatch.setitem(sys.modules, "plugins.memory.honcho.client", None)

        output = self._run_doctor_and_capture(monkeypatch, tmp_path, provider="honcho")

        assert "Memory Provider" in output
        assert "Built-in memory active" not in output

    def test_mem0_provider_not_installed_shows_fail(self, monkeypatch, tmp_path):
        """With mem0 selected but unimportable, built-in memory is not reported."""
        # A None entry in sys.modules makes importing that module fail.
        monkeypatch.setitem(sys.modules, "plugins.memory.mem0", None)

        output = self._run_doctor_and_capture(monkeypatch, tmp_path, provider="mem0")

        assert "Memory Provider" in output
        assert "Built-in memory active" not in output

View File

@@ -641,3 +641,69 @@ class TestEnsureUserSystemdEnv:
result = gateway_cli._systemctl_cmd(system=True)
assert result == ["systemctl"]
assert calls == []
class TestProfileArg:
    """Checks that _profile_arg yields '--profile <name>' only for named profiles."""

    def test_default_hermes_home_returns_empty(self, tmp_path, monkeypatch):
        """The default ~/.hermes directory produces no --profile flag."""
        default_home = tmp_path / ".hermes"
        default_home.mkdir()
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        assert gateway_cli._profile_arg(str(default_home)) == ""

    def test_named_profile_returns_flag(self, tmp_path, monkeypatch):
        """~/.hermes/profiles/mybot maps to '--profile mybot'."""
        mybot_dir = tmp_path / ".hermes" / "profiles" / "mybot"
        mybot_dir.mkdir(parents=True)
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        assert gateway_cli._profile_arg(str(mybot_dir)) == "--profile mybot"

    def test_hash_path_returns_empty(self, tmp_path, monkeypatch):
        """A HERMES_HOME outside the profiles tree produces an empty string."""
        unrelated_home = tmp_path / "custom" / "hermes"
        unrelated_home.mkdir(parents=True)
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        assert gateway_cli._profile_arg(str(unrelated_home)) == ""

    def test_nested_profile_path_returns_empty(self, tmp_path, monkeypatch):
        """A directory below a profile directory is too deep to match."""
        too_deep = tmp_path / ".hermes" / "profiles" / "mybot" / "subdir"
        too_deep.mkdir(parents=True)
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        assert gateway_cli._profile_arg(str(too_deep)) == ""

    def test_invalid_profile_name_returns_empty(self, tmp_path, monkeypatch):
        """A profile directory named 'My Bot!' is rejected."""
        invalid_dir = tmp_path / ".hermes" / "profiles" / "My Bot!"
        invalid_dir.mkdir(parents=True)
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        assert gateway_cli._profile_arg(str(invalid_dir)) == ""

    def test_systemd_unit_includes_profile(self, tmp_path, monkeypatch):
        """The generated systemd unit carries --profile for a named profile."""
        mybot_dir = tmp_path / ".hermes" / "profiles" / "mybot"
        mybot_dir.mkdir(parents=True)
        monkeypatch.setattr(Path, "home", lambda: tmp_path)
        monkeypatch.setenv("HERMES_HOME", str(mybot_dir))
        monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: mybot_dir)

        unit_text = gateway_cli.generate_systemd_unit(system=False)

        assert "--profile mybot" in unit_text
        assert "gateway run --replace" in unit_text

    def test_launchd_plist_includes_profile(self, tmp_path, monkeypatch):
        """The generated launchd plist carries --profile and the profile name."""
        mybot_dir = tmp_path / ".hermes" / "profiles" / "mybot"
        mybot_dir.mkdir(parents=True)
        monkeypatch.setattr(Path, "home", lambda: tmp_path)
        monkeypatch.setenv("HERMES_HOME", str(mybot_dir))
        monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: mybot_dir)

        plist_text = gateway_cli.generate_launchd_plist()

        assert "<string>--profile</string>" in plist_text
        assert "<string>mybot</string>" in plist_text

View File

@@ -171,7 +171,11 @@ class TestGeminiModelNormalization:
class TestGeminiContextLength:
def test_gemma_4_31b_context(self):
    """gemma-4-31b-it on the gemini provider resolves to 256000 tokens.

    The lookup runs only while both external metadata sources are patched
    out, so the assertion exercises the hardcoded default rather than
    whatever a live service reports.
    """
    # Mock external API lookups to test against hardcoded defaults
    # (models.dev and OpenRouter may return different values like 262144).
    with patch("agent.models_dev.lookup_models_dev_context", return_value=None), \
            patch("agent.model_metadata.fetch_model_metadata", return_value={}):
        ctx = get_model_context_length("gemma-4-31b-it", provider="gemini")
    assert ctx == 256000
def test_gemma_4_26b_context(self):

View File

@@ -1,6 +1,15 @@
"""Tests for the hermes_cli models module."""
from hermes_cli.models import OPENROUTER_MODELS, menu_labels, model_ids, detect_provider_for_model
from unittest.mock import patch, MagicMock
from hermes_cli.models import (
OPENROUTER_MODELS, menu_labels, model_ids, detect_provider_for_model,
filter_nous_free_models, _NOUS_ALLOWED_FREE_MODELS,
is_nous_free_tier, partition_nous_models_by_tier,
check_nous_free_tier, clear_nous_free_tier_cache,
_FREE_TIER_CACHE_TTL,
)
import hermes_cli.models as _models_mod
class TestModelIds:
@@ -124,3 +133,226 @@ class TestDetectProviderForModel:
result = detect_provider_for_model("claude-opus-4-6", "openai-codex")
assert result is not None
assert result[0] not in ("nous",) # nous has claude models but shouldn't be suggested
class TestFilterNousFreeModels:
    """Covers filter_nous_free_models, the Nous Portal free-model policy."""

    # Pricing fixtures: non-zero prices mark a paid model, zero prices a free one.
    _PAID = {"prompt": "0.000003", "completion": "0.000015"}
    _FREE = {"prompt": "0", "completion": "0"}

    def test_paid_models_kept(self):
        """Paid models come back unchanged."""
        models = ["anthropic/claude-opus-4.6", "openai/gpt-5.4"]
        pricing = dict.fromkeys(models, self._PAID)

        assert filter_nous_free_models(models, pricing) == models

    def test_free_non_allowlist_models_removed(self):
        """A free model outside the allowlist is dropped."""
        paid_id = "anthropic/claude-opus-4.6"
        free_id = "arcee-ai/trinity-large-preview:free"
        pricing = {paid_id: self._PAID, free_id: self._FREE}

        kept = filter_nous_free_models([paid_id, free_id], pricing)

        assert kept == ["anthropic/claude-opus-4.6"]

    def test_allowlist_model_kept_when_free(self):
        """An allowlist model priced as free is kept."""
        paid_id = "anthropic/claude-opus-4.6"
        allowlisted_id = "xiaomi/mimo-v2-pro"
        pricing = {paid_id: self._PAID, allowlisted_id: self._FREE}

        kept = filter_nous_free_models([paid_id, allowlisted_id], pricing)

        assert kept == ["anthropic/claude-opus-4.6", "xiaomi/mimo-v2-pro"]

    def test_allowlist_model_removed_when_paid(self):
        """An allowlist model that is not free is dropped."""
        paid_id = "anthropic/claude-opus-4.6"
        allowlisted_id = "xiaomi/mimo-v2-pro"
        pricing = {paid_id: self._PAID, allowlisted_id: self._PAID}

        kept = filter_nous_free_models([paid_id, allowlisted_id], pricing)

        assert kept == ["anthropic/claude-opus-4.6"]

    def test_no_pricing_returns_all(self):
        """With an empty pricing dict every model passes through."""
        models = ["anthropic/claude-opus-4.6", "nvidia/nemotron-3-super-120b-a12b:free"]
        no_pricing = {}

        assert filter_nous_free_models(models, no_pricing) == models

    def test_model_with_no_pricing_entry_treated_as_paid(self):
        """A model absent from the pricing dict is kept."""
        models = ["anthropic/claude-opus-4.6", "openai/gpt-5.4"]
        # Only the first model is priced; gpt-5.4 has no entry at all.
        partial_pricing = {"anthropic/claude-opus-4.6": self._PAID}

        kept = filter_nous_free_models(models, partial_pricing)

        assert kept == models

    def test_mixed_scenario(self):
        """Paid, free-allowed, free-disallowed and allowlist-not-free together."""
        # (model id, pricing) pairs in the order handed to the filter.
        catalogue = [
            ("anthropic/claude-opus-4.6", self._PAID),  # paid, not allowlist → keep
            ("nvidia/nemotron-3-super-120b-a12b:free", self._FREE),  # free, not allowlist → drop
            ("xiaomi/mimo-v2-pro", self._FREE),  # free, allowlist → keep
            ("xiaomi/mimo-v2-omni", self._PAID),  # paid, allowlist → drop
            ("openai/gpt-5.4", self._PAID),  # paid, not allowlist → keep
        ]
        models = [model_id for model_id, _ in catalogue]
        pricing = dict(catalogue)

        kept = filter_nous_free_models(models, pricing)

        assert kept == [
            "anthropic/claude-opus-4.6",
            "xiaomi/mimo-v2-pro",
            "openai/gpt-5.4",
        ]

    def test_allowlist_contains_expected_models(self):
        """The allowlist contains both expected xiaomi models."""
        for expected_id in ("xiaomi/mimo-v2-pro", "xiaomi/mimo-v2-omni"):
            assert expected_id in _NOUS_ALLOWED_FREE_MODELS
class TestIsNousFreeTier:
    """Covers is_nous_free_tier, which classifies an account payload by tier."""

    def test_paid_plus_tier(self):
        """A Plus plan with a monthly charge of 20 is not free."""
        account = {"subscription": {"plan": "Plus", "tier": 2, "monthly_charge": 20}}
        assert is_nous_free_tier(account) is False

    def test_free_tier_by_charge(self):
        """A monthly charge of 0 marks the account as free."""
        account = {"subscription": {"plan": "Free", "tier": 0, "monthly_charge": 0}}
        assert is_nous_free_tier(account) is True

    def test_no_charge_field_not_free(self):
        """Without monthly_charge the account is not free (don't block users)."""
        account = {"subscription": {"plan": "Free", "tier": 0}}
        assert is_nous_free_tier(account) is False

    def test_plan_name_alone_not_free(self):
        """A plan name of 'free' without monthly_charge is not enough."""
        account = {"subscription": {"plan": "free"}}
        assert is_nous_free_tier(account) is False

    def test_empty_subscription_not_free(self):
        """An empty subscription dict is not free (don't block users)."""
        account = {"subscription": {}}
        assert is_nous_free_tier(account) is False

    def test_no_subscription_not_free(self):
        """A payload with no subscription key is not free."""
        account = {}
        assert is_nous_free_tier(account) is False

    def test_empty_response_not_free(self):
        """A completely empty response is not free."""
        empty_response = {}
        assert is_nous_free_tier(empty_response) is False
class TestPartitionNousModelsByTier:
    """Covers partition_nous_models_by_tier, the selectable/unavailable split."""

    # Pricing fixtures: non-zero prices mark a paid model, zero prices a free one.
    _PAID = {"prompt": "0.000003", "completion": "0.000015"}
    _FREE = {"prompt": "0", "completion": "0"}

    def test_paid_tier_all_selectable(self):
        """On a paid tier every model is selectable and none is unavailable."""
        models = ["anthropic/claude-opus-4.6", "xiaomi/mimo-v2-pro"]
        pricing = dict(zip(models, (self._PAID, self._FREE)))

        selectable, unavailable = partition_nous_models_by_tier(
            models, pricing, free_tier=False
        )

        assert selectable == models
        assert unavailable == []

    def test_free_tier_splits_correctly(self):
        """On the free tier only free models are selectable; paid ones are not."""
        models = ["anthropic/claude-opus-4.6", "xiaomi/mimo-v2-pro", "openai/gpt-5.4"]
        pricing = dict(zip(models, (self._PAID, self._FREE, self._PAID)))

        selectable, unavailable = partition_nous_models_by_tier(
            models, pricing, free_tier=True
        )

        assert selectable == ["xiaomi/mimo-v2-pro"]
        assert unavailable == ["anthropic/claude-opus-4.6", "openai/gpt-5.4"]

    def test_no_pricing_returns_all(self):
        """With an empty pricing dict every model is selectable."""
        models = ["anthropic/claude-opus-4.6", "openai/gpt-5.4"]

        selectable, unavailable = partition_nous_models_by_tier(
            models, {}, free_tier=True
        )

        assert selectable == models
        assert unavailable == []

    def test_all_free_models(self):
        """When every model is free, the free tier can select all of them."""
        models = ["xiaomi/mimo-v2-pro", "xiaomi/mimo-v2-omni"]
        pricing = dict.fromkeys(models, self._FREE)

        selectable, unavailable = partition_nous_models_by_tier(
            models, pricing, free_tier=True
        )

        assert selectable == models
        assert unavailable == []

    def test_all_paid_models(self):
        """When every model is paid, the free tier can select none of them."""
        models = ["anthropic/claude-opus-4.6", "openai/gpt-5.4"]
        pricing = dict.fromkeys(models, self._PAID)

        selectable, unavailable = partition_nous_models_by_tier(
            models, pricing, free_tier=True
        )

        assert selectable == []
        assert unavailable == models
class TestCheckNousFreeTierCache:
    """Covers the TTL cache in front of check_nous_free_tier()."""

    def setup_method(self):
        """Start every test with an empty cache."""
        clear_nous_free_tier_cache()

    def teardown_method(self):
        """Leave an empty cache behind after every test."""
        clear_nous_free_tier_cache()

    @patch("hermes_cli.models.fetch_nous_account_tier")
    @patch("hermes_cli.models.is_nous_free_tier", return_value=True)
    def test_result_is_cached(self, mock_is_free, mock_fetch):
        """Two back-to-back calls fetch the account tier only once."""
        mock_fetch.return_value = {"subscription": {"monthly_charge": 0}}
        auth_state = {"access_token": "tok"}

        with patch("hermes_cli.auth.get_provider_auth_state", return_value=auth_state), \
                patch("hermes_cli.auth.resolve_nous_runtime_credentials"):
            first = check_nous_free_tier()
            second = check_nous_free_tier()

        assert first is True
        assert second is True
        # The second call must have been answered from the cache.
        assert mock_fetch.call_count == 1

    @patch("hermes_cli.models.fetch_nous_account_tier")
    @patch("hermes_cli.models.is_nous_free_tier", return_value=False)
    def test_cache_expires_after_ttl(self, mock_is_free, mock_fetch):
        """Once the cached entry is older than the TTL, the tier is fetched again."""
        mock_fetch.return_value = {"subscription": {"monthly_charge": 20}}
        auth_state = {"access_token": "tok"}

        with patch("hermes_cli.auth.get_provider_auth_state", return_value=auth_state), \
                patch("hermes_cli.auth.resolve_nous_runtime_credentials"):
            first = check_nous_free_tier()
            assert mock_fetch.call_count == 1

            # Backdate the cached timestamp by TTL + 1 to simulate expiry.
            value, stamped_at = _models_mod._free_tier_cache
            backdated = stamped_at - _FREE_TIER_CACHE_TTL - 1
            _models_mod._free_tier_cache = (value, backdated)

            second = check_nous_free_tier()
            assert mock_fetch.call_count == 2

        assert first is False
        assert second is False

    def test_clear_cache_forces_refresh(self):
        """clear_nous_free_tier_cache() resets the module-level cache to None."""
        import time

        # Seed the cache by hand with a (result, timestamp) pair.
        seeded_entry = (True, time.monotonic())
        _models_mod._free_tier_cache = seeded_entry

        clear_nous_free_tier_cache()

        assert _models_mod._free_tier_cache is None

    def test_cache_ttl_is_short(self):
        """The TTL stays at or below 300 so upgrades are noticed quickly."""
        upper_bound = 300
        assert _FREE_TIER_CACHE_TTL <= upper_bound

View File

@@ -44,7 +44,62 @@ def test_get_nous_subscription_features_prefers_managed_modal_in_auto_mode(monke
assert features.modal.direct_override is False
def test_get_nous_subscription_features_prefers_camofox_over_managed_browserbase(monkeypatch):
def test_get_nous_subscription_features_marks_browser_use_as_managed_when_gateway_ready(monkeypatch):
    """Browser Use is reported as Nous-managed when its gateway is ready."""
    # Replacement callables for the module, applied in the listed order.
    stubs = {
        "get_env_value": lambda name: "",
        "get_nous_auth_status": lambda: {"logged_in": True},
        "managed_nous_tools_enabled": lambda: True,
        "_toolset_enabled": lambda config, key: key == "browser",
        "_has_agent_browser": lambda: True,
        "resolve_openai_audio_api_key": lambda: "",
        "has_direct_modal_credentials": lambda: False,
        "is_managed_tool_gateway_ready": lambda vendor: vendor == "browser-use",
    }
    for attr_name, stub in stubs.items():
        monkeypatch.setattr(ns, attr_name, stub)

    config = {"browser": {"cloud_provider": "browser-use"}}
    browser = ns.get_nous_subscription_features(config).browser

    assert browser.available is True
    assert browser.active is True
    assert browser.managed_by_nous is True
    assert browser.direct_override is False
    assert browser.current_provider == "Browser Use"
def test_get_nous_subscription_features_uses_direct_browserbase_when_no_managed_gateway(monkeypatch):
    """Direct Browserbase keys win when no managed gateway is ready.

    With an empty config and both Browserbase variables set, the browser
    feature is expected to be a direct (not Nous-managed) Browserbase provider.
    """
    direct_keys = {
        "BROWSERBASE_API_KEY": "bb-key",
        "BROWSERBASE_PROJECT_ID": "bb-project",
    }
    # Replacement callables for the module, applied in the listed order.
    stubs = {
        "get_env_value": lambda name: direct_keys.get(name, ""),
        "get_nous_auth_status": lambda: {"logged_in": True},
        "managed_nous_tools_enabled": lambda: True,
        "_toolset_enabled": lambda config, key: key == "browser",
        "_has_agent_browser": lambda: True,
        "resolve_openai_audio_api_key": lambda: "",
        "has_direct_modal_credentials": lambda: False,
        # No managed gateway is available for any vendor.
        "is_managed_tool_gateway_ready": lambda vendor: False,
    }
    for attr_name, stub in stubs.items():
        monkeypatch.setattr(ns, attr_name, stub)

    browser = ns.get_nous_subscription_features({}).browser

    assert browser.available is True
    assert browser.active is True
    assert browser.managed_by_nous is False
    assert browser.direct_override is True
    assert browser.current_provider == "Browserbase"
def test_get_nous_subscription_features_prefers_camofox_over_managed_browser_use(monkeypatch):
env = {"CAMOFOX_URL": "http://localhost:9377"}
monkeypatch.setattr(ns, "get_env_value", lambda name: env.get(name, ""))
@@ -57,11 +112,11 @@ def test_get_nous_subscription_features_prefers_camofox_over_managed_browserbase
monkeypatch.setattr(
ns,
"is_managed_tool_gateway_ready",
lambda vendor: vendor == "browserbase",
lambda vendor: vendor == "browser-use",
)
features = ns.get_nous_subscription_features(
{"browser": {"cloud_provider": "browserbase"}}
{"browser": {"cloud_provider": "browser-use"}}
)
assert features.browser.available is True

Some files were not shown because too many files have changed in this diff Show More