Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor
This commit is contained in:
@@ -1276,6 +1276,258 @@ class TestRoleAlternation:
|
||||
assert [m["role"] for m in result] == ["user", "assistant", "user"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thinking block signature management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestThinkingBlockSignatureManagement:
    """Tests for the thinking block handling strategy.

    Strategy under test: strip thinking from old turns, preserve the latest
    signed block, and downgrade unsigned blocks.
    """

    @staticmethod
    def _blocks_of_type(blocks, *types):
        """Return the dict blocks in *blocks* whose ``type`` is one of *types*."""
        return [
            block for block in blocks
            if isinstance(block, dict) and block.get("type") in types
        ]

    def test_thinking_stripped_from_non_last_assistant(self):
        """Thinking blocks are removed from all assistant messages except the last."""
        messages = [
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {"id": "tc_1", "function": {"name": "tool1", "arguments": "{}"}},
                ],
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Old reasoning.", "signature": "sig_old"},
                ],
            },
            {"role": "tool", "tool_call_id": "tc_1", "content": "result 1"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {"id": "tc_2", "function": {"name": "tool2", "arguments": "{}"}},
                ],
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Latest reasoning.", "signature": "sig_new"},
                ],
            },
            {"role": "tool", "tool_call_id": "tc_2", "content": "result 2"},
        ]
        _, result = convert_messages_to_anthropic(messages)

        # Find both assistant messages
        assistants = [m for m in result if m["role"] == "assistant"]
        assert len(assistants) == 2

        # First (non-last) assistant: no thinking blocks
        first_types = [b.get("type") for b in assistants[0]["content"]]
        assert "thinking" not in first_types
        assert "redacted_thinking" not in first_types
        assert "tool_use" in first_types  # tool_use should survive

        # Last assistant: thinking block preserved with signature
        last_blocks = assistants[1]["content"]
        thinking_blocks = [b for b in last_blocks if b.get("type") == "thinking"]
        assert len(thinking_blocks) == 1
        assert thinking_blocks[0]["thinking"] == "Latest reasoning."
        assert thinking_blocks[0]["signature"] == "sig_new"

    def test_signed_thinking_preserved_on_last_turn(self):
        """A signed thinking block on the last assistant message is kept."""
        messages = [
            {
                "role": "assistant",
                "content": "The answer is 42.",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Deep thought.", "signature": "sig_valid"},
                ],
            },
        ]
        _, result = convert_messages_to_anthropic(messages)
        blocks = result[0]["content"]
        thinking = [b for b in blocks if b.get("type") == "thinking"]
        assert len(thinking) == 1
        assert thinking[0]["signature"] == "sig_valid"

    def test_unsigned_thinking_downgraded_to_text_on_last_turn(self):
        """Unsigned thinking blocks on the last turn become text blocks."""
        messages = [
            {
                "role": "assistant",
                "content": "Response text.",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Unsigned reasoning."},
                    # No 'signature' field
                ],
            },
        ]
        _, result = convert_messages_to_anthropic(messages)
        blocks = result[0]["content"]

        # No thinking blocks should remain
        assert not any(b.get("type") == "thinking" for b in blocks)
        # The reasoning text should be preserved as a text block
        text_contents = [b.get("text", "") for b in blocks if b.get("type") == "text"]
        assert "Unsigned reasoning." in text_contents

    def test_redacted_thinking_with_data_preserved(self):
        """Redacted thinking with 'data' field is kept on last turn."""
        messages = [
            {
                "role": "assistant",
                "content": "Response.",
                "reasoning_details": [
                    {"type": "redacted_thinking", "data": "opaque_signature_data"},
                ],
            },
        ]
        _, result = convert_messages_to_anthropic(messages)
        blocks = result[0]["content"]
        redacted = [b for b in blocks if b.get("type") == "redacted_thinking"]
        assert len(redacted) == 1
        assert redacted[0]["data"] == "opaque_signature_data"

    def test_redacted_thinking_without_data_dropped(self):
        """Redacted thinking without 'data' is dropped — can't be validated."""
        messages = [
            {
                "role": "assistant",
                "content": "Response.",
                "reasoning_details": [
                    {"type": "redacted_thinking"},
                    # No 'data' field
                ],
            },
        ]
        _, result = convert_messages_to_anthropic(messages)
        blocks = result[0]["content"]
        assert not any(b.get("type") == "redacted_thinking" for b in blocks)

    def test_cache_control_stripped_from_thinking_blocks(self):
        """cache_control markers are removed from thinking/redacted_thinking blocks."""
        messages = [
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {"id": "tc_1", "function": {"name": "t", "arguments": "{}"}},
                ],
                "reasoning_details": [
                    {
                        "type": "thinking",
                        "thinking": "Reasoning.",
                        "signature": "sig_1",
                        "cache_control": {"type": "ephemeral"},
                    },
                ],
            },
            {"role": "tool", "tool_call_id": "tc_1", "content": "result"},
        ]
        _, result = convert_messages_to_anthropic(messages)
        assistant = next(m for m in result if m["role"] == "assistant")
        thinking = self._blocks_of_type(
            assistant["content"], "thinking", "redacted_thinking"
        )
        # Guard against a vacuous pass: this is the last assistant turn and the
        # block is signed, so it must still be present for the check below to
        # actually verify anything.
        assert thinking, "signed thinking block on the last assistant turn should be kept"
        for block in thinking:
            assert "cache_control" not in block

    def test_thinking_stripped_from_merged_consecutive_assistants(self):
        """When consecutive assistants are merged, second one's thinking is dropped."""
        messages = [
            {
                "role": "assistant",
                "content": "First response.",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "First thought.", "signature": "sig_1"},
                ],
            },
            {
                "role": "assistant",
                "content": "Second response.",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Second thought.", "signature": "sig_2"},
                ],
            },
        ]
        _, result = convert_messages_to_anthropic(messages)

        # Should be merged into one assistant message
        assistants = [m for m in result if m["role"] == "assistant"]
        assert len(assistants) == 1

        # Only the first thinking block should remain (signed, on the last/only assistant)
        blocks = assistants[0]["content"]
        thinking = [b for b in blocks if b.get("type") == "thinking"]
        assert len(thinking) == 1
        assert thinking[0]["thinking"] == "First thought."

    def test_empty_content_after_strip_gets_placeholder(self):
        """If stripping thinking leaves an empty message, a placeholder is added."""
        messages = [
            {
                "role": "assistant",
                "content": "",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Only thinking, no text."},
                    # Unsigned — will be downgraded, but content was empty string
                ],
            },
            {"role": "user", "content": "Next message."},
            {"role": "assistant", "content": "Final."},
        ]
        _, result = convert_messages_to_anthropic(messages)
        # First assistant is non-last, so thinking is stripped completely.
        # The original content was empty and thinking was unsigned → placeholder
        first_assistant = result[0]
        assert first_assistant["role"] == "assistant"
        assert len(first_assistant["content"]) >= 1

    def test_multi_turn_conversation_preserves_only_last(self):
        """Full multi-turn conversation: only last assistant keeps thinking."""
        messages = [
            {"role": "user", "content": "Question 1"},
            {
                "role": "assistant",
                "content": "Answer 1",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Thought 1", "signature": "sig_1"},
                ],
            },
            {"role": "user", "content": "Question 2"},
            {
                "role": "assistant",
                "content": "Answer 2",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Thought 2", "signature": "sig_2"},
                ],
            },
            {"role": "user", "content": "Question 3"},
            {
                "role": "assistant",
                "content": "Answer 3",
                "reasoning_details": [
                    {"type": "thinking", "thinking": "Thought 3", "signature": "sig_3"},
                ],
            },
        ]
        _, result = convert_messages_to_anthropic(messages)

        assistants = [m for m in result if m["role"] == "assistant"]
        assert len(assistants) == 3

        # First two: no thinking blocks
        for earlier in assistants[:2]:
            assert not self._blocks_of_type(
                earlier["content"], "thinking", "redacted_thinking"
            )

        # Last one: thinking preserved
        last_thinking = self._blocks_of_type(assistants[2]["content"], "thinking")
        assert len(last_thinking) == 1
        assert last_thinking[0]["signature"] == "sig_3"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool choice
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -471,6 +471,23 @@ class TestExplicitProviderRouting:
|
||||
client, model = resolve_provider_client("zai")
|
||||
assert client is not None
|
||||
|
||||
def test_explicit_google_alias_uses_gemini_credentials(self):
    """The 'google' alias resolves a client built from the gemini API-key credentials."""
    gemini_base_url = "https://generativelanguage.googleapis.com/v1beta/openai"
    fake_credentials = {"api_key": "gemini-key", "base_url": gemini_base_url}

    with (
        patch(
            "hermes_cli.auth.resolve_api_key_provider_credentials",
            return_value=fake_credentials,
        ),
        patch("agent.auxiliary_client.OpenAI") as mock_openai,
    ):
        mock_openai.return_value = MagicMock()
        client, model = resolve_provider_client("google", model="gemini-3.1-pro-preview")

    assert client is not None
    assert model == "gemini-3.1-pro-preview"

    # The OpenAI-compatible client must have been constructed with the
    # patched credentials.
    openai_kwargs = mock_openai.call_args.kwargs
    assert openai_kwargs["api_key"] == "gemini-key"
    assert openai_kwargs["base_url"] == gemini_base_url
|
||||
|
||||
def test_explicit_unknown_returns_none(self, monkeypatch):
|
||||
"""Unknown provider should return None."""
|
||||
client, model = resolve_provider_client("nonexistent-provider")
|
||||
@@ -624,12 +641,15 @@ class TestVisionClientFallback:
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_vision_auto_includes_anthropic_when_configured(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-key")
|
||||
def test_vision_auto_includes_active_provider_when_configured(self, monkeypatch):
|
||||
"""Active provider appears in available backends when credentials exist."""
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value="anthropic"),
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="claude-sonnet-4"),
|
||||
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api03-key"),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="***"),
|
||||
):
|
||||
backends = get_available_vision_backends()
|
||||
|
||||
@@ -702,88 +722,51 @@ class TestAuxiliaryPoolAwareness:
|
||||
assert call_kwargs["base_url"] == "https://api.githubcopilot.com"
|
||||
assert call_kwargs["default_headers"]["Editor-Version"]
|
||||
|
||||
def test_vision_auto_uses_anthropic_when_no_higher_priority_backend(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-key")
|
||||
def test_vision_auto_uses_active_provider_as_fallback(self, monkeypatch):
|
||||
"""When no OpenRouter/Nous available, vision auto falls back to active provider."""
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value="anthropic"),
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="claude-sonnet-4"),
|
||||
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api03-key"),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="***"),
|
||||
):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
|
||||
assert client is not None
|
||||
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
|
||||
assert model == "claude-haiku-4-5-20251001"
|
||||
|
||||
def test_selected_anthropic_provider_is_preferred_for_vision_auto(self, monkeypatch):
|
||||
def test_vision_auto_prefers_active_provider_over_openrouter(self, monkeypatch):
|
||||
"""Active provider is tried before OpenRouter in vision auto."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-key")
|
||||
|
||||
def fake_load_config():
|
||||
return {"model": {"provider": "anthropic", "default": "claude-sonnet-4-6"}}
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
|
||||
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value="anthropic"),
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="claude-sonnet-4"),
|
||||
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api03-key"),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
patch("hermes_cli.config.load_config", fake_load_config),
|
||||
):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
|
||||
assert client is not None
|
||||
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
|
||||
assert model == "claude-haiku-4-5-20251001"
|
||||
|
||||
def test_selected_codex_provider_short_circuits_vision_auto(self, monkeypatch):
|
||||
def fake_load_config():
|
||||
return {"model": {"provider": "openai-codex", "default": "gpt-5.2-codex"}}
|
||||
|
||||
codex_client = MagicMock()
|
||||
with (
|
||||
patch("hermes_cli.config.load_config", fake_load_config),
|
||||
patch("agent.auxiliary_client._try_codex", return_value=(codex_client, "gpt-5.2-codex")) as mock_codex,
|
||||
patch("agent.auxiliary_client._try_openrouter") as mock_openrouter,
|
||||
patch("agent.auxiliary_client._try_nous") as mock_nous,
|
||||
patch("agent.auxiliary_client._try_anthropic") as mock_anthropic,
|
||||
patch("agent.auxiliary_client._try_custom_endpoint") as mock_custom,
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="***"),
|
||||
):
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "openai-codex"
|
||||
assert client is codex_client
|
||||
assert model == "gpt-5.2-codex"
|
||||
mock_codex.assert_called_once()
|
||||
mock_openrouter.assert_not_called()
|
||||
mock_nous.assert_not_called()
|
||||
mock_anthropic.assert_not_called()
|
||||
mock_custom.assert_not_called()
|
||||
# Active provider should win over OpenRouter
|
||||
assert provider == "anthropic"
|
||||
|
||||
def test_vision_auto_includes_codex(self, codex_auth_dir):
|
||||
"""Codex supports vision (gpt-5.3-codex), so auto mode should use it."""
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
def test_vision_auto_falls_back_to_custom_endpoint(self, monkeypatch):
|
||||
"""Custom endpoint is used as fallback in vision auto mode.
|
||||
|
||||
Many local models (Qwen-VL, LLaVA, etc.) support vision.
|
||||
When no OpenRouter/Nous/Codex is available, try the custom endpoint.
|
||||
"""
|
||||
def test_vision_auto_uses_named_custom_as_active_provider(self, monkeypatch):
|
||||
"""Named custom provider works as active provider fallback in vision auto."""
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client._select_pool_entry", return_value=(False, None)), \
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
|
||||
patch("agent.auxiliary_client._resolve_custom_runtime",
|
||||
return_value=("http://localhost:1234/v1", "local-key")), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is not None # Custom endpoint picked up as fallback
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value="custom:local"), \
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="my-local-model"), \
|
||||
patch("agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(MagicMock(), "my-local-model")) as mock_resolve:
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
assert client is not None
|
||||
assert provider == "custom:local"
|
||||
|
||||
def test_vision_direct_endpoint_override(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
@@ -822,6 +805,31 @@ class TestAuxiliaryPoolAwareness:
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert client is not None
|
||||
|
||||
def test_vision_config_google_provider_uses_gemini_credentials(self, monkeypatch):
    """auxiliary.vision.provider='google' resolves as 'gemini' with its credentials."""
    gemini_base_url = "https://generativelanguage.googleapis.com/v1beta/openai"
    vision_settings = {"provider": "google", "model": "gemini-3.1-pro-preview"}
    config = {"auxiliary": {"vision": vision_settings}}
    monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)

    fake_credentials = {"api_key": "gemini-key", "base_url": gemini_base_url}
    with (
        patch(
            "hermes_cli.auth.resolve_api_key_provider_credentials",
            return_value=fake_credentials,
        ),
        patch("agent.auxiliary_client.OpenAI") as mock_openai,
    ):
        resolved_provider, client, model = resolve_vision_provider_client()

    assert resolved_provider == "gemini"
    assert client is not None
    assert model == "gemini-3.1-pro-preview"

    # The client constructor must have received the patched credentials.
    openai_kwargs = mock_openai.call_args.kwargs
    assert openai_kwargs["api_key"] == "gemini-key"
    assert openai_kwargs["base_url"] == gemini_base_url
|
||||
|
||||
def test_vision_forced_main_uses_custom_endpoint(self, monkeypatch):
|
||||
"""When explicitly forced to 'main', vision CAN use custom endpoint."""
|
||||
config = {
|
||||
@@ -846,7 +854,14 @@ class TestAuxiliaryPoolAwareness:
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "main")
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
# Clear client cache to avoid stale entries from previous tests
|
||||
from agent.auxiliary_client import _client_cache
|
||||
_client_cache.clear()
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value=""), \
|
||||
patch("agent.auxiliary_client._read_main_model", return_value=""), \
|
||||
patch("agent.auxiliary_client._select_pool_entry", return_value=(False, None)), \
|
||||
patch("agent.auxiliary_client._resolve_custom_runtime", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
|
||||
patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
|
||||
@@ -13,7 +13,7 @@ from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
|
||||
def _run_auxiliary_bridge(config_dict, monkeypatch):
|
||||
@@ -199,7 +199,7 @@ class TestGatewayBridgeCodeParity:
|
||||
|
||||
def test_gateway_has_auxiliary_bridge(self):
|
||||
"""The gateway config bridge must include auxiliary.* bridging."""
|
||||
gateway_path = Path(__file__).parent.parent / "gateway" / "run.py"
|
||||
gateway_path = Path(__file__).parent.parent.parent / "gateway" / "run.py"
|
||||
content = gateway_path.read_text()
|
||||
# Check for key patterns that indicate the bridge is present
|
||||
assert "AUXILIARY_VISION_PROVIDER" in content
|
||||
@@ -213,7 +213,7 @@ class TestGatewayBridgeCodeParity:
|
||||
|
||||
def test_gateway_no_compression_env_bridge(self):
|
||||
"""Gateway should NOT bridge compression config to env vars (config-only)."""
|
||||
gateway_path = Path(__file__).parent.parent / "gateway" / "run.py"
|
||||
gateway_path = Path(__file__).parent.parent.parent / "gateway" / "run.py"
|
||||
content = gateway_path.read_text()
|
||||
assert "CONTEXT_COMPRESSION_PROVIDER" not in content
|
||||
assert "CONTEXT_COMPRESSION_MODEL" not in content
|
||||
151
tests/agent/test_auxiliary_named_custom_providers.py
Normal file
151
tests/agent/test_auxiliary_named_custom_providers.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Tests for named custom provider and 'main' alias resolution in auxiliary_client."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _isolate(tmp_path, monkeypatch):
    """Point HERMES_HOME at a per-test directory seeded with a minimal config."""
    home_dir = tmp_path / ".hermes"
    home_dir.mkdir()
    # Seed a minimal config so load_config doesn't fail
    seed_config = home_dir / "config.yaml"
    seed_config.write_text("model:\n default: test-model\n")
    monkeypatch.setenv("HERMES_HOME", str(home_dir))
|
||||
|
||||
|
||||
def _write_config(tmp_path, config_dict):
    """Serialise *config_dict* as YAML into ``<tmp_path>/.hermes/config.yaml``."""
    import yaml

    serialized = yaml.dump(config_dict)
    (tmp_path / ".hermes" / "config.yaml").write_text(serialized)
|
||||
|
||||
|
||||
class TestNormalizeVisionProvider:
    """_normalize_vision_provider should resolve 'main' to actual main provider."""

    def test_main_resolves_to_named_custom(self, tmp_path):
        """'main' maps to the configured 'custom:<name>' provider."""
        config = {
            "model": {"default": "my-model", "provider": "custom:beans"},
            "custom_providers": [{"name": "beans", "base_url": "http://localhost/v1"}],
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import _normalize_vision_provider

        normalized = _normalize_vision_provider("main")
        assert normalized == "custom:beans"

    def test_main_resolves_to_openrouter(self, tmp_path):
        """'main' maps to 'openrouter' when that is the configured provider."""
        config = {
            "model": {"default": "anthropic/claude-sonnet-4", "provider": "openrouter"},
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import _normalize_vision_provider

        normalized = _normalize_vision_provider("main")
        assert normalized == "openrouter"

    def test_main_resolves_to_deepseek(self, tmp_path):
        """'main' maps to 'deepseek' when that is the configured provider."""
        config = {
            "model": {"default": "deepseek-chat", "provider": "deepseek"},
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import _normalize_vision_provider

        normalized = _normalize_vision_provider("main")
        assert normalized == "deepseek"

    def test_main_falls_back_to_custom_when_no_provider(self, tmp_path):
        """With no provider key in the config, 'main' maps to 'custom'."""
        _write_config(tmp_path, {"model": {"default": "gpt-4o"}})
        from agent.auxiliary_client import _normalize_vision_provider

        normalized = _normalize_vision_provider("main")
        assert normalized == "custom"

    def test_bare_provider_name_unchanged(self):
        """Provider names other than the aliases pass through untouched."""
        from agent.auxiliary_client import _normalize_vision_provider

        assert _normalize_vision_provider("beans") == "beans"
        assert _normalize_vision_provider("deepseek") == "deepseek"

    def test_codex_alias_still_works(self):
        """The 'codex' alias maps to 'openai-codex'."""
        from agent.auxiliary_client import _normalize_vision_provider

        normalized = _normalize_vision_provider("codex")
        assert normalized == "openai-codex"

    def test_auto_unchanged(self):
        """'auto' and None both normalise to 'auto'."""
        from agent.auxiliary_client import _normalize_vision_provider

        assert _normalize_vision_provider("auto") == "auto"
        assert _normalize_vision_provider(None) == "auto"
|
||||
|
||||
|
||||
class TestResolveProviderClientMainAlias:
    """resolve_provider_client('main', ...) should resolve to actual main provider."""

    def test_main_resolves_to_named_custom_provider(self, tmp_path):
        """A bare custom provider name as model.provider is honoured for 'main'."""
        beans_entry = {"name": "beans", "base_url": "http://beans.local/v1", "api_key": "k"}
        config = {
            "model": {"default": "my-model", "provider": "beans"},
            "custom_providers": [beans_entry],
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import resolve_provider_client

        client, model = resolve_provider_client("main", "override-model")

        assert client is not None
        assert model == "override-model"
        assert "beans.local" in str(client.base_url)

    def test_main_with_custom_colon_prefix(self, tmp_path):
        """The 'custom:<name>' spelling of model.provider is honoured for 'main'."""
        beans_entry = {"name": "beans", "base_url": "http://beans.local/v1", "api_key": "k"}
        config = {
            "model": {"default": "my-model", "provider": "custom:beans"},
            "custom_providers": [beans_entry],
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import resolve_provider_client

        client, model = resolve_provider_client("main", "test")

        assert client is not None
        assert "beans.local" in str(client.base_url)
|
||||
|
||||
|
||||
class TestResolveProviderClientNamedCustom:
    """resolve_provider_client should resolve named custom providers directly."""

    def test_named_custom_provider(self, tmp_path):
        """A provider named in custom_providers yields a client at its base_url."""
        beans_entry = {"name": "beans", "base_url": "http://beans.local/v1", "api_key": "k"}
        config = {
            "model": {"default": "test-model"},
            "custom_providers": [beans_entry],
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import resolve_provider_client

        client, model = resolve_provider_client("beans", "my-model")

        assert client is not None
        assert model == "my-model"
        assert "beans.local" in str(client.base_url)

    def test_named_custom_provider_default_model(self, tmp_path):
        """Without an explicit model, the configured default model is returned."""
        beans_entry = {"name": "beans", "base_url": "http://beans.local/v1", "api_key": "k"}
        config = {
            "model": {"default": "main-model"},
            "custom_providers": [beans_entry],
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import resolve_provider_client

        client, model = resolve_provider_client("beans")

        assert client is not None
        # Should use _read_main_model() fallback
        assert model == "main-model"

    def test_named_custom_no_api_key_uses_fallback(self, tmp_path):
        """A custom provider entry without api_key still produces a client."""
        keyless_entry = {"name": "local", "base_url": "http://localhost:8080/v1"}
        config = {
            "model": {"default": "test"},
            "custom_providers": [keyless_entry],
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import resolve_provider_client

        client, model = resolve_provider_client("local", "test")

        assert client is not None
        # no-key-required should be used
        # NOTE(review): the fallback key itself is not asserted here — confirm
        # whether a check on the client's api_key is wanted.

    def test_nonexistent_named_custom_falls_through(self, tmp_path):
        """A name absent from custom_providers resolves to no client."""
        beans_entry = {"name": "beans", "base_url": "http://beans.local/v1"}
        config = {
            "model": {"default": "test"},
            "custom_providers": [beans_entry],
        }
        _write_config(tmp_path, config)
        from agent.auxiliary_client import resolve_provider_client

        # "coffee" doesn't exist in custom_providers
        client, model = resolve_provider_client("coffee", "test")

        assert client is None
|
||||
@@ -197,6 +197,44 @@ class TestNonStringContent:
|
||||
assert summary is not None
|
||||
assert summary == SUMMARY_PREFIX
|
||||
|
||||
def test_summary_call_does_not_force_temperature(self):
    """_generate_summary must not pass a 'temperature' keyword to call_llm."""
    llm_reply = MagicMock()
    llm_reply.choices = [MagicMock()]
    llm_reply.choices[0].message.content = "ok"

    with patch("agent.context_compressor.get_model_context_length", return_value=100000):
        compressor = ContextCompressor(model="test", quiet_mode=True)

    conversation = [
        {"role": "user", "content": "do something"},
        {"role": "assistant", "content": "ok"},
    ]

    with patch("agent.context_compressor.call_llm", return_value=llm_reply) as mock_call:
        compressor._generate_summary(conversation)

    # Inspect the keyword arguments of the recorded call_llm invocation.
    passed_kwargs = mock_call.call_args.kwargs
    assert "temperature" not in passed_kwargs
|
||||
|
||||
|
||||
class TestSummaryFailureCooldown:
    """Behaviour of _generate_summary after the underlying LLM call fails."""

    def test_summary_failure_enters_cooldown_and_skips_retry(self):
        """After one failed call, a second attempt returns None without calling call_llm."""
        with patch("agent.context_compressor.get_model_context_length", return_value=100000):
            compressor = ContextCompressor(model="test", quiet_mode=True)

        conversation = [
            {"role": "user", "content": "do something"},
            {"role": "assistant", "content": "ok"},
        ]

        with patch("agent.context_compressor.call_llm", side_effect=Exception("boom")) as mock_call:
            outcomes = [
                compressor._generate_summary(conversation),
                compressor._generate_summary(conversation),
            ]

        first, second = outcomes
        assert first is None
        assert second is None
        # Only the first attempt may reach the LLM.
        assert mock_call.call_count == 1
|
||||
|
||||
|
||||
class TestSummaryPrefixNormalization:
|
||||
def test_legacy_prefix_is_replaced(self):
|
||||
|
||||
@@ -947,7 +947,7 @@ def test_list_custom_pool_providers(tmp_path, monkeypatch):
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-ant-xxx",
|
||||
"access_token": "***",
|
||||
}
|
||||
],
|
||||
"custom:together.ai": [
|
||||
@@ -957,7 +957,7 @@ def test_list_custom_pool_providers(tmp_path, monkeypatch):
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-tog-xxx",
|
||||
"access_token": "***",
|
||||
}
|
||||
],
|
||||
"custom:fireworks": [
|
||||
@@ -967,7 +967,7 @@ def test_list_custom_pool_providers(tmp_path, monkeypatch):
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-fw-xxx",
|
||||
"access_token": "***",
|
||||
}
|
||||
],
|
||||
"custom:empty": [],
|
||||
@@ -980,3 +980,78 @@ def test_list_custom_pool_providers(tmp_path, monkeypatch):
|
||||
result = list_custom_pool_providers()
|
||||
assert result == ["custom:fireworks", "custom:together.ai"]
|
||||
# "custom:empty" not included because it's empty
|
||||
|
||||
|
||||
|
||||
def test_acquire_lease_prefers_unleased_entry(tmp_path, monkeypatch):
    """Two successive leases on a two-credential pool go to different entries."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))

    primary = {
        "id": "cred-1",
        "label": "primary",
        "auth_type": "api_key",
        "priority": 0,
        "source": "manual",
        "access_token": "***",
    }
    secondary = {
        "id": "cred-2",
        "label": "secondary",
        "auth_type": "api_key",
        "priority": 1,
        "source": "manual",
        "access_token": "***",
    }
    auth_store = {
        "version": 1,
        "credential_pool": {"openrouter": [primary, secondary]},
    }
    _write_auth_store(tmp_path, auth_store)

    from agent.credential_pool import load_pool

    pool = load_pool("openrouter")
    first = pool.acquire_lease()
    second = pool.acquire_lease()

    assert first == "cred-1"
    assert second == "cred-2"
    # Each credential should carry exactly one active lease.
    for cred_id in ("cred-1", "cred-2"):
        assert pool.active_lease_count(cred_id) == 1
|
||||
|
||||
|
||||
|
||||
def test_release_lease_decrements_counter(tmp_path, monkeypatch):
    """Releasing a lease brings the credential's active lease count back to 0."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))

    only_credential = {
        "id": "cred-1",
        "label": "primary",
        "auth_type": "api_key",
        "priority": 0,
        "source": "manual",
        "access_token": "***",
    }
    auth_store = {
        "version": 1,
        "credential_pool": {"openrouter": [only_credential]},
    }
    _write_auth_store(tmp_path, auth_store)

    # Imported after HERMES_HOME has been pointed at the temp directory.
    from agent.credential_pool import load_pool

    pool = load_pool("openrouter")

    # Acquire: the single credential is leased once.
    assert pool.acquire_lease() == "cred-1"
    assert pool.active_lease_count("cred-1") == 1

    # Release: the counter drops back to zero.
    pool.release_lease("cred-1")
    assert pool.active_lease_count("cred-1") == 0
|
||||
289
tests/agent/test_memory_user_id.py
Normal file
289
tests/agent/test_memory_user_id.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""Tests for per-user memory scoping via user_id threading.
|
||||
|
||||
Verifies that gateway user_id flows from AIAgent -> MemoryManager -> plugins,
|
||||
so each gateway user gets their own memory bucket instead of sharing a static one.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
from agent.memory_manager import MemoryManager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concrete test provider that records init kwargs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RecordingProvider(MemoryProvider):
    """Test double that captures the arguments passed to ``initialize()``.

    Every other provider hook is a no-op that returns an empty value, so tests
    can inspect ``_init_session_id`` and ``_init_kwargs`` afterwards.
    """

    def __init__(self, name="recording"):
        self._name = name
        # Populated by initialize(); None / empty until then.
        self._init_session_id = None
        self._init_kwargs = {}

    @property
    def name(self) -> str:
        """Provider name supplied at construction time."""
        return self._name

    def is_available(self) -> bool:
        """Always report the provider as usable."""
        return True

    def initialize(self, session_id: str, **kwargs) -> None:
        """Remember the session id and a shallow copy of the keyword arguments."""
        self._init_session_id = session_id
        self._init_kwargs = {**kwargs}

    def system_prompt_block(self) -> str:
        """Contribute nothing to the system prompt."""
        return ""

    def prefetch(self, query: str, *, session_id: str = "") -> str:
        """Return an empty prefetch result for any query."""
        return ""

    def sync_turn(self, user_content, assistant_content, *, session_id=""):
        """Ignore the turn; nothing is stored."""
        return None

    def get_tool_schemas(self):
        """Expose no tools."""
        return []

    def handle_tool_call(self, tool_name, args, **kwargs):
        """Answer every tool call with an empty JSON object."""
        return json.dumps({})

    def shutdown(self):
        """Nothing to clean up."""
        return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryManager user_id threading tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMemoryManagerUserIdThreading:
    """Check which keyword arguments ``initialize_all`` hands to providers."""

    @staticmethod
    def _manager_with_recorder():
        """Return a fresh MemoryManager plus the RecordingProvider added to it."""
        manager = MemoryManager()
        recorder = RecordingProvider()
        manager.add_provider(recorder)
        return manager, recorder

    def test_user_id_forwarded_to_provider(self):
        manager, recorder = self._manager_with_recorder()

        manager.initialize_all(
            session_id="sess-123",
            platform="telegram",
            user_id="tg_user_42",
        )

        received = recorder._init_kwargs
        assert received.get("user_id") == "tg_user_42"
        assert received.get("platform") == "telegram"
        assert recorder._init_session_id == "sess-123"

    def test_no_user_id_when_cli(self):
        """A CLI session passes no user_id, so the provider must not see one."""
        manager, recorder = self._manager_with_recorder()

        manager.initialize_all(session_id="sess-456", platform="cli")

        received = recorder._init_kwargs
        assert "user_id" not in received
        assert received.get("platform") == "cli"

    def test_user_id_none_not_forwarded(self):
        """Explicit None user_id should not appear in kwargs."""
        manager, recorder = self._manager_with_recorder()

        # Simulates what happens when AIAgent passes user_id=None
        # (the agent code only adds user_id to kwargs when it's truthy),
        # i.e. the key is simply absent from the call.
        manager.initialize_all(session_id="sess-789", platform="discord")

        assert "user_id" not in recorder._init_kwargs

    def test_multiple_providers_all_receive_user_id(self):
        from agent.builtin_memory_provider import BuiltinMemoryProvider

        manager = MemoryManager()
        # Use builtin + one external (MemoryManager only allows one external)
        external = RecordingProvider("external")
        for provider in (BuiltinMemoryProvider(), external):
            manager.add_provider(provider)

        manager.initialize_all(
            session_id="sess-multi",
            platform="slack",
            user_id="slack_U12345",
        )

        received = external._init_kwargs
        assert received.get("user_id") == "slack_U12345"
        assert received.get("platform") == "slack"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mem0 provider user_id tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMem0UserIdScoping:
    """Check which user id the Mem0 provider ends up with after ``initialize``."""

    @staticmethod
    def _patched_config(config):
        """Patch the Mem0 plugin's ``_load_config`` so it returns *config*."""
        return patch("plugins.memory.mem0._load_config", return_value=config)

    def test_gateway_user_id_overrides_default(self):
        """A user_id passed to initialize() wins over the one in the config."""
        from plugins.memory.mem0 import Mem0MemoryProvider

        provider = Mem0MemoryProvider()
        # Config carries its own default user_id, which must be overridden.
        config = {
            "api_key": "test-key",
            "user_id": "hermes-user",
            "agent_id": "hermes",
            "rerank": True,
        }
        with self._patched_config(config):
            provider.initialize(session_id="test-sess", user_id="tg_user_99")

        assert provider._user_id == "tg_user_99"

    def test_no_user_id_falls_back_to_config(self):
        """With no user_id argument the configured value is used."""
        from plugins.memory.mem0 import Mem0MemoryProvider

        provider = Mem0MemoryProvider()
        config = {
            "api_key": "test-key",
            "user_id": "custom-default",
            "agent_id": "hermes",
            "rerank": True,
        }
        with self._patched_config(config):
            provider.initialize(session_id="test-sess")

        assert provider._user_id == "custom-default"

    def test_no_user_id_no_config_uses_hermes_user(self):
        """With neither argument nor config entry the id is 'hermes-user'."""
        from plugins.memory.mem0 import Mem0MemoryProvider

        provider = Mem0MemoryProvider()
        config = {
            "api_key": "test-key",
            "agent_id": "hermes",
            "rerank": True,
        }
        with self._patched_config(config):
            provider.initialize(session_id="test-sess")

        assert provider._user_id == "hermes-user"

    def test_different_users_get_different_ids(self):
        """Two providers initialized with different user_ids keep them apart."""
        from plugins.memory.mem0 import Mem0MemoryProvider

        alice = Mem0MemoryProvider()
        bob = Mem0MemoryProvider()
        config = {
            "api_key": "test-key",
            "user_id": "hermes-user",
            "agent_id": "hermes",
            "rerank": True,
        }

        # Both providers are initialized under the same patched config.
        with self._patched_config(config):
            alice.initialize(session_id="sess-1", user_id="alice_123")
            bob.initialize(session_id="sess-2", user_id="bob_456")

        assert alice._user_id == "alice_123"
        assert bob._user_id == "bob_456"
        assert alice._user_id != bob._user_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Honcho provider user_id tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHonchoUserIdScoping:
    """Check how the Honcho provider treats ``cfg.peer_name`` during initialize."""

    _CONFIG_FACTORY = "plugins.memory.honcho.client.HonchoClientConfig.from_global_config"

    @staticmethod
    def _mock_config(peer_name):
        """Build a mock Honcho config with the given static peer name."""
        return MagicMock(
            enabled=True,
            api_key="test-key",
            base_url=None,
            peer_name=peer_name,
            recall_mode="tools",  # Use tools mode to defer session init
        )

    def test_gateway_user_id_overrides_peer_name(self):
        """When user_id is in kwargs, cfg.peer_name should be overridden."""
        from plugins.memory.honcho import HonchoMemoryProvider

        provider = HonchoMemoryProvider()
        config = self._mock_config("static-user")

        with patch(self._CONFIG_FACTORY, return_value=config):
            provider.initialize(
                session_id="test-sess",
                user_id="discord_user_789",
                platform="discord",
            )

        # The static peer name has been replaced by the gateway user id.
        assert config.peer_name == "discord_user_789"

    def test_no_user_id_preserves_config_peer_name(self):
        """Without user_id, the config peer_name should be preserved."""
        from plugins.memory.honcho import HonchoMemoryProvider

        provider = HonchoMemoryProvider()
        config = self._mock_config("my-custom-peer")

        with patch(self._CONFIG_FACTORY, return_value=config):
            provider.initialize(session_id="test-sess", platform="cli")

        # Nothing overrode the configured peer name.
        assert config.peer_name == "my-custom-peer"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AIAgent user_id propagation test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAIAgentUserIdPropagation:
    """Verify AIAgent stores user_id and passes it to memory init kwargs."""

    # NOTE(review): both tests assign ``_user_id`` by hand on an instance built
    # with ``object.__new__``, so ``AIAgent.__init__`` never runs and the
    # assertions only read back what the test itself wrote. Consider
    # exercising the real constructor once its dependencies can be stubbed.

    def test_user_id_stored_on_agent(self, tmp_path, monkeypatch):
        """AIAgent should store user_id as instance attribute."""
        # Use a per-test directory for HERMES_HOME instead of a fixed /tmp path
        # so runs cannot collide and the test does not assume a POSIX /tmp.
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
        from run_agent import AIAgent

        agent = object.__new__(AIAgent)
        # Manually set the attribute as __init__ does
        agent._user_id = "test_user_42"
        assert agent._user_id == "test_user_42"

    def test_user_id_none_by_default(self, tmp_path, monkeypatch):
        """AIAgent should have None user_id when not provided (CLI mode)."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
        from run_agent import AIAgent

        agent = object.__new__(AIAgent)
        agent._user_id = None
        assert agent._user_id is None
|
||||
42
tests/agent/test_minimax_auxiliary_url.py
Normal file
42
tests/agent/test_minimax_auxiliary_url.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Tests for MiniMax auxiliary client URL normalization.
|
||||
|
||||
MiniMax and MiniMax-CN set inference_base_url to the /anthropic path.
|
||||
The auxiliary client uses the OpenAI SDK, which needs /v1 instead.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from agent.auxiliary_client import _to_openai_base_url
|
||||
|
||||
|
||||
class TestToOpenaiBaseUrl:
    """Expected outputs of ``_to_openai_base_url`` for a range of base URLs."""

    def test_minimax_global_anthropic_suffix_replaced(self):
        converted = _to_openai_base_url("https://api.minimax.io/anthropic")
        assert converted == "https://api.minimax.io/v1"

    def test_minimax_cn_anthropic_suffix_replaced(self):
        converted = _to_openai_base_url("https://api.minimaxi.com/anthropic")
        assert converted == "https://api.minimaxi.com/v1"

    def test_trailing_slash_stripped_before_replace(self):
        converted = _to_openai_base_url("https://api.minimax.io/anthropic/")
        assert converted == "https://api.minimax.io/v1"

    def test_v1_url_unchanged(self):
        url = "https://api.openai.com/v1"
        assert _to_openai_base_url(url) == url

    def test_openrouter_url_unchanged(self):
        url = "https://openrouter.ai/api/v1"
        assert _to_openai_base_url(url) == url

    def test_anthropic_domain_unchanged(self):
        """api.anthropic.com doesn't end with /anthropic — should be untouched."""
        url = "https://api.anthropic.com"
        assert _to_openai_base_url(url) == url

    def test_anthropic_in_subpath_unchanged(self):
        # "/anthropic" appears mid-path here, not as the final segment.
        url = "https://example.com/anthropic/extra"
        assert _to_openai_base_url(url) == url

    def test_empty_string(self):
        assert _to_openai_base_url("") == ""

    def test_none(self):
        # None is expected to come back as an empty string.
        assert _to_openai_base_url(None) == ""
|
||||
105
tests/agent/test_minimax_provider.py
Normal file
105
tests/agent/test_minimax_provider.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Tests for MiniMax provider hardening — context lengths, thinking guard, catalog."""
|
||||
|
||||
|
||||
class TestMinimaxContextLengths:
    """Check the MiniMax entries in ``DEFAULT_CONTEXT_LENGTHS``."""

    def test_m1_variants_have_1m_context(self):
        from agent.model_metadata import DEFAULT_CONTEXT_LENGTHS

        # Keys are lowercase because the lookup lowercases model names
        m1_models = (
            "minimax-m1",
            "minimax-m1-40k",
            "minimax-m1-80k",
            "minimax-m1-128k",
            "minimax-m1-256k",
        )
        for model in m1_models:
            assert model in DEFAULT_CONTEXT_LENGTHS, f"{model} missing from context lengths"
            assert DEFAULT_CONTEXT_LENGTHS[model] == 1_000_000, f"{model} expected 1M"

    def test_m2_variants_have_1m_context(self):
        from agent.model_metadata import DEFAULT_CONTEXT_LENGTHS

        # Keys are lowercase because the lookup lowercases model names
        m2_models = ("minimax-m2.5", "minimax-m2.7")
        for model in m2_models:
            assert model in DEFAULT_CONTEXT_LENGTHS, f"{model} missing from context lengths"
            assert DEFAULT_CONTEXT_LENGTHS[model] == 1_048_576, f"{model} expected 1048576"

    def test_minimax_prefix_fallback(self):
        from agent.model_metadata import DEFAULT_CONTEXT_LENGTHS

        # The generic "minimax" prefix entry (used for unknown models) is 1048576.
        fallback = DEFAULT_CONTEXT_LENGTHS["minimax"]
        assert fallback == 1_048_576
|
||||
|
||||
|
||||
|
||||
class TestMinimaxThinkingGuard:
    """``build_anthropic_kwargs`` must skip thinking params for MiniMax models."""

    @staticmethod
    def _kwargs_for(model, effort):
        """Build request kwargs for *model* with reasoning enabled at *effort*."""
        from agent.anthropic_adapter import build_anthropic_kwargs

        return build_anthropic_kwargs(
            model=model,
            messages=[{"role": "user", "content": "hello"}],
            tools=None,
            max_tokens=4096,
            reasoning_config={"enabled": True, "effort": effort},
        )

    def test_no_thinking_for_minimax_m27(self):
        kwargs = self._kwargs_for("MiniMax-M2.7", "medium")
        assert "thinking" not in kwargs
        assert "output_config" not in kwargs

    def test_no_thinking_for_minimax_m1(self):
        kwargs = self._kwargs_for("MiniMax-M1-128k", "high")
        assert "thinking" not in kwargs

    def test_thinking_still_works_for_claude(self):
        # Control case: a Claude model still gets the thinking parameter.
        kwargs = self._kwargs_for("claude-sonnet-4-20250514", "medium")
        assert "thinking" in kwargs
|
||||
|
||||
|
||||
class TestMinimaxAuxModel:
    """The auxiliary model for both MiniMax providers is the standard M2.7."""

    _PROVIDERS = ("minimax", "minimax-cn")

    def test_minimax_aux_is_standard(self):
        from agent.auxiliary_client import _API_KEY_PROVIDER_AUX_MODELS

        for provider in self._PROVIDERS:
            assert _API_KEY_PROVIDER_AUX_MODELS[provider] == "MiniMax-M2.7"

    def test_minimax_aux_not_highspeed(self):
        from agent.auxiliary_client import _API_KEY_PROVIDER_AUX_MODELS

        for provider in self._PROVIDERS:
            assert "highspeed" not in _API_KEY_PROVIDER_AUX_MODELS[provider]
|
||||
|
||||
|
||||
class TestMinimaxModelCatalog:
    """Membership checks on the MiniMax entries of ``_PROVIDER_MODELS``."""

    _PROVIDERS = ("minimax", "minimax-cn")

    def test_catalog_includes_m1_family(self):
        from hermes_cli.models import _PROVIDER_MODELS

        m1_family = (
            "MiniMax-M1",
            "MiniMax-M1-40k",
            "MiniMax-M1-80k",
            "MiniMax-M1-128k",
            "MiniMax-M1-256k",
        )
        for provider in self._PROVIDERS:
            catalog = _PROVIDER_MODELS[provider]
            for model_name in m1_family:
                assert model_name in catalog

    def test_catalog_excludes_deprecated(self):
        from hermes_cli.models import _PROVIDER_MODELS

        for provider in self._PROVIDERS:
            assert "MiniMax-M2.1" not in _PROVIDER_MODELS[provider]

    def test_catalog_excludes_highspeed(self):
        from hermes_cli.models import _PROVIDER_MODELS

        highspeed_variants = ("MiniMax-M2.7-highspeed", "MiniMax-M2.5-highspeed")
        for provider in self._PROVIDERS:
            catalog = _PROVIDER_MODELS[provider]
            for model_name in highspeed_variants:
                assert model_name not in catalog
|
||||
@@ -423,7 +423,7 @@ class TestBuildNousSubscriptionPrompt:
|
||||
"web": NousFeatureState("web", "Web tools", True, True, True, True, False, True, "firecrawl"),
|
||||
"image_gen": NousFeatureState("image_gen", "Image generation", True, True, True, True, False, True, "Nous Subscription"),
|
||||
"tts": NousFeatureState("tts", "OpenAI TTS", True, True, True, True, False, True, "OpenAI TTS"),
|
||||
"browser": NousFeatureState("browser", "Browser automation", True, True, True, True, False, True, "Browserbase"),
|
||||
"browser": NousFeatureState("browser", "Browser automation", True, True, True, True, False, True, "Browser Use"),
|
||||
"modal": NousFeatureState("modal", "Modal execution", False, True, False, False, False, True, "local"),
|
||||
},
|
||||
),
|
||||
@@ -431,9 +431,9 @@ class TestBuildNousSubscriptionPrompt:
|
||||
|
||||
prompt = build_nous_subscription_prompt({"web_search", "browser_navigate"})
|
||||
|
||||
assert "Browserbase" in prompt
|
||||
assert "Browser Use" in prompt
|
||||
assert "Modal execution is optional" in prompt
|
||||
assert "do not ask the user for Firecrawl, FAL, OpenAI TTS, or Browserbase API keys" in prompt
|
||||
assert "do not ask the user for Firecrawl, FAL, OpenAI TTS, or Browser-Use API keys" in prompt
|
||||
|
||||
def test_non_subscriber_prompt_includes_relevant_upgrade_guidance(self, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_ENABLE_NOUS_MANAGED_TOOLS", "1")
|
||||
|
||||
0
tests/cli/__init__.py
Normal file
0
tests/cli/__init__.py
Normal file
46
tests/cli/test_cli_browser_connect.py
Normal file
46
tests/cli/test_cli_browser_connect.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Tests for CLI browser CDP auto-launch helpers."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
class TestChromeDebugLaunch:
    """Exercise ``HermesCLI._try_launch_chrome_debug`` for the "Windows" system."""

    def test_windows_launch_uses_browser_found_on_path(self):
        chrome_exe = r"C:\Chrome\chrome.exe"
        launched = {}

        def fake_which(name):
            # Only "chrome.exe" resolves on the fake PATH.
            return chrome_exe if name == "chrome.exe" else None

        def fake_isfile(path):
            return path == chrome_exe

        def fake_popen(cmd, **kwargs):
            launched["cmd"] = cmd
            launched["kwargs"] = kwargs
            return object()

        with patch("cli.shutil.which", side_effect=fake_which):
            with patch("cli.os.path.isfile", side_effect=fake_isfile):
                with patch("subprocess.Popen", side_effect=fake_popen):
                    assert HermesCLI._try_launch_chrome_debug(9333, "Windows") is True

        assert launched["cmd"] == [chrome_exe, "--remote-debugging-port=9333"]
        assert launched["kwargs"]["start_new_session"] is True

    def test_windows_launch_falls_back_to_common_install_dirs(self, monkeypatch):
        launched = {}
        program_files = r"C:\Program Files"
        # Use os.path.join so path separators match cross-platform
        installed = os.path.join(program_files, "Google", "Chrome", "Application", "chrome.exe")

        def fake_isfile(path):
            return path == installed

        def fake_popen(cmd, **kwargs):
            launched["cmd"] = cmd
            launched["kwargs"] = kwargs
            return object()

        # Only ProgramFiles is defined; the other install roots are removed.
        monkeypatch.setenv("ProgramFiles", program_files)
        for unset_var in ("ProgramFiles(x86)", "LOCALAPPDATA"):
            monkeypatch.delenv(unset_var, raising=False)

        # Nothing is found on PATH, so only the install-dir candidate exists.
        with patch("cli.shutil.which", return_value=None):
            with patch("cli.os.path.isfile", side_effect=fake_isfile):
                with patch("subprocess.Popen", side_effect=fake_popen):
                    assert HermesCLI._try_launch_chrome_debug(9222, "Windows") is True

        assert launched["cmd"] == [installed, "--remote-debugging-port=9222"]
|
||||
@@ -330,7 +330,7 @@ def test_model_flow_nous_prints_subscription_guidance_without_mutating_explicit_
|
||||
"hermes_cli.auth.fetch_nous_models",
|
||||
lambda *args, **kwargs: ["claude-opus-4-6"],
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.auth._prompt_model_selection", lambda model_ids, current_model="", pricing=None: "claude-opus-4-6")
|
||||
monkeypatch.setattr("hermes_cli.auth._prompt_model_selection", lambda model_ids, current_model="", pricing=None, **kw: "claude-opus-4-6")
|
||||
monkeypatch.setattr("hermes_cli.auth._save_model_choice", lambda model: None)
|
||||
monkeypatch.setattr("hermes_cli.auth._update_config_for_provider", lambda provider, url: None)
|
||||
monkeypatch.setattr(
|
||||
@@ -368,7 +368,7 @@ def test_model_flow_nous_applies_managed_tts_default_when_unconfigured(monkeypat
|
||||
"hermes_cli.auth.fetch_nous_models",
|
||||
lambda *args, **kwargs: ["claude-opus-4-6"],
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.auth._prompt_model_selection", lambda model_ids, current_model="", pricing=None: "claude-opus-4-6")
|
||||
monkeypatch.setattr("hermes_cli.auth._prompt_model_selection", lambda model_ids, current_model="", pricing=None, **kw: "claude-opus-4-6")
|
||||
monkeypatch.setattr("hermes_cli.auth._save_model_choice", lambda model: None)
|
||||
monkeypatch.setattr("hermes_cli.auth._update_config_for_provider", lambda provider, url: None)
|
||||
monkeypatch.setattr(
|
||||
@@ -538,7 +538,7 @@ def test_cmd_model_falls_back_to_auto_on_invalid_provider(monkeypatch, capsys):
|
||||
return "openrouter"
|
||||
|
||||
monkeypatch.setattr("hermes_cli.auth.resolve_provider", _resolve_provider)
|
||||
monkeypatch.setattr(hermes_main, "_prompt_provider_choice", lambda choices: len(choices) - 1)
|
||||
monkeypatch.setattr(hermes_main, "_prompt_provider_choice", lambda choices, **kwargs: len(choices) - 1)
|
||||
monkeypatch.setattr("sys.stdin", type("FakeTTY", (), {"isatty": lambda self: True})())
|
||||
|
||||
hermes_main.cmd_model(SimpleNamespace())
|
||||
@@ -579,6 +579,7 @@ def test_model_flow_custom_saves_verified_v1_base_url(monkeypatch, capsys):
|
||||
# "Use this model? [Y/n]:" — confirm with Enter, then context length.
|
||||
answers = iter(["http://localhost:8000", "local-key", "", ""])
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers))
|
||||
monkeypatch.setattr("getpass.getpass", lambda _prompt="": next(answers))
|
||||
|
||||
hermes_main._model_flow_custom({})
|
||||
output = capsys.readouterr().out
|
||||
@@ -601,7 +602,7 @@ def test_cmd_model_forwards_nous_login_tls_options(monkeypatch):
|
||||
monkeypatch.setattr("hermes_cli.config.save_env_value", lambda key, value: None)
|
||||
monkeypatch.setattr("hermes_cli.auth.resolve_provider", lambda requested, **kwargs: "nous")
|
||||
monkeypatch.setattr("hermes_cli.auth.get_provider_auth_state", lambda provider_id: None)
|
||||
monkeypatch.setattr(hermes_main, "_prompt_provider_choice", lambda choices: 0)
|
||||
monkeypatch.setattr(hermes_main, "_prompt_provider_choice", lambda choices, **kwargs: 0)
|
||||
|
||||
captured = {}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Regression tests for CLI /retry history replacement semantics."""
|
||||
|
||||
from tests.test_cli_init import _make_cli
|
||||
from tests.cli.test_cli_init import _make_cli
|
||||
|
||||
|
||||
def test_retry_last_truncates_history_before_requeueing_message():
|
||||
98
tests/cli/test_cli_skin_integration.py
Normal file
98
tests/cli/test_cli_skin_integration.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from cli import HermesCLI, _rich_text_from_ansi
|
||||
from hermes_cli.skin_engine import get_active_skin, set_active_skin
|
||||
|
||||
|
||||
def _make_cli_stub():
    """Return a ``HermesCLI`` created without ``__init__`` and seeded by hand.

    The stub gets inactive prompt states, cleared activity flags, a fixed
    spinner frame, a base TUI style dict, a placeholder app whose ``style`` is
    None, and a mock ``_invalidate``.
    """
    stub = HermesCLI.__new__(HermesCLI)

    # Interactive prompt states start out inactive.
    for state_attr in ("_sudo_state", "_secret_state", "_approval_state", "_clarify_state"):
        setattr(stub, state_attr, None)

    # Activity flags all start cleared.
    for flag_attr in (
        "_clarify_freetext",
        "_command_running",
        "_agent_running",
        "_voice_recording",
        "_voice_processing",
        "_voice_mode",
    ):
        setattr(stub, flag_attr, False)

    stub._command_spinner_frame = lambda: "⟳"
    stub._tui_style_base = {
        "prompt": "#fff",
        "input-area": "#fff",
        "input-rule": "#aaa",
        "prompt-working": "#888 italic",
    }
    stub._app = SimpleNamespace(style=None)
    stub._invalidate = MagicMock()
    return stub
|
||||
|
||||
|
||||
class TestCliSkinPromptIntegration:
    """Prompt fragments and TUI styles must follow the active skin."""

    def teardown_method(self):
        # set_active_skin() selects the skin without being given a CLI object,
        # so the choice is shared state. Several tests below switch to "ares";
        # restore "default" afterwards so the selection cannot leak into tests
        # that run later.
        set_active_skin("default")

    def test_default_prompt_fragments_use_default_symbol(self):
        cli = _make_cli_stub()

        set_active_skin("default")
        assert cli._get_tui_prompt_fragments() == [("class:prompt", "❯ ")]

    def test_ares_prompt_fragments_use_skin_symbol(self):
        cli = _make_cli_stub()

        set_active_skin("ares")
        assert cli._get_tui_prompt_fragments() == [("class:prompt", "⚔ ❯ ")]

    def test_secret_prompt_fragments_preserve_secret_state(self):
        cli = _make_cli_stub()
        # A pending secret prompt takes over the prompt class and icon.
        cli._secret_state = {"response_queue": object()}

        set_active_skin("ares")
        assert cli._get_tui_prompt_fragments() == [("class:sudo-prompt", "🔑 ❯ ")]

    def test_icon_only_skin_symbol_still_visible_in_special_states(self):
        cli = _make_cli_stub()
        cli._secret_state = {"response_queue": object()}

        # The prompt symbol is patched directly, so no skin is selected here.
        with patch("hermes_cli.skin_engine.get_active_prompt_symbol", return_value="⚔ "):
            assert cli._get_tui_prompt_fragments() == [("class:sudo-prompt", "🔑 ⚔ ")]

    def test_build_tui_style_dict_uses_skin_overrides(self):
        cli = _make_cli_stub()

        set_active_skin("ares")
        skin = get_active_skin()
        style_dict = cli._build_tui_style_dict()

        assert style_dict["prompt"] == skin.get_color("prompt")
        assert style_dict["input-rule"] == skin.get_color("input_rule")
        assert style_dict["prompt-working"] == f"{skin.get_color('banner_dim')} italic"
        assert style_dict["approval-title"] == f"{skin.get_color('ui_warn')} bold"

    def test_apply_tui_skin_style_updates_running_app(self):
        cli = _make_cli_stub()

        set_active_skin("ares")
        assert cli._apply_tui_skin_style() is True
        # The stub app started with style=None; applying the skin replaces it.
        assert cli._app.style is not None
        cli._invalidate.assert_called_once_with(min_interval=0.0)

    def test_handle_skin_command_refreshes_live_tui(self, capsys):
        cli = _make_cli_stub()

        # Config persistence is stubbed out; only the in-memory effect matters.
        with patch("cli.save_config_value", return_value=True):
            cli._handle_skin_command("/skin ares")

        output = capsys.readouterr().out
        assert "Skin set to: ares (saved)" in output
        assert "Prompt + TUI colors updated." in output
        assert cli._app.style is not None
|
||||
|
||||
|
||||
class TestAnsiRichTextHelper:
    """Check the ``.plain`` text produced by ``_rich_text_from_ansi``."""

    def test_preserves_literal_brackets(self):
        # Square brackets must come through as literal text.
        source = "[notatag] literal"
        rendered = _rich_text_from_ansi(source)
        assert rendered.plain == "[notatag] literal"

    def test_strips_ansi_but_keeps_plain_text(self):
        # Red colour escape, the word "red", then a reset escape.
        rendered = _rich_text_from_ansi("\x1b[31mred\x1b[0m")
        assert rendered.plain == "red"
|
||||
@@ -1,5 +1,6 @@
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from cli import HermesCLI
|
||||
|
||||
@@ -78,6 +79,92 @@ class TestCLIStatusBar:
|
||||
assert "$0.06" not in text # cost hidden by default
|
||||
assert "15m" in text
|
||||
|
||||
def test_input_height_counts_wide_characters_using_cell_width(self):
    """Ten "你" characters need two rows when 12 cells are available.

    The mocked terminal is 14 columns wide and the prompt "❯ " is counted as
    at least 2 cells, leaving 12 cells per row.
    """
    cli_obj = _make_cli()

    class _Doc:
        lines = ["你" * 10]

    class _Buffer:
        document = _Doc()

    input_area = SimpleNamespace(buffer=_Buffer())

    # NOTE(review): this local function appears to mirror the CLI's own
    # input-height callback — confirm the two are kept in sync.
    def _input_height():
        try:
            # Imported at call time so the patched get_app below is picked up.
            from prompt_toolkit.application import get_app
            from prompt_toolkit.utils import get_cwidth

            document = input_area.buffer.document
            prompt_cells = max(2, get_cwidth(cli_obj._get_tui_prompt_text()))
            try:
                usable = get_app().output.get_size().columns - prompt_cells
            except Exception:
                import shutil

                usable = shutil.get_terminal_size((80, 24)).columns - prompt_cells
            if usable < 10:
                usable = 40

            rows = 0
            for text in document.lines:
                cells = get_cwidth(text)
                # -(-a // b) is ceiling division: rows needed to fit the cells.
                rows += 1 if cells <= 0 else max(1, -(-cells // usable))
            return max(1, min(rows, 8))
        except Exception:
            return 1

    mock_app = MagicMock()
    mock_app.output.get_size.return_value = MagicMock(columns=14)

    with patch.object(HermesCLI, "_get_tui_prompt_text", return_value="❯ "):
        with patch("prompt_toolkit.application.get_app", return_value=mock_app):
            assert _input_height() == 2
|
||||
|
||||
def test_input_height_uses_prompt_toolkit_width_over_shutil(self):
    """The width reported by ``get_app()`` is used, never the shutil fallback.

    ``shutil.get_terminal_size`` is patched and must not be called while the
    prompt_toolkit app reports a 14-column terminal.
    """
    cli_obj = _make_cli()

    class _Doc:
        lines = ["你" * 10]

    class _Buffer:
        document = _Doc()

    input_area = SimpleNamespace(buffer=_Buffer())

    # NOTE(review): this local function appears to mirror the CLI's own
    # input-height callback — confirm the two are kept in sync.
    def _input_height():
        try:
            # Imported at call time so the patched get_app below is picked up.
            from prompt_toolkit.application import get_app
            from prompt_toolkit.utils import get_cwidth

            document = input_area.buffer.document
            prompt_cells = max(2, get_cwidth(cli_obj._get_tui_prompt_text()))
            try:
                usable = get_app().output.get_size().columns - prompt_cells
            except Exception:
                import shutil

                usable = shutil.get_terminal_size((80, 24)).columns - prompt_cells
            if usable < 10:
                usable = 40

            rows = 0
            for text in document.lines:
                cells = get_cwidth(text)
                # -(-a // b) is ceiling division: rows needed to fit the cells.
                rows += 1 if cells <= 0 else max(1, -(-cells // usable))
            return max(1, min(rows, 8))
        except Exception:
            return 1

    mock_app = MagicMock()
    mock_app.output.get_size.return_value = MagicMock(columns=14)

    with patch.object(HermesCLI, "_get_tui_prompt_text", return_value="❯ "):
        with patch("prompt_toolkit.application.get_app", return_value=mock_app):
            with patch("shutil.get_terminal_size") as mock_shutil:
                assert _input_height() == 2
                mock_shutil.assert_not_called()
|
||||
|
||||
def test_build_status_bar_text_no_cost_in_status_bar(self):
|
||||
cli_obj = _attach_agent(
|
||||
_make_cli(),
|
||||
66
tests/cli/test_session_boundary_hooks.py
Normal file
66
tests/cli/test_session_boundary_hooks.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from hermes_cli.plugins import VALID_HOOKS, PluginManager
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
def test_session_hooks_in_valid_hooks():
    """Both session-boundary hook names must be present in VALID_HOOKS."""
    for hook_name in ("on_session_finalize", "on_session_reset"):
        assert hook_name in VALID_HOOKS
|
||||
|
||||
|
||||
@patch("hermes_cli.plugins.invoke_hook")
def test_session_finalize_on_reset(mock_invoke_hook):
    """Starting a new session must finalize the old one and reset the new one.

    ``new_session`` is what /new and /reset go through; the patched
    ``invoke_hook`` records which hooks were fired.
    """
    hermes = HermesCLI()
    hermes.agent = MagicMock(session_id="test-session-id")

    hermes.new_session(silent=True)

    # Old session id is reported to on_session_finalize ...
    mock_invoke_hook.assert_any_call(
        "on_session_finalize", session_id="test-session-id", platform="cli"
    )
    # ... and whatever id the CLI holds afterwards goes to on_session_reset.
    mock_invoke_hook.assert_any_call(
        "on_session_reset", session_id=hermes.session_id, platform="cli"
    )
|
||||
|
||||
|
||||
@patch("hermes_cli.plugins.invoke_hook")
def test_session_finalize_on_cleanup(mock_invoke_hook):
    """Verify on_session_finalize fires during CLI exit cleanup.

    The fake agent is installed on the ``cli`` module only for the duration
    of the ``_run_cleanup()`` call, so the MagicMock does not stay behind as
    the module's active agent for tests that run afterwards.
    """
    import cli as cli_mod

    mock_agent = MagicMock()
    mock_agent.session_id = "cleanup-session-id"

    # create=True keeps this working even if the module has not defined the
    # attribute yet, matching what a plain assignment would have allowed.
    with patch.object(cli_mod, "_active_agent_ref", mock_agent, create=True):
        # Clear the guard flag before invoking cleanup.  It is deliberately
        # not restored afterwards (same as before this change).
        # NOTE(review): presumably _run_cleanup() sets it again — confirm.
        cli_mod._cleanup_done = False

        cli_mod._run_cleanup()

    mock_invoke_hook.assert_any_call(
        "on_session_finalize", session_id="cleanup-session-id", platform="cli"
    )
|
||||
|
||||
|
||||
@patch("hermes_cli.plugins.invoke_hook")
def test_hook_errors_are_caught(mock_invoke_hook):
    """A raising hook callback must be swallowed by PluginManager.invoke_hook."""
    manager = PluginManager()

    def exploding_hook(**kwargs):
        raise Exception("Hook failed")

    # Install the failing callback directly in the manager's hook table.
    manager._hooks["on_session_finalize"] = [exploding_hook]

    # Must return normally; the failing callback contributes no result.
    outcome = manager.invoke_hook(
        "on_session_finalize", session_id="test", platform="cli"
    )
    assert outcome == []
|
||||
@@ -33,6 +33,13 @@ def git_repo(tmp_path):
|
||||
["git", "commit", "-m", "Initial commit"],
|
||||
cwd=repo, capture_output=True,
|
||||
)
|
||||
# Add a fake remote ref so cleanup logic sees the initial commit as
|
||||
# "pushed". Without this, `git log HEAD --not --remotes` treats every
|
||||
# commit as unpushed and cleanup refuses to delete worktrees.
|
||||
subprocess.run(
|
||||
["git", "update-ref", "refs/remotes/origin/main", "HEAD"],
|
||||
cwd=repo, capture_output=True,
|
||||
)
|
||||
return repo
|
||||
|
||||
|
||||
@@ -81,7 +88,11 @@ def _setup_worktree(repo_root):
|
||||
|
||||
|
||||
def _cleanup_worktree(info):
|
||||
"""Test version of _cleanup_worktree."""
|
||||
"""Test version of _cleanup_worktree.
|
||||
|
||||
Preserves the worktree only if it has unpushed commits.
|
||||
Dirty working tree alone is not enough to keep it.
|
||||
"""
|
||||
wt_path = info["path"]
|
||||
branch = info["branch"]
|
||||
repo_root = info["repo_root"]
|
||||
@@ -89,15 +100,15 @@ def _cleanup_worktree(info):
|
||||
if not Path(wt_path).exists():
|
||||
return
|
||||
|
||||
# Check for uncommitted changes
|
||||
status = subprocess.run(
|
||||
["git", "status", "--porcelain"],
|
||||
# Check for unpushed commits
|
||||
result = subprocess.run(
|
||||
["git", "log", "--oneline", "HEAD", "--not", "--remotes"],
|
||||
capture_output=True, text=True, timeout=10, cwd=wt_path,
|
||||
)
|
||||
has_changes = bool(status.stdout.strip())
|
||||
has_unpushed = bool(result.stdout.strip())
|
||||
|
||||
if has_changes:
|
||||
return False # Did not clean up
|
||||
if has_unpushed:
|
||||
return False # Did not clean up — has unpushed commits
|
||||
|
||||
subprocess.run(
|
||||
["git", "worktree", "remove", wt_path, "--force"],
|
||||
@@ -204,20 +215,45 @@ class TestWorktreeCleanup:
|
||||
assert result is True
|
||||
assert not Path(info["path"]).exists()
|
||||
|
||||
def test_dirty_worktree_kept(self, git_repo):
|
||||
def test_dirty_worktree_cleaned_when_no_unpushed(self, git_repo):
|
||||
"""Dirty working tree without unpushed commits is cleaned up.
|
||||
|
||||
Agent sessions typically leave untracked files / artifacts behind.
|
||||
Since all real work is in pushed commits, these don't warrant
|
||||
keeping the worktree.
|
||||
"""
|
||||
info = _setup_worktree(str(git_repo))
|
||||
assert info is not None
|
||||
|
||||
# Make uncommitted changes
|
||||
# Make uncommitted changes (untracked file)
|
||||
(Path(info["path"]) / "new-file.txt").write_text("uncommitted")
|
||||
subprocess.run(
|
||||
["git", "add", "new-file.txt"],
|
||||
cwd=info["path"], capture_output=True,
|
||||
)
|
||||
|
||||
# The git_repo fixture already has a fake remote ref so the initial
|
||||
# commit is seen as "pushed". No unpushed commits → cleanup proceeds.
|
||||
result = _cleanup_worktree(info)
|
||||
assert result is False
|
||||
assert Path(info["path"]).exists() # Still there
|
||||
assert result is True # Cleaned up despite dirty working tree
|
||||
assert not Path(info["path"]).exists()
|
||||
|
||||
def test_worktree_with_unpushed_commits_kept(self, git_repo):
    """A worktree holding a commit that is on no remote survives cleanup."""
    info = _setup_worktree(str(git_repo))
    assert info is not None

    worktree_dir = Path(info["path"])

    # Create and commit a file locally; nothing is pushed anywhere.
    (worktree_dir / "work.txt").write_text("real work")
    for argv in (
        ["git", "add", "work.txt"],
        ["git", "commit", "-m", "agent work"],
    ):
        subprocess.run(argv, cwd=info["path"], capture_output=True)

    cleaned = _cleanup_worktree(info)
    assert cleaned is False  # Kept — has unpushed commits
    assert Path(info["path"]).exists()
|
||||
|
||||
def test_branch_deleted_on_cleanup(self, git_repo):
|
||||
info = _setup_worktree(str(git_repo))
|
||||
@@ -367,7 +403,7 @@ class TestMultipleWorktrees:
|
||||
lines = [l for l in result.stdout.strip().splitlines() if l.strip()]
|
||||
assert len(lines) == 11
|
||||
|
||||
# Cleanup all
|
||||
# Cleanup all (git_repo fixture has a fake remote ref so cleanup works)
|
||||
for info in worktrees:
|
||||
# Discard changes first so cleanup works
|
||||
subprocess.run(
|
||||
@@ -492,33 +528,77 @@ class TestStaleWorktreePruning:
|
||||
assert not pruned
|
||||
assert Path(info["path"]).exists()
|
||||
|
||||
def test_keeps_dirty_old_worktree(self, git_repo):
|
||||
"""Old worktrees with uncommitted changes should NOT be pruned."""
|
||||
def test_keeps_old_worktree_with_unpushed_commits(self, git_repo):
|
||||
"""Old worktrees (24-72h) with unpushed commits should NOT be pruned."""
|
||||
import time
|
||||
|
||||
info = _setup_worktree(str(git_repo))
|
||||
assert info is not None
|
||||
|
||||
# Make it dirty
|
||||
(Path(info["path"]) / "dirty.txt").write_text("uncommitted")
|
||||
# Make an unpushed commit
|
||||
(Path(info["path"]) / "work.txt").write_text("real work")
|
||||
subprocess.run(["git", "add", "work.txt"], cwd=info["path"], capture_output=True)
|
||||
subprocess.run(
|
||||
["git", "add", "dirty.txt"],
|
||||
["git", "commit", "-m", "agent work"],
|
||||
cwd=info["path"], capture_output=True,
|
||||
)
|
||||
|
||||
# Make it old
|
||||
# Make it old (25h — in the 24-72h soft tier)
|
||||
old_time = time.time() - (25 * 3600)
|
||||
os.utime(info["path"], (old_time, old_time))
|
||||
|
||||
# Check if it would be pruned
|
||||
status = subprocess.run(
|
||||
["git", "status", "--porcelain"],
|
||||
# Check for unpushed commits (simulates prune logic)
|
||||
result = subprocess.run(
|
||||
["git", "log", "--oneline", "HEAD", "--not", "--remotes"],
|
||||
capture_output=True, text=True, cwd=info["path"],
|
||||
)
|
||||
has_changes = bool(status.stdout.strip())
|
||||
assert has_changes # Should be dirty → not pruned
|
||||
has_unpushed = bool(result.stdout.strip())
|
||||
assert has_unpushed # Has unpushed commits → not pruned in soft tier
|
||||
assert Path(info["path"]).exists()
|
||||
|
||||
def test_force_prunes_very_old_worktree(self, git_repo):
    """A worktree past the 72h threshold is removed even with unpushed work.

    The test re-enacts the force-prune steps by hand: age check, worktree
    removal with --force, then deletion of the worktree's branch.
    """
    import time

    info = _setup_worktree(str(git_repo))
    assert info is not None

    worktree = info["path"]
    repo_dir = str(git_repo)

    # An unpushed commit would normally protect the worktree.
    (Path(worktree) / "work.txt").write_text("stale work")
    for argv in (
        ["git", "add", "work.txt"],
        ["git", "commit", "-m", "old agent work"],
    ):
        subprocess.run(argv, cwd=worktree, capture_output=True)

    # Backdate the directory to 73h ago, one hour past the hard threshold.
    backdated = time.time() - (73 * 3600)
    os.utime(worktree, (backdated, backdated))

    # Age check used by the force-prune tier.
    hard_cutoff = time.time() - (72 * 3600)
    assert Path(worktree).stat().st_mtime <= hard_cutoff

    # Record the branch name first; it cannot be queried once the worktree is gone.
    current = subprocess.run(
        ["git", "branch", "--show-current"],
        capture_output=True, text=True, timeout=5, cwd=worktree,
    )
    branch = current.stdout.strip()

    subprocess.run(
        ["git", "worktree", "remove", worktree, "--force"],
        capture_output=True, text=True, timeout=15, cwd=repo_dir,
    )
    if branch:
        subprocess.run(
            ["git", "branch", "-D", branch],
            capture_output=True, text=True, timeout=10, cwd=repo_dir,
        )

    assert not Path(worktree).exists()
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases for robustness."""
|
||||
@@ -611,6 +691,133 @@ class TestTerminalCWDIntegration:
|
||||
assert result.stdout.strip() == "true"
|
||||
|
||||
|
||||
class TestOrphanedBranchPruning:
    """Test cleanup of orphaned hermes/* and pr-* branches."""

    @staticmethod
    def _local_branches(repo):
        """Return the short names of all local branches in *repo*."""
        listing = subprocess.run(
            ["git", "branch", "--format=%(refname:short)"],
            capture_output=True, text=True, cwd=str(repo),
        )
        return [
            name.strip()
            for name in listing.stdout.strip().split("\n")
            if name.strip()
        ]

    @staticmethod
    def _worktree_branches(repo):
        """Return the branch names reported by ``git worktree list --porcelain``."""
        listing = subprocess.run(
            ["git", "worktree", "list", "--porcelain"],
            capture_output=True, text=True, cwd=str(repo),
        )
        marker = "branch refs/heads/"
        return {
            row.split(marker, 1)[-1].strip()
            for row in listing.stdout.split("\n")
            if row.startswith(marker)
        }

    def test_prunes_orphaned_hermes_branch(self, git_repo):
        """hermes/hermes-* branches with no worktree should be deleted."""
        repo_dir = str(git_repo)
        stale = "hermes/hermes-deadbeef"

        # A branch named like a worktree branch, but with no worktree behind it.
        subprocess.run(
            ["git", "branch", stale, "HEAD"],
            cwd=repo_dir, capture_output=True,
        )

        before = subprocess.run(
            ["git", "branch", "--list", stale],
            capture_output=True, text=True, cwd=repo_dir,
        )
        assert "hermes/hermes-deadbeef" in before.stdout

        # Re-enact the orphan detection: every branch that is neither "main"
        # nor checked out in a worktree, and carries a managed prefix.
        protected = {"main"} | self._worktree_branches(git_repo)
        orphaned = [
            name for name in self._local_branches(git_repo)
            if name not in protected
            and (name.startswith("hermes/hermes-") or name.startswith("pr-"))
        ]
        assert "hermes/hermes-deadbeef" in orphaned

        if orphaned:
            subprocess.run(
                ["git", "branch", "-D"] + orphaned,
                capture_output=True, text=True, cwd=repo_dir,
            )

        after = subprocess.run(
            ["git", "branch", "--list", stale],
            capture_output=True, text=True, cwd=repo_dir,
        )
        assert "hermes/hermes-deadbeef" not in after.stdout

    def test_prunes_orphaned_pr_branch(self, git_repo):
        """pr-* branches should be deleted during pruning."""
        repo_dir = str(git_repo)

        for pr_branch in ("pr-1234", "pr-5678"):
            subprocess.run(
                ["git", "branch", pr_branch, "HEAD"],
                cwd=repo_dir, capture_output=True,
            )

        protected = {"main"}
        orphaned = [
            name for name in self._local_branches(git_repo)
            if name not in protected and name.startswith("pr-")
        ]
        assert "pr-1234" in orphaned
        assert "pr-5678" in orphaned

        subprocess.run(
            ["git", "branch", "-D"] + orphaned,
            capture_output=True, text=True, cwd=repo_dir,
        )

        # Substring check against the raw listing, not an exact-name match.
        relisted = subprocess.run(
            ["git", "branch", "--format=%(refname:short)"],
            capture_output=True, text=True, cwd=repo_dir,
        )
        remaining = relisted.stdout.strip()
        assert "pr-1234" not in remaining
        assert "pr-5678" not in remaining

    def test_preserves_active_worktree_branch(self, git_repo):
        """Branches with active worktrees should NOT be pruned."""
        info = _setup_worktree(str(git_repo))
        assert info is not None

        # The worktree's branch must show up among the checked-out branches.
        assert info["branch"] in self._worktree_branches(git_repo)  # Protected

    def test_preserves_main_branch(self, git_repo):
        """main branch should never be pruned."""
        protected = {"main"}
        orphaned = [
            name for name in self._local_branches(git_repo)
            if name not in protected
            and (name.startswith("hermes/hermes-") or name.startswith("pr-"))
        ]
        assert "main" not in orphaned
|
||||
|
||||
|
||||
class TestSystemPromptInjection:
|
||||
"""Test that the agent gets worktree context in its system prompt."""
|
||||
|
||||
@@ -625,7 +832,7 @@ class TestSystemPromptInjection:
|
||||
f"{info['path']}. Your branch is `{info['branch']}`. "
|
||||
f"Changes here do not affect the main working tree or other agents. "
|
||||
f"Remember to commit and push your changes, and create a PR if appropriate. "
|
||||
f"The original repo is at {info['repo_root']}.]"
|
||||
f"The original repo is at {info['repo_root']}.]\n"
|
||||
)
|
||||
|
||||
assert info["path"] in wt_note
|
||||
@@ -152,11 +152,22 @@ def test_gateway_run_agent_codex_path_handles_internal_401_refresh(monkeypatch):
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._running_agents = {}
|
||||
runner._smart_model_routing = {}
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
runner.hooks.loaded_hooks = []
|
||||
runner._session_db = None
|
||||
# Ensure model resolution returns the codex model even if xdist
|
||||
# leaked env vars cleared HERMES_MODEL.
|
||||
monkeypatch.setattr(
|
||||
gateway_run.GatewayRunner,
|
||||
"_resolve_turn_agent_config",
|
||||
lambda self, msg, model, runtime: {
|
||||
"model": model or "gpt-5.3-codex",
|
||||
"runtime": runtime,
|
||||
},
|
||||
)
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.LOCAL,
|
||||
@@ -339,6 +339,36 @@ class TestMarkJobRun:
|
||||
assert updated["last_status"] == "error"
|
||||
assert updated["last_error"] == "timeout"
|
||||
|
||||
def test_delivery_error_tracked_separately(self, tmp_cron_dir):
    """A successful run with a failed delivery records only the delivery error."""
    job_id = create_job(prompt="Report", schedule="every 1h")["id"]

    mark_job_run(job_id, success=True, delivery_error="platform 'telegram' not configured")

    stored = get_job(job_id)
    assert stored["last_status"] == "ok"
    assert stored["last_error"] is None
    assert stored["last_delivery_error"] == "platform 'telegram' not configured"
|
||||
|
||||
def test_delivery_error_cleared_on_success(self, tmp_cron_dir):
    """A later run without a delivery error wipes the previously stored one."""
    job_id = create_job(prompt="Report", schedule="every 1h")["id"]

    # First run: delivery fails and the message is stored.
    mark_job_run(job_id, success=True, delivery_error="network timeout")
    assert get_job(job_id)["last_delivery_error"] == "network timeout"

    # Second run: delivery succeeds, so the stored error must be cleared.
    mark_job_run(job_id, success=True, delivery_error=None)
    assert get_job(job_id)["last_delivery_error"] is None
|
||||
|
||||
def test_both_agent_and_delivery_error(self, tmp_cron_dir):
    """When the run and the delivery both fail, each error is stored in its own field."""
    job_id = create_job(prompt="Report", schedule="every 1h")["id"]

    mark_job_run(
        job_id,
        success=False,
        error="model timeout",
        delivery_error="platform 'discord' not enabled",
    )

    stored = get_job(job_id)
    assert stored["last_status"] == "error"
    assert stored["last_error"] == "model timeout"
    assert stored["last_delivery_error"] == "platform 'discord' not enabled"
|
||||
|
||||
|
||||
class TestAdvanceNextRun:
|
||||
"""Tests for advance_next_run() — crash-safety for recurring jobs."""
|
||||
|
||||
@@ -7,7 +7,7 @@ from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from cron.scheduler import _resolve_origin, _resolve_delivery_target, _deliver_result, run_job, SILENT_MARKER, _build_job_prompt
|
||||
from cron.scheduler import _resolve_origin, _resolve_delivery_target, _deliver_result, _send_media_via_adapter, run_job, SILENT_MARKER, _build_job_prompt
|
||||
|
||||
|
||||
class TestResolveOrigin:
|
||||
@@ -277,6 +277,188 @@ class TestDeliverResultWrapping:
|
||||
# Media files should be forwarded separately
|
||||
assert kwargs["media_files"] == [("/tmp/test-voice.ogg", False)]
|
||||
|
||||
def test_live_adapter_sends_media_as_attachments(self):
    """With a live adapter, MEDIA files go out as native platform attachments
    (e.g., Discord voice, Telegram audio) instead of literal 'MEDIA:/path' text."""
    from concurrent.futures import Future
    from gateway.config import Platform

    live_adapter = AsyncMock()
    live_adapter.send.return_value = MagicMock(success=True)
    live_adapter.send_voice.return_value = MagicMock(success=True)

    gateway_cfg = MagicMock()
    gateway_cfg.platforms = {Platform.DISCORD: MagicMock(enabled=True)}

    loop = MagicMock()
    loop.is_running.return_value = True

    # Stand-in for asyncio.run_coroutine_threadsafe: hands back an already
    # resolved concurrent.futures.Future and closes the coroutine unawaited.
    def fake_run_coro(coro, _loop):
        coro.close()
        done = Future()
        done.set_result(MagicMock(success=True))
        return done

    job = {
        "id": "tts-job",
        "deliver": "origin",
        "origin": {"platform": "discord", "chat_id": "9876"},
    }
    content = "Here is TTS\nMEDIA:/tmp/cron-voice.mp3"

    with patch("gateway.config.load_gateway_config", return_value=gateway_cfg), \
         patch("cron.scheduler.load_config", return_value={"cron": {"wrap_response": False}}), \
         patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro):
        _deliver_result(
            job,
            content,
            adapters={Platform.DISCORD: live_adapter},
            loop=loop,
        )

    # The text message is sent exactly once and carries no MEDIA tag.
    live_adapter.send.assert_called_once()
    text_sent = live_adapter.send.call_args.args[1]
    assert "MEDIA:" not in text_sent
    assert "Here is TTS" in text_sent

    # The audio file travels as a voice attachment.
    live_adapter.send_voice.assert_called_once()
    assert live_adapter.send_voice.call_args.kwargs["audio_path"] == "/tmp/cron-voice.mp3"
|
||||
|
||||
def test_live_adapter_routes_image_to_send_image_file(self):
    """An image MEDIA file must go through send_image_file and never send_voice."""
    from concurrent.futures import Future
    from gateway.config import Platform

    live_adapter = AsyncMock()
    live_adapter.send.return_value = MagicMock(success=True)
    live_adapter.send_image_file.return_value = MagicMock(success=True)

    gateway_cfg = MagicMock()
    gateway_cfg.platforms = {Platform.DISCORD: MagicMock(enabled=True)}

    loop = MagicMock()
    loop.is_running.return_value = True

    # Replacement for asyncio.run_coroutine_threadsafe: returns a finished
    # Future and closes the coroutine without running it.
    def fake_run_coro(coro, _loop):
        coro.close()
        done = Future()
        done.set_result(MagicMock(success=True))
        return done

    job = {
        "id": "img-job",
        "deliver": "origin",
        "origin": {"platform": "discord", "chat_id": "1234"},
    }

    with patch("gateway.config.load_gateway_config", return_value=gateway_cfg), \
         patch("cron.scheduler.load_config", return_value={"cron": {"wrap_response": False}}), \
         patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro):
        _deliver_result(
            job,
            "Chart attached\nMEDIA:/tmp/chart.png",
            adapters={Platform.DISCORD: live_adapter},
            loop=loop,
        )

    live_adapter.send_image_file.assert_called_once()
    assert live_adapter.send_image_file.call_args.kwargs["image_path"] == "/tmp/chart.png"
    live_adapter.send_voice.assert_not_called()
|
||||
|
||||
def test_live_adapter_media_only_no_text(self):
    """Content consisting solely of a MEDIA tag still delivers the media file."""
    from concurrent.futures import Future
    from gateway.config import Platform

    live_adapter = AsyncMock()
    live_adapter.send_voice.return_value = MagicMock(success=True)

    gateway_cfg = MagicMock()
    gateway_cfg.platforms = {Platform.TELEGRAM: MagicMock(enabled=True)}

    loop = MagicMock()
    loop.is_running.return_value = True

    # Replacement for asyncio.run_coroutine_threadsafe: returns a finished
    # Future and closes the coroutine without running it.
    def fake_run_coro(coro, _loop):
        coro.close()
        done = Future()
        done.set_result(MagicMock(success=True))
        return done

    job = {
        "id": "voice-only",
        "deliver": "origin",
        "origin": {"platform": "telegram", "chat_id": "999"},
    }

    with patch("gateway.config.load_gateway_config", return_value=gateway_cfg), \
         patch("cron.scheduler.load_config", return_value={"cron": {"wrap_response": False}}), \
         patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro):
        _deliver_result(
            job,
            "MEDIA:/tmp/voice.ogg",
            adapters={Platform.TELEGRAM: live_adapter},
            loop=loop,
        )

    # No text message is expected for media-only content ...
    live_adapter.send.assert_not_called()
    # ... but the audio itself must still be delivered.
    live_adapter.send_voice.assert_called_once()
|
||||
|
||||
def test_live_adapter_sends_cleaned_text_not_raw(self):
    """The live adapter path must send cleaned text (MEDIA tags stripped),
    not the raw delivery_content with embedded MEDIA: tags."""
    from gateway.config import Platform
    from concurrent.futures import Future

    adapter = AsyncMock()
    adapter.send.return_value = MagicMock(success=True)

    pconfig = MagicMock()
    pconfig.enabled = True
    mock_cfg = MagicMock()
    mock_cfg.platforms = {Platform.TELEGRAM: pconfig}

    loop = MagicMock()
    loop.is_running.return_value = True

    # Stand-in for asyncio.run_coroutine_threadsafe: returns an already
    # resolved Future and closes the coroutine so it is never left pending.
    def fake_run_coro(coro, _loop):
        future = Future()
        future.set_result(MagicMock(success=True))
        coro.close()
        return future

    job = {
        "id": "img-job",
        "deliver": "origin",
        "origin": {"platform": "telegram", "chat_id": "555"},
    }

    with patch("gateway.config.load_gateway_config", return_value=mock_cfg), \
         patch("cron.scheduler.load_config", return_value={"cron": {"wrap_response": False}}), \
         patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro):
        _deliver_result(
            job,
            "Report\nMEDIA:/tmp/chart.png",
            adapters={Platform.TELEGRAM: adapter},
            loop=loop,
        )

    # call_args is None when send() was never invoked; assert the call first
    # so a regression fails with a clear message instead of a TypeError.
    adapter.send.assert_called()
    text_sent = adapter.send.call_args[0][1]
    assert "MEDIA:" not in text_sent
    assert "Report" in text_sent
|
||||
|
||||
def test_no_mirror_to_session_call(self):
|
||||
"""Cron deliveries should NOT mirror into the gateway session."""
|
||||
from gateway.config import Platform
|
||||
@@ -326,6 +508,90 @@ class TestDeliverResultWrapping:
|
||||
assert send_mock.call_args.kwargs["thread_id"] == "17585"
|
||||
|
||||
|
||||
class TestDeliverResultErrorReturns:
    """Verify _deliver_result returns error strings on failure, None on success."""

    @staticmethod
    def _telegram_gateway_config(enabled):
        """Build a mocked gateway config whose only platform is Telegram."""
        from gateway.config import Platform

        platform_cfg = MagicMock()
        platform_cfg.enabled = enabled
        gateway_cfg = MagicMock()
        gateway_cfg.platforms = {Platform.TELEGRAM: platform_cfg}
        return gateway_cfg

    @staticmethod
    def _origin_job(job_id, platform="telegram"):
        """Build a job that delivers back to its origin chat ("123")."""
        return {
            "id": job_id,
            "deliver": "origin",
            "origin": {"platform": platform, "chat_id": "123"},
        }

    def test_returns_none_on_successful_delivery(self):
        gateway_cfg = self._telegram_gateway_config(enabled=True)
        sender = AsyncMock(return_value={"success": True})

        with patch("gateway.config.load_gateway_config", return_value=gateway_cfg), \
             patch("tools.send_message_tool._send_to_platform", new=sender):
            outcome = _deliver_result(self._origin_job("ok-job"), "Output.")

        assert outcome is None

    def test_returns_none_for_local_delivery(self):
        """local-only jobs don't deliver — not a failure."""
        outcome = _deliver_result({"id": "local-job", "deliver": "local"}, "Output.")
        assert outcome is None

    def test_returns_error_for_unknown_platform(self):
        job = self._origin_job("bad-platform", platform="fax")

        with patch("gateway.config.load_gateway_config"):
            outcome = _deliver_result(job, "Output.")

        assert outcome is not None
        assert "unknown platform" in outcome

    def test_returns_error_when_platform_disabled(self):
        gateway_cfg = self._telegram_gateway_config(enabled=False)

        with patch("gateway.config.load_gateway_config", return_value=gateway_cfg):
            outcome = _deliver_result(self._origin_job("disabled"), "Output.")

        assert outcome is not None
        assert "not configured" in outcome

    def test_returns_error_on_send_failure(self):
        gateway_cfg = self._telegram_gateway_config(enabled=True)
        sender = AsyncMock(return_value={"error": "rate limited"})

        with patch("gateway.config.load_gateway_config", return_value=gateway_cfg), \
             patch("tools.send_message_tool._send_to_platform", new=sender):
            outcome = _deliver_result(self._origin_job("rate-limited"), "Output.")

        assert outcome is not None
        assert "rate limited" in outcome

    def test_returns_error_for_unresolved_target(self, monkeypatch):
        """Non-local delivery with no resolvable target should return an error."""
        # Drop the env var so it cannot supply a delivery target for this job.
        monkeypatch.delenv("TELEGRAM_HOME_CHANNEL", raising=False)

        outcome = _deliver_result({"id": "no-target", "deliver": "telegram"}, "Output.")

        assert outcome is not None
        assert "no delivery target" in outcome
|
||||
|
||||
|
||||
class TestRunJobSessionPersistence:
|
||||
def test_run_job_passes_session_db_and_cron_platform(self, tmp_path):
|
||||
job = {
|
||||
@@ -709,6 +975,18 @@ class TestSilentDelivery:
|
||||
tick(verbose=False)
|
||||
deliver_mock.assert_not_called()
|
||||
|
||||
def test_silent_trailing_suppresses_delivery(self):
    """A [SILENT] marker placed after explanatory text must still block delivery."""
    response = "2 deals filtered out (like<10, reply<15).\n\n[SILENT]"
    due_jobs = [self._make_job()]
    run_outcome = (True, "# output", response, None)

    with patch("cron.scheduler.get_due_jobs", return_value=due_jobs), \
         patch("cron.scheduler.run_job", return_value=run_outcome), \
         patch("cron.scheduler.save_job_output", return_value="/tmp/out.md"), \
         patch("cron.scheduler._deliver_result") as deliver_mock, \
         patch("cron.scheduler.mark_job_run"):
        from cron.scheduler import tick

        tick(verbose=False)
        deliver_mock.assert_not_called()
|
||||
|
||||
def test_silent_is_case_insensitive(self):
|
||||
with patch("cron.scheduler.get_due_jobs", return_value=[self._make_job()]), \
|
||||
patch("cron.scheduler.run_job", return_value=(True, "# output", "[silent] nothing new", None)), \
|
||||
@@ -850,3 +1128,57 @@ class TestTickAdvanceBeforeRun:
|
||||
adv_mock.assert_called_once_with("test-advance")
|
||||
# advance must happen before run
|
||||
assert call_order == [("advance", "test-advance"), ("run", "test-advance")]
|
||||
|
||||
|
||||
class TestSendMediaViaAdapter:
    """Unit tests for _send_media_via_adapter — routes files to typed adapter methods."""

    @staticmethod
    def _run_with_loop(adapter, chat_id, media_files, metadata, job):
        """Call _send_media_via_adapter against an event loop running in a thread.

        The loop is started on a daemon thread before the call and is always
        stopped, joined (up to 5 seconds) and closed afterwards.
        """
        import asyncio
        import threading

        loop = asyncio.new_event_loop()
        runner = threading.Thread(target=loop.run_forever, daemon=True)
        runner.start()
        try:
            _send_media_via_adapter(adapter, chat_id, media_files, metadata, loop, job)
        finally:
            # stop() must be scheduled from outside the loop's own thread.
            loop.call_soon_threadsafe(loop.stop)
            runner.join(timeout=5)
            loop.close()

    def test_video_dispatched_to_send_video(self):
        """An .mp4 file is handed to send_video with its path as video_path."""
        adapter = MagicMock(send_video=AsyncMock())

        self._run_with_loop(adapter, "123", [("/tmp/clip.mp4", False)], None, {"id": "j1"})

        adapter.send_video.assert_called_once()
        assert adapter.send_video.call_args.kwargs["video_path"] == "/tmp/clip.mp4"

    def test_unknown_ext_dispatched_to_send_document(self):
        """A .pdf file is handed to send_document with its path as file_path."""
        adapter = MagicMock(send_document=AsyncMock())

        self._run_with_loop(adapter, "123", [("/tmp/report.pdf", False)], None, {"id": "j2"})

        adapter.send_document.assert_called_once()
        assert adapter.send_document.call_args.kwargs["file_path"] == "/tmp/report.pdf"

    def test_multiple_media_files_all_delivered(self):
        """Every file in a mixed batch reaches its own adapter method once."""
        adapter = MagicMock(send_voice=AsyncMock(), send_image_file=AsyncMock())
        batch = [("/tmp/voice.mp3", False), ("/tmp/photo.jpg", False)]

        self._run_with_loop(adapter, "123", batch, None, {"id": "j3"})

        adapter.send_voice.assert_called_once()
        adapter.send_image_file.assert_called_once()

    def test_single_failure_does_not_block_others(self):
        """A send_voice that raises must not stop the following image send."""
        adapter = MagicMock(
            send_voice=AsyncMock(side_effect=RuntimeError("network error")),
            send_image_file=AsyncMock(),
        )
        batch = [("/tmp/voice.ogg", False), ("/tmp/photo.png", False)]

        self._run_with_loop(adapter, "123", batch, None, {"id": "j4"})

        adapter.send_voice.assert_called_once()
        adapter.send_image_file.assert_called_once()
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
"""Security tests for Terminal-Bench 2 archive extraction."""
|
||||
|
||||
import base64
|
||||
import importlib
|
||||
import io
|
||||
import sys
|
||||
import tarfile
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
    """Return a new empty module object called *name* carrying *attrs*.

    Each keyword argument becomes a module-level attribute of the stub.
    """
    stub = types.ModuleType(name)
    # A plain module stores its attributes in __dict__, so one bulk update
    # has the same effect as assigning them one at a time.
    vars(stub).update(attrs)
    return stub
|
||||
|
||||
|
||||
def _load_terminalbench_module(monkeypatch):
    """Import the Terminal-Bench 2 environment module against stub dependencies.

    Placeholder modules for ``atroposlib``, the ``environments`` helpers and
    ``tools.terminal_tool`` are installed into ``sys.modules`` through
    *monkeypatch*, any cached copy of the target module is dropped, and the
    module is imported afresh and returned.
    """

    class _EvalHandlingEnum:
        STOP_TRAIN = "stop_train"

    class _APIServerConfig:
        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs

    class _AgentResult:
        pass

    class _HermesAgentLoop:
        pass

    class _HermesAgentBaseEnv:
        pass

    class _HermesAgentEnvConfig:
        pass

    class _ToolContext:
        pass

    def _noop(*args, **kwargs):
        return None

    # (module name, attributes exposed by the stub), in registration order.
    stub_specs = [
        ("atroposlib", {}),
        ("atroposlib.envs", {}),
        ("atroposlib.envs.base", {"EvalHandlingEnum": _EvalHandlingEnum}),
        ("atroposlib.envs.server_handling", {}),
        (
            "atroposlib.envs.server_handling.server_manager",
            {"APIServerConfig": _APIServerConfig},
        ),
        (
            "environments.agent_loop",
            {"AgentResult": _AgentResult, "HermesAgentLoop": _HermesAgentLoop},
        ),
        (
            "environments.hermes_base_env",
            {
                "HermesAgentBaseEnv": _HermesAgentBaseEnv,
                "HermesAgentEnvConfig": _HermesAgentEnvConfig,
            },
        ),
        ("environments.tool_context", {"ToolContext": _ToolContext}),
        (
            "tools.terminal_tool",
            {
                "register_task_env_overrides": _noop,
                "clear_task_env_overrides": _noop,
                "cleanup_vm": _noop,
            },
        ),
    ]
    stub_modules = {
        mod_name: _stub_module(mod_name, **mod_attrs)
        for mod_name, mod_attrs in stub_specs
    }

    # Link each atroposlib stub package to its child so attribute access
    # (``atroposlib.envs.base``) works as well as ``import`` does.
    package_links = (
        ("atroposlib", "envs"),
        ("atroposlib.envs", "base"),
        ("atroposlib.envs", "server_handling"),
        ("atroposlib.envs.server_handling", "server_manager"),
    )
    for parent_name, child_attr in package_links:
        child = stub_modules[f"{parent_name}.{child_attr}"]
        setattr(stub_modules[parent_name], child_attr, child)

    for mod_name, stub in stub_modules.items():
        monkeypatch.setitem(sys.modules, mod_name, stub)

    module_name = "environments.benchmarks.terminalbench_2.terminalbench2_env"
    # Drop any cached import so the module body runs against the stubs.
    # NOTE(review): the re-imported module is not registered with monkeypatch,
    # so it appears to stay in sys.modules after the test — confirm intended.
    sys.modules.pop(module_name, None)
    return importlib.import_module(module_name)
|
||||
|
||||
|
||||
def _build_tar_b64(entries):
    """Build a gzip-compressed tar archive in memory and return it as base64 text.

    Args:
        entries: Iterable of dicts describing archive members.  Every dict
            needs ``"kind"`` (``"dir"``, ``"file"`` or ``"symlink"``) and
            ``"name"``.  ``"file"`` entries also need ``"data"`` (a ``str``,
            stored UTF-8 encoded) and ``"symlink"`` entries need ``"target"``.

    Returns:
        The archive bytes, base64-encoded and decoded to an ASCII ``str``.

    Raises:
        ValueError: If an entry has an unrecognised ``"kind"``.
    """
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w:gz") as tar:
        for entry in entries:
            kind = entry["kind"]
            info = tarfile.TarInfo(entry["name"])

            if kind == "dir":
                info.type = tarfile.DIRTYPE
                # TarInfo defaults to mode 0o644, which has no search bit.
                # An extractor that applies archive modes would leave the
                # directory untraversable, so give it a normal directory mode.
                info.mode = 0o755
                tar.addfile(info)
            elif kind == "file":
                payload = entry["data"].encode("utf-8")
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))
            elif kind == "symlink":
                info.type = tarfile.SYMTYPE
                info.linkname = entry["target"]
                tar.addfile(info)
            else:
                raise ValueError(f"Unknown tar entry kind: {kind}")

    return base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
|
||||
def test_extract_base64_tar_allows_safe_files(tmp_path, monkeypatch):
    """A directory plus a file inside it are extracted under the target."""
    tb2 = _load_terminalbench_module(monkeypatch)
    members = [
        {"kind": "dir", "name": "nested"},
        {"kind": "file", "name": "nested/hello.txt", "data": "hello"},
    ]
    destination = tmp_path / "extract"

    tb2._extract_base64_tar(_build_tar_b64(members), destination)

    extracted = destination / "nested" / "hello.txt"
    assert extracted.read_text(encoding="utf-8") == "hello"
|
||||
|
||||
|
||||
def test_extract_base64_tar_rejects_path_traversal(tmp_path, monkeypatch):
    """A member named ``../escape.txt`` raises and writes nothing outside the target."""
    tb2 = _load_terminalbench_module(monkeypatch)
    hostile = _build_tar_b64(
        [{"kind": "file", "name": "../escape.txt", "data": "owned"}]
    )
    destination = tmp_path / "extract"

    with pytest.raises(ValueError, match="Unsafe archive member path"):
        tb2._extract_base64_tar(hostile, destination)

    # The sibling of the extraction directory is where "../" would land.
    escaped = tmp_path / "escape.txt"
    assert not escaped.exists()
|
||||
|
||||
|
||||
def test_extract_base64_tar_rejects_symlinks(tmp_path, monkeypatch):
    """A symlink member raises and no link is created in the target."""
    tb2 = _load_terminalbench_module(monkeypatch)
    hostile = _build_tar_b64(
        [{"kind": "symlink", "name": "link", "target": "../../escape.txt"}]
    )
    destination = tmp_path / "extract"

    with pytest.raises(ValueError, match="Unsupported archive member type"):
        tb2._extract_base64_tar(hostile, destination)

    created_link = destination / "link"
    assert not created_link.exists()
|
||||
@@ -439,7 +439,7 @@ class TestChatCompletionsEndpoint:
|
||||
tp_cb = kwargs.get("tool_progress_callback")
|
||||
# Simulate tool progress before streaming content
|
||||
if tp_cb:
|
||||
tp_cb("terminal", "ls -la", {"command": "ls -la"})
|
||||
tp_cb("tool.started", "terminal", "ls -la", {"command": "ls -la"})
|
||||
if cb:
|
||||
await asyncio.sleep(0.05)
|
||||
cb("Here are the files.")
|
||||
@@ -476,8 +476,8 @@ class TestChatCompletionsEndpoint:
|
||||
cb = kwargs.get("stream_delta_callback")
|
||||
tp_cb = kwargs.get("tool_progress_callback")
|
||||
if tp_cb:
|
||||
tp_cb("_thinking", "some internal state", {})
|
||||
tp_cb("web_search", "Python docs", {"query": "Python docs"})
|
||||
tp_cb("tool.started", "_thinking", "some internal state", {})
|
||||
tp_cb("tool.started", "web_search", "Python docs", {"query": "Python docs"})
|
||||
if cb:
|
||||
await asyncio.sleep(0.05)
|
||||
cb("Found it.")
|
||||
|
||||
@@ -39,7 +39,7 @@ class TestHermesApiServerToolset:
|
||||
tools = resolve_toolset("hermes-api-server")
|
||||
for tool in ["browser_navigate", "browser_snapshot", "browser_click",
|
||||
"browser_type", "browser_scroll", "browser_back",
|
||||
"browser_press", "browser_close"]:
|
||||
"browser_press"]:
|
||||
assert tool in tools, f"Missing browser tool: {tool}"
|
||||
|
||||
def test_toolset_includes_homeassistant_tools(self):
|
||||
|
||||
313
tests/gateway/test_command_bypass_active_session.py
Normal file
313
tests/gateway/test_command_bypass_active_session.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""Regression tests: slash commands must bypass the base adapter's active-session guard.
|
||||
|
||||
When an agent is running, the base adapter's Level 1 guard in
|
||||
handle_message() intercepts all incoming messages and queues them as
|
||||
pending. Certain commands (/stop, /new, /reset, /approve, /deny,
|
||||
/status) must bypass this guard and be dispatched directly to the gateway
|
||||
runner — otherwise they are queued as user text and either:
|
||||
- leak into the conversation as agent input (/stop, /new), or
|
||||
- deadlock (/approve, /deny — agent blocks on Event.wait)
|
||||
|
||||
These tests verify that the bypass works at the adapter level and that
|
||||
the safety net in _run_agent discards leaked command text.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StubAdapter(BasePlatformAdapter):
    """Minimal concrete adapter whose platform hooks do nothing.

    Presumably these four coroutines are the abstract methods of
    ``BasePlatformAdapter``; they are stubbed so the class can be instantiated.
    """

    async def connect(self):
        """No-op connect."""
        return None

    async def disconnect(self):
        """No-op disconnect."""
        return None

    async def send(self, chat_id, text, **kwargs):
        """Accept and discard an outgoing message."""
        return None

    async def get_chat_info(self, chat_id):
        """Report no information for any chat."""
        return {}
|
||||
|
||||
|
||||
def _make_adapter():
    """Build a Telegram ``_StubAdapter`` with a fake handler and a recording sender.

    The returned adapter has ``sent_responses``, a list that collects every
    ``content`` passed to the replaced ``_send_with_retry``.  The replaced
    message handler answers ``handled:<command>`` for commands and
    ``handled:text:<text>`` otherwise.
    """
    stub = _StubAdapter(
        PlatformConfig(enabled=True, token="test-token"),
        Platform.TELEGRAM,
    )
    stub.sent_responses = []

    async def _mock_handler(event):
        command = event.get_command()
        if command:
            return f"handled:{command}"
        return f"handled:text:{event.text}"

    async def _mock_send_retry(chat_id, content, **kwargs):
        stub.sent_responses.append(content)

    stub._message_handler = _mock_handler
    stub._send_with_retry = _mock_send_retry
    return stub
|
||||
|
||||
|
||||
def _make_event(text="/stop", chat_id="12345"):
    """Return a text ``MessageEvent`` from a Telegram DM with the given chat id."""
    dm_source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id=chat_id,
        chat_type="dm",
    )
    return MessageEvent(
        text=text,
        message_type=MessageType.TEXT,
        source=dm_source,
    )
|
||||
|
||||
|
||||
def _session_key(chat_id="12345"):
    """Return the session key for a Telegram DM with the given chat id."""
    dm_source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id=chat_id,
        chat_type="dm",
    )
    return build_session_key(dm_source)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: commands bypass Level 1 when session is active
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCommandBypassActiveSession:
    """Control commands are dispatched immediately even while a session is active."""

    @staticmethod
    async def _send_during_active_session(text):
        """Mark the default session active, deliver *text*, return ``(adapter, key)``."""
        adapter = _make_adapter()
        key = _session_key()
        adapter._active_sessions[key] = asyncio.Event()
        await adapter.handle_message(_make_event(text))
        return adapter, key

    @pytest.mark.asyncio
    async def test_stop_bypasses_guard(self):
        """``/stop`` is handled right away rather than parked as pending."""
        adapter, key = await self._send_during_active_session("/stop")

        assert key not in adapter._pending_messages, (
            "/stop was queued as a pending message instead of being dispatched"
        )
        assert any("handled:stop" in reply for reply in adapter.sent_responses), (
            "/stop response was not sent back to the user"
        )

    @pytest.mark.asyncio
    async def test_new_bypasses_guard(self):
        """``/new`` is handled right away rather than parked as pending."""
        adapter, key = await self._send_during_active_session("/new")

        assert key not in adapter._pending_messages
        assert any("handled:new" in reply for reply in adapter.sent_responses)

    @pytest.mark.asyncio
    async def test_reset_bypasses_guard(self):
        """``/reset`` (alias for ``/new``) is handled right away."""
        adapter, key = await self._send_during_active_session("/reset")

        assert key not in adapter._pending_messages
        assert any("handled:reset" in reply for reply in adapter.sent_responses)

    @pytest.mark.asyncio
    async def test_approve_bypasses_guard(self):
        """``/approve`` is handled right away (deadlock prevention)."""
        adapter, key = await self._send_during_active_session("/approve")

        assert key not in adapter._pending_messages
        assert any("handled:approve" in reply for reply in adapter.sent_responses)

    @pytest.mark.asyncio
    async def test_deny_bypasses_guard(self):
        """``/deny`` is handled right away (deadlock prevention)."""
        adapter, key = await self._send_during_active_session("/deny")

        assert key not in adapter._pending_messages
        assert any("handled:deny" in reply for reply in adapter.sent_responses)

    @pytest.mark.asyncio
    async def test_status_bypasses_guard(self):
        """``/status`` is handled right away so it yields a system response."""
        adapter, key = await self._send_during_active_session("/status")

        assert key not in adapter._pending_messages
        assert any("handled:status" in reply for reply in adapter.sent_responses)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: non-bypass messages still get queued
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNonBypassStillQueued:
    """Anything that is not a bypass command is parked while a session is active."""

    @staticmethod
    async def _send_during_active_session(text):
        """Mark the default session active, deliver *text*, return ``(adapter, key)``."""
        adapter = _make_adapter()
        key = _session_key()
        adapter._active_sessions[key] = asyncio.Event()
        await adapter.handle_message(_make_event(text))
        return adapter, key

    @pytest.mark.asyncio
    async def test_regular_text_queued(self):
        """Plain text sent during a running session ends up pending, with no reply."""
        adapter, key = await self._send_during_active_session("hello world")

        assert key in adapter._pending_messages, (
            "Regular text was not queued — it should be pending"
        )
        assert len(adapter.sent_responses) == 0, (
            "Regular text should not produce a direct response"
        )

    @pytest.mark.asyncio
    async def test_unknown_command_queued(self):
        """An unrecognised slash command is queued like ordinary text."""
        adapter, key = await self._send_during_active_session("/foobar")

        assert key in adapter._pending_messages
        assert len(adapter.sent_responses) == 0

    @pytest.mark.asyncio
    async def test_file_path_not_treated_as_command(self):
        """Text such as ``/path/to/file.py`` is queued, not treated as a command."""
        adapter, key = await self._send_during_active_session("/path/to/file.py")

        assert key in adapter._pending_messages
        assert len(adapter.sent_responses) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: no active session — commands go through normally
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNoActiveSessionNormalDispatch:
    """With no session marked active, messages are not routed to the pending queue."""

    @pytest.mark.asyncio
    async def test_stop_when_no_session_active(self):
        """``/stop`` with nothing running takes the normal path and is not queued.

        Per the original test notes, the normal path spawns a background task
        and the Level 2 handler answers 'No active task'; this test only
        checks that nothing was queued.
        """
        adapter = _make_adapter()
        key = _session_key()

        # Precondition: a fresh adapter has no active session for this key.
        assert key not in adapter._active_sessions

        await adapter.handle_message(_make_event("/stop"))

        # _pending_messages is only filled on the queued-while-active path,
        # so it must still be empty for this key.
        assert key not in adapter._pending_messages
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: safety net in _run_agent discards command text from pending queue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPendingCommandSafetyNet:
    """``resolve_command`` recognises the commands the ``_run_agent`` safety net relies on.

    The safety net in gateway/run.py is meant to discard command text that
    leaks into the pending queue via the interrupt_message fallback; these
    tests only exercise the lookup it depends on.
    """

    def test_stop_command_detected(self):
        """``stop`` resolves to the command named ``stop``."""
        from hermes_cli.commands import resolve_command

        resolved = resolve_command("stop")
        assert resolved is not None
        assert resolved.name == "stop"

    def test_new_command_detected(self):
        """``new`` resolves to the command named ``new``."""
        from hermes_cli.commands import resolve_command

        resolved = resolve_command("new")
        assert resolved is not None
        assert resolved.name == "new"

    def test_reset_alias_detected(self):
        """``reset`` resolves to the ``new`` command because it is an alias."""
        from hermes_cli.commands import resolve_command

        resolved = resolve_command("reset")
        assert resolved is not None
        assert resolved.name == "new"  # alias

    def test_unknown_command_not_detected(self):
        """An unregistered name resolves to ``None``."""
        from hermes_cli.commands import resolve_command

        assert resolve_command("foobar") is None

    def test_file_path_not_detected_as_command(self):
        """A path-like first word does not resolve to any command."""
        from hermes_cli.commands import resolve_command

        # Per the original notes, the safety net takes the first
        # whitespace-separated word and strips the leading '/', so
        # '/path/to/file' is looked up as 'path/to/file'.
        assert resolve_command("path/to/file") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: bypass with @botname suffix (Telegram-style)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBypassWithBotnameSuffix:
    """Commands carrying a Telegram-style ``@botname`` suffix still bypass the guard."""

    @staticmethod
    async def _send_during_active_session(text):
        """Mark the default session active, deliver *text*, return ``(adapter, key)``."""
        adapter = _make_adapter()
        key = _session_key()
        adapter._active_sessions[key] = asyncio.Event()
        await adapter.handle_message(_make_event(text))
        return adapter, key

    @pytest.mark.asyncio
    async def test_stop_with_botname(self):
        """``/stop@MyHermesBot`` is dispatched as ``stop`` and not queued."""
        adapter, key = await self._send_during_active_session("/stop@MyHermesBot")

        assert key not in adapter._pending_messages, (
            "/stop@MyHermesBot was queued instead of bypassing"
        )
        assert any("handled:stop" in reply for reply in adapter.sent_responses)

    @pytest.mark.asyncio
    async def test_new_with_botname(self):
        """``/new@MyHermesBot`` is dispatched as ``new`` and not queued."""
        adapter, key = await self._send_during_active_session("/new@MyHermesBot")

        assert key not in adapter._pending_messages
        assert any("handled:new" in reply for reply in adapter.sent_responses)
|
||||
343
tests/gateway/test_discord_channel_controls.py
Normal file
343
tests/gateway/test_discord_channel_controls.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""Tests for Discord ignored_channels and no_thread_channels config."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
def _ensure_discord_mock():
    """Register stand-in ``discord`` modules unless a file-backed one is loaded.

    If ``sys.modules["discord"]` exists and has a ``__file__`` attribute the
    function returns without touching anything.  Otherwise it builds mock
    ``discord``, ``discord.ext`` and ``discord.ext.commands`` modules and
    registers each one only if that name is not already present.
    """
    loaded = sys.modules.get("discord")
    if loaded is not None and hasattr(loaded, "__file__"):
        return

    def _identity(fn):
        return fn

    def _kw_decorator(**kwargs):
        # Decorator factory that takes keyword options and changes nothing.
        return _identity

    fake_discord = MagicMock()
    fake_discord.Intents.default.return_value = MagicMock()

    # Attributes that are themselves the MagicMock class.
    for attr_name in ("Client", "File", "Embed"):
        setattr(fake_discord, attr_name, MagicMock)

    # Plain empty classes for the channel types.
    for type_name in ("DMChannel", "Thread", "ForumChannel"):
        setattr(fake_discord, type_name, type(type_name, (), {}))

    fake_discord.Interaction = object
    fake_discord.ui = SimpleNamespace(
        View=object,
        button=lambda *a, **k: _identity,
        Button=object,
    )
    fake_discord.ButtonStyle = SimpleNamespace(
        success=1,
        primary=2,
        secondary=2,
        danger=3,
        green=1,
        grey=2,
        blurple=2,
        red=3,
    )
    fake_discord.Color = SimpleNamespace(
        orange=lambda: 1,
        green=lambda: 2,
        blue=lambda: 3,
        red=lambda: 4,
        purple=lambda: 5,
    )
    fake_discord.app_commands = SimpleNamespace(
        describe=_kw_decorator,
        choices=_kw_decorator,
        Choice=lambda **kwargs: SimpleNamespace(**kwargs),
    )

    fake_commands = MagicMock()
    fake_commands.Bot = MagicMock
    fake_ext = MagicMock()
    fake_ext.commands = fake_commands

    # setdefault keeps whatever is already registered under each name.
    registrations = (
        ("discord", fake_discord),
        ("discord.ext", fake_ext),
        ("discord.ext.commands", fake_commands),
    )
    for module_name, module_obj in registrations:
        sys.modules.setdefault(module_name, module_obj)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
import gateway.platforms.discord as discord_platform # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
class FakeDMChannel:
    """Stand-in for a Discord DM channel: just an ``id`` and a ``name``."""

    def __init__(self, channel_id: int = 1, name: str = "dm"):
        self.id, self.name = channel_id, name
|
||||
|
||||
|
||||
class FakeTextChannel:
    """Stand-in for a Discord guild text channel.

    Exposes ``id``, ``name``, a ``guild`` namespace holding the guild name,
    and a ``topic`` that is always ``None``.
    """

    def __init__(
        self,
        channel_id: int = 1,
        name: str = "general",
        guild_name: str = "Hermes Server",
    ):
        self.id, self.name = channel_id, name
        self.guild = SimpleNamespace(name=guild_name)
        self.topic = None
|
||||
|
||||
|
||||
class FakeThread:
    """Stand-in for a Discord thread, optionally attached to a parent channel.

    ``parent_id`` mirrors ``parent.id`` when the parent has one.  ``guild`` is
    taken from the parent when it has a truthy ``guild``; otherwise a fresh
    namespace named *guild_name* is used.
    """

    def __init__(
        self,
        channel_id: int = 1,
        name: str = "thread",
        parent=None,
        guild_name: str = "Hermes Server",
    ):
        self.id, self.name = channel_id, name
        self.parent = parent
        self.parent_id = getattr(parent, "id", None)
        inherited_guild = getattr(parent, "guild", None)
        if inherited_guild:
            self.guild = inherited_guild
        else:
            self.guild = SimpleNamespace(name=guild_name)
        self.topic = None
|
||||
|
||||
|
||||
@pytest.fixture
def adapter(monkeypatch):
    """Provide a ``DiscordAdapter`` wired to fake channel types and a mocked handler.

    ``discord.DMChannel`` and ``discord.Thread`` (as seen by the platform
    module) are swapped for the fakes, the client is a namespace whose user
    id is 999, and ``handle_message`` is an ``AsyncMock``.
    """
    for attr_name, fake_type in (("DMChannel", FakeDMChannel), ("Thread", FakeThread)):
        monkeypatch.setattr(discord_platform.discord, attr_name, fake_type, raising=False)

    discord_adapter = DiscordAdapter(PlatformConfig(enabled=True, token="fake-token"))
    discord_adapter._client = SimpleNamespace(user=SimpleNamespace(id=999))
    discord_adapter.handle_message = AsyncMock()
    return discord_adapter
|
||||
|
||||
|
||||
def make_message(*, channel, content: str, mentions=None):
    """Build a namespace that looks like an incoming Discord message.

    Args:
        channel: Object stored as the message's ``channel``.
        content: Message text.
        mentions: Optional iterable of mentioned users; copied into a new list.

    Returns:
        A ``SimpleNamespace`` with id 123, a fixed "TestUser" author (id 42),
        no attachments, no reference, and a UTC ``created_at`` of "now".
    """
    sender = SimpleNamespace(id=42, display_name="TestUser", name="TestUser")
    mentioned = list(mentions) if mentions else []
    fields = {
        "id": 123,
        "content": content,
        "mentions": mentioned,
        "attachments": [],
        "reference": None,
        "created_at": datetime.now(timezone.utc),
        "channel": channel,
        "author": sender,
    }
    return SimpleNamespace(**fields)
|
||||
|
||||
|
||||
# ── ignored_channels ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ignored_channel_blocks_message(adapter, monkeypatch):
    """A message in a channel listed in DISCORD_IGNORED_CHANNELS is dropped."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500")
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

    ignored_channel = FakeTextChannel(channel_id=500)
    incoming = make_message(channel=ignored_channel, content="hello")
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ignored_channel_blocks_even_with_mention(adapter, monkeypatch):
    """An @mention of the bot in an ignored channel is still dropped."""
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500")
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")

    bot = adapter._client.user
    ignored_channel = FakeTextChannel(channel_id=500)
    incoming = make_message(
        channel=ignored_channel,
        content=f"<@{bot.id}> hello",
        mentions=[bot],
    )
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_non_ignored_channel_processes_normally(adapter, monkeypatch):
    """A channel absent from the ignore list is forwarded to handle_message."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500,600")
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

    other_channel = FakeTextChannel(channel_id=700)
    incoming = make_message(channel=other_channel, content="hello")
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ignored_channels_csv_parsing(adapter, monkeypatch):
    """Every id in a comma list with stray spaces is treated as ignored."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500, 600 , 700")
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

    for ignored_id in (500, 600, 700):
        # Start each iteration from a clean mock so the check is per channel.
        adapter.handle_message.reset_mock()
        channel = FakeTextChannel(channel_id=ignored_id)
        await adapter._handle_message(make_message(channel=channel, content="hello"))
        adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ignored_channels_empty_string_ignores_nothing(adapter, monkeypatch):
    """An empty DISCORD_IGNORED_CHANNELS value ignores no channel."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "")
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

    channel = FakeTextChannel(channel_id=500)
    incoming = make_message(channel=channel, content="hello")
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ignored_channel_thread_parent_match(adapter, monkeypatch):
    """A thread is dropped when its parent channel is on the ignore list."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500")
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

    ignored_parent = FakeTextChannel(channel_id=500, name="ignored-channel")
    # The thread's own id (501) is not listed; only the parent's id is.
    child_thread = FakeThread(
        channel_id=501, name="thread-in-ignored", parent=ignored_parent
    )
    incoming = make_message(channel=child_thread, content="hello from thread")
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dms_unaffected_by_ignored_channels(adapter, monkeypatch):
    """A DM is processed even when its channel id appears in the ignore list."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500")

    dm_channel = FakeDMChannel(channel_id=500)
    incoming = make_message(channel=dm_channel, content="dm hello")
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_awaited_once()
|
||||
|
||||
|
||||
# ── no_thread_channels ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_thread_channel_skips_auto_thread(adapter, monkeypatch):
    """A channel in DISCORD_NO_THREAD_CHANNELS is handled without creating a thread."""
    for unset_var in (
        "DISCORD_AUTO_THREAD",
        "DISCORD_IGNORED_CHANNELS",
        "DISCORD_FREE_RESPONSE_CHANNELS",
    ):
        monkeypatch.delenv(unset_var, raising=False)
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "800")

    adapter._auto_create_thread = AsyncMock(return_value=FakeThread(channel_id=999))

    channel = FakeTextChannel(channel_id=800)
    await adapter._handle_message(make_message(channel=channel, content="hello"))

    adapter._auto_create_thread.assert_not_awaited()
    adapter.handle_message.assert_awaited_once()
    forwarded_event = adapter.handle_message.await_args.args[0]
    assert forwarded_event.source.chat_type == "group"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_normal_channel_still_auto_threads(adapter, monkeypatch):
    """A channel not listed in DISCORD_NO_THREAD_CHANNELS still gets an auto thread."""
    for unset_var in (
        "DISCORD_AUTO_THREAD",
        "DISCORD_IGNORED_CHANNELS",
        "DISCORD_FREE_RESPONSE_CHANNELS",
    ):
        monkeypatch.delenv(unset_var, raising=False)
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "800")

    created_thread = FakeThread(channel_id=999, name="auto-thread")
    adapter._auto_create_thread = AsyncMock(return_value=created_thread)

    channel = FakeTextChannel(channel_id=900)
    await adapter._handle_message(make_message(channel=channel, content="hello"))

    adapter._auto_create_thread.assert_awaited_once()
    adapter.handle_message.assert_awaited_once()
    forwarded_event = adapter.handle_message.await_args.args[0]
    assert forwarded_event.source.chat_type == "thread"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_thread_channels_csv_parsing(adapter, monkeypatch):
    """Each id in a comma list of no-thread channels suppresses thread creation."""
    for unset_var in (
        "DISCORD_AUTO_THREAD",
        "DISCORD_IGNORED_CHANNELS",
        "DISCORD_FREE_RESPONSE_CHANNELS",
    ):
        monkeypatch.delenv(unset_var, raising=False)
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "800, 900")

    adapter._auto_create_thread = AsyncMock(return_value=FakeThread(channel_id=999))

    for no_thread_id in (800, 900):
        # Reset both mocks so each channel is judged on its own call.
        adapter._auto_create_thread.reset_mock()
        adapter.handle_message.reset_mock()
        channel = FakeTextChannel(channel_id=no_thread_id)
        await adapter._handle_message(make_message(channel=channel, content="hello"))
        adapter._auto_create_thread.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_thread_with_auto_thread_disabled_is_noop(adapter, monkeypatch):
    """With DISCORD_AUTO_THREAD=false no thread is created and the message is handled."""
    for unset_var in ("DISCORD_IGNORED_CHANNELS", "DISCORD_FREE_RESPONSE_CHANNELS"):
        monkeypatch.delenv(unset_var, raising=False)
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "800")

    adapter._auto_create_thread = AsyncMock()

    channel = FakeTextChannel(channel_id=800)
    await adapter._handle_message(make_message(channel=channel, content="hello"))

    adapter._auto_create_thread.assert_not_awaited()
    adapter.handle_message.assert_awaited_once()
|
||||
|
||||
|
||||
# ── config.py bridging ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_config_bridges_ignored_channels(monkeypatch, tmp_path):
    """load_gateway_config() exports discord.ignored_channels as DISCORD_IGNORED_CHANNELS."""
    import os

    import yaml

    settings = {"discord": {"ignored_channels": ["111", "222"]}}
    (tmp_path / "config.yaml").write_text(yaml.dump(settings))

    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    # setenv rather than delenv: monkeypatch then registers cleanup even if
    # the variable did not exist beforehand. load_gateway_config is expected
    # to overwrite this empty placeholder.
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "")

    # Imported only after HERMES_HOME points at the temporary directory.
    from gateway.config import load_gateway_config

    load_gateway_config()

    bridged = os.getenv("DISCORD_IGNORED_CHANNELS")
    assert bridged == "111,222"
|
||||
|
||||
|
||||
def test_config_bridges_no_thread_channels(monkeypatch, tmp_path):
    """load_gateway_config() exports discord.no_thread_channels as DISCORD_NO_THREAD_CHANNELS."""
    import os

    import yaml

    settings = {"discord": {"no_thread_channels": ["333"]}}
    (tmp_path / "config.yaml").write_text(yaml.dump(settings))

    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    # Empty placeholder so monkeypatch restores the variable after the test.
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "")

    # Imported only after HERMES_HOME points at the temporary directory.
    from gateway.config import load_gateway_config

    load_gateway_config()

    bridged = os.getenv("DISCORD_NO_THREAD_CHANNELS")
    assert bridged == "333"
|
||||
|
||||
|
||||
def test_config_env_var_takes_precedence(monkeypatch, tmp_path):
    """A DISCORD_IGNORED_CHANNELS value already in the environment beats config.yaml."""
    import os

    import yaml

    settings = {"discord": {"ignored_channels": ["111"]}}
    (tmp_path / "config.yaml").write_text(yaml.dump(settings))

    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "999")

    # Imported only after HERMES_HOME points at the temporary directory.
    from gateway.config import load_gateway_config

    load_gateway_config()

    # The pre-existing environment value must survive the config load.
    surviving = os.getenv("DISCORD_IGNORED_CHANNELS")
    assert surviving == "999"
|
||||
277
tests/gateway/test_discord_reply_mode.py
Normal file
277
tests/gateway/test_discord_reply_mode.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""Tests for Discord reply_to_mode functionality.
|
||||
|
||||
Covers the threading behavior control for multi-chunk replies:
|
||||
- "off": Never reply-reference to original message
|
||||
- "first": Only first chunk uses reply reference (default)
|
||||
- "all": All chunks reply-reference the original message
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig, GatewayConfig, Platform, _apply_env_overrides
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install a mock discord module when discord.py isn't available."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, secondary=2, danger=3, green=1, grey=2, blurple=2, red=3)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4, purple=lambda: 5)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture()
def adapter_factory():
    """Provide a callable that builds a DiscordAdapter for a given reply_to_mode."""

    def _build(reply_to_mode: str = "first"):
        platform_config = PlatformConfig(
            enabled=True,
            token="test-token",
            reply_to_mode=reply_to_mode,
        )
        return DiscordAdapter(platform_config)

    return _build
|
||||
|
||||
|
||||
class TestReplyToModeConfig:
    """How DiscordAdapter picks up ``reply_to_mode`` from its PlatformConfig."""

    def test_default_mode_is_first(self, adapter_factory):
        """Calling the factory without arguments yields the "first" mode."""
        assert adapter_factory()._reply_to_mode == "first"

    def test_off_mode(self, adapter_factory):
        """An explicit "off" is stored unchanged."""
        built = adapter_factory(reply_to_mode="off")
        assert built._reply_to_mode == "off"

    def test_first_mode(self, adapter_factory):
        """An explicit "first" is stored unchanged."""
        built = adapter_factory(reply_to_mode="first")
        assert built._reply_to_mode == "first"

    def test_all_mode(self, adapter_factory):
        """An explicit "all" is stored unchanged."""
        built = adapter_factory(reply_to_mode="all")
        assert built._reply_to_mode == "all"

    def test_invalid_mode_stored_as_is(self, adapter_factory):
        """Invalid modes are stored but send() handles them gracefully."""
        built = adapter_factory(reply_to_mode="invalid")
        assert built._reply_to_mode == "invalid"

    def test_none_mode_defaults_to_first(self):
        """A config built without reply_to_mode gives the adapter "first"."""
        built = DiscordAdapter(PlatformConfig(enabled=True, token="test-token"))
        assert built._reply_to_mode == "first"

    def test_empty_string_mode_defaults_to_first(self):
        """An empty-string reply_to_mode is treated like the default."""
        built = DiscordAdapter(
            PlatformConfig(enabled=True, token="test-token", reply_to_mode="")
        )
        assert built._reply_to_mode == "first"
|
||||
|
||||
|
||||
def _make_discord_adapter(reply_to_mode: str = "first"):
    """Build a DiscordAdapter wired to a mocked client and channel.

    Returns a ``(adapter, channel, ref_message)`` tuple: ``channel`` is the
    mock returned by the client's ``get_channel`` and ``ref_message`` is what
    its ``fetch_message`` resolves to. ``channel.send`` resolves to a mock
    message whose ``id`` is 42.
    """
    adapter = DiscordAdapter(
        PlatformConfig(enabled=True, token="test-token", reply_to_mode=reply_to_mode)
    )

    # Message returned when the adapter looks up the message being replied to.
    ref_message = MagicMock()

    # Message object handed back for every send() on the channel.
    delivered = MagicMock()
    delivered.id = 42

    mock_channel = AsyncMock()
    mock_channel.fetch_message = AsyncMock(return_value=ref_message)
    mock_channel.send = AsyncMock(return_value=delivered)

    fake_client = MagicMock()
    fake_client.get_channel = MagicMock(return_value=mock_channel)
    adapter._client = fake_client

    return adapter, mock_channel, ref_message
|
||||
|
||||
|
||||
class TestSendWithReplyToMode:
    """Tests for send() method respecting reply_to_mode.

    Every test replaces ``adapter.truncate_message`` with a stub returning a
    fixed chunk list, then inspects the ``reference`` keyword passed to each
    ``channel.send`` call.
    """

    @pytest.mark.asyncio
    async def test_off_mode_no_reply_reference(self):
        """"off": the referenced message is never fetched and no chunk references it."""
        adapter, channel, ref_msg = _make_discord_adapter("off")
        adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]

        await adapter.send("12345", "test content", reply_to="999")

        # Should never try to fetch the reference message
        channel.fetch_message.assert_not_called()
        calls = channel.send.call_args_list
        # Guard against a vacuous pass: the loop below checks nothing if no
        # chunk was actually sent.
        assert len(calls) == 3
        # All chunks sent without reference
        for call in calls:
            assert call.kwargs.get("reference") is None

    @pytest.mark.asyncio
    async def test_first_mode_only_first_chunk_references(self):
        """"first": only the first of several chunks carries the reference."""
        adapter, channel, ref_msg = _make_discord_adapter("first")
        adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]

        await adapter.send("12345", "test content", reply_to="999")

        # Should fetch the reference message (the string ID is looked up as an int)
        channel.fetch_message.assert_called_once_with(999)
        calls = channel.send.call_args_list
        assert len(calls) == 3
        assert calls[0].kwargs.get("reference") is ref_msg
        assert calls[1].kwargs.get("reference") is None
        assert calls[2].kwargs.get("reference") is None

    @pytest.mark.asyncio
    async def test_all_mode_all_chunks_reference(self):
        """"all": every chunk carries the reference, fetched only once."""
        adapter, channel, ref_msg = _make_discord_adapter("all")
        adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]

        await adapter.send("12345", "test content", reply_to="999")

        channel.fetch_message.assert_called_once_with(999)
        calls = channel.send.call_args_list
        assert len(calls) == 3
        for call in calls:
            assert call.kwargs.get("reference") is ref_msg

    @pytest.mark.asyncio
    async def test_no_reply_to_param_no_reference(self):
        """Without a reply_to argument nothing is referenced, even in "all" mode."""
        adapter, channel, ref_msg = _make_discord_adapter("all")
        adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2"]

        await adapter.send("12345", "test content", reply_to=None)

        channel.fetch_message.assert_not_called()
        calls = channel.send.call_args_list
        # Guard against a vacuous pass when nothing is sent at all.
        assert len(calls) == 2
        for call in calls:
            assert call.kwargs.get("reference") is None

    @pytest.mark.asyncio
    async def test_single_chunk_respects_first_mode(self):
        """"first" with a single chunk: that chunk carries the reference."""
        adapter, channel, ref_msg = _make_discord_adapter("first")
        adapter.truncate_message = lambda content, max_len: ["single chunk"]

        await adapter.send("12345", "test", reply_to="999")

        calls = channel.send.call_args_list
        assert len(calls) == 1
        assert calls[0].kwargs.get("reference") is ref_msg

    @pytest.mark.asyncio
    async def test_single_chunk_off_mode(self):
        """"off" with a single chunk: no fetch and no reference."""
        adapter, channel, ref_msg = _make_discord_adapter("off")
        adapter.truncate_message = lambda content, max_len: ["single chunk"]

        await adapter.send("12345", "test", reply_to="999")

        channel.fetch_message.assert_not_called()
        calls = channel.send.call_args_list
        assert len(calls) == 1
        assert calls[0].kwargs.get("reference") is None

    @pytest.mark.asyncio
    async def test_invalid_mode_falls_back_to_first_behavior(self):
        """Invalid mode behaves like 'first' — only first chunk gets reference."""
        adapter, channel, ref_msg = _make_discord_adapter("banana")
        adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2"]

        await adapter.send("12345", "test", reply_to="999")

        calls = channel.send.call_args_list
        assert len(calls) == 2
        assert calls[0].kwargs.get("reference") is ref_msg
        assert calls[1].kwargs.get("reference") is None
|
||||
|
||||
|
||||
class TestConfigSerialization:
    """Round-tripping of ``reply_to_mode`` through PlatformConfig (shared with Telegram)."""

    def test_to_dict_includes_reply_to_mode(self):
        """to_dict() emits the configured mode under the "reply_to_mode" key."""
        serialized = PlatformConfig(enabled=True, token="test", reply_to_mode="all").to_dict()
        assert serialized["reply_to_mode"] == "all"

    def test_from_dict_loads_reply_to_mode(self):
        """from_dict() reads an explicit "reply_to_mode" entry."""
        loaded = PlatformConfig.from_dict(
            {"enabled": True, "token": "***", "reply_to_mode": "off"}
        )
        assert loaded.reply_to_mode == "off"

    def test_from_dict_defaults_to_first(self):
        """from_dict() falls back to "first" when the key is absent."""
        loaded = PlatformConfig.from_dict({"enabled": True, "token": "***"})
        assert loaded.reply_to_mode == "first"
|
||||
|
||||
|
||||
class TestEnvVarOverride:
    """Behaviour of the DISCORD_REPLY_TO_MODE environment override."""

    def _make_config(self):
        """Return a GatewayConfig holding an enabled Discord platform entry."""
        gateway_config = GatewayConfig()
        gateway_config.platforms[Platform.DISCORD] = PlatformConfig(enabled=True, token="test")
        return gateway_config

    def _mode_after_override(self, raw_value):
        """Apply env overrides with DISCORD_REPLY_TO_MODE=raw_value; return the resulting mode."""
        gateway_config = self._make_config()
        with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": raw_value}, clear=False):
            _apply_env_overrides(gateway_config)
        return gateway_config.platforms[Platform.DISCORD].reply_to_mode

    def test_env_var_sets_off_mode(self):
        assert self._mode_after_override("off") == "off"

    def test_env_var_sets_all_mode(self):
        assert self._mode_after_override("all") == "all"

    def test_env_var_case_insensitive(self):
        # Upper-case input is expected to come out lower-cased.
        assert self._mode_after_override("ALL") == "all"

    def test_env_var_invalid_value_ignored(self):
        # An unrecognised value leaves the default mode in place.
        assert self._mode_after_override("banana") == "first"

    def test_env_var_empty_value_ignored(self):
        # An empty value leaves the default mode in place.
        assert self._mode_after_override("") == "first"

    def test_env_var_creates_platform_config_if_missing(self):
        """DISCORD_REPLY_TO_MODE creates PlatformConfig even without DISCORD_BOT_TOKEN."""
        gateway_config = GatewayConfig()
        assert Platform.DISCORD not in gateway_config.platforms

        with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": "off"}, clear=False):
            _apply_env_overrides(gateway_config)

        assert Platform.DISCORD in gateway_config.platforms
        assert gateway_config.platforms[Platform.DISCORD].reply_to_mode == "off"
|
||||
@@ -8,7 +8,7 @@ import time
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
try:
|
||||
import lark_oapi
|
||||
@@ -17,6 +17,18 @@ except ImportError:
|
||||
_HAS_LARK_OAPI = False
|
||||
|
||||
|
||||
def _mock_event_dispatcher_builder(mock_handler_class):
|
||||
mock_builder = Mock()
|
||||
mock_builder.register_p2_im_message_message_read_v1 = Mock(return_value=mock_builder)
|
||||
mock_builder.register_p2_im_message_receive_v1 = Mock(return_value=mock_builder)
|
||||
mock_builder.register_p2_im_message_reaction_created_v1 = Mock(return_value=mock_builder)
|
||||
mock_builder.register_p2_im_message_reaction_deleted_v1 = Mock(return_value=mock_builder)
|
||||
mock_builder.register_p2_card_action_trigger = Mock(return_value=mock_builder)
|
||||
mock_builder.build = Mock(return_value=object())
|
||||
mock_handler_class.builder = Mock(return_value=mock_builder)
|
||||
return mock_builder
|
||||
|
||||
|
||||
class TestPlatformEnum(unittest.TestCase):
|
||||
def test_feishu_in_platform_enum(self):
|
||||
from gateway.config import Platform
|
||||
@@ -262,12 +274,14 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
|
||||
with (
|
||||
patch("gateway.platforms.feishu.FEISHU_AVAILABLE", True),
|
||||
patch("gateway.platforms.feishu.FEISHU_WEBHOOK_AVAILABLE", True),
|
||||
patch("gateway.platforms.feishu.EventDispatcherHandler") as mock_handler_class,
|
||||
patch("gateway.platforms.feishu.acquire_scoped_lock", return_value=(True, None)),
|
||||
patch("gateway.platforms.feishu.release_scoped_lock"),
|
||||
patch.object(adapter, "_hydrate_bot_identity", new=AsyncMock()),
|
||||
patch.object(adapter, "_build_lark_client", return_value=SimpleNamespace()),
|
||||
patch("gateway.platforms.feishu.web", web_module),
|
||||
):
|
||||
_mock_event_dispatcher_builder(mock_handler_class)
|
||||
connected = asyncio.run(adapter.connect())
|
||||
|
||||
self.assertTrue(connected)
|
||||
@@ -283,13 +297,13 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
ws_client = object()
|
||||
ws_client = SimpleNamespace()
|
||||
|
||||
with (
|
||||
patch("gateway.platforms.feishu.FEISHU_AVAILABLE", True),
|
||||
patch("gateway.platforms.feishu.FEISHU_WEBSOCKET_AVAILABLE", True),
|
||||
patch("gateway.platforms.feishu.lark", SimpleNamespace(LogLevel=SimpleNamespace(INFO="INFO", WARNING="WARNING"))),
|
||||
patch("gateway.platforms.feishu.EventDispatcherHandler", object()),
|
||||
patch("gateway.platforms.feishu.EventDispatcherHandler") as mock_handler_class,
|
||||
patch("gateway.platforms.feishu.FeishuWSClient", return_value=ws_client),
|
||||
patch("gateway.platforms.feishu._run_official_feishu_ws_client"),
|
||||
patch("gateway.platforms.feishu.acquire_scoped_lock", return_value=(True, None)) as acquire_lock,
|
||||
@@ -297,6 +311,8 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
|
||||
patch.object(adapter, "_hydrate_bot_identity", new=AsyncMock()),
|
||||
patch.object(adapter, "_build_lark_client", return_value=SimpleNamespace()),
|
||||
):
|
||||
_mock_event_dispatcher_builder(mock_handler_class)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
future = loop.create_future()
|
||||
future.set_result(None)
|
||||
@@ -305,6 +321,9 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
|
||||
def run_in_executor(self, *_args, **_kwargs):
|
||||
return future
|
||||
|
||||
def is_closed(self):
|
||||
return False
|
||||
|
||||
try:
|
||||
with patch("gateway.platforms.feishu.asyncio.get_running_loop", return_value=_Loop()):
|
||||
connected = asyncio.run(adapter.connect())
|
||||
@@ -313,6 +332,7 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
|
||||
loop.close()
|
||||
|
||||
self.assertTrue(connected)
|
||||
self.assertIsNone(adapter._event_handler)
|
||||
acquire_lock.assert_called_once_with(
|
||||
"feishu-app-id",
|
||||
"cli_app",
|
||||
@@ -354,14 +374,14 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
ws_client = object()
|
||||
ws_client = SimpleNamespace()
|
||||
sleeps = []
|
||||
|
||||
with (
|
||||
patch("gateway.platforms.feishu.FEISHU_AVAILABLE", True),
|
||||
patch("gateway.platforms.feishu.FEISHU_WEBSOCKET_AVAILABLE", True),
|
||||
patch("gateway.platforms.feishu.lark", SimpleNamespace(LogLevel=SimpleNamespace(INFO="INFO", WARNING="WARNING"))),
|
||||
patch("gateway.platforms.feishu.EventDispatcherHandler", object()),
|
||||
patch("gateway.platforms.feishu.EventDispatcherHandler") as mock_handler_class,
|
||||
patch("gateway.platforms.feishu.FeishuWSClient", return_value=ws_client),
|
||||
patch("gateway.platforms.feishu.acquire_scoped_lock", return_value=(True, None)),
|
||||
patch("gateway.platforms.feishu.release_scoped_lock"),
|
||||
@@ -369,6 +389,8 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
|
||||
patch("gateway.platforms.feishu.asyncio.sleep", side_effect=lambda delay: sleeps.append(delay)),
|
||||
patch.object(adapter, "_build_lark_client", return_value=SimpleNamespace()),
|
||||
):
|
||||
_mock_event_dispatcher_builder(mock_handler_class)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
future = loop.create_future()
|
||||
future.set_result(None)
|
||||
@@ -383,6 +405,9 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
|
||||
raise OSError("temporary websocket failure")
|
||||
return future
|
||||
|
||||
def is_closed(self):
|
||||
return False
|
||||
|
||||
fake_loop = _Loop()
|
||||
try:
|
||||
with patch("gateway.platforms.feishu.asyncio.get_running_loop", return_value=fake_loop):
|
||||
@@ -536,6 +561,113 @@ class TestAdapterModule(unittest.TestCase):
|
||||
self.assertIn("register_p2_im_message_reaction_deleted_v1", source)
|
||||
self.assertIn("register_p2_card_action_trigger", source)
|
||||
|
||||
def test_load_settings_uses_sdk_defaults_for_invalid_ws_reconnect_values(self):
    """Invalid reconnect settings load as 30 (nonce) and 120 (interval)."""
    from gateway.platforms.feishu import FeishuAdapter

    # A negative nonce and a non-numeric interval are both invalid inputs.
    raw_extra = {"ws_reconnect_nonce": -1, "ws_reconnect_interval": "bad"}
    loaded = FeishuAdapter._load_settings(raw_extra)

    self.assertEqual(loaded.ws_reconnect_nonce, 30)
    self.assertEqual(loaded.ws_reconnect_interval, 120)
|
||||
|
||||
def test_load_settings_accepts_custom_ws_reconnect_values(self):
    """Valid reconnect settings, including a nonce of 0, are kept as given."""
    from gateway.platforms.feishu import FeishuAdapter

    raw_extra = {"ws_reconnect_nonce": 0, "ws_reconnect_interval": 3}
    loaded = FeishuAdapter._load_settings(raw_extra)

    self.assertEqual(loaded.ws_reconnect_nonce, 0)
    self.assertEqual(loaded.ws_reconnect_interval, 3)
|
||||
|
||||
def test_load_settings_accepts_custom_ws_ping_values(self):
    """Positive ping interval and timeout settings are kept as given."""
    from gateway.platforms.feishu import FeishuAdapter

    raw_extra = {"ws_ping_interval": 10, "ws_ping_timeout": 8}
    loaded = FeishuAdapter._load_settings(raw_extra)

    self.assertEqual(loaded.ws_ping_interval, 10)
    self.assertEqual(loaded.ws_ping_timeout, 8)
|
||||
|
||||
def test_load_settings_ignores_invalid_ws_ping_values(self):
    """A zero interval and a negative timeout both load as None."""
    from gateway.platforms.feishu import FeishuAdapter

    raw_extra = {"ws_ping_interval": 0, "ws_ping_timeout": -1}
    loaded = FeishuAdapter._load_settings(raw_extra)

    self.assertIsNone(loaded.ws_ping_interval)
    self.assertIsNone(loaded.ws_ping_timeout)
|
||||
|
||||
def test_runtime_ws_overrides_reapply_after_sdk_configure(self):
    """Adapter websocket overrides end up on the client after the client's own _configure().

    The fake client's ``start()`` calls ``_configure`` with 99/88/77 and then
    raises to stop the run. The test expects ``_configure`` to have been
    recorded exactly once and the client to finish with the adapter's values
    (2/3/4) instead.
    """
    import sys
    from types import ModuleType

    class _FakeWSClient:
        def __init__(self):
            self._reconnect_nonce = 30
            self._reconnect_interval = 120
            self._ping_interval = 120
            self.configure_calls = []

        def _configure(self, conf):
            # Record the call, then copy the supplied values onto the client.
            self.configure_calls.append(conf)
            self._reconnect_nonce = conf.ReconnectNonce
            self._reconnect_interval = conf.ReconnectInterval
            self._ping_interval = conf.PingInterval

        def start(self):
            self._configure(
                SimpleNamespace(ReconnectNonce=99, ReconnectInterval=88, PingInterval=77)
            )
            raise RuntimeError("stop test client")

    fake_client = _FakeWSClient()
    fake_adapter = SimpleNamespace(
        _ws_thread_loop=None,
        _ws_reconnect_nonce=2,
        _ws_reconnect_interval=3,
        _ws_ping_interval=4,
        _ws_ping_timeout=5,
    )

    # Minimal stand-ins for the lark_oapi package hierarchy.
    client_stub = ModuleType("lark_oapi.ws.client")
    client_stub.loop = None
    client_stub.websockets = SimpleNamespace(connect=AsyncMock())
    ws_stub = ModuleType("lark_oapi.ws")
    ws_stub.client = client_stub
    root_stub = ModuleType("lark_oapi")
    root_stub.ws = ws_stub

    stubbed_modules = {
        "lark_oapi": root_stub,
        "lark_oapi.ws": ws_stub,
        "lark_oapi.ws.client": client_stub,
    }
    # patch.dict snapshots sys.modules on entry and restores that snapshot on
    # exit, even if the call under test raises.
    with patch.dict(sys.modules, stubbed_modules):
        from gateway.platforms.feishu import _run_official_feishu_ws_client

        _run_official_feishu_ws_client(fake_client, fake_adapter)

    self.assertEqual(len(fake_client.configure_calls), 1)
    self.assertEqual(fake_client._reconnect_nonce, 2)
    self.assertEqual(fake_client._reconnect_interval, 3)
    self.assertEqual(fake_client._ping_interval, 4)
|
||||
|
||||
|
||||
class TestAdapterBehavior(unittest.TestCase):
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
@@ -690,10 +822,10 @@ class TestAdapterBehavior(unittest.TestCase):
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
message = SimpleNamespace(mentions=[])
|
||||
sender_id = SimpleNamespace(open_id="ou_any", user_id=None)
|
||||
self.assertFalse(adapter._should_accept_group_message(message, sender_id))
|
||||
self.assertFalse(adapter._should_accept_group_message(message, sender_id, ""))
|
||||
|
||||
message_with_mention = SimpleNamespace(mentions=[SimpleNamespace(key="@_user_1")])
|
||||
self.assertFalse(adapter._should_accept_group_message(message_with_mention, sender_id))
|
||||
self.assertFalse(adapter._should_accept_group_message(message_with_mention, sender_id, ""))
|
||||
|
||||
@patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "open"}, clear=True)
|
||||
def test_group_message_with_other_user_mention_is_rejected_when_bot_identity_unknown(self):
|
||||
@@ -707,7 +839,7 @@ class TestAdapterBehavior(unittest.TestCase):
|
||||
id=SimpleNamespace(open_id="ou_other", user_id="u_other"),
|
||||
)
|
||||
|
||||
self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[other_mention]), sender_id))
|
||||
self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[other_mention]), sender_id, ""))
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
@@ -736,28 +868,222 @@ class TestAdapterBehavior(unittest.TestCase):
|
||||
adapter._should_accept_group_message(
|
||||
mentioned,
|
||||
SimpleNamespace(open_id="ou_allowed", user_id=None),
|
||||
"",
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
adapter._should_accept_group_message(
|
||||
mentioned,
|
||||
SimpleNamespace(open_id="ou_blocked", user_id=None),
|
||||
"",
|
||||
)
|
||||
)
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"FEISHU_GROUP_POLICY": "open",
|
||||
"FEISHU_BOT_OPEN_ID": "ou_bot",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
def test_per_group_allowlist_policy_gates_by_sender(self):
    """An "allowlist" group rule accepts a listed sender and rejects an unlisted one."""
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    group_rules = {
        "oc_chat_a": {
            "policy": "allowlist",
            "allowlist": ["ou_alice", "ou_bob"],
        }
    }
    adapter = FeishuAdapter(PlatformConfig(extra={"group_rules": group_rules}))
    adapter._bot_open_id = "ou_bot"

    # The message mentions the bot itself.
    bot_mention = SimpleNamespace(name="Bot", id=SimpleNamespace(open_id="ou_bot", user_id=None))
    message = SimpleNamespace(mentions=[bot_mention])

    def accepted(open_id):
        sender = SimpleNamespace(open_id=open_id, user_id=None)
        return adapter._should_accept_group_message(message, sender, "oc_chat_a")

    self.assertTrue(accepted("ou_alice"))
    self.assertFalse(accepted("ou_charlie"))
|
||||
|
||||
def test_per_group_blacklist_policy_blocks_specific_users(self):
    """A "blacklist" group rule rejects a listed sender and accepts others."""
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    group_rules = {
        "oc_chat_b": {
            "policy": "blacklist",
            "blacklist": ["ou_blocked"],
        }
    }
    adapter = FeishuAdapter(PlatformConfig(extra={"group_rules": group_rules}))
    adapter._bot_open_id = "ou_bot"

    # The message mentions the bot itself.
    bot_mention = SimpleNamespace(name="Bot", id=SimpleNamespace(open_id="ou_bot", user_id=None))
    message = SimpleNamespace(mentions=[bot_mention])

    def accepted(open_id):
        sender = SimpleNamespace(open_id=open_id, user_id=None)
        return adapter._should_accept_group_message(message, sender, "oc_chat_b")

    self.assertTrue(accepted("ou_alice"))
    self.assertFalse(accepted("ou_blocked"))
|
||||
|
||||
def test_per_group_admin_only_policy_requires_admin(self):
    """An "admin_only" group rule accepts a configured admin and rejects a regular sender."""
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    extra = {
        "admins": ["ou_admin"],
        "group_rules": {"oc_chat_c": {"policy": "admin_only"}},
    }
    adapter = FeishuAdapter(PlatformConfig(extra=extra))
    adapter._bot_open_id = "ou_bot"

    # The message mentions the bot itself.
    bot_mention = SimpleNamespace(name="Bot", id=SimpleNamespace(open_id="ou_bot", user_id=None))
    message = SimpleNamespace(mentions=[bot_mention])

    def accepted(open_id):
        sender = SimpleNamespace(open_id=open_id, user_id=None)
        return adapter._should_accept_group_message(message, sender, "oc_chat_c")

    self.assertTrue(accepted("ou_admin"))
    self.assertFalse(accepted("ou_regular"))
|
||||
|
||||
def test_per_group_disabled_policy_blocks_all(self):
    """A "disabled" group rule rejects a regular sender; the configured admin is still accepted."""
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    extra = {
        "admins": ["ou_admin"],
        "group_rules": {"oc_chat_d": {"policy": "disabled"}},
    }
    adapter = FeishuAdapter(PlatformConfig(extra=extra))
    adapter._bot_open_id = "ou_bot"

    # The message mentions the bot itself.
    bot_mention = SimpleNamespace(name="Bot", id=SimpleNamespace(open_id="ou_bot", user_id=None))
    message = SimpleNamespace(mentions=[bot_mention])

    def accepted(open_id):
        sender = SimpleNamespace(open_id=open_id, user_id=None)
        return adapter._should_accept_group_message(message, sender, "oc_chat_d")

    self.assertTrue(accepted("ou_admin"))
    self.assertFalse(accepted("ou_regular"))
|
||||
|
||||
def test_global_admins_bypass_all_group_rules(self):
    """A configured admin is accepted even when absent from the group's allowlist."""
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    extra = {
        "admins": ["ou_admin"],
        "group_rules": {
            "oc_chat_e": {
                "policy": "allowlist",
                # Deliberately omits the admin.
                "allowlist": ["ou_alice"],
            }
        },
    }
    adapter = FeishuAdapter(PlatformConfig(extra=extra))
    adapter._bot_open_id = "ou_bot"

    # The message mentions the bot itself.
    bot_mention = SimpleNamespace(name="Bot", id=SimpleNamespace(open_id="ou_bot", user_id=None))
    message = SimpleNamespace(mentions=[bot_mention])
    admin_sender = SimpleNamespace(open_id="ou_admin", user_id=None)

    verdict = adapter._should_accept_group_message(message, admin_sender, "oc_chat_e")
    self.assertTrue(verdict)
|
||||
|
||||
def test_default_group_policy_fallback_for_chats_without_explicit_rule(self):
    """With default_group_policy "open", a chat that has no rule of its own accepts the sender."""
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    adapter = FeishuAdapter(PlatformConfig(extra={"default_group_policy": "open"}))
    adapter._bot_open_id = "ou_bot"

    # The message mentions the bot itself.
    bot_mention = SimpleNamespace(name="Bot", id=SimpleNamespace(open_id="ou_bot", user_id=None))
    message = SimpleNamespace(mentions=[bot_mention])
    any_sender = SimpleNamespace(open_id="ou_anyone", user_id=None)

    verdict = adapter._should_accept_group_message(message, any_sender, "oc_chat_unknown")
    self.assertTrue(verdict)
|
||||
|
||||
@patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "open"}, clear=True)
|
||||
def test_group_message_matches_bot_open_id_when_configured(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
adapter._bot_open_id = "ou_bot"
|
||||
sender_id = SimpleNamespace(open_id="ou_any", user_id=None)
|
||||
|
||||
bot_mention = SimpleNamespace(
|
||||
@@ -769,22 +1095,16 @@ class TestAdapterBehavior(unittest.TestCase):
|
||||
id=SimpleNamespace(open_id="ou_other", user_id="u_other"),
|
||||
)
|
||||
|
||||
self.assertTrue(adapter._should_accept_group_message(SimpleNamespace(mentions=[bot_mention]), sender_id))
|
||||
self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[other_mention]), sender_id))
|
||||
self.assertTrue(adapter._should_accept_group_message(SimpleNamespace(mentions=[bot_mention]), sender_id, ""))
|
||||
self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[other_mention]), sender_id, ""))
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"FEISHU_GROUP_POLICY": "open",
|
||||
"FEISHU_BOT_NAME": "Hermes Bot",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
@patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "open"}, clear=True)
|
||||
def test_group_message_matches_bot_name_when_only_name_available(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
adapter._bot_name = "Hermes Bot"
|
||||
sender_id = SimpleNamespace(open_id="ou_any", user_id=None)
|
||||
|
||||
named_mention = SimpleNamespace(
|
||||
@@ -796,22 +1116,16 @@ class TestAdapterBehavior(unittest.TestCase):
|
||||
id=SimpleNamespace(open_id="ou_other", user_id="u_other"),
|
||||
)
|
||||
|
||||
self.assertTrue(adapter._should_accept_group_message(SimpleNamespace(mentions=[named_mention]), sender_id))
|
||||
self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[different_mention]), sender_id))
|
||||
self.assertTrue(adapter._should_accept_group_message(SimpleNamespace(mentions=[named_mention]), sender_id, ""))
|
||||
self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[different_mention]), sender_id, ""))
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"FEISHU_GROUP_POLICY": "open",
|
||||
"FEISHU_BOT_OPEN_ID": "ou_bot",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
@patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "open"}, clear=True)
|
||||
def test_group_post_message_uses_parsed_mentions_when_sdk_mentions_missing(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
adapter._bot_open_id = "ou_bot"
|
||||
sender_id = SimpleNamespace(open_id="ou_any", user_id=None)
|
||||
message = SimpleNamespace(
|
||||
message_type="post",
|
||||
@@ -819,7 +1133,7 @@ class TestAdapterBehavior(unittest.TestCase):
|
||||
content='{"en_us":{"content":[[{"tag":"at","user_name":"Hermes","open_id":"ou_bot"}]]}}',
|
||||
)
|
||||
|
||||
self.assertTrue(adapter._should_accept_group_message(message, sender_id))
|
||||
self.assertTrue(adapter._should_accept_group_message(message, sender_id, ""))
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_extract_post_message_as_text(self):
|
||||
@@ -1196,7 +1510,12 @@ class TestAdapterBehavior(unittest.TestCase):
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
adapter._loop = object()
|
||||
|
||||
class _Loop:
|
||||
def is_closed(self):
|
||||
return False
|
||||
|
||||
adapter._loop = _Loop()
|
||||
|
||||
message = SimpleNamespace(
|
||||
message_id="om_text",
|
||||
@@ -1210,6 +1529,7 @@ class TestAdapterBehavior(unittest.TestCase):
|
||||
data = SimpleNamespace(event=SimpleNamespace(message=message, sender=sender))
|
||||
|
||||
future = SimpleNamespace(add_done_callback=lambda *_args, **_kwargs: None)
|
||||
|
||||
def _submit(coro, _loop):
|
||||
coro.close()
|
||||
return future
|
||||
@@ -1219,6 +1539,30 @@ class TestAdapterBehavior(unittest.TestCase):
|
||||
|
||||
self.assertTrue(submit.called)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
def test_webhook_request_uses_same_message_dispatch_path(self):
    """A webhook ``im.message.receive_v1`` payload reaches ``_on_message_event``."""
    from gateway.config import PlatformConfig
    from gateway.platforms.feishu import FeishuAdapter

    adapter = FeishuAdapter(PlatformConfig())
    # Replace the dispatch hook so the test can observe that it was invoked.
    adapter._on_message_event = Mock()

    payload = {
        "header": {"event_type": "im.message.receive_v1"},
        "event": {"message": {"message_id": "om_test"}},
    }
    raw_body = json.dumps(payload).encode("utf-8")
    # Minimal stand-in for the incoming HTTP request object.
    request = SimpleNamespace(
        remote="127.0.0.1",
        content_length=None,
        headers={},
        read=AsyncMock(return_value=raw_body),
    )

    response = asyncio.run(adapter._handle_webhook_request(request))

    self.assertEqual(response.status, 200)
    adapter._on_message_event.assert_called_once()
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_process_inbound_message_uses_event_sender_identity_only(self):
|
||||
from gateway.config import PlatformConfig
|
||||
@@ -2456,7 +2800,7 @@ class TestGroupMentionAtAll(unittest.TestCase):
|
||||
mentions=[],
|
||||
)
|
||||
sender_id = SimpleNamespace(open_id="ou_any", user_id=None)
|
||||
self.assertTrue(adapter._should_accept_group_message(message, sender_id))
|
||||
self.assertTrue(adapter._should_accept_group_message(message, sender_id, ""))
|
||||
|
||||
@patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "allowlist", "FEISHU_ALLOWED_USERS": "ou_allowed"}, clear=True)
|
||||
def test_at_all_still_requires_policy_gate(self):
|
||||
@@ -2468,10 +2812,10 @@ class TestGroupMentionAtAll(unittest.TestCase):
|
||||
message = SimpleNamespace(content='{"text":"@_all attention"}', mentions=[])
|
||||
# Non-allowlisted user — should be blocked even with @_all.
|
||||
blocked_sender = SimpleNamespace(open_id="ou_blocked", user_id=None)
|
||||
self.assertFalse(adapter._should_accept_group_message(message, blocked_sender))
|
||||
self.assertFalse(adapter._should_accept_group_message(message, blocked_sender, ""))
|
||||
# Allowlisted user — should pass.
|
||||
allowed_sender = SimpleNamespace(open_id="ou_allowed", user_id=None)
|
||||
self.assertTrue(adapter._should_accept_group_message(message, allowed_sender))
|
||||
self.assertTrue(adapter._should_accept_group_message(message, allowed_sender, ""))
|
||||
|
||||
|
||||
@unittest.skipUnless(_HAS_LARK_OAPI, "lark-oapi not installed")
|
||||
|
||||
432
tests/gateway/test_feishu_approval_buttons.py
Normal file
432
tests/gateway/test_feishu_approval_buttons.py
Normal file
@@ -0,0 +1,432 @@
|
||||
"""Tests for Feishu interactive card approval buttons."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ensure the repo root is importable
|
||||
# ---------------------------------------------------------------------------
|
||||
# This file lives at tests/gateway/<name>.py, so parents[2] of its resolved
# path is the repository root.
_repo = str(Path(__file__).resolve().parents[2])
if _repo not in sys.path:
    # Insert at index 0 so this path is searched before the existing entries.
    sys.path.insert(0, _repo)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal Feishu mock so FeishuAdapter can be imported without lark-oapi
|
||||
# ---------------------------------------------------------------------------
|
||||
def _ensure_feishu_mocks():
    """Register ``MagicMock`` stand-ins for lark-oapi and aiohttp in ``sys.modules``.

    A package is stubbed only when its top-level name is absent from
    ``sys.modules``.  Registration goes through ``dict.setdefault``, so any
    entry that already exists is left untouched.
    """
    registered = sys.modules

    if "lark_oapi" not in registered:
        lark_stub = MagicMock()
        # One shared mock backs every lark-oapi module path listed here.
        lark_names = (
            "lark_oapi",
            "lark_oapi.api.im.v1",
            "lark_oapi.event",
            "lark_oapi.event.callback_type",
        )
        for module_name in lark_names:
            registered.setdefault(module_name, lark_stub)

    if "aiohttp" not in registered:
        aiohttp_stub = MagicMock()
        registered.setdefault("aiohttp", aiohttp_stub)
        # ``aiohttp.web`` resolves to the ``web`` attribute of the same mock.
        registered.setdefault("aiohttp.web", aiohttp_stub.web)
|
||||
|
||||
|
||||
_ensure_feishu_mocks()
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_adapter() -> FeishuAdapter:
    """Build a ``FeishuAdapter`` from an enabled config with ``_client`` mocked."""
    feishu = FeishuAdapter(PlatformConfig(enabled=True))
    # Swap in a MagicMock so tests can drive client calls without the real SDK.
    feishu._client = MagicMock()
    return feishu
|
||||
|
||||
|
||||
def _make_card_action_data(
    action_value: dict,
    chat_id: str = "oc_12345",
    open_id: str = "ou_user1",
    token: str = "tok_abc",
) -> SimpleNamespace:
    """Build a stand-in for the data object of a Feishu card-action callback.

    Args:
        action_value: Stored as-is under ``event.action.value``.
        chat_id: Exposed as ``event.context.open_chat_id``.
        open_id: Exposed as ``event.operator.open_id``.
        token: Exposed as ``event.token``.

    Returns:
        A ``SimpleNamespace`` whose only attribute is ``event``; the action
        is always tagged ``"button"``.
    """
    button = SimpleNamespace(tag="button", value=action_value)
    event = SimpleNamespace(
        token=token,
        context=SimpleNamespace(open_chat_id=chat_id),
        operator=SimpleNamespace(open_id=open_id),
        action=button,
    )
    return SimpleNamespace(event=event)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# send_exec_approval — interactive card with buttons
|
||||
# ===========================================================================
|
||||
|
||||
class TestFeishuExecApproval:
    """``send_exec_approval`` should deliver an interactive approval card."""

    @staticmethod
    def _patched_send(adapter, message_id):
        """Patch ``adapter._feishu_send_with_retry`` with an ``AsyncMock``.

        The mock resolves to a response object whose ``success()`` returns
        ``True`` and whose ``data.message_id`` is *message_id*.
        """
        fake_response = SimpleNamespace(
            success=lambda: True,
            data=SimpleNamespace(message_id=message_id),
        )
        return patch.object(
            adapter,
            "_feishu_send_with_retry",
            new_callable=AsyncMock,
            return_value=fake_response,
        )

    @pytest.mark.asyncio
    async def test_sends_interactive_card(self):
        """The card is sent as ``interactive`` and carries the command and four buttons."""
        adapter = _make_adapter()

        with self._patched_send(adapter, "msg_001") as send_mock:
            outcome = await adapter.send_exec_approval(
                chat_id="oc_12345",
                command="rm -rf /important",
                session_key="agent:main:feishu:group:oc_12345",
                description="dangerous deletion",
            )

        assert outcome.success is True
        assert outcome.message_id == "msg_001"

        send_mock.assert_called_once()
        sent_kwargs = send_mock.call_args[1]
        assert sent_kwargs["chat_id"] == "oc_12345"
        assert sent_kwargs["msg_type"] == "interactive"

        # The payload is a JSON card: element 0 is the text, element 1 the buttons.
        card = json.loads(sent_kwargs["payload"])
        assert card["header"]["template"] == "orange"
        body_text = card["elements"][0]["content"]
        assert "rm -rf /important" in body_text
        assert "dangerous deletion" in body_text

        buttons = card["elements"][1]["actions"]
        assert len(buttons) == 4
        assert [btn["value"]["hermes_action"] for btn in buttons] == [
            "approve_once", "approve_session", "approve_always", "deny",
        ]

    @pytest.mark.asyncio
    async def test_stores_approval_state(self):
        """Sending a card records session key, message id and chat id."""
        adapter = _make_adapter()

        with self._patched_send(adapter, "msg_002"):
            await adapter.send_exec_approval(
                chat_id="oc_12345",
                command="echo test",
                session_key="my-session-key",
            )

        pending = adapter._approval_state
        assert len(pending) == 1
        first_id = list(pending.keys())[0]
        entry = pending[first_id]
        assert entry["session_key"] == "my-session-key"
        assert entry["message_id"] == "msg_002"
        assert entry["chat_id"] == "oc_12345"

    @pytest.mark.asyncio
    async def test_not_connected(self):
        """Without a client the send reports failure."""
        adapter = _make_adapter()
        adapter._client = None

        outcome = await adapter.send_exec_approval(
            chat_id="oc_12345",
            command="ls",
            session_key="s",
        )

        assert outcome.success is False

    @pytest.mark.asyncio
    async def test_truncates_long_command(self):
        """A 5000-character command is shortened and marked with an ellipsis."""
        adapter = _make_adapter()

        with self._patched_send(adapter, "msg_003") as send_mock:
            oversized = "x" * 5000
            await adapter.send_exec_approval(
                chat_id="oc_12345",
                command=oversized,
                session_key="s",
            )

        card = json.loads(send_mock.call_args[1]["payload"])
        body_text = card["elements"][0]["content"]
        assert "..." in body_text
        assert len(body_text) < 5000

    @pytest.mark.asyncio
    async def test_multiple_approvals_get_unique_ids(self):
        """Two sends leave two distinct keys in ``_approval_state``."""
        adapter = _make_adapter()

        with self._patched_send(adapter, "msg_x"):
            await adapter.send_exec_approval(
                chat_id="oc_1", command="cmd1", session_key="s1"
            )
            await adapter.send_exec_approval(
                chat_id="oc_2", command="cmd2", session_key="s2"
            )

        assert len(adapter._approval_state) == 2
        approval_ids = list(adapter._approval_state.keys())
        assert approval_ids[0] != approval_ids[1]
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _handle_card_action_event — approval button clicks
|
||||
# ===========================================================================
|
||||
|
||||
class TestFeishuApprovalCallback:
    """Approval-button handling inside ``_handle_card_action_event``."""

    @staticmethod
    def _profile(user_id, user_name):
        """Return the dict the patched ``_resolve_sender_profile`` resolves to."""
        return {"user_id": user_id, "user_name": user_name, "user_id_alt": None}

    @staticmethod
    async def _click(adapter, data, profile):
        """Feed *data* to the handler with its collaborators patched.

        Patches the sender-profile lookup (resolving to *profile*), the card
        update, and ``tools.approval.resolve_gateway_approval``.

        Returns:
            ``(resolve_mock, update_mock)`` for the caller to assert on.
        """
        with (
            patch.object(
                adapter, "_resolve_sender_profile", new_callable=AsyncMock,
                return_value=profile,
            ),
            patch.object(adapter, "_update_approval_card", new_callable=AsyncMock) as update_mock,
            patch("tools.approval.resolve_gateway_approval", return_value=1) as resolve_mock,
        ):
            await adapter._handle_card_action_event(data)
        return resolve_mock, update_mock

    @pytest.mark.asyncio
    async def test_resolves_approval_on_click(self):
        """``approve_once`` resolves with "once", updates the card, clears state."""
        adapter = _make_adapter()
        adapter._approval_state[1] = {
            "session_key": "agent:main:feishu:group:oc_12345",
            "message_id": "msg_001",
            "chat_id": "oc_12345",
        }
        click = _make_card_action_data(
            action_value={"hermes_action": "approve_once", "approval_id": 1},
        )

        resolve_mock, update_mock = await self._click(
            adapter, click, self._profile("ou_user1", "Norbert")
        )

        resolve_mock.assert_called_once_with("agent:main:feishu:group:oc_12345", "once")
        update_mock.assert_called_once_with("msg_001", "Approved once", "Norbert", "once")
        # The pending entry must be gone once the click has been handled.
        assert 1 not in adapter._approval_state

    @pytest.mark.asyncio
    async def test_deny_button(self):
        """``deny`` resolves with "deny" and labels the card "Denied"."""
        adapter = _make_adapter()
        adapter._approval_state[2] = {
            "session_key": "some-session",
            "message_id": "msg_002",
            "chat_id": "oc_12345",
        }
        click = _make_card_action_data(
            action_value={"hermes_action": "deny", "approval_id": 2},
            token="tok_deny",
        )

        resolve_mock, update_mock = await self._click(
            adapter, click, self._profile("ou_alice", "Alice")
        )

        resolve_mock.assert_called_once_with("some-session", "deny")
        update_mock.assert_called_once_with("msg_002", "Denied", "Alice", "deny")

    @pytest.mark.asyncio
    async def test_session_approval(self):
        """``approve_session`` resolves with "session"."""
        adapter = _make_adapter()
        adapter._approval_state[3] = {
            "session_key": "sess-3",
            "message_id": "msg_003",
            "chat_id": "oc_99",
        }
        click = _make_card_action_data(
            action_value={"hermes_action": "approve_session", "approval_id": 3},
            token="tok_ses",
        )

        resolve_mock, update_mock = await self._click(
            adapter, click, self._profile("ou_u", "Bob")
        )

        resolve_mock.assert_called_once_with("sess-3", "session")
        update_mock.assert_called_once_with("msg_003", "Approved for session", "Bob", "session")

    @pytest.mark.asyncio
    async def test_always_approval(self):
        """``approve_always`` resolves with "always"."""
        adapter = _make_adapter()
        adapter._approval_state[4] = {
            "session_key": "sess-4",
            "message_id": "msg_004",
            "chat_id": "oc_55",
        }
        click = _make_card_action_data(
            action_value={"hermes_action": "approve_always", "approval_id": 4},
            token="tok_alw",
        )

        resolve_mock, _ = await self._click(
            adapter, click, self._profile("ou_u", "Carol")
        )

        resolve_mock.assert_called_once_with("sess-4", "always")

    @pytest.mark.asyncio
    async def test_already_resolved_drops_silently(self):
        """A click whose approval id has no stored state resolves nothing."""
        adapter = _make_adapter()
        # Nothing is stored under approval_id 99, as if it was already resolved.
        click = _make_card_action_data(
            action_value={"hermes_action": "approve_once", "approval_id": 99},
            token="tok_gone",
        )

        with patch("tools.approval.resolve_gateway_approval") as resolve_mock:
            await adapter._handle_card_action_event(click)

        resolve_mock.assert_not_called()

    @pytest.mark.asyncio
    async def test_non_approval_actions_route_normally(self):
        """Non-approval card actions should still become synthetic commands."""
        adapter = _make_adapter()
        click = _make_card_action_data(
            action_value={"custom_action": "something_else"},
            token="tok_normal",
        )

        with (
            patch.object(
                adapter, "_resolve_sender_profile", new_callable=AsyncMock,
                return_value=self._profile("ou_u", "Dave"),
            ),
            patch.object(adapter, "get_chat_info", new_callable=AsyncMock, return_value={"name": "Test Chat"}),
            patch.object(adapter, "_handle_message_with_guards", new_callable=AsyncMock) as handle_mock,
            patch("tools.approval.resolve_gateway_approval") as resolve_mock,
        ):
            await adapter._handle_card_action_event(click)

        # No approval is resolved for an unrelated action...
        resolve_mock.assert_not_called()
        # ...and the click is forwarded as a synthetic command event instead.
        handle_mock.assert_called_once()
        routed_event = handle_mock.call_args[0][0]
        assert "/card button" in routed_event.text
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _update_approval_card — card replacement after resolution
|
||||
# ===========================================================================
|
||||
|
||||
class TestFeishuUpdateApprovalCard:
    """Test the card update after approval resolution."""

    @pytest.mark.asyncio
    async def test_updates_card_on_approve(self):
        """An approval passes the client's ``message.update`` to ``asyncio.to_thread``."""
        adapter = _make_adapter()
        adapter._client.im.v1.message.update = MagicMock()

        with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
            await adapter._update_approval_card(
                "msg_001", "Approved once", "Norbert", "once"
            )

        mock_thread.assert_called_once()
        # The first positional argument to to_thread is the callable to run.
        call_args = mock_thread.call_args
        assert call_args[0][0] == adapter._client.im.v1.message.update

    @pytest.mark.asyncio
    async def test_updates_card_on_deny(self):
        """A denial triggers exactly one ``asyncio.to_thread`` call."""
        adapter = _make_adapter()

        with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
            await adapter._update_approval_card(
                "msg_002", "Denied", "Alice", "deny"
            )

        mock_thread.assert_called_once()

    @pytest.mark.asyncio
    async def test_skips_update_when_not_connected(self):
        """With ``_client`` set to ``None`` no update is attempted."""
        adapter = _make_adapter()
        adapter._client = None

        with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
            await adapter._update_approval_card(
                "msg_001", "Approved", "Bob", "once"
            )

        mock_thread.assert_not_called()

    @pytest.mark.asyncio
    async def test_skips_update_when_no_message_id(self):
        """An empty message id means no update is attempted."""
        adapter = _make_adapter()

        with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
            await adapter._update_approval_card(
                "", "Approved", "Bob", "once"
            )

        mock_thread.assert_not_called()

    @pytest.mark.asyncio
    async def test_swallows_update_errors(self):
        """An exception raised by the update call must not reach the caller."""
        adapter = _make_adapter()

        with patch("asyncio.to_thread", new_callable=AsyncMock, side_effect=Exception("API error")):
            # Should not raise
            await adapter._update_approval_card(
                "msg_001", "Approved", "Bob", "once"
            )
|
||||
@@ -2,12 +2,54 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
def _make_fake_nio():
    """Build a throwaway module named ``nio`` that exposes real response classes.

    Code under test that does ``import nio`` followed by
    ``isinstance(resp, nio.XxxResponse)`` needs genuine classes; attributes
    auto-created by ``MagicMock`` cannot serve as ``isinstance`` targets.
    Install it with ``patch.dict("sys.modules", {"nio": _make_fake_nio()})``.

    Returns:
        A fresh ``types.ModuleType`` with ``Api`` and the five response
        classes attached; every call creates new class objects.
    """

    class RoomSendResponse:
        def __init__(self, event_id="$fake"):
            self.event_id = event_id

    class RoomRedactResponse:
        pass

    class RoomCreateResponse:
        def __init__(self, room_id="!fake:example.org"):
            self.room_id = room_id

    class RoomInviteResponse:
        pass

    class UploadResponse:
        def __init__(self, content_uri="mxc://example.org/fake"):
            self.content_uri = content_uri

    # Minimal Api stub for code that checks nio.Api.RoomPreset
    class _Api:
        pass

    fake_module = types.ModuleType("nio")
    fake_module.Api = _Api
    # Each response class is published under its own class name.
    response_classes = (
        RoomSendResponse,
        RoomRedactResponse,
        RoomCreateResponse,
        RoomInviteResponse,
        UploadResponse,
    )
    for response_cls in response_classes:
        setattr(fake_module, response_cls.__name__, response_cls)
    return fake_module
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Platform & Config
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -428,6 +470,7 @@ class TestMatrixRequirements:
|
||||
def test_check_requirements_with_token(self, monkeypatch):
|
||||
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test")
|
||||
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
|
||||
monkeypatch.delenv("MATRIX_ENCRYPTION", raising=False)
|
||||
from gateway.platforms.matrix import check_matrix_requirements
|
||||
try:
|
||||
import nio # noqa: F401
|
||||
@@ -448,6 +491,45 @@ class TestMatrixRequirements:
|
||||
from gateway.platforms.matrix import check_matrix_requirements
|
||||
assert check_matrix_requirements() is False
|
||||
|
||||
def test_check_requirements_encryption_true_no_e2ee_deps(self, monkeypatch):
    """With MATRIX_ENCRYPTION=true and the E2EE dependency check failing, requirements fail."""
    required_env = (
        ("MATRIX_ACCESS_TOKEN", "syt_test"),
        ("MATRIX_HOMESERVER", "https://matrix.example.org"),
        ("MATRIX_ENCRYPTION", "true"),
    )
    for var_name, var_value in required_env:
        monkeypatch.setenv(var_name, var_value)

    from gateway.platforms import matrix as matrix_mod

    # Force the E2EE dependency probe to report "missing".
    e2ee_missing = patch.object(matrix_mod, "_check_e2ee_deps", return_value=False)
    with e2ee_missing:
        assert matrix_mod.check_matrix_requirements() is False
|
||||
|
||||
def test_check_requirements_encryption_false_no_e2ee_deps_ok(self, monkeypatch):
    """With MATRIX_ENCRYPTION unset, a failing E2EE dependency check does not block."""
    monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test")
    monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
    monkeypatch.delenv("MATRIX_ENCRYPTION", raising=False)

    from gateway.platforms import matrix as matrix_mod

    e2ee_missing = patch.object(matrix_mod, "_check_e2ee_deps", return_value=False)
    with e2ee_missing:
        # The expected outcome depends on whether nio is importable here.
        try:
            import nio  # noqa: F401
            assert matrix_mod.check_matrix_requirements() is True
        except ImportError:
            assert matrix_mod.check_matrix_requirements() is False
|
||||
|
||||
def test_check_requirements_encryption_true_with_e2ee_deps(self, monkeypatch):
    """With MATRIX_ENCRYPTION=true and the E2EE dependency check passing, requirements pass."""
    required_env = (
        ("MATRIX_ACCESS_TOKEN", "syt_test"),
        ("MATRIX_HOMESERVER", "https://matrix.example.org"),
        ("MATRIX_ENCRYPTION", "true"),
    )
    for var_name, var_value in required_env:
        monkeypatch.setenv(var_name, var_value)

    from gateway.platforms import matrix as matrix_mod

    e2ee_present = patch.object(matrix_mod, "_check_e2ee_deps", return_value=True)
    with e2ee_present:
        # The expected outcome depends on whether nio is importable here.
        try:
            import nio  # noqa: F401
            assert matrix_mod.check_matrix_requirements() is True
        except ImportError:
            assert matrix_mod.check_matrix_requirements() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Access-token auth / E2EE bootstrap
|
||||
@@ -516,10 +598,12 @@ class TestMatrixAccessTokenAuth:
|
||||
fake_nio.InviteMemberEvent = type("InviteMemberEvent", (), {})
|
||||
fake_nio.MegolmEvent = type("MegolmEvent", (), {})
|
||||
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
with patch.object(adapter, "_refresh_dm_cache", AsyncMock()):
|
||||
with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)):
|
||||
assert await adapter.connect() is True
|
||||
from gateway.platforms import matrix as matrix_mod
|
||||
with patch.object(matrix_mod, "_check_e2ee_deps", return_value=True):
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
with patch.object(adapter, "_refresh_dm_cache", AsyncMock()):
|
||||
with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)):
|
||||
assert await adapter.connect() is True
|
||||
|
||||
fake_client.restore_login.assert_called_once_with(
|
||||
"@bot:example.org", "DEV123", "syt_test_access_token"
|
||||
@@ -532,6 +616,326 @@ class TestMatrixAccessTokenAuth:
|
||||
await adapter.disconnect()
|
||||
|
||||
|
||||
class TestMatrixE2EEHardFail:
    """``connect()`` must return ``False`` when encryption is requested but unusable."""

    @staticmethod
    def _encrypted_config():
        """Return an enabled config with a token, homeserver, user id and encryption on."""
        return PlatformConfig(
            enabled=True,
            token="syt_test_access_token",
            extra={
                "homeserver": "https://matrix.example.org",
                "user_id": "@bot:example.org",
                "encryption": True,
            },
        )

    @pytest.mark.asyncio
    async def test_connect_fails_when_encryption_true_but_no_e2ee_deps(self):
        """A failing E2EE dependency check makes ``connect()`` return ``False``."""
        from gateway.platforms.matrix import MatrixAdapter

        adapter = MatrixAdapter(self._encrypted_config())

        nio_stub = MagicMock()
        nio_stub.AsyncClient = MagicMock()

        from gateway.platforms import matrix as matrix_mod

        with patch.object(matrix_mod, "_check_e2ee_deps", return_value=False), \
                patch.dict("sys.modules", {"nio": nio_stub}):
            connected = await adapter.connect()

        assert connected is False

    @pytest.mark.asyncio
    async def test_connect_fails_when_olm_not_loaded_after_login(self):
        """Even if _check_e2ee_deps passes, if olm is None after auth, hard-fail."""
        from gateway.platforms.matrix import MatrixAdapter

        adapter = MatrixAdapter(self._encrypted_config())

        class FakeWhoamiResponse:
            def __init__(self, user_id, device_id):
                self.user_id = user_id
                self.device_id = device_id

        client = MagicMock()
        client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "DEV123"))
        client.close = AsyncMock()
        # olm stays None, i.e. the crypto store was not loaded.
        client.olm = None
        client.should_upload_keys = False

        def _restore_login(user_id, device_id, access_token):
            # Copy the restored credentials onto the fake client.
            client.user_id = user_id
            client.device_id = device_id
            client.access_token = access_token

        client.restore_login = MagicMock(side_effect=_restore_login)

        nio_stub = MagicMock()
        nio_stub.AsyncClient = MagicMock(return_value=client)
        nio_stub.WhoamiResponse = FakeWhoamiResponse

        from gateway.platforms import matrix as matrix_mod

        with patch.object(matrix_mod, "_check_e2ee_deps", return_value=True), \
                patch.dict("sys.modules", {"nio": nio_stub}):
            connected = await adapter.connect()

        assert connected is False
        # The client must be closed on the failure path.
        client.close.assert_awaited_once()
|
||||
|
||||
|
||||
class TestMatrixDeviceId:
|
||||
"""MATRIX_DEVICE_ID should be used for stable device identity."""
|
||||
|
||||
def test_device_id_from_config_extra(self):
    """``extra["device_id"]`` in the config becomes ``adapter._device_id``."""
    from gateway.platforms.matrix import MatrixAdapter

    extra = {
        "homeserver": "https://matrix.example.org",
        "device_id": "HERMES_BOT_STABLE",
    }
    adapter = MatrixAdapter(PlatformConfig(enabled=True, token="syt_test", extra=extra))

    assert adapter._device_id == "HERMES_BOT_STABLE"
|
||||
|
||||
def test_device_id_from_env(self, monkeypatch):
    """With no ``device_id`` in the config, ``MATRIX_DEVICE_ID`` is used."""
    monkeypatch.setenv("MATRIX_DEVICE_ID", "FROM_ENV")

    from gateway.platforms.matrix import MatrixAdapter

    # The config deliberately carries no "device_id" key.
    extra = {"homeserver": "https://matrix.example.org"}
    adapter = MatrixAdapter(PlatformConfig(enabled=True, token="syt_test", extra=extra))

    assert adapter._device_id == "FROM_ENV"
|
||||
|
||||
def test_device_id_config_takes_precedence_over_env(self, monkeypatch):
    """When both are set, the config ``device_id`` wins over ``MATRIX_DEVICE_ID``."""
    monkeypatch.setenv("MATRIX_DEVICE_ID", "FROM_ENV")

    from gateway.platforms.matrix import MatrixAdapter

    extra = {
        "homeserver": "https://matrix.example.org",
        "device_id": "FROM_CONFIG",
    }
    adapter = MatrixAdapter(PlatformConfig(enabled=True, token="syt_test", extra=extra))

    assert adapter._device_id == "FROM_CONFIG"
|
||||
|
||||
@pytest.mark.asyncio
async def test_connect_uses_configured_device_id_over_whoami(self):
    """When MATRIX_DEVICE_ID is set, it should be used instead of whoami device_id."""
    from gateway.platforms.matrix import MatrixAdapter

    adapter = MatrixAdapter(
        PlatformConfig(
            enabled=True,
            token="syt_test_access_token",
            extra={
                "homeserver": "https://matrix.example.org",
                "user_id": "@bot:example.org",
                "encryption": True,
                "device_id": "MY_STABLE_DEVICE",
            },
        )
    )

    class FakeWhoamiResponse:
        def __init__(self, user_id, device_id):
            self.user_id = user_id
            self.device_id = device_id

    class FakeSyncResponse:
        def __init__(self):
            self.rooms = MagicMock(join={})

    # Client double: whoami deliberately reports a *different* device id
    # than the configured one so the override is observable.
    fake_client = MagicMock()
    fake_client.whoami = AsyncMock(
        return_value=FakeWhoamiResponse("@bot:example.org", "WHOAMI_DEV")
    )
    fake_client.sync = AsyncMock(return_value=FakeSyncResponse())
    for coroutine_name in ("keys_upload", "keys_query", "keys_claim"):
        setattr(fake_client, coroutine_name, AsyncMock())
    fake_client.send_to_device_messages = AsyncMock(return_value=[])
    fake_client.get_users_for_key_claiming = MagicMock(return_value={})
    fake_client.close = AsyncMock()
    fake_client.add_event_callback = MagicMock()
    fake_client.rooms = {}
    fake_client.account_data = {}
    fake_client.olm = object()
    for flag_name in ("should_upload_keys", "should_query_keys", "should_claim_keys"):
        setattr(fake_client, flag_name, False)

    def _record_login(user_id, device_id, access_token):
        # Mirror the credentials onto the client so they can be asserted on later.
        fake_client.user_id = user_id
        fake_client.device_id = device_id
        fake_client.access_token = access_token

    fake_client.restore_login = MagicMock(side_effect=_record_login)

    # Stand-in for the ``nio`` module with just the names the test wires up.
    fake_nio = MagicMock()
    fake_nio.AsyncClient = MagicMock(return_value=fake_client)
    fake_nio.WhoamiResponse = FakeWhoamiResponse
    fake_nio.SyncResponse = FakeSyncResponse
    placeholder_types = (
        "LoginResponse",
        "RoomMessageText",
        "RoomMessageImage",
        "RoomMessageAudio",
        "RoomMessageVideo",
        "RoomMessageFile",
        "InviteMemberEvent",
        "MegolmEvent",
    )
    for type_name in placeholder_types:
        setattr(fake_nio, type_name, type(type_name, (), {}))

    from gateway.platforms import matrix as matrix_mod

    e2ee_patch = patch.object(matrix_mod, "_check_e2ee_deps", return_value=True)
    nio_patch = patch.dict("sys.modules", {"nio": fake_nio})
    dm_cache_patch = patch.object(adapter, "_refresh_dm_cache", AsyncMock())
    sync_loop_patch = patch.object(adapter, "_sync_loop", AsyncMock(return_value=None))
    with e2ee_patch, nio_patch, dm_cache_patch, sync_loop_patch:
        assert await adapter.connect() is True

    # The configured device_id should override the whoami device_id
    fake_client.restore_login.assert_called_once_with(
        "@bot:example.org", "MY_STABLE_DEVICE", "syt_test_access_token"
    )
    assert fake_client.device_id == "MY_STABLE_DEVICE"

    # Verify device_id was passed to nio.AsyncClient constructor
    constructor_kwargs = fake_nio.AsyncClient.call_args.kwargs
    assert constructor_kwargs.get("device_id") == "MY_STABLE_DEVICE"

    await adapter.disconnect()
|
||||
|
||||
|
||||
class TestMatrixE2EEClientConstructorFailure:
    """connect() should hard-fail if nio.AsyncClient() raises when encryption is on."""

    @pytest.mark.asyncio
    async def test_connect_fails_when_e2ee_client_constructor_raises(self):
        """``connect()`` returns ``False`` when the client constructor raises."""
        from gateway.platforms.matrix import MatrixAdapter

        adapter = MatrixAdapter(
            PlatformConfig(
                enabled=True,
                token="syt_test_access_token",
                extra={
                    "homeserver": "https://matrix.example.org",
                    "user_id": "@bot:example.org",
                    "encryption": True,
                },
            )
        )

        # Every attempt to build the client blows up.
        exploding_ctor = MagicMock(side_effect=Exception("olm init failed"))
        fake_nio = MagicMock()
        fake_nio.AsyncClient = exploding_ctor

        from gateway.platforms import matrix as matrix_mod

        e2ee_patch = patch.object(matrix_mod, "_check_e2ee_deps", return_value=True)
        nio_patch = patch.dict("sys.modules", {"nio": fake_nio})
        with e2ee_patch, nio_patch:
            result = await adapter.connect()

        assert result is False
|
||||
|
||||
|
||||
class TestMatrixPasswordLoginDeviceId:
    """MATRIX_DEVICE_ID should be passed to nio.AsyncClient even with password login."""

    @pytest.mark.asyncio
    async def test_password_login_passes_device_id_to_constructor(self):
        """With password credentials, the configured device id reaches the constructor."""
        from gateway.platforms.matrix import MatrixAdapter

        adapter = MatrixAdapter(
            PlatformConfig(
                enabled=True,
                extra={
                    "homeserver": "https://matrix.example.org",
                    "user_id": "@bot:example.org",
                    "password": "secret",
                    "device_id": "STABLE_PW_DEVICE",
                },
            )
        )

        class FakeLoginResponse:
            pass

        class FakeSyncResponse:
            def __init__(self):
                self.rooms = MagicMock(join={})

        fake_client = MagicMock()
        fake_client.login = AsyncMock(return_value=FakeLoginResponse())
        fake_client.sync = AsyncMock(return_value=FakeSyncResponse())
        fake_client.close = AsyncMock()
        fake_client.add_event_callback = MagicMock()
        fake_client.rooms = {}
        fake_client.account_data = {}

        fake_nio = MagicMock()
        fake_nio.AsyncClient = MagicMock(return_value=fake_client)
        fake_nio.LoginResponse = FakeLoginResponse
        fake_nio.SyncResponse = FakeSyncResponse
        placeholder_types = (
            "RoomMessageText",
            "RoomMessageImage",
            "RoomMessageAudio",
            "RoomMessageVideo",
            "RoomMessageFile",
            "InviteMemberEvent",
        )
        for type_name in placeholder_types:
            setattr(fake_nio, type_name, type(type_name, (), {}))

        nio_patch = patch.dict("sys.modules", {"nio": fake_nio})
        dm_cache_patch = patch.object(adapter, "_refresh_dm_cache", AsyncMock())
        sync_loop_patch = patch.object(adapter, "_sync_loop", AsyncMock(return_value=None))
        with nio_patch, dm_cache_patch, sync_loop_patch:
            assert await adapter.connect() is True

        # Verify device_id was passed to the nio.AsyncClient constructor
        constructor_kwargs = fake_nio.AsyncClient.call_args.kwargs
        assert constructor_kwargs.get("device_id") == "STABLE_PW_DEVICE"

        await adapter.disconnect()
|
||||
|
||||
|
||||
class TestMatrixDeviceIdConfig:
    """MATRIX_DEVICE_ID should be plumbed through gateway config."""

    def test_device_id_in_config_extra(self, monkeypatch):
        """A set ``MATRIX_DEVICE_ID`` appears as ``extra["device_id"]`` after overrides."""
        environment = {
            "MATRIX_ACCESS_TOKEN": "syt_abc123",
            "MATRIX_HOMESERVER": "https://matrix.example.org",
            "MATRIX_DEVICE_ID": "HERMES_BOT",
        }
        for variable, value in environment.items():
            monkeypatch.setenv(variable, value)

        from gateway.config import GatewayConfig, _apply_env_overrides

        gateway_config = GatewayConfig()
        _apply_env_overrides(gateway_config)

        matrix_config = gateway_config.platforms[Platform.MATRIX]
        assert matrix_config.extra.get("device_id") == "HERMES_BOT"

    def test_device_id_not_set_when_env_empty(self, monkeypatch):
        """With ``MATRIX_DEVICE_ID`` absent, no ``device_id`` key is added to ``extra``."""
        monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
        monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
        monkeypatch.delenv("MATRIX_DEVICE_ID", raising=False)

        from gateway.config import GatewayConfig, _apply_env_overrides

        gateway_config = GatewayConfig()
        _apply_env_overrides(gateway_config)

        matrix_config = gateway_config.platforms[Platform.MATRIX]
        assert "device_id" not in matrix_config.extra
|
||||
|
||||
|
||||
class TestMatrixE2EEMaintenance:
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_loop_runs_e2ee_maintenance_requests(self):
|
||||
@@ -1071,10 +1475,12 @@ class TestMatrixEncryptedMedia:
|
||||
fake_nio.InviteMemberEvent = FakeInviteMemberEvent
|
||||
fake_nio.MegolmEvent = FakeMegolmEvent
|
||||
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
with patch.object(adapter, "_refresh_dm_cache", AsyncMock()):
|
||||
with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)):
|
||||
assert await adapter.connect() is True
|
||||
from gateway.platforms import matrix as matrix_mod
|
||||
with patch.object(matrix_mod, "_check_e2ee_deps", return_value=True):
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
with patch.object(adapter, "_refresh_dm_cache", AsyncMock()):
|
||||
with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)):
|
||||
assert await adapter.connect() is True
|
||||
|
||||
callback_classes = [call.args[1] for call in fake_client.add_event_callback.call_args_list]
|
||||
assert FakeRoomEncryptedImage in callback_classes
|
||||
@@ -1086,7 +1492,10 @@ class TestMatrixEncryptedMedia:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_room_message_media_decrypts_encrypted_image_and_passes_local_path(self):
|
||||
from nio.crypto.attachments import encrypt_attachment
|
||||
try:
|
||||
from nio.crypto.attachments import encrypt_attachment
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pytest.skip("matrix-nio[e2e] required for encryption tests")
|
||||
|
||||
adapter = _make_adapter()
|
||||
adapter._user_id = "@bot:example.org"
|
||||
@@ -1154,7 +1563,10 @@ class TestMatrixEncryptedMedia:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_room_message_media_decrypts_encrypted_voice_and_caches_audio(self):
|
||||
from nio.crypto.attachments import encrypt_attachment
|
||||
try:
|
||||
from nio.crypto.attachments import encrypt_attachment
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pytest.skip("matrix-nio[e2e] required for encryption tests")
|
||||
|
||||
adapter = _make_adapter()
|
||||
adapter._user_id = "@bot:example.org"
|
||||
@@ -1223,7 +1635,10 @@ class TestMatrixEncryptedMedia:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_room_message_media_decrypts_encrypted_file_and_caches_document(self):
|
||||
from nio.crypto.attachments import encrypt_attachment
|
||||
try:
|
||||
from nio.crypto.attachments import encrypt_attachment
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pytest.skip("matrix-nio[e2e] required for encryption tests")
|
||||
|
||||
adapter = _make_adapter()
|
||||
adapter._user_id = "@bot:example.org"
|
||||
@@ -1519,14 +1934,15 @@ class TestMatrixReactions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_reaction(self):
|
||||
"""_send_reaction should call room_send with m.reaction."""
|
||||
nio = pytest.importorskip("nio")
|
||||
fake_nio = _make_fake_nio()
|
||||
mock_client = MagicMock()
|
||||
mock_client.room_send = AsyncMock(
|
||||
return_value=MagicMock(spec=nio.RoomSendResponse)
|
||||
return_value=fake_nio.RoomSendResponse("$reaction1")
|
||||
)
|
||||
self.adapter._client = mock_client
|
||||
|
||||
result = await self.adapter._send_reaction("!room:ex", "$event1", "👍")
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
result = await self.adapter._send_reaction("!room:ex", "$event1", "👍")
|
||||
assert result is True
|
||||
mock_client.room_send.assert_called_once()
|
||||
args = mock_client.room_send.call_args
|
||||
@@ -1538,7 +1954,8 @@ class TestMatrixReactions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_reaction_no_client(self):
|
||||
self.adapter._client = None
|
||||
result = await self.adapter._send_reaction("!room:ex", "$ev", "👍")
|
||||
with patch.dict("sys.modules", {"nio": _make_fake_nio()}):
|
||||
result = await self.adapter._send_reaction("!room:ex", "$ev", "👍")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -1635,21 +2052,23 @@ class TestMatrixRedaction:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redact_message(self):
|
||||
nio = pytest.importorskip("nio")
|
||||
fake_nio = _make_fake_nio()
|
||||
mock_client = MagicMock()
|
||||
mock_client.room_redact = AsyncMock(
|
||||
return_value=MagicMock(spec=nio.RoomRedactResponse)
|
||||
return_value=fake_nio.RoomRedactResponse()
|
||||
)
|
||||
self.adapter._client = mock_client
|
||||
|
||||
result = await self.adapter.redact_message("!room:ex", "$ev1", "oops")
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
result = await self.adapter.redact_message("!room:ex", "$ev1", "oops")
|
||||
assert result is True
|
||||
mock_client.room_redact.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redact_no_client(self):
|
||||
self.adapter._client = None
|
||||
result = await self.adapter.redact_message("!room:ex", "$ev1")
|
||||
with patch.dict("sys.modules", {"nio": _make_fake_nio()}):
|
||||
result = await self.adapter.redact_message("!room:ex", "$ev1")
|
||||
assert result is False
|
||||
|
||||
|
||||
@@ -1663,33 +2082,35 @@ class TestMatrixRoomManagement:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_room(self):
|
||||
nio = pytest.importorskip("nio")
|
||||
mock_resp = MagicMock(spec=nio.RoomCreateResponse)
|
||||
mock_resp.room_id = "!new:example.org"
|
||||
fake_nio = _make_fake_nio()
|
||||
mock_resp = fake_nio.RoomCreateResponse(room_id="!new:example.org")
|
||||
mock_client = MagicMock()
|
||||
mock_client.room_create = AsyncMock(return_value=mock_resp)
|
||||
self.adapter._client = mock_client
|
||||
|
||||
room_id = await self.adapter.create_room(name="Test Room", topic="A test")
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
room_id = await self.adapter.create_room(name="Test Room", topic="A test")
|
||||
assert room_id == "!new:example.org"
|
||||
assert "!new:example.org" in self.adapter._joined_rooms
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_user(self):
|
||||
nio = pytest.importorskip("nio")
|
||||
fake_nio = _make_fake_nio()
|
||||
mock_client = MagicMock()
|
||||
mock_client.room_invite = AsyncMock(
|
||||
return_value=MagicMock(spec=nio.RoomInviteResponse)
|
||||
return_value=fake_nio.RoomInviteResponse()
|
||||
)
|
||||
self.adapter._client = mock_client
|
||||
|
||||
result = await self.adapter.invite_user("!room:ex", "@user:ex")
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
result = await self.adapter.invite_user("!room:ex", "@user:ex")
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_room_no_client(self):
|
||||
self.adapter._client = None
|
||||
result = await self.adapter.create_room()
|
||||
with patch.dict("sys.modules", {"nio": _make_fake_nio()}):
|
||||
result = await self.adapter.create_room()
|
||||
assert result is None
|
||||
|
||||
|
||||
@@ -1735,28 +2156,28 @@ class TestMatrixMessageTypes:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_emote(self):
|
||||
nio = pytest.importorskip("nio")
|
||||
fake_nio = _make_fake_nio()
|
||||
mock_client = MagicMock()
|
||||
mock_resp = MagicMock(spec=nio.RoomSendResponse)
|
||||
mock_resp.event_id = "$emote1"
|
||||
mock_resp = fake_nio.RoomSendResponse(event_id="$emote1")
|
||||
mock_client.room_send = AsyncMock(return_value=mock_resp)
|
||||
self.adapter._client = mock_client
|
||||
|
||||
result = await self.adapter.send_emote("!room:ex", "waves hello")
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
result = await self.adapter.send_emote("!room:ex", "waves hello")
|
||||
assert result.success is True
|
||||
call_args = mock_client.room_send.call_args[0]
|
||||
assert call_args[2]["msgtype"] == "m.emote"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_notice(self):
|
||||
nio = pytest.importorskip("nio")
|
||||
fake_nio = _make_fake_nio()
|
||||
mock_client = MagicMock()
|
||||
mock_resp = MagicMock(spec=nio.RoomSendResponse)
|
||||
mock_resp.event_id = "$notice1"
|
||||
mock_resp = fake_nio.RoomSendResponse(event_id="$notice1")
|
||||
mock_client.room_send = AsyncMock(return_value=mock_resp)
|
||||
self.adapter._client = mock_client
|
||||
|
||||
result = await self.adapter.send_notice("!room:ex", "System message")
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
result = await self.adapter.send_notice("!room:ex", "System message")
|
||||
assert result.success is True
|
||||
call_args = mock_client.room_send.call_args[0]
|
||||
assert call_args[2]["msgtype"] == "m.notice"
|
||||
@@ -1764,5 +2185,6 @@ class TestMatrixMessageTypes:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_emote_empty_text(self):
|
||||
self.adapter._client = MagicMock()
|
||||
result = await self.adapter.send_emote("!room:ex", "")
|
||||
with patch.dict("sys.modules", {"nio": _make_fake_nio()}):
|
||||
result = await self.adapter.send_emote("!room:ex", "")
|
||||
assert result.success is False
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
"""Tests for Matrix voice message support (MSC3245)."""
|
||||
import io
|
||||
import types
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
nio = pytest.importorskip("nio", reason="matrix-nio not installed")
|
||||
# Try importing real nio; skip entire file if not available.
|
||||
# A MagicMock in sys.modules (from another test) is not the real package.
|
||||
try:
|
||||
import nio as _nio_probe
|
||||
if not isinstance(_nio_probe, types.ModuleType) or not hasattr(_nio_probe, "__file__"):
|
||||
pytest.skip("nio in sys.modules is a mock, not the real package", allow_module_level=True)
|
||||
except ImportError:
|
||||
pytest.skip("matrix-nio not installed", allow_module_level=True)
|
||||
|
||||
from gateway.platforms.base import MessageType
|
||||
|
||||
|
||||
@@ -504,7 +504,8 @@ class TestMattermostFileUpload:
|
||||
self.adapter._session = MagicMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_image_downloads_and_uploads(self):
|
||||
@patch("tools.url_safety.is_safe_url", return_value=True)
|
||||
async def test_send_image_downloads_and_uploads(self, _mock_safe):
|
||||
"""send_image should download the URL, upload via /api/v4/files, then post."""
|
||||
# Mock the download (GET)
|
||||
mock_dl_resp = AsyncMock()
|
||||
|
||||
@@ -596,10 +596,11 @@ def _make_aiohttp_resp(status: int, content: bytes = b"file bytes",
|
||||
return resp
|
||||
|
||||
|
||||
@patch("tools.url_safety.is_safe_url", return_value=True)
|
||||
class TestMattermostSendUrlAsFile:
|
||||
"""Tests for MattermostAdapter._send_url_as_file"""
|
||||
|
||||
def test_success_on_first_attempt(self):
|
||||
def test_success_on_first_attempt(self, _mock_safe):
|
||||
"""200 on first attempt → file uploaded and post created."""
|
||||
adapter = _make_mm_adapter()
|
||||
resp = _make_aiohttp_resp(200)
|
||||
@@ -616,7 +617,7 @@ class TestMattermostSendUrlAsFile:
|
||||
adapter._upload_file.assert_called_once()
|
||||
adapter._api_post.assert_called_once()
|
||||
|
||||
def test_retries_on_429_then_succeeds(self):
|
||||
def test_retries_on_429_then_succeeds(self, _mock_safe):
|
||||
"""429 on first attempt is retried; 200 on second attempt succeeds."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
@@ -637,7 +638,7 @@ class TestMattermostSendUrlAsFile:
|
||||
assert adapter._session.get.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_retries_on_500_then_succeeds(self):
|
||||
def test_retries_on_500_then_succeeds(self, _mock_safe):
|
||||
"""5xx on first attempt is retried; 200 on second attempt succeeds."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
@@ -655,7 +656,7 @@ class TestMattermostSendUrlAsFile:
|
||||
assert result.success
|
||||
assert adapter._session.get.call_count == 2
|
||||
|
||||
def test_falls_back_to_text_after_max_retries_on_5xx(self):
|
||||
def test_falls_back_to_text_after_max_retries_on_5xx(self, _mock_safe):
|
||||
"""Three consecutive 500s exhaust retries; falls back to send() with URL text."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
@@ -674,7 +675,7 @@ class TestMattermostSendUrlAsFile:
|
||||
text_arg = adapter.send.call_args[0][1]
|
||||
assert "http://cdn.example.com/img.png" in text_arg
|
||||
|
||||
def test_falls_back_on_client_error(self):
|
||||
def test_falls_back_on_client_error(self, _mock_safe):
|
||||
"""aiohttp.ClientError on every attempt falls back to send() with URL."""
|
||||
import aiohttp
|
||||
|
||||
@@ -699,7 +700,7 @@ class TestMattermostSendUrlAsFile:
|
||||
text_arg = adapter.send.call_args[0][1]
|
||||
assert "http://cdn.example.com/img.png" in text_arg
|
||||
|
||||
def test_non_retryable_404_falls_back_immediately(self):
|
||||
def test_non_retryable_404_falls_back_immediately(self, _mock_safe):
|
||||
"""404 is non-retryable (< 500, != 429); send() is called right away."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from gateway.platforms.base import (
|
||||
GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
_safe_url_for_log,
|
||||
)
|
||||
|
||||
|
||||
@@ -18,6 +19,31 @@ class TestSecretCaptureGuidance:
|
||||
assert "~/.hermes/.env" in message
|
||||
|
||||
|
||||
class TestSafeUrlForLog:
    """Checks on the strings ``_safe_url_for_log`` returns for logging."""

    def test_strips_query_fragment_and_userinfo(self):
        """Credentials, query string and fragment must not appear in the result."""
        raw_url = (
            "https://user:pass@example.com/private/path/image.png"
            "?X-Amz-Signature=supersecret&token=abc#frag"
        )

        sanitized = _safe_url_for_log(raw_url)

        assert sanitized == "https://example.com/.../image.png"
        for sensitive_fragment in ("supersecret", "token=abc", "user:pass@"):
            assert sensitive_fragment not in sanitized

    def test_truncates_long_values(self):
        """A result longer than ``max_len`` is cut to exactly ``max_len`` with an ellipsis."""
        oversized_url = "https://example.com/" + ("a" * 300)

        shortened = _safe_url_for_log(oversized_url, max_len=40)

        assert len(shortened) == 40
        assert shortened.endswith("...")

    def test_handles_small_and_non_positive_max_len(self):
        """Tiny or zero ``max_len`` values yield a correspondingly short result."""
        url = "https://example.com/very/long/path/file.png?token=secret"
        expectations = ((3, "..."), (2, ".."), (0, ""))
        for limit, expected in expectations:
            assert _safe_url_for_log(url, max_len=limit) == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageEvent — command parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -59,6 +59,7 @@ def _make_runner():
|
||||
runner._honcho_managers = {}
|
||||
runner._honcho_configs = {}
|
||||
runner._shutdown_all_gateway_honcho = lambda: None
|
||||
runner.session_store = MagicMock()
|
||||
return runner
|
||||
|
||||
|
||||
|
||||
@@ -87,7 +87,6 @@ class TestReasoningCommand:
|
||||
)
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
|
||||
monkeypatch.delenv("HERMES_REASONING_EFFORT", raising=False)
|
||||
|
||||
runner = _make_runner()
|
||||
runner._reasoning_config = {"enabled": True, "effort": "xhigh"}
|
||||
@@ -108,7 +107,6 @@ class TestReasoningCommand:
|
||||
config_path.write_text("agent:\n reasoning_effort: medium\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
|
||||
monkeypatch.delenv("HERMES_REASONING_EFFORT", raising=False)
|
||||
|
||||
runner = _make_runner()
|
||||
runner._reasoning_config = {"enabled": True, "effort": "medium"}
|
||||
@@ -138,7 +136,6 @@ class TestReasoningCommand:
|
||||
"api_key": "test-key",
|
||||
},
|
||||
)
|
||||
monkeypatch.delenv("HERMES_REASONING_EFFORT", raising=False)
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = _CapturingAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
@@ -170,55 +167,6 @@ class TestReasoningCommand:
|
||||
assert _CapturingAgent.last_init is not None
|
||||
assert _CapturingAgent.last_init["reasoning_config"] == {"enabled": True, "effort": "low"}
|
||||
|
||||
def test_run_agent_prefers_config_over_stale_reasoning_env(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text("agent:\n reasoning_effort: none\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
|
||||
monkeypatch.setattr(gateway_run, "_env_path", hermes_home / ".env")
|
||||
monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(
|
||||
gateway_run,
|
||||
"_resolve_runtime_agent_kwargs",
|
||||
lambda: {
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"api_key": "test-key",
|
||||
},
|
||||
)
|
||||
monkeypatch.setenv("HERMES_REASONING_EFFORT", "low")
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = _CapturingAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
_CapturingAgent.last_init = None
|
||||
runner = _make_runner()
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.LOCAL,
|
||||
chat_id="cli",
|
||||
chat_name="CLI",
|
||||
chat_type="dm",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
runner._run_agent(
|
||||
message="ping",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="session-1",
|
||||
session_key="agent:main:local:dm",
|
||||
)
|
||||
)
|
||||
|
||||
assert result["final_response"] == "ok"
|
||||
assert _CapturingAgent.last_init is not None
|
||||
assert _CapturingAgent.last_init["reasoning_config"] == {"enabled": False}
|
||||
|
||||
def test_run_agent_includes_enabled_mcp_servers_in_gateway_toolsets(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
@@ -201,8 +201,8 @@ class TestHandleResumeCommand:
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_flushes_memories_with_gateway_session_key(self, tmp_path):
|
||||
"""Resume should preserve the gateway session key for Honcho flushes."""
|
||||
async def test_resume_flushes_memories(self, tmp_path):
|
||||
"""Resume should flush memories from the current session before switching."""
|
||||
from hermes_state import SessionDB
|
||||
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
@@ -221,6 +221,5 @@ class TestHandleResumeCommand:
|
||||
|
||||
runner._async_flush_memories.assert_called_once_with(
|
||||
"current_session_001",
|
||||
_session_key_for_event(event),
|
||||
)
|
||||
db.close()
|
||||
|
||||
@@ -71,6 +71,24 @@ class FakeAgent:
|
||||
}
|
||||
|
||||
|
||||
class LongPreviewAgent:
    """Stand-in agent whose single tool call carries a very long preview string."""

    # Same command string as before, split across lines purely for readability.
    LONG_CMD = (
        "cd /home/teknium/.hermes/hermes-agent/.worktrees/hermes-d8860339"
        " && source .venv/bin/activate"
        " && python -m pytest tests/gateway/test_run_progress_topics.py -n0 -q"
    )

    def __init__(self, **kwargs):
        # Only the progress callback is read; any other keyword is ignored.
        self.tool_progress_callback = kwargs.get("tool_progress_callback")
        self.tools = []

    def run_conversation(self, message, conversation_history=None, task_id=None):
        """Report one ``terminal`` tool start, pause briefly, and return a fixed result."""
        report_progress = self.tool_progress_callback
        report_progress("tool.started", "terminal", self.LONG_CMD, {})
        time.sleep(0.35)
        outcome = {"final_response": "done", "messages": [], "api_calls": 1}
        return outcome
|
||||
|
||||
|
||||
def _make_runner(adapter):
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
GatewayRunner = gateway_run.GatewayRunner
|
||||
@@ -217,3 +235,102 @@ async def test_run_agent_progress_uses_event_message_id_for_slack_dm(monkeypatch
|
||||
assert adapter.sent
|
||||
assert adapter.sent[0]["metadata"] == {"thread_id": "1234567890.000001"}
|
||||
assert all(call["metadata"] == {"thread_id": "1234567890.000001"} for call in adapter.typing)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preview truncation tests (all/new mode respects tool_preview_length)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_long_preview_helper(monkeypatch, tmp_path, preview_length=0):
    """Shared setup for long-preview truncation tests.

    Runs ``runner._run_agent`` once with ``LongPreviewAgent`` substituted for
    the real agent class and returns what the capture adapter recorded.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture, used for env vars,
            ``sys.modules`` entries and ``gateway.run`` attributes.
        tmp_path: pytest ``tmp_path`` fixture; ``config.yaml`` is written here
            and ``gateway.run._hermes_home`` is pointed at it.
        preview_length: value written to ``display.tool_preview_length`` in
            the config file that ``_run_agent`` reads, so the gateway picks it
            up the same way production does.

    Returns:
        ``(adapter, result)`` where ``adapter`` is the ``ProgressCaptureAdapter``
        and ``result`` is whatever ``_run_agent`` returned.
    """
    import asyncio
    import yaml

    monkeypatch.setenv("HERMES_TOOL_PROGRESS_MODE", "all")

    # Replace dotenv with a no-op module so importing gateway code cannot load
    # a real .env file during the test.
    fake_dotenv = types.ModuleType("dotenv")
    fake_dotenv.load_dotenv = lambda *args, **kwargs: None
    monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)

    # Substitute the agent implementation with the long-preview stub.
    fake_run_agent = types.ModuleType("run_agent")
    fake_run_agent.AIAgent = LongPreviewAgent
    monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)

    # Write config.yaml so _run_agent picks up tool_preview_length
    config = {"display": {"tool_preview_length": preview_length}}
    (tmp_path / "config.yaml").write_text(yaml.dump(config), encoding="utf-8")

    adapter = ProgressCaptureAdapter()
    runner = _make_runner(adapter)
    gateway_run = importlib.import_module("gateway.run")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})

    source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="12345",
        chat_type="dm",
        thread_id=None,
    )

    # asyncio.run() always creates (and closes) its own loop.  The previous
    # asyncio.get_event_loop().run_until_complete(...) is deprecated when no
    # loop is current and raises RuntimeError once any earlier asyncio.run()
    # call in the same process has cleared the current loop, which made this
    # helper depend on test execution order.
    result = asyncio.run(
        runner._run_agent(
            message="hello",
            context_prompt="",
            history=[],
            source=source,
            session_id="sess-trunc",
            session_key="agent:main:telegram:dm:12345",
        )
    )
    return adapter, result
|
||||
|
||||
|
||||
def test_all_mode_default_truncation_40_chars(monkeypatch, tmp_path):
    """When tool_preview_length is 0 (default), all/new mode truncates to 40 chars."""
    import re

    adapter, result = _run_long_preview_helper(monkeypatch, tmp_path, preview_length=0)

    assert result["final_response"] == "done"
    assert adapter.sent

    content = adapter.sent[0]["content"]
    # An ellipsis marks that the long command was cut short.
    assert "..." in content

    # The preview is whatever sits between the outermost double quotes.
    quoted = re.search(r'"(.+)"', content)
    assert quoted, f"No quoted preview found in: {content}"

    preview = quoted.group(1)
    assert len(preview) <= 40, f"Preview too long ({len(preview)}): {preview}"
|
||||
|
||||
|
||||
def test_all_mode_respects_custom_preview_length(monkeypatch, tmp_path):
    """When tool_preview_length is explicitly set (e.g. 120), all/new mode uses that."""
    import re

    adapter, result = _run_long_preview_helper(monkeypatch, tmp_path, preview_length=120)

    assert result["final_response"] == "done"
    assert adapter.sent

    content = adapter.sent[0]["content"]
    # The command is longer than 120 chars, so it is still cut, just later.
    quoted = re.search(r'"(.+)"', content)
    assert quoted, f"No quoted preview found in: {content}"

    preview = quoted.group(1)
    # Longer than the 40-char default ...
    assert len(preview) > 40, f"Preview suspiciously short ({len(preview)}): {preview}"
    # ... yet never beyond the configured 120.
    assert len(preview) <= 120, f"Preview too long ({len(preview)}): {preview}"
|
||||
|
||||
|
||||
def test_all_mode_no_truncation_when_preview_fits(monkeypatch, tmp_path):
    """Short previews (under the cap) are not truncated."""
    # A 200-char cap is larger than the LongPreviewAgent command.
    adapter, result = _run_long_preview_helper(monkeypatch, tmp_path, preview_length=200)

    assert result["final_response"] == "done"
    assert adapter.sent

    first_message = adapter.sent[0]
    content = first_message["content"]
    # Nothing should have been cut, so no ellipsis may appear.
    assert "..." not in content, f"Preview was truncated when it shouldn't be: {content}"
|
||||
|
||||
158
tests/gateway/test_session_boundary_hooks.py
Normal file
158
tests/gateway/test_session_boundary_hooks.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Tests that on_session_finalize and on_session_reset plugin hooks fire in the gateway."""
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_source() -> SessionSource:
    """Build the Telegram DM ``SessionSource`` shared by the tests in this module."""
    source_fields = {
        "platform": Platform.TELEGRAM,
        "user_id": "u1",
        "chat_id": "c1",
        "user_name": "tester",
        "chat_type": "dm",
    }
    return SessionSource(**source_fields)
|
||||
|
||||
|
||||
def _make_event(text: str) -> MessageEvent:
    """Wrap *text* in a ``MessageEvent`` coming from the shared test source."""
    origin = _make_source()
    return MessageEvent(text=text, source=origin, message_id="m1")
|
||||
|
||||
|
||||
def _make_runner():
    """Return a ``GatewayRunner`` stub prepared for ``/new`` handling.

    The runner is created with ``object.__new__`` so ``__init__`` never runs;
    only the attributes assigned below exist on it.  The mocked session store
    holds an existing entry with id ``"sess-old"`` and hands back an entry
    with id ``"sess-new"`` from ``reset_session`` / ``get_or_create_session``.
    """
    from gateway.run import GatewayRunner

    session_key = build_session_key(_make_source())

    def _entry(session_id):
        # Both entries share everything except the session id.
        return SessionEntry(
            session_key=session_key,
            session_id=session_id,
            created_at=datetime.now(),
            updated_at=datetime.now(),
            platform=Platform.TELEGRAM,
            chat_type="dm",
        )

    existing_entry = _entry("sess-old")
    replacement_entry = _entry("sess-new")

    telegram = MagicMock()
    telegram.send = AsyncMock()

    store = MagicMock()
    store.get_or_create_session.return_value = replacement_entry
    store.reset_session.return_value = replacement_entry
    store._entries = {session_key: existing_entry}
    store._generate_session_key.return_value = session_key

    runner = object.__new__(GatewayRunner)
    runner.config = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
    )
    runner.adapters = {Platform.TELEGRAM: telegram}
    runner._voice_mode = {}
    runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
    runner._session_model_overrides = {}
    runner._pending_model_notes = {}
    runner._background_tasks = set()
    runner.session_store = store
    runner._running_agents = {}
    runner._pending_messages = {}
    runner._pending_approvals = {}
    runner._session_db = None
    runner._agent_cache_lock = None
    runner._is_user_authorized = lambda _source: True
    runner._format_session_info = lambda: ""
    return runner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("hermes_cli.plugins.invoke_hook")
async def test_reset_fires_finalize_hook(mock_invoke_hook):
    """``/new`` reports the session being closed (the OLD id) via ``on_session_finalize``."""
    runner = _make_runner()
    reset_event = _make_event("/new")

    await runner._handle_reset_command(reset_event)

    expected_kwargs = {"session_id": "sess-old", "platform": "telegram"}
    mock_invoke_hook.assert_any_call("on_session_finalize", **expected_kwargs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("hermes_cli.plugins.invoke_hook")
async def test_reset_fires_reset_hook(mock_invoke_hook):
    """``/new`` announces the replacement session (the NEW id) via ``on_session_reset``."""
    runner = _make_runner()
    reset_event = _make_event("/new")

    await runner._handle_reset_command(reset_event)

    expected_kwargs = {"session_id": "sess-new", "platform": "telegram"}
    mock_invoke_hook.assert_any_call("on_session_reset", **expected_kwargs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("hermes_cli.plugins.invoke_hook")
async def test_finalize_before_reset(mock_invoke_hook):
    """During ``/new`` the finalize hook is invoked before the reset hook."""
    runner = _make_runner()

    await runner._handle_reset_command(_make_event("/new"))

    session_hooks = ("on_session_finalize", "on_session_reset")
    # Each recorded call unpacks to (positional args, keyword args); the hook
    # name is the first positional argument.
    fired_in_order = [
        positional[0]
        for positional, _keywords in mock_invoke_hook.call_args_list
        if positional[0] in session_hooks
    ]
    assert fired_in_order == ["on_session_finalize", "on_session_reset"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("hermes_cli.plugins.invoke_hook")
async def test_shutdown_fires_finalize_for_active_agents(mock_invoke_hook):
    """``GatewayRunner.stop()`` emits ``on_session_finalize`` once per running agent."""
    from gateway.run import GatewayRunner

    # Bare instance: __init__ is skipped, so only the attributes assigned
    # below exist on the runner.
    runner = object.__new__(GatewayRunner)
    runner._running = True
    runner._background_tasks = set()
    runner._pending_messages = {}
    runner._pending_approvals = {}
    runner._shutdown_event = MagicMock()
    runner.adapters = {}
    runner._exit_reason = "test"

    active_agents = {}
    for agent_key, session_id in (("key-a", "sess-a"), ("key-b", "sess-b")):
        agent = MagicMock()
        agent.session_id = session_id
        active_agents[agent_key] = agent
    runner._running_agents = active_agents

    # Patch out the PID-file and runtime-status helpers for the duration of stop().
    with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
        await runner.stop()

    finalized_ids = {
        keywords["session_id"]
        for positional, keywords in mock_invoke_hook.call_args_list
        if positional[0] == "on_session_finalize"
    }
    assert finalized_ids == {"sess-a", "sess-b"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("hermes_cli.plugins.invoke_hook", side_effect=Exception("boom"))
async def test_hook_error_does_not_break_reset(mock_invoke_hook):
    """A raising plugin hook must not stop ``/new`` from returning its reply."""
    runner = _make_runner()

    reply = await runner._handle_reset_command(_make_event("/new"))

    # Either success wording is accepted; the point is that a reply came back
    # even though every hook invocation raised.
    success_markers = ("Session reset", "New session")
    assert any(marker in reply for marker in success_markers)
|
||||
@@ -36,11 +36,16 @@ def _make_runner():
|
||||
)
|
||||
runner.adapters = {Platform.TELEGRAM: _FakeAdapter()}
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._voice_mode = {}
|
||||
runner._background_tasks = set()
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
runner.session_store = MagicMock()
|
||||
runner.delivery_router = MagicMock()
|
||||
return runner
|
||||
|
||||
|
||||
|
||||
@@ -699,6 +699,147 @@ class TestReactions:
|
||||
assert remove_calls[0].kwargs["name"] == "eyes"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestThreadReplyHandling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestThreadReplyHandling:
    """Thread replies in a channel that carry no explicit bot mention.

    The tests expect such a reply to reach ``handle_message`` only when the
    attached session store already has an entry for that thread; top-level
    channel messages still need a mention.
    """

    # Session-store key used by the tests that simulate an active thread session.
    _THREAD_SESSION_KEY = "agent:main:slack:group:C123:123.000:U_USER"

    @staticmethod
    def _channel_event(text, ts="123.456", thread_ts="123.000"):
        """Build a Slack channel message payload; ``thread_ts=None`` makes it top-level."""
        payload = {"text": text, "user": "U_USER", "channel": "C123", "ts": ts}
        if thread_ts is not None:
            # A thread_ts that differs from ts marks the message as a thread reply.
            payload["thread_ts"] = thread_ts
        payload["channel_type"] = "channel"
        payload["team"] = "T_TEAM"
        return payload

    @pytest.fixture()
    def mock_session_store(self):
        """Mock session store exposing an (initially empty) ``_entries`` dict."""
        session_store = MagicMock()
        session_store._entries = {}
        session_store._ensure_loaded = MagicMock()
        session_store.config = MagicMock()
        session_store.config.group_sessions_per_user = True
        return session_store

    @pytest.fixture()
    def adapter_with_session_store(self, mock_session_store):
        """``SlackAdapter`` with mocked internals and the mock session store attached."""
        slack = SlackAdapter(PlatformConfig(enabled=True, token="***"))
        slack._app = MagicMock()
        slack._app.client = AsyncMock()
        slack._bot_user_id = "U_BOT"
        slack._team_bot_user_ids = {"T_TEAM": "U_BOT"}
        slack._running = True
        slack.handle_message = AsyncMock()
        slack.set_session_store(mock_session_store)
        return slack

    @pytest.mark.asyncio
    async def test_thread_reply_without_mention_no_session_ignored(
        self, adapter_with_session_store, mock_session_store
    ):
        """A mention-less thread reply is dropped when no session is active."""
        mock_session_store._entries = {}  # No active sessions

        reply = self._channel_event("Just replying in the thread")
        await adapter_with_session_store._handle_slack_message(reply)

        adapter_with_session_store.handle_message.assert_not_called()

    @pytest.mark.asyncio
    async def test_thread_reply_without_mention_with_session_processed(
        self, adapter_with_session_store, mock_session_store
    ):
        """A mention-less thread reply is processed when the thread has a session."""
        # Simulate an active session for thread 123.000.
        mock_session_store._entries = {self._THREAD_SESSION_KEY: MagicMock()}

        reply = self._channel_event("Follow-up question")
        await adapter_with_session_store._handle_slack_message(reply)

        adapter_with_session_store.handle_message.assert_called_once()
        # With no mention to strip, the text must arrive unchanged.
        forwarded = adapter_with_session_store.handle_message.call_args[0][0]
        assert forwarded.text == "Follow-up question"

    @pytest.mark.asyncio
    async def test_thread_reply_with_mention_strips_bot_id(
        self, adapter_with_session_store, mock_session_store
    ):
        """An @mention in a thread reply is stripped even when a session exists."""
        mock_session_store._entries = {self._THREAD_SESSION_KEY: MagicMock()}

        reply = self._channel_event("<@U_BOT> thanks for the help")
        await adapter_with_session_store._handle_slack_message(reply)

        adapter_with_session_store.handle_message.assert_called_once()
        forwarded = adapter_with_session_store.handle_message.call_args[0][0]
        assert "<@U_BOT>" not in forwarded.text
        assert forwarded.text == "thanks for the help"

    @pytest.mark.asyncio
    async def test_top_level_message_requires_mention_even_with_session(
        self, adapter_with_session_store, mock_session_store
    ):
        """A top-level channel message without a mention is ignored despite a stored session."""
        mock_session_store._entries = {self._THREAD_SESSION_KEY: MagicMock()}

        # No thread_ts in the payload: this is a top-level message.
        top_level = self._channel_event(
            "New question without mention", ts="456.789", thread_ts=None
        )
        await adapter_with_session_store._handle_slack_message(top_level)

        adapter_with_session_store.handle_message.assert_not_called()

    @pytest.mark.asyncio
    async def test_no_session_store_ignores_thread_replies(
        self, adapter
    ):
        """Without any session store attached, mention-less thread replies are ignored."""
        # The plain ``adapter`` fixture has no session store attached.
        reply = self._channel_event("Thread reply without mention")
        await adapter._handle_slack_message(reply)

        adapter.handle_message.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestUserNameResolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
426
tests/gateway/test_slack_approval_buttons.py
Normal file
426
tests/gateway/test_slack_approval_buttons.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""Tests for Slack Block Kit approval buttons and thread context fetching."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ensure the repo root is importable
|
||||
# ---------------------------------------------------------------------------
|
||||
# parents[2] of tests/gateway/<this file> is the repository root; add it to
# sys.path (once) so the ``gateway`` imports further down resolve.
_repo = str(Path(__file__).resolve().parents[2])
if _repo not in sys.path:
    sys.path.insert(0, _repo)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal Slack SDK mock so SlackAdapter can be imported
|
||||
# ---------------------------------------------------------------------------
|
||||
def _ensure_slack_mock():
    """Register stand-in ``slack_bolt`` / ``slack_sdk`` modules in ``sys.modules``.

    Does nothing when a ``slack_bolt`` entry is already present, so an
    existing module (real or mocked) is never overwritten.
    """
    if "slack_bolt" in sys.modules:
        return

    bolt = MagicMock()
    bolt.async_app.AsyncApp = MagicMock

    socket_handler = MagicMock()
    socket_handler.AsyncSocketModeHandler = MagicMock

    sdk = MagicMock()
    sdk.web = MagicMock()
    sdk.web.async_client = MagicMock()
    sdk.web.async_client.AsyncWebClient = MagicMock

    # Every dotted path gets its own entry: the import machinery looks each
    # package level up in sys.modules.
    sys.modules.update({
        "slack_bolt": bolt,
        "slack_bolt.async_app": bolt.async_app,
        "slack_bolt.adapter": MagicMock(),
        "slack_bolt.adapter.socket_mode": MagicMock(),
        "slack_bolt.adapter.socket_mode.async_handler": socket_handler,
        "slack_sdk": sdk,
        "slack_sdk.web": sdk.web,
        "slack_sdk.web.async_client": sdk.web.async_client,
    })
|
||||
|
||||
|
||||
# Install the mocks at import time, before the SlackAdapter import that follows.
_ensure_slack_mock()
|
||||
|
||||
from gateway.platforms.slack import SlackAdapter
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
def _make_adapter():
    """Build a ``SlackAdapter`` whose app and per-team client are mocks.

    Team ``T1`` owns channel ``C1`` and is served by an ``AsyncMock`` client;
    the bot user id is ``U_BOT`` both globally and for that team.
    """
    slack = SlackAdapter(PlatformConfig(enabled=True, token="xoxb-test-token"))
    slack._app = MagicMock()
    slack._bot_user_id = "U_BOT"
    slack._team_clients = {"T1": AsyncMock()}
    slack._team_bot_user_ids = {"T1": "U_BOT"}
    slack._channel_team = {"C1": "T1"}
    return slack
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# send_exec_approval — Block Kit buttons
|
||||
# ===========================================================================
|
||||
|
||||
class TestSlackExecApproval:
    """``send_exec_approval`` posts a Block Kit message with approval buttons."""

    @pytest.mark.asyncio
    async def test_sends_blocks_with_buttons(self):
        """The message has a section describing the command plus four action buttons."""
        session_key = "agent:main:slack:group:C1:1111"
        adapter = _make_adapter()
        client = adapter._team_clients["T1"]
        client.chat_postMessage = AsyncMock(return_value={"ts": "1234.5678"})

        outcome = await adapter.send_exec_approval(
            chat_id="C1",
            command="rm -rf /important",
            session_key=session_key,
            description="dangerous deletion",
        )

        assert outcome.success is True
        assert outcome.message_id == "1234.5678"

        # chat_postMessage must have been called exactly once, with blocks.
        client.chat_postMessage.assert_called_once()
        post_kwargs = client.chat_postMessage.call_args[1]
        assert "blocks" in post_kwargs
        blocks = post_kwargs["blocks"]
        assert len(blocks) == 2

        section = blocks[0]
        assert section["type"] == "section"
        assert "rm -rf /important" in section["text"]["text"]
        assert "dangerous deletion" in section["text"]["text"]

        actions = blocks[1]
        assert actions["type"] == "actions"
        buttons = actions["elements"]
        assert len(buttons) == 4
        button_ids = [button["action_id"] for button in buttons]
        for expected_id in (
            "hermes_approve_once",
            "hermes_approve_session",
            "hermes_approve_always",
            "hermes_deny",
        ):
            assert expected_id in button_ids
        # Every button carries the session key as its value.
        for button in buttons:
            assert button["value"] == session_key

    @pytest.mark.asyncio
    async def test_sends_in_thread(self):
        """``metadata['thread_id']`` is forwarded to Slack as ``thread_ts``."""
        adapter = _make_adapter()
        client = adapter._team_clients["T1"]
        client.chat_postMessage = AsyncMock(return_value={"ts": "1234.5678"})

        await adapter.send_exec_approval(
            chat_id="C1",
            command="echo test",
            session_key="test-session",
            metadata={"thread_id": "9999.0000"},
        )

        post_kwargs = client.chat_postMessage.call_args[1]
        assert post_kwargs.get("thread_ts") == "9999.0000"

    @pytest.mark.asyncio
    async def test_not_connected(self):
        """With no Slack app set, the call reports failure."""
        adapter = _make_adapter()
        adapter._app = None

        outcome = await adapter.send_exec_approval(
            chat_id="C1", command="ls", session_key="s"
        )

        assert outcome.success is False

    @pytest.mark.asyncio
    async def test_truncates_long_command(self):
        """A 5000-char command is shortened and marked with an ellipsis."""
        adapter = _make_adapter()
        client = adapter._team_clients["T1"]
        client.chat_postMessage = AsyncMock(return_value={"ts": "1.2"})

        oversized_command = "x" * 5000
        await adapter.send_exec_approval(
            chat_id="C1", command=oversized_command, session_key="s"
        )

        post_kwargs = client.chat_postMessage.call_args[1]
        rendered = post_kwargs["blocks"][0]["text"]["text"]
        assert "..." in rendered
        assert len(rendered) < 5000
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _handle_approval_action — button click handler
|
||||
# ===========================================================================
|
||||
|
||||
class TestSlackApprovalAction:
    """``_handle_approval_action``: the handler for approval button clicks."""

    @pytest.mark.asyncio
    async def test_resolves_approval(self):
        """A first click acks, resolves with choice ``once`` and rewrites the message."""
        adapter = _make_adapter()
        adapter._approval_resolved["1234.5678"] = False

        ack = AsyncMock()
        message = {
            "ts": "1234.5678",
            "blocks": [
                {"type": "section", "text": {"type": "mrkdwn", "text": "original text"}},
                {"type": "actions", "elements": []},
            ],
        }
        body = {
            "message": message,
            "channel": {"id": "C1"},
            "user": {"name": "norbert"},
        }
        action = {
            "action_id": "hermes_approve_once",
            "value": "agent:main:slack:group:C1:1111",
        }

        client = adapter._team_clients["T1"]
        client.chat_update = AsyncMock()

        with patch("tools.approval.resolve_gateway_approval", return_value=1) as resolver:
            await adapter._handle_approval_action(ack, body, action)

        ack.assert_called_once()
        resolver.assert_called_once_with("agent:main:slack:group:C1:1111", "once")

        # The original message is updated to show the decision.
        client.chat_update.assert_called_once()
        updated = client.chat_update.call_args[1]
        assert "Approved once by norbert" in updated["text"]

    @pytest.mark.asyncio
    async def test_prevents_double_click(self):
        """A click on an already-resolved message is acked but not resolved again."""
        adapter = _make_adapter()
        adapter._approval_resolved["1234.5678"] = True  # Already resolved

        ack = AsyncMock()
        body = {
            "message": {"ts": "1234.5678", "blocks": []},
            "channel": {"id": "C1"},
            "user": {"name": "norbert"},
        }
        action = {
            "action_id": "hermes_approve_once",
            "value": "some-session",
        }

        with patch("tools.approval.resolve_gateway_approval") as resolver:
            await adapter._handle_approval_action(ack, body, action)

        # Acknowledged, but the approval must NOT be resolved a second time.
        ack.assert_called_once()
        resolver.assert_not_called()

    @pytest.mark.asyncio
    async def test_deny_action(self):
        """The deny button resolves with ``deny`` and names the user in the update."""
        adapter = _make_adapter()
        adapter._approval_resolved["1.2"] = False

        ack = AsyncMock()
        section = {"type": "section", "text": {"type": "mrkdwn", "text": "cmd"}}
        body = {
            "message": {"ts": "1.2", "blocks": [section]},
            "channel": {"id": "C1"},
            "user": {"name": "alice"},
        }
        action = {"action_id": "hermes_deny", "value": "session-key"}

        client = adapter._team_clients["T1"]
        client.chat_update = AsyncMock()

        with patch("tools.approval.resolve_gateway_approval", return_value=1) as resolver:
            await adapter._handle_approval_action(ack, body, action)

        resolver.assert_called_once_with("session-key", "deny")
        updated = client.chat_update.call_args[1]
        assert "Denied by alice" in updated["text"]
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _fetch_thread_context
|
||||
# ===========================================================================
|
||||
|
||||
class TestSlackThreadContext:
    """``_fetch_thread_context``: fetching and formatting earlier thread messages."""

    @pytest.mark.asyncio
    async def test_fetches_and_formats_context(self):
        """Earlier messages are rendered with user names; the triggering one is left out."""
        adapter = _make_adapter()
        client = adapter._team_clients["T1"]
        thread_messages = [
            {"ts": "1000.0", "user": "U1", "text": "This is the parent message"},
            {"ts": "1000.1", "user": "U2", "text": "I think we should refactor"},
            {"ts": "1000.2", "user": "U1", "text": "Good idea, <@U_BOT> what do you think?"},
        ]
        client.conversations_replies = AsyncMock(return_value={"messages": thread_messages})

        # Pre-seed the name cache so user ids map to display names.
        adapter._user_name_cache = {"U1": "Alice", "U2": "Bob"}

        context = await adapter._fetch_thread_context(
            channel_id="C1",
            thread_ts="1000.0",
            current_ts="1000.2",  # The message that triggered the fetch
            team_id="T1",
        )

        assert "[Thread context" in context
        assert "[thread parent] Alice: This is the parent message" in context
        assert "Bob: I think we should refactor" in context
        # The triggering (current) message must be excluded ...
        assert "what do you think" not in context
        # ... and the raw bot mention must not appear in the context.
        assert "<@U_BOT>" not in context

    @pytest.mark.asyncio
    async def test_skips_bot_messages(self):
        """Messages carrying a ``bot_id`` are left out of the context."""
        adapter = _make_adapter()
        client = adapter._team_clients["T1"]
        thread_messages = [
            {"ts": "1000.0", "user": "U1", "text": "Parent"},
            {"ts": "1000.1", "bot_id": "B1", "text": "Bot reply (should be skipped)"},
            {"ts": "1000.2", "user": "U1", "text": "Current"},
        ]
        client.conversations_replies = AsyncMock(return_value={"messages": thread_messages})
        adapter._user_name_cache = {"U1": "Alice"}

        context = await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.2", team_id="T1"
        )

        assert "Bot reply" not in context
        assert "Alice: Parent" in context

    @pytest.mark.asyncio
    async def test_empty_thread(self):
        """An empty ``messages`` list yields an empty context string."""
        adapter = _make_adapter()
        client = adapter._team_clients["T1"]
        client.conversations_replies = AsyncMock(return_value={"messages": []})

        context = await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.1", team_id="T1"
        )

        assert context == ""

    @pytest.mark.asyncio
    async def test_api_failure_returns_empty(self):
        """An exception from ``conversations_replies`` yields an empty string, not a raise."""
        adapter = _make_adapter()
        client = adapter._team_clients["T1"]
        client.conversations_replies = AsyncMock(side_effect=Exception("API error"))

        context = await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.1", team_id="T1"
        )

        assert context == ""
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _has_active_session_for_thread — session key fix (#5833)
|
||||
# ===========================================================================
|
||||
|
||||
class TestSessionKeyFix:
    """``_has_active_session_for_thread`` lookups against a mocked session store (#5833)."""

    @staticmethod
    def _attach_store(adapter, entries, group_sessions_per_user):
        """Give *adapter* a mock session store holding *entries* with the given config flag."""
        store = MagicMock()
        store._entries = entries
        store._ensure_loaded = MagicMock()
        store.config = MagicMock()
        store.config.group_sessions_per_user = group_sessions_per_user
        store.config.thread_sessions_per_user = False
        adapter._session_store = store

    def test_uses_build_session_key(self):
        """A stored key without a user-id suffix is found when per-user sessions are off."""
        adapter = _make_adapter()
        # The stored key has no user id appended, matching
        # group_sessions_per_user=False.
        self._attach_store(
            adapter,
            {"agent:main:slack:group:C1:1000.0": MagicMock()},
            group_sessions_per_user=False,
        )

        found = adapter._has_active_session_for_thread(
            channel_id="C1", thread_ts="1000.0", user_id="U123"
        )

        # The lookup succeeds only if the computed key also omits the user id.
        assert found is True

    def test_no_session_returns_false(self):
        """An empty store reports no active session."""
        adapter = _make_adapter()
        self._attach_store(adapter, {}, group_sessions_per_user=True)

        found = adapter._has_active_session_for_thread(
            channel_id="C1", thread_ts="1000.0", user_id="U123"
        )

        assert found is False

    def test_no_session_store(self):
        """No store is attached here; the lookup must report ``False``."""
        adapter = _make_adapter()

        found = adapter._has_active_session_for_thread(
            channel_id="C1", thread_ts="1000.0", user_id="U123"
        )

        assert found is False
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Thread engagement — bot-started threads & mentioned threads
|
||||
# ===========================================================================
|
||||
|
||||
class TestThreadEngagement:
    """Bookkeeping of ``_bot_message_ts`` and ``_mentioned_threads`` on the adapter."""

    @pytest.mark.asyncio
    async def test_send_tracks_bot_message_ts(self):
        """``send`` records both the new message ts and the thread root it replied in."""
        adapter = _make_adapter()
        client = adapter._team_clients["T1"]
        client.chat_postMessage = AsyncMock(return_value={"ts": "9000.1"})

        await adapter.send(chat_id="C1", content="Hello!", metadata={"thread_id": "8000.0"})

        tracked = adapter._bot_message_ts
        assert "9000.1" in tracked
        # The thread root is tracked as well.
        assert "8000.0" in tracked

    @pytest.mark.asyncio
    async def test_bot_message_ts_cap(self):
        """``_bot_message_ts`` stays within ``_BOT_TS_MAX`` after many sends."""
        adapter = _make_adapter()
        adapter._BOT_TS_MAX = 10  # low cap for testing
        client = adapter._team_clients["T1"]

        for index in range(20):
            # Each send gets a fresh mock so it returns a distinct ts.
            client.chat_postMessage = AsyncMock(return_value={"ts": f"{index}.0"})
            await adapter.send(chat_id="C1", content=f"msg {index}")

        assert len(adapter._bot_message_ts) <= 10

    def test_mentioned_threads_populated_on_mention(self):
        """A thread ts added to ``_mentioned_threads`` is reported as present."""
        adapter = _make_adapter()
        # NOTE(review): this adds the ts by hand instead of driving
        # _handle_slack_message, so it only checks the container itself.
        adapter._mentioned_threads.add("1000.0")
        assert "1000.0" in adapter._mentioned_threads

    def test_mentioned_threads_cap(self):
        """``_mentioned_threads`` stays bounded under a halve-on-overflow policy."""
        adapter = _make_adapter()
        adapter._MENTIONED_THREADS_MAX = 10
        # NOTE(review): the eviction below is written out in the test rather
        # than invoked on the adapter -- confirm it mirrors the production code.
        for index in range(15):
            adapter._mentioned_threads.add(f"{index}.0")
            if len(adapter._mentioned_threads) > adapter._MENTIONED_THREADS_MAX:
                evict_count = adapter._MENTIONED_THREADS_MAX // 2
                evicted = list(adapter._mentioned_threads)[:evict_count]
                for thread_ts in evicted:
                    adapter._mentioned_threads.discard(thread_ts)
        assert len(adapter._mentioned_threads) <= 10
|
||||
@@ -51,7 +51,8 @@ def _make_runner(session_entry: SessionEntry):
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = None
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.get_session_title.return_value = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
@@ -82,12 +83,34 @@ async def test_status_command_reports_running_agent_without_interrupt(monkeypatc
|
||||
|
||||
result = await runner._handle_message(_make_event("/status"))
|
||||
|
||||
assert "**Session ID:** `sess-1`" in result
|
||||
assert "**Tokens:** 321" in result
|
||||
assert "**Agent Running:** Yes ⚡" in result
|
||||
assert "**Title:**" not in result
|
||||
running_agent.interrupt.assert_not_called()
|
||||
assert runner._pending_messages == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_status_command_includes_session_title_when_present():
    """``/status`` includes a Title line when the session DB returns a title."""
    entry_fields = {
        "session_key": build_session_key(_make_source()),
        "session_id": "sess-1",
        "created_at": datetime.now(),
        "updated_at": datetime.now(),
        "platform": Platform.TELEGRAM,
        "chat_type": "dm",
        "total_tokens": 321,
    }
    runner = _make_runner(SessionEntry(**entry_fields))
    runner._session_db.get_session_title.return_value = "My titled session"

    status_text = await runner._handle_message(_make_event("/status"))

    assert "**Session ID:** `sess-1`" in status_text
    assert "**Title:** My titled session" in status_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_persists_agent_token_counts(monkeypatch):
|
||||
import gateway.run as gateway_run
|
||||
|
||||
@@ -177,3 +177,238 @@ class TestStreamRunMediaStripping:
|
||||
assert "MEDIA:" not in sent_text, f"MEDIA: leaked into display: {sent_text!r}"
|
||||
|
||||
assert consumer.already_sent
|
||||
|
||||
|
||||
# ── Segment break (tool boundary) tests ──────────────────────────────────
|
||||
|
||||
|
||||
class TestSegmentBreakOnToolBoundary:
|
||||
"""Verify that on_delta(None) finalizes the current message and starts a
|
||||
new one so the final response appears below tool-progress messages."""
|
||||
|
||||
@pytest.mark.asyncio
async def test_segment_break_creates_new_message(self):
    """Text arriving after a ``None`` boundary goes out as a second, separate message."""
    first_send = SimpleNamespace(success=True, message_id="msg_1")
    second_send = SimpleNamespace(success=True, message_id="msg_2")

    adapter = MagicMock()
    adapter.send = AsyncMock(side_effect=[first_send, second_send])
    adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
    adapter.MAX_MESSAGE_LENGTH = 4096

    consumer = GatewayStreamConsumer(
        adapter,
        "chat_123",
        StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5),
    )

    # Queue the whole stream up front: pre-tool text, the tool boundary
    # (None), the post-tool answer, then the finish signal.
    consumer.on_delta("Let me search for that...")
    consumer.on_delta(None)
    consumer.on_delta("Here are the results.")
    consumer.finish()

    await consumer.run()

    # Two adapter.send calls means two distinct messages were posted, rather
    # than the first one merely being edited.
    assert adapter.send.call_count == 2
    sends = adapter.send.call_args_list
    before_tools = sends[0][1]["content"]
    after_tools = sends[1][1]["content"]
    assert "search" in before_tools
    assert "results" in after_tools
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_segment_break_no_text_before(self):
|
||||
"""A None boundary with no preceding text is a no-op."""
|
||||
adapter = MagicMock()
|
||||
send_result = SimpleNamespace(success=True, message_id="msg_1")
|
||||
adapter.send = AsyncMock(return_value=send_result)
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5)
|
||||
consumer = GatewayStreamConsumer(adapter, "chat_123", config)
|
||||
|
||||
# No text before the boundary — model went straight to tool calls
|
||||
consumer.on_delta(None)
|
||||
consumer.on_delta("Final answer.")
|
||||
consumer.finish()
|
||||
|
||||
await consumer.run()
|
||||
|
||||
# Only one send call (the final answer)
|
||||
assert adapter.send.call_count == 1
|
||||
assert "Final answer" in adapter.send.call_args_list[0][1]["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_segment_break_removes_cursor(self):
|
||||
"""The finalized segment message should not have a cursor."""
|
||||
adapter = MagicMock()
|
||||
send_result = SimpleNamespace(success=True, message_id="msg_1")
|
||||
edit_result = SimpleNamespace(success=True)
|
||||
adapter.send = AsyncMock(return_value=send_result)
|
||||
adapter.edit_message = AsyncMock(return_value=edit_result)
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉")
|
||||
consumer = GatewayStreamConsumer(adapter, "chat_123", config)
|
||||
|
||||
consumer.on_delta("Thinking...")
|
||||
consumer.on_delta(None)
|
||||
consumer.on_delta("Done.")
|
||||
consumer.finish()
|
||||
|
||||
await consumer.run()
|
||||
|
||||
# The first segment should have been finalized without cursor.
|
||||
# Check all edit_message calls + the initial send for the first segment.
|
||||
# The last state of msg_1 should NOT have the cursor.
|
||||
all_texts = []
|
||||
for call in adapter.send.call_args_list:
|
||||
all_texts.append(call[1].get("content", ""))
|
||||
for call in adapter.edit_message.call_args_list:
|
||||
all_texts.append(call[1].get("content", ""))
|
||||
|
||||
# Find the text(s) that contain "Thinking" — the finalized version
|
||||
# should not have the cursor.
|
||||
thinking_texts = [t for t in all_texts if "Thinking" in t]
|
||||
assert thinking_texts, "Expected at least one message with 'Thinking'"
|
||||
# The LAST occurrence is the finalized version
|
||||
assert "▉" not in thinking_texts[-1], (
|
||||
f"Cursor found in finalized segment: {thinking_texts[-1]!r}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_segment_breaks(self):
|
||||
"""Multiple tool boundaries create multiple message segments."""
|
||||
adapter = MagicMock()
|
||||
msg_counter = iter(["msg_1", "msg_2", "msg_3"])
|
||||
adapter.send = AsyncMock(
|
||||
side_effect=lambda **kw: SimpleNamespace(success=True, message_id=next(msg_counter))
|
||||
)
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5)
|
||||
consumer = GatewayStreamConsumer(adapter, "chat_123", config)
|
||||
|
||||
consumer.on_delta("Phase 1")
|
||||
consumer.on_delta(None) # tool boundary
|
||||
consumer.on_delta("Phase 2")
|
||||
consumer.on_delta(None) # another tool boundary
|
||||
consumer.on_delta("Phase 3")
|
||||
consumer.finish()
|
||||
|
||||
await consumer.run()
|
||||
|
||||
# Three separate messages
|
||||
assert adapter.send.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_sent_stays_true_after_segment(self):
|
||||
"""already_sent remains True after a segment break."""
|
||||
adapter = MagicMock()
|
||||
send_result = SimpleNamespace(success=True, message_id="msg_1")
|
||||
adapter.send = AsyncMock(return_value=send_result)
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5)
|
||||
consumer = GatewayStreamConsumer(adapter, "chat_123", config)
|
||||
|
||||
consumer.on_delta("Text")
|
||||
consumer.on_delta(None)
|
||||
consumer.finish()
|
||||
|
||||
await consumer.run()
|
||||
|
||||
assert consumer.already_sent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_failure_sends_only_unsent_tail_at_finish(self):
|
||||
"""If an edit fails mid-stream, send only the missing tail once at finish."""
|
||||
adapter = MagicMock()
|
||||
send_results = [
|
||||
SimpleNamespace(success=True, message_id="msg_1"),
|
||||
SimpleNamespace(success=True, message_id="msg_2"),
|
||||
]
|
||||
adapter.send = AsyncMock(side_effect=send_results)
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉")
|
||||
consumer = GatewayStreamConsumer(adapter, "chat_123", config)
|
||||
|
||||
consumer.on_delta("Hello")
|
||||
task = asyncio.create_task(consumer.run())
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(" world")
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.finish()
|
||||
await task
|
||||
|
||||
assert adapter.send.call_count == 2
|
||||
first_text = adapter.send.call_args_list[0][1]["content"]
|
||||
second_text = adapter.send.call_args_list[1][1]["content"]
|
||||
assert "Hello" in first_text
|
||||
assert second_text.strip() == "world"
|
||||
assert consumer.already_sent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_segment_break_clears_failed_edit_fallback_state(self):
|
||||
"""A tool boundary after edit failure must not duplicate the next segment."""
|
||||
adapter = MagicMock()
|
||||
send_results = [
|
||||
SimpleNamespace(success=True, message_id="msg_1"),
|
||||
SimpleNamespace(success=True, message_id="msg_2"),
|
||||
]
|
||||
adapter.send = AsyncMock(side_effect=send_results)
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉")
|
||||
consumer = GatewayStreamConsumer(adapter, "chat_123", config)
|
||||
|
||||
consumer.on_delta("Hello")
|
||||
task = asyncio.create_task(consumer.run())
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(" world")
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(None)
|
||||
consumer.on_delta("Next segment")
|
||||
consumer.finish()
|
||||
await task
|
||||
|
||||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
assert sent_texts == ["Hello ▉", "Next segment"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_final_splits_long_continuation_without_dropping_text(self):
|
||||
"""Long continuation tails should be chunked when fallback final-send runs."""
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock(side_effect=[
|
||||
SimpleNamespace(success=True, message_id="msg_1"),
|
||||
SimpleNamespace(success=True, message_id="msg_2"),
|
||||
SimpleNamespace(success=True, message_id="msg_3"),
|
||||
])
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
|
||||
adapter.MAX_MESSAGE_LENGTH = 610
|
||||
|
||||
config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉")
|
||||
consumer = GatewayStreamConsumer(adapter, "chat_123", config)
|
||||
|
||||
prefix = "abc"
|
||||
tail = "x" * 620
|
||||
consumer.on_delta(prefix)
|
||||
task = asyncio.create_task(consumer.run())
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(tail)
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.finish()
|
||||
await task
|
||||
|
||||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
assert len(sent_texts) == 3
|
||||
assert sent_texts[0].startswith(prefix)
|
||||
assert sum(len(t) for t in sent_texts[1:]) == len(tail)
|
||||
|
||||
291
tests/gateway/test_telegram_approval_buttons.py
Normal file
291
tests/gateway/test_telegram_approval_buttons.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""Tests for Telegram inline keyboard approval buttons."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ensure the repo root is importable
|
||||
# ---------------------------------------------------------------------------
|
||||
_repo = str(Path(__file__).resolve().parents[2])
|
||||
if _repo not in sys.path:
|
||||
sys.path.insert(0, _repo)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal Telegram mock so TelegramAdapter can be imported
|
||||
# ---------------------------------------------------------------------------
|
||||
def _ensure_telegram_mock():
    """Register a stand-in ``telegram`` package so TelegramAdapter can be imported.

    Nothing is done when ``sys.modules`` already holds a ``telegram`` entry
    that has a ``__file__`` attribute.  Otherwise a single ``MagicMock`` is
    registered (via ``setdefault``, so existing entries win) under the
    ``telegram``, ``telegram.ext``, ``telegram.constants`` and
    ``telegram.request`` names, and its ``error`` attribute is registered as
    ``telegram.error``.
    """
    current = sys.modules.get("telegram")
    if current is not None and hasattr(current, "__file__"):
        return

    mod = MagicMock()
    mod.ext.ContextTypes.DEFAULT_TYPE = type(None)

    parse_mode = mod.constants.ParseMode
    parse_mode.MARKDOWN = "Markdown"
    parse_mode.MARKDOWN_V2 = "MarkdownV2"
    parse_mode.HTML = "HTML"

    chat_type = mod.constants.ChatType
    chat_type.PRIVATE = "private"
    chat_type.GROUP = "group"
    chat_type.SUPERGROUP = "supergroup"
    chat_type.CHANNEL = "channel"

    # Real exception classes are needed so that ``except (NetworkError, ...)``
    # in connect() keeps working under xdist when this mock leaks.
    errors = mod.error
    errors.NetworkError = type("NetworkError", (OSError,), {})
    errors.TimedOut = type("TimedOut", (OSError,), {})
    errors.BadRequest = type("BadRequest", (Exception,), {})

    registrations = {
        "telegram": mod,
        "telegram.ext": mod,
        "telegram.constants": mod,
        "telegram.request": mod,
        "telegram.error": errors,
    }
    for module_name, fake in registrations.items():
        sys.modules.setdefault(module_name, fake)
|
||||
|
||||
|
||||
# Install the telegram stub at import time, before the TelegramAdapter
# import that follows is executed.
_ensure_telegram_mock()
|
||||
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
def _make_adapter():
    """Return a ``TelegramAdapter`` whose ``_bot`` and ``_app`` are mocks.

    The adapter is built from an enabled ``PlatformConfig`` with a dummy
    token; ``_bot`` is an ``AsyncMock`` and ``_app`` a ``MagicMock``.
    """
    adapter = TelegramAdapter(PlatformConfig(enabled=True, token="test-token"))
    adapter._bot, adapter._app = AsyncMock(), MagicMock()
    return adapter
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# send_exec_approval — inline keyboard buttons
|
||||
# ===========================================================================
|
||||
|
||||
class TestTelegramExecApproval:
    """Exercise ``send_exec_approval``: message content, stored state, edge cases."""

    @staticmethod
    def _adapter_replying_with(message_id):
        """Return an adapter whose ``send_message`` yields a message with that id."""
        adapter = _make_adapter()
        sent = MagicMock()
        sent.message_id = message_id
        adapter._bot.send_message = AsyncMock(return_value=sent)
        return adapter

    @pytest.mark.asyncio
    async def test_sends_inline_keyboard(self):
        """The approval prompt is sent once with command, description and markup."""
        adapter = self._adapter_replying_with(42)

        outcome = await adapter.send_exec_approval(
            chat_id="12345",
            command="rm -rf /important",
            session_key="agent:main:telegram:group:12345:99",
            description="dangerous deletion",
        )

        assert outcome.success is True
        assert outcome.message_id == "42"

        send_message = adapter._bot.send_message
        send_message.assert_called_once()
        sent_kwargs = send_message.call_args.kwargs
        assert sent_kwargs["chat_id"] == 12345
        assert "rm -rf /important" in sent_kwargs["text"]
        assert "dangerous deletion" in sent_kwargs["text"]
        assert sent_kwargs["reply_markup"] is not None  # InlineKeyboardMarkup

    @pytest.mark.asyncio
    async def test_stores_approval_state(self):
        """Sending a prompt records exactly one approval id mapped to the session key."""
        adapter = self._adapter_replying_with(42)

        await adapter.send_exec_approval(
            chat_id="12345",
            command="echo test",
            session_key="my-session-key",
        )

        # The single stored approval id must point at the session key.
        state = adapter._approval_state
        assert len(state) == 1
        approval_id = next(iter(state))
        assert state[approval_id] == "my-session-key"

    @pytest.mark.asyncio
    async def test_sends_in_thread(self):
        """A ``thread_id`` in metadata is passed on as integer ``message_thread_id``."""
        adapter = self._adapter_replying_with(42)

        await adapter.send_exec_approval(
            chat_id="12345",
            command="ls",
            session_key="s",
            metadata={"thread_id": "999"},
        )

        sent_kwargs = adapter._bot.send_message.call_args.kwargs
        assert sent_kwargs.get("message_thread_id") == 999

    @pytest.mark.asyncio
    async def test_not_connected(self):
        """With no bot attached the call reports failure."""
        adapter = _make_adapter()
        adapter._bot = None

        outcome = await adapter.send_exec_approval(
            chat_id="12345", command="ls", session_key="s"
        )

        assert outcome.success is False

    @pytest.mark.asyncio
    async def test_truncates_long_command(self):
        """A 5000-character command is shortened and marked with an ellipsis."""
        adapter = self._adapter_replying_with(1)
        oversized_command = "x" * 5000

        await adapter.send_exec_approval(
            chat_id="12345", command=oversized_command, session_key="s"
        )

        text = adapter._bot.send_message.call_args.kwargs["text"]
        assert "..." in text
        assert len(text) < 5000
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _handle_callback_query — approval button clicks
|
||||
# ===========================================================================
|
||||
|
||||
class TestTelegramApprovalCallback:
    """Exercise approval-button handling in ``_handle_callback_query``."""

    @staticmethod
    def _click(data):
        """Return ``(query, update)`` mocks for a button press carrying ``data``.

        Only the parts shared by every test are set up here (callback data,
        chat id 12345, a mocked ``from_user``); each test adds what it needs.
        """
        query = AsyncMock()
        query.data = data
        query.message = MagicMock()
        query.message.chat_id = 12345
        query.from_user = MagicMock()
        update = MagicMock()
        update.callback_query = query
        return query, update

    @pytest.mark.asyncio
    async def test_resolves_approval_on_click(self):
        """Clicking "once" resolves the stored session and clears its state."""
        session_key = "agent:main:telegram:group:12345:99"
        adapter = _make_adapter()
        adapter._approval_state[1] = session_key

        query, update = self._click("ea:once:1")
        query.from_user.first_name = "Norbert"
        query.answer = AsyncMock()
        query.edit_message_text = AsyncMock()
        context = MagicMock()

        with patch("tools.approval.resolve_gateway_approval", return_value=1) as resolver:
            await adapter._handle_callback_query(update, context)

        resolver.assert_called_once_with(session_key, "once")
        query.answer.assert_called_once()
        query.edit_message_text.assert_called_once()

        # The handled approval id must be gone from the adapter state.
        assert 1 not in adapter._approval_state

    @pytest.mark.asyncio
    async def test_deny_button(self):
        """Clicking "deny" resolves with ``"deny"`` and the edited text says Denied."""
        adapter = _make_adapter()
        adapter._approval_state[2] = "some-session"

        query, update = self._click("ea:deny:2")
        query.from_user.first_name = "Alice"
        query.answer = AsyncMock()
        query.edit_message_text = AsyncMock()
        context = MagicMock()

        with patch("tools.approval.resolve_gateway_approval", return_value=1) as resolver:
            await adapter._handle_callback_query(update, context)

        resolver.assert_called_once_with("some-session", "deny")
        edited = query.edit_message_text.call_args.kwargs
        assert "Denied" in edited["text"]

    @pytest.mark.asyncio
    async def test_already_resolved(self):
        """A click for an unknown approval id is acknowledged but not resolved."""
        adapter = _make_adapter()
        # No state is registered for approval id 99, i.e. it was already resolved.

        query, update = self._click("ea:once:99")
        query.from_user.first_name = "Bob"
        query.answer = AsyncMock()
        context = MagicMock()

        with patch("tools.approval.resolve_gateway_approval") as resolver:
            await adapter._handle_callback_query(update, context)

        # Nothing to resolve any more ...
        resolver.assert_not_called()
        # ... but the click is still answered with an "already resolved" notice.
        query.answer.assert_called_once()
        assert "already been resolved" in query.answer.call_args.kwargs["text"]

    @pytest.mark.asyncio
    async def test_model_picker_callback_not_affected(self):
        """A model-picker callback (``mp:`` data) never reaches approval resolution."""
        adapter = _make_adapter()
        _query, update = self._click("mp:some_provider")
        context = MagicMock()

        # The picker handler itself is stubbed out; the only thing checked is
        # that no approval gets resolved for this kind of callback.
        with patch("tools.approval.resolve_gateway_approval") as resolver, \
                patch.object(adapter, "_handle_model_picker_callback", new_callable=AsyncMock):
            await adapter._handle_callback_query(update, context)

        resolver.assert_not_called()

    @pytest.mark.asyncio
    async def test_update_prompt_callback_not_affected(self):
        """An ``update_prompt`` callback never reaches approval resolution."""
        adapter = _make_adapter()

        query, update = self._click("update_prompt:y")
        query.from_user.id = 123
        query.answer = AsyncMock()
        query.edit_message_text = AsyncMock()
        context = MagicMock()

        with patch("tools.approval.resolve_gateway_approval") as resolver, \
                patch("hermes_constants.get_hermes_home", return_value=Path("/tmp/test")):
            try:
                await adapter._handle_callback_query(update, context)
            except Exception:
                # Deliberately best-effort: the handler may fail on a file
                # write, which is irrelevant to what this test checks.
                pass

        # Approval resolution must not have been triggered.
        resolver.assert_not_called()
|
||||
77
tests/gateway/test_telegram_caption_merge.py
Normal file
77
tests/gateway/test_telegram_caption_merge.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Tests for TelegramAdapter._merge_caption caption deduplication logic."""
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
|
||||
# Short alias for the function under test; the tests below call it as
# ``merge(existing_text, new_text)`` without an adapter instance.
merge = TelegramAdapter._merge_caption
|
||||
|
||||
|
||||
class TestMergeCaptionBasic:
    """Baseline behaviour: first caption, exact duplicate, two distinct captions."""

    def test_no_existing_text(self):
        """With ``None`` as existing text the new caption is returned as-is."""
        merged = merge(None, "Hello")
        assert merged == "Hello"

    def test_empty_existing_text(self):
        """With an empty existing text the new caption is returned as-is."""
        merged = merge("", "Hello")
        assert merged == "Hello"

    def test_exact_duplicate_dropped(self):
        """A caption identical to the existing text is not appended again."""
        merged = merge("Revenue", "Revenue")
        assert merged == "Revenue"

    def test_different_captions_merged(self):
        """Two different captions are joined with a blank line between them."""
        expected = "Q3 Results\n\nQ4 Projections"
        assert merge("Q3 Results", "Q4 Projections") == expected
|
||||
|
||||
|
||||
class TestMergeCaptionSubstringBug:
    """Regression cases for the scenarios the old substring check got wrong."""

    def test_shorter_caption_not_dropped_when_substring(self):
        """A new caption that is a substring of the existing text is still kept."""
        # Old bug: "Meeting" in "Meeting agenda" was True, so the caption
        # was silently lost.
        merged = merge("Meeting agenda", "Meeting")
        expected = "Meeting agenda\n\nMeeting"
        assert merged == expected

    def test_longer_caption_not_dropped_when_contains_existing(self):
        """A new caption that contains the existing text is a distinct caption."""
        # "Revenue and Profit" contains "Revenue", yet the two differ.
        merged = merge("Revenue", "Revenue and Profit")
        expected = "Revenue\n\nRevenue and Profit"
        assert merged == expected

    def test_prefix_caption_not_dropped(self):
        """A new caption that is a prefix of the existing text is still kept."""
        merged = merge("Q3 Results - Revenue", "Q3 Results")
        expected = "Q3 Results - Revenue\n\nQ3 Results"
        assert merged == expected
|
||||
|
||||
|
||||
class TestMergeCaptionWhitespace:
    """Captions that differ from the existing text only by whitespace."""

    def test_trailing_space_treated_as_duplicate(self):
        """A trailing space does not make the caption a new one."""
        merged = merge("Revenue", "Revenue ")
        assert merged == "Revenue"

    def test_leading_space_treated_as_duplicate(self):
        """A leading space does not make the caption a new one."""
        merged = merge("Revenue", " Revenue")
        assert merged == "Revenue"

    def test_whitespace_only_new_text_not_added(self):
        """Whitespace-only new text: either outcome is accepted here.

        NOTE(review): despite the test name, the assertion below passes both
        when the text is merged and when it is dropped.
        """
        # Per the original note, the stripped value is "" and is not among
        # the existing captions, so _merge_caption itself would merge it;
        # callers are said to guard with `if event.text:`, which makes this
        # an edge case (not verifiable from this file).
        result = merge("Revenue", " ")
        assert result == "Revenue" or "\n\n" in result
|
||||
|
||||
|
||||
class TestMergeCaptionMultipleItems:
    """Captions folded in one after another, as for a multi-item album."""

    def test_three_unique_captions_all_present(self):
        """Three distinct captions all end up in the text, in arrival order."""
        text = None
        for caption in ("A", "B", "C"):
            text = merge(text, caption)
        assert text == "A\n\nB\n\nC"

    def test_duplicate_in_middle_dropped(self):
        """Re-adding an earlier caption leaves the merged text unchanged."""
        text = None
        for caption in ("A", "B", "A"):  # the second "A" is the duplicate
            text = merge(text, caption)
        assert text == "A\n\nB"

    def test_album_scenario_revenue_profit(self):
        """Album items "Revenue and Profit" then "Revenue" are both kept."""
        # Old bug: "Revenue" was found inside "Revenue and Profit" and the
        # second item's caption was lost.
        text = None
        for caption in ("Revenue and Profit", "Revenue"):
            text = merge(text, caption)
        assert text == "Revenue and Profit\n\nRevenue"
|
||||
@@ -20,8 +20,16 @@ def _ensure_telegram_mock():
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
# Provide real exception classes so ``except (NetworkError, ...)`` in
|
||||
# connect() doesn't blow up with "catching classes that do not inherit
|
||||
# from BaseException" when another xdist worker pollutes sys.modules.
|
||||
telegram_mod.error.NetworkError = type("NetworkError", (OSError,), {})
|
||||
telegram_mod.error.TimedOut = type("TimedOut", (OSError,), {})
|
||||
telegram_mod.error.BadRequest = type("BadRequest", (Exception,), {})
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
sys.modules.setdefault("telegram.error", telegram_mod.error)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
260
tests/gateway/test_telegram_reactions.py
Normal file
260
tests/gateway/test_telegram_reactions.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""Tests for Telegram message reactions tied to processing lifecycle hooks."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _make_adapter(**extra_env):
    """Return a bare ``TelegramAdapter`` with a mocked bot, skipping ``__init__``.

    The instance is allocated with ``object.__new__`` and only ``platform``,
    ``config`` and ``_bot`` are set.  ``extra_env`` is accepted but not used
    by this helper.
    """
    from gateway.platforms.telegram import TelegramAdapter

    fake_bot = AsyncMock()
    fake_bot.set_message_reaction = AsyncMock()

    adapter = object.__new__(TelegramAdapter)
    adapter.platform = Platform.TELEGRAM
    adapter.config = PlatformConfig(enabled=True, token="fake-token")
    adapter._bot = fake_bot
    return adapter
|
||||
|
||||
|
||||
def _make_event(chat_id: str = "123", message_id: str = "456") -> MessageEvent:
    """Build a plain-text Telegram ``MessageEvent`` from a private chat.

    Args:
        chat_id: Chat id placed on the event's session source.
        message_id: Id of the message the event represents.
    """
    origin = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id=chat_id,
        chat_type="private",
        user_id="42",
        user_name="TestUser",
    )
    return MessageEvent(
        text="hello",
        message_type=MessageType.TEXT,
        source=origin,
        message_id=message_id,
    )
|
||||
|
||||
|
||||
# ── _reactions_enabled ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_reactions_disabled_by_default(monkeypatch):
    """With TELEGRAM_REACTIONS unset, reactions are reported as disabled."""
    monkeypatch.delenv("TELEGRAM_REACTIONS", raising=False)
    enabled = _make_adapter()._reactions_enabled()
    assert enabled is False
|
||||
|
||||
|
||||
def test_reactions_enabled_when_set_true(monkeypatch):
    """TELEGRAM_REACTIONS=true turns reactions on."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    enabled = _make_adapter()._reactions_enabled()
    assert enabled is True
|
||||
|
||||
|
||||
def test_reactions_enabled_with_1(monkeypatch):
    """TELEGRAM_REACTIONS=1 turns reactions on."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "1")
    enabled = _make_adapter()._reactions_enabled()
    assert enabled is True
|
||||
|
||||
|
||||
def test_reactions_disabled_with_false(monkeypatch):
    """TELEGRAM_REACTIONS=false keeps reactions off."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "false")
    enabled = _make_adapter()._reactions_enabled()
    assert enabled is False
|
||||
|
||||
|
||||
def test_reactions_disabled_with_0(monkeypatch):
    """TELEGRAM_REACTIONS=0 keeps reactions off."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "0")
    enabled = _make_adapter()._reactions_enabled()
    assert enabled is False
|
||||
|
||||
|
||||
def test_reactions_disabled_with_no(monkeypatch):
    """TELEGRAM_REACTIONS=no keeps reactions off."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "no")
    enabled = _make_adapter()._reactions_enabled()
    assert enabled is False
|
||||
|
||||
|
||||
# ── _set_reaction ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_set_reaction_calls_bot_api(monkeypatch):
    """``_set_reaction`` awaits ``set_message_reaction`` once with integer ids."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()
    eyes = "\U0001f440"

    outcome = await adapter._set_reaction("123", "456", eyes)

    assert outcome is True
    # The string ids passed in are expected as integers at the bot API.
    expected_call = {"chat_id": 123, "message_id": 456, "reaction": eyes}
    adapter._bot.set_message_reaction.assert_awaited_once_with(**expected_call)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_set_reaction_returns_false_without_bot(monkeypatch):
    """Without a bot, ``_set_reaction`` returns False."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()
    adapter._bot = None  # simulate an adapter that has no bot available

    outcome = await adapter._set_reaction("123", "456", "\U0001f440")

    assert outcome is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_set_reaction_handles_api_error_gracefully(monkeypatch):
    """An exception from the bot API yields False instead of propagating."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()
    failing_call = AsyncMock(side_effect=RuntimeError("no perms"))
    adapter._bot.set_message_reaction = failing_call

    outcome = await adapter._set_reaction("123", "456", "\U0001f440")

    assert outcome is False
|
||||
|
||||
|
||||
# ── on_processing_start ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_processing_start_adds_eyes_reaction(monkeypatch):
    """With reactions enabled, processing start sets the eyes reaction."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()

    await adapter.on_processing_start(_make_event())

    # _make_event() defaults to chat "123" / message "456".
    expected_call = {"chat_id": 123, "message_id": 456, "reaction": "\U0001f440"}
    adapter._bot.set_message_reaction.assert_awaited_once_with(**expected_call)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_processing_start_skipped_when_disabled(monkeypatch):
    """With TELEGRAM_REACTIONS unset, processing start sets no reaction."""
    monkeypatch.delenv("TELEGRAM_REACTIONS", raising=False)
    adapter = _make_adapter()

    await adapter.on_processing_start(_make_event())

    reaction_call = adapter._bot.set_message_reaction
    reaction_call.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_processing_start_handles_missing_ids(monkeypatch):
    """An event lacking chat id and message id triggers no reaction call."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()
    # The source is a bare namespace with chat_id=None; message_id is None too.
    anonymous_source = SimpleNamespace(chat_id=None)
    event = MessageEvent(
        text="hello",
        message_type=MessageType.TEXT,
        source=anonymous_source,
        message_id=None,
    )

    await adapter.on_processing_start(event)

    reaction_call = adapter._bot.set_message_reaction
    reaction_call.assert_not_awaited()
|
||||
|
||||
|
||||
# ── on_processing_complete ───────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_processing_complete_success(monkeypatch):
    """A successful completion sets the check-mark reaction."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()

    await adapter.on_processing_complete(_make_event(), success=True)

    expected_call = {"chat_id": 123, "message_id": 456, "reaction": "\u2705"}
    adapter._bot.set_message_reaction.assert_awaited_once_with(**expected_call)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_processing_complete_failure(monkeypatch):
    """A failed completion sets the cross-mark reaction."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()

    await adapter.on_processing_complete(_make_event(), success=False)

    expected_call = {"chat_id": 123, "message_id": 456, "reaction": "\u274c"}
    adapter._bot.set_message_reaction.assert_awaited_once_with(**expected_call)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_processing_complete_skipped_when_disabled(monkeypatch):
    """With TELEGRAM_REACTIONS unset, on_processing_complete sets no reaction."""
    # Remove the flag regardless of what the ambient environment holds.
    monkeypatch.delenv("TELEGRAM_REACTIONS", raising=False)
    telegram = _make_adapter()
    incoming = _make_event()

    await telegram.on_processing_complete(incoming, success=True)

    telegram._bot.set_message_reaction.assert_not_awaited()
|
||||
|
||||
|
||||
# ── config.py bridging ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_config_bridges_telegram_reactions(monkeypatch, tmp_path):
    """load_gateway_config exports telegram.reactions as TELEGRAM_REACTIONS."""
    import os

    import yaml

    settings = {"telegram": {"reactions": True}}
    (tmp_path / "config.yaml").write_text(yaml.dump(settings))
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    # setenv (rather than delenv) makes monkeypatch record the variable for
    # cleanup even if it did not exist beforehand; load_gateway_config is
    # expected to overwrite the empty value.
    monkeypatch.setenv("TELEGRAM_REACTIONS", "")

    # Imported only after the environment is prepared, as in the original.
    from gateway.config import load_gateway_config

    load_gateway_config()

    assert os.environ.get("TELEGRAM_REACTIONS") == "true"
|
||||
|
||||
|
||||
def test_config_reactions_env_takes_precedence(monkeypatch, tmp_path):
    """A pre-set TELEGRAM_REACTIONS env var is not overridden by config.yaml."""
    import os

    import yaml

    settings = {"telegram": {"reactions": True}}
    (tmp_path / "config.yaml").write_text(yaml.dump(settings))
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    # The env var disagrees with the config file on purpose.
    monkeypatch.setenv("TELEGRAM_REACTIONS", "false")

    # Imported only after the environment is prepared, as in the original.
    from gateway.config import load_gateway_config

    load_gateway_config()

    assert os.environ.get("TELEGRAM_REACTIONS") == "false"
|
||||
@@ -590,8 +590,15 @@ class TestSessionIsolation:
|
||||
class TestDeliveryCleanup:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delivery_info_cleaned_after_send(self):
|
||||
"""send() pops delivery_info so the entry doesn't leak memory."""
|
||||
async def test_delivery_info_survives_multiple_sends(self):
|
||||
"""send() must NOT pop delivery_info.
|
||||
|
||||
Interim status messages (fallback notifications, context-pressure
|
||||
warnings, etc.) flow through the same send() path as the final
|
||||
response. If the entry were popped on the first send, the final
|
||||
response would silently downgrade to the ``log`` deliver type.
|
||||
Regression test for that bug.
|
||||
"""
|
||||
adapter = _make_adapter()
|
||||
chat_id = "webhook:test:d-xyz"
|
||||
adapter._delivery_info[chat_id] = {
|
||||
@@ -599,10 +606,40 @@ class TestDeliveryCleanup:
|
||||
"deliver_extra": {},
|
||||
"payload": {"x": 1},
|
||||
}
|
||||
adapter._delivery_info_created[chat_id] = time.time()
|
||||
|
||||
result = await adapter.send(chat_id, "Agent response here")
|
||||
assert result.success is True
|
||||
assert chat_id not in adapter._delivery_info
|
||||
# First send (e.g. an interim status message)
|
||||
result1 = await adapter.send(chat_id, "Status: switching to fallback")
|
||||
assert result1.success is True
|
||||
# Entry must still be present so the final send can read it
|
||||
assert chat_id in adapter._delivery_info
|
||||
|
||||
# Second send (the final agent response)
|
||||
result2 = await adapter.send(chat_id, "Final agent response")
|
||||
assert result2.success is True
|
||||
assert chat_id in adapter._delivery_info
|
||||
|
||||
@pytest.mark.asyncio
async def test_delivery_info_pruned_via_ttl(self):
    """_prune_delivery_info drops entries older than the TTL and keeps fresh ones."""
    adapter = _make_adapter()
    adapter._idempotency_ttl = 60  # short TTL for the test
    now = time.time()

    stale_key = "webhook:test:old"
    fresh_key = "webhook:test:new"
    # Seed one entry older than the TTL (120s) and one well inside it (5s).
    for chat_key, age_seconds in ((stale_key, 120), (fresh_key, 5)):
        adapter._delivery_info[chat_key] = {"deliver": "log"}
        adapter._delivery_info_created[chat_key] = now - age_seconds

    adapter._prune_delivery_info(now)

    assert stale_key not in adapter._delivery_info
    assert stale_key not in adapter._delivery_info_created
    assert fresh_key in adapter._delivery_info
    assert fresh_key in adapter._delivery_info_created
|
||||
|
||||
|
||||
# ===================================================================
|
||||
@@ -617,3 +654,107 @@ class TestCheckRequirements:
|
||||
@patch("gateway.platforms.webhook.AIOHTTP_AVAILABLE", False)
def test_returns_false_without_aiohttp(self):
    """check_webhook_requirements() is False when AIOHTTP_AVAILABLE is patched off."""
    requirements_met = check_webhook_requirements()
    assert requirements_met is False
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# __raw__ template token
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestRawTemplateToken:
    """Tests for the {__raw__} special token in _render_prompt."""

    def test_raw_resolves_to_full_json_payload(self):
        """{__raw__} expands to the payload serialized as indented JSON."""
        adapter = _make_adapter()
        payload = {"action": "opened", "number": 42}

        rendered = adapter._render_prompt(
            "Payload: {__raw__}", payload, "push", "test"
        )

        assert rendered == "Payload: " + json.dumps(payload, indent=2)

    def test_raw_truncated_at_4000_chars(self):
        """A payload whose JSON exceeds 4000 chars renders to at most 4000 chars."""
        adapter = _make_adapter()
        # 5000 filler characters guarantee the JSON form is over the limit.
        oversized = {"data": "x" * 5000}

        rendered = adapter._render_prompt("{__raw__}", oversized, "push", "test")

        assert len(rendered) <= 4000

    def test_raw_mixed_with_other_variables(self):
        """{__raw__} can share a template with ordinary payload variables."""
        adapter = _make_adapter()
        payload = {"action": "closed", "number": 7}

        rendered = adapter._render_prompt(
            "Action={action} Raw={__raw__}", payload, "push", "test"
        )

        assert rendered.startswith("Action=closed Raw=")
        for fragment in ('"action": "closed"', '"number": 7'):
            assert fragment in rendered
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Cross-platform delivery thread_id passthrough
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestDeliverCrossPlatformThreadId:
    """Tests for thread_id passthrough in _deliver_cross_platform."""

    def _setup_adapter_with_mock_target(self):
        """Return (webhook adapter, mocked telegram adapter) wired via a fake runner."""
        adapter = _make_adapter()

        target = AsyncMock()
        target.send = AsyncMock(return_value=SendResult(success=True))

        runner = MagicMock()
        runner.adapters = {Platform("telegram"): target}
        runner.config.get_home_channel.return_value = None
        adapter.gateway_runner = runner

        return adapter, target

    async def _deliver_hello(self, deliver_extra):
        """Deliver "hello" to telegram with the given deliver_extra; return the target mock."""
        adapter, target = self._setup_adapter_with_mock_target()
        await adapter._deliver_cross_platform(
            "telegram", "hello", {"deliver_extra": deliver_extra}
        )
        return target

    @pytest.mark.asyncio
    async def test_thread_id_passed_as_metadata(self):
        """thread_id from deliver_extra is passed as metadata to adapter.send()."""
        target = await self._deliver_hello({"chat_id": "12345", "thread_id": "999"})

        target.send.assert_awaited_once_with(
            "12345", "hello", metadata={"thread_id": "999"}
        )

    @pytest.mark.asyncio
    async def test_message_thread_id_passed_as_thread_id(self):
        """message_thread_id from deliver_extra is mapped to thread_id in metadata."""
        target = await self._deliver_hello(
            {"chat_id": "12345", "message_thread_id": "888"}
        )

        target.send.assert_awaited_once_with(
            "12345", "hello", metadata={"thread_id": "888"}
        )

    @pytest.mark.asyncio
    async def test_no_thread_id_sends_no_metadata(self):
        """When no thread_id is present, metadata is None."""
        target = await self._deliver_hello({"chat_id": "12345"})

        target.send.assert_awaited_once_with("12345", "hello", metadata=None)
|
||||
|
||||
@@ -257,10 +257,11 @@ class TestCrossPlatformDelivery:
|
||||
|
||||
assert result.success is True
|
||||
mock_tg_adapter.send.assert_awaited_once_with(
|
||||
"12345", "I've acknowledged the alert."
|
||||
"12345", "I've acknowledged the alert.", metadata=None
|
||||
)
|
||||
# Delivery info should be cleaned up
|
||||
assert chat_id not in adapter._delivery_info
|
||||
# Delivery info is retained after send() so interim status messages
|
||||
# don't strand the final response (TTL-based cleanup happens on POST).
|
||||
assert chat_id in adapter._delivery_info
|
||||
|
||||
|
||||
# ===================================================================
|
||||
@@ -333,5 +334,6 @@ class TestGitHubCommentDelivery:
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
# Delivery info cleaned up
|
||||
assert chat_id not in adapter._delivery_info
|
||||
# Delivery info is retained after send() so interim status messages
|
||||
# don't strand the final response (TTL-based cleanup happens on POST).
|
||||
assert chat_id in adapter._delivery_info
|
||||
|
||||
@@ -40,6 +40,7 @@ def test_run_anthropic_oauth_flow_manual_token_still_persists(tmp_path, monkeypa
|
||||
monkeypatch.setattr("agent.anthropic_adapter.read_claude_code_credentials", lambda: None)
|
||||
monkeypatch.setattr("agent.anthropic_adapter.is_claude_code_token_valid", lambda creds: False)
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt="": "sk-ant-oat01-manual-token")
|
||||
monkeypatch.setattr("getpass.getpass", lambda _prompt="": "sk-ant-oat01-manual-token")
|
||||
|
||||
from hermes_cli.main import _run_anthropic_oauth_flow
|
||||
|
||||
@@ -350,6 +350,7 @@ class TestResolveApiKeyProviderCredentials:
|
||||
|
||||
def test_resolve_zai_with_key(self, monkeypatch):
|
||||
monkeypatch.setenv("GLM_API_KEY", "glm-secret-key")
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_zai_endpoint", lambda *a, **kw: None)
|
||||
creds = resolve_api_key_provider_credentials("zai")
|
||||
assert creds["provider"] == "zai"
|
||||
assert creds["api_key"] == "glm-secret-key"
|
||||
@@ -471,6 +472,7 @@ class TestResolveApiKeyProviderCredentials:
|
||||
"""GLM_API_KEY takes priority over ZAI_API_KEY."""
|
||||
monkeypatch.setenv("GLM_API_KEY", "primary")
|
||||
monkeypatch.setenv("ZAI_API_KEY", "secondary")
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_zai_endpoint", lambda *a, **kw: None)
|
||||
creds = resolve_api_key_provider_credentials("zai")
|
||||
assert creds["api_key"] == "primary"
|
||||
assert creds["source"] == "GLM_API_KEY"
|
||||
@@ -478,6 +480,7 @@ class TestResolveApiKeyProviderCredentials:
|
||||
def test_zai_key_fallback(self, monkeypatch):
|
||||
"""ZAI_API_KEY used when GLM_API_KEY not set."""
|
||||
monkeypatch.setenv("ZAI_API_KEY", "secondary")
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_zai_endpoint", lambda *a, **kw: None)
|
||||
creds = resolve_api_key_provider_credentials("zai")
|
||||
assert creds["api_key"] == "secondary"
|
||||
assert creds["source"] == "ZAI_API_KEY"
|
||||
@@ -830,11 +833,58 @@ class TestKimiCodeCredentialAutoDetect:
|
||||
|
||||
def test_non_kimi_providers_unaffected(self, monkeypatch):
|
||||
"""Ensure the auto-detect logic doesn't leak to other providers."""
|
||||
monkeypatch.setenv("GLM_API_KEY", "sk-kimi-looks-like-kimi-but-isnt")
|
||||
monkeypatch.setenv("GLM_API_KEY", "sk-kim...isnt")
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_zai_endpoint", lambda *a, **kw: None)
|
||||
creds = resolve_api_key_provider_credentials("zai")
|
||||
assert creds["base_url"] == "https://api.z.ai/api/paas/v4"
|
||||
|
||||
|
||||
class TestZaiEndpointAutoDetect:
    """Test that resolve_api_key_provider_credentials auto-detects Z.AI endpoints.

    Every test starts by clearing the Z.AI-related environment variables it
    depends on, so results do not vary with the developer's or CI shell
    environment (e.g. an exported GLM_BASE_URL or ZAI_API_KEY).
    """

    @staticmethod
    def _reset_zai_env(monkeypatch):
        """Remove the Z.AI variables referenced by these tests from the environment."""
        # NOTE(review): only the variables visible in this test module are
        # cleared; if the provider reads further aliases, add them here.
        for name in ("GLM_API_KEY", "ZAI_API_KEY", "GLM_BASE_URL"):
            monkeypatch.delenv(name, raising=False)

    def test_probe_success_returns_detected_url(self, monkeypatch):
        """A successful probe result supplies the resolved base_url."""
        self._reset_zai_env(monkeypatch)
        monkeypatch.setenv("GLM_API_KEY", "glm-coding-key")
        detected = {
            "id": "coding-global",
            "base_url": "https://api.z.ai/api/coding/paas/v4",
            "model": "glm-4.7",
            "label": "Global (Coding Plan)",
        }
        monkeypatch.setattr(
            "hermes_cli.auth.detect_zai_endpoint",
            lambda *a, **kw: detected,
        )

        creds = resolve_api_key_provider_credentials("zai")

        assert creds["base_url"] == "https://api.z.ai/api/coding/paas/v4"

    def test_probe_failure_falls_back_to_default(self, monkeypatch):
        """When the probe returns None the default Z.AI base_url is used."""
        self._reset_zai_env(monkeypatch)
        monkeypatch.setenv("GLM_API_KEY", "glm-key")
        monkeypatch.setattr(
            "hermes_cli.auth.detect_zai_endpoint", lambda *a, **kw: None
        )

        creds = resolve_api_key_provider_credentials("zai")

        assert creds["base_url"] == "https://api.z.ai/api/paas/v4"

    def test_env_override_skips_probe(self, monkeypatch):
        """GLM_BASE_URL should always win without probing."""
        self._reset_zai_env(monkeypatch)
        monkeypatch.setenv("GLM_API_KEY", "glm-key")
        monkeypatch.setenv("GLM_BASE_URL", "https://custom.example/v4")
        probe_calls = []

        def _never_called(*a, **kw):
            # Record the call so the test can prove the probe was bypassed.
            probe_calls.append((a, kw))
            return None

        monkeypatch.setattr("hermes_cli.auth.detect_zai_endpoint", _never_called)

        creds = resolve_api_key_provider_credentials("zai")

        assert creds["base_url"] == "https://custom.example/v4"
        assert not probe_calls

    def test_no_key_skips_probe(self, monkeypatch):
        """Without an API key, the resolved api_key is empty."""
        # Without this reset, a key exported in the ambient environment
        # would make the empty-key assertion below fail.
        self._reset_zai_env(monkeypatch)
        monkeypatch.setattr(
            "hermes_cli.auth.detect_zai_endpoint", lambda *a, **kw: None
        )

        creds = resolve_api_key_provider_credentials("zai")

        assert creds["api_key"] == ""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Kimi / Moonshot model list isolation tests
|
||||
# =============================================================================
|
||||
399
tests/hermes_cli/test_auth_qwen_provider.py
Normal file
399
tests/hermes_cli/test_auth_qwen_provider.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""Tests for Qwen OAuth provider authentication (hermes_cli/auth.py).
|
||||
|
||||
Covers: _qwen_cli_auth_path, _read_qwen_cli_tokens, _save_qwen_cli_tokens,
|
||||
_qwen_access_token_is_expiring, _refresh_qwen_cli_tokens,
|
||||
resolve_qwen_runtime_credentials, get_qwen_auth_status.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import stat
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.auth import (
|
||||
AuthError,
|
||||
DEFAULT_QWEN_BASE_URL,
|
||||
QWEN_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
|
||||
_qwen_cli_auth_path,
|
||||
_read_qwen_cli_tokens,
|
||||
_save_qwen_cli_tokens,
|
||||
_qwen_access_token_is_expiring,
|
||||
_refresh_qwen_cli_tokens,
|
||||
resolve_qwen_runtime_credentials,
|
||||
get_qwen_auth_status,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_qwen_tokens(
|
||||
access_token="test-access-token",
|
||||
refresh_token="test-refresh-token",
|
||||
expiry_date=None,
|
||||
**extra,
|
||||
):
|
||||
"""Create a minimal Qwen CLI OAuth credential dict."""
|
||||
if expiry_date is None:
|
||||
# 1 hour from now in milliseconds
|
||||
expiry_date = int((time.time() + 3600) * 1000)
|
||||
data = {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "Bearer",
|
||||
"expiry_date": expiry_date,
|
||||
"resource_url": "portal.qwen.ai",
|
||||
}
|
||||
data.update(extra)
|
||||
return data
|
||||
|
||||
|
||||
def _write_qwen_creds(tmp_path, tokens=None):
|
||||
"""Write tokens to the Qwen CLI credentials file and return the path."""
|
||||
qwen_dir = tmp_path / ".qwen"
|
||||
qwen_dir.mkdir(parents=True, exist_ok=True)
|
||||
creds_path = qwen_dir / "oauth_creds.json"
|
||||
if tokens is None:
|
||||
tokens = _make_qwen_tokens()
|
||||
creds_path.write_text(json.dumps(tokens), encoding="utf-8")
|
||||
return creds_path
|
||||
|
||||
|
||||
@pytest.fixture()
def qwen_env(tmp_path, monkeypatch):
    """Make hermes_cli.auth._qwen_cli_auth_path return tmp_path/.qwen/oauth_creds.json.

    Returns ``tmp_path`` so tests can build paths beneath it.
    """
    fake_creds_file = tmp_path / ".qwen" / "oauth_creds.json"

    def _fake_auth_path():
        return fake_creds_file

    monkeypatch.setattr("hermes_cli.auth._qwen_cli_auth_path", _fake_auth_path)
    return tmp_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _qwen_cli_auth_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_qwen_cli_auth_path_returns_expected_location():
    """_qwen_cli_auth_path() resolves to ~/.qwen/oauth_creds.json."""
    expected = Path.home().joinpath(".qwen", "oauth_creds.json")
    assert _qwen_cli_auth_path() == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_qwen_cli_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_read_qwen_cli_tokens_success(qwen_env):
    """Tokens written to the credentials file are read back unchanged."""
    _write_qwen_creds(qwen_env, _make_qwen_tokens(access_token="my-access"))

    loaded = _read_qwen_cli_tokens()

    assert loaded["access_token"] == "my-access"
    assert loaded["refresh_token"] == "test-refresh-token"
|
||||
|
||||
|
||||
def test_read_qwen_cli_tokens_missing_file(qwen_env):
    """With no credentials file, AuthError code is "qwen_auth_missing"."""
    with pytest.raises(AuthError) as raised:
        _read_qwen_cli_tokens()

    assert raised.value.code == "qwen_auth_missing"
|
||||
|
||||
|
||||
def test_read_qwen_cli_tokens_invalid_json(qwen_env):
    """A credentials file that is not valid JSON yields "qwen_auth_read_failed"."""
    broken_file = qwen_env.joinpath(".qwen", "oauth_creds.json")
    broken_file.parent.mkdir(parents=True, exist_ok=True)
    broken_file.write_text("not json{{{", encoding="utf-8")

    with pytest.raises(AuthError) as raised:
        _read_qwen_cli_tokens()

    assert raised.value.code == "qwen_auth_read_failed"
|
||||
|
||||
|
||||
def test_read_qwen_cli_tokens_non_dict(qwen_env):
    """Valid JSON that is not an object yields "qwen_auth_invalid"."""
    list_file = qwen_env.joinpath(".qwen", "oauth_creds.json")
    list_file.parent.mkdir(parents=True, exist_ok=True)
    # A JSON array parses fine but is not a credential mapping.
    list_file.write_text(json.dumps(["a", "b"]), encoding="utf-8")

    with pytest.raises(AuthError) as raised:
        _read_qwen_cli_tokens()

    assert raised.value.code == "qwen_auth_invalid"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _save_qwen_cli_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_save_qwen_cli_tokens_roundtrip(qwen_env):
    """Saved tokens exist on disk and parse back with the same access_token."""
    destination = _save_qwen_cli_tokens(_make_qwen_tokens(access_token="saved-token"))

    assert destination.exists()
    on_disk = json.loads(destination.read_text(encoding="utf-8"))
    assert on_disk["access_token"] == "saved-token"
|
||||
|
||||
|
||||
def test_save_qwen_cli_tokens_creates_parent(qwen_env):
    """After saving, the returned path's parent directory exists."""
    destination = _save_qwen_cli_tokens(_make_qwen_tokens())

    assert destination.parent.exists()
|
||||
|
||||
|
||||
def test_save_qwen_cli_tokens_permissions(qwen_env):
    """The saved file is owner-readable/writable and not readable by group/other."""
    destination = _save_qwen_cli_tokens(_make_qwen_tokens())
    permission_bits = destination.stat().st_mode

    # Owner must be able to read and write.
    for required in (stat.S_IRUSR, stat.S_IWUSR):
        assert permission_bits & required
    # Neither group nor others may read.
    for forbidden in (stat.S_IRGRP, stat.S_IROTH):
        assert not (permission_bits & forbidden)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _qwen_access_token_is_expiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_expiring_token_not_expired():
    """An expiry one hour ahead (in ms) is not reported as expiring."""
    one_hour_ahead_ms = int((time.time() + 3600) * 1000)
    expiring = _qwen_access_token_is_expiring(one_hour_ahead_ms)
    assert not expiring
|
||||
|
||||
|
||||
def test_expiring_token_already_expired():
    """An expiry one hour in the past (in ms) is reported as expiring."""
    one_hour_ago_ms = int((time.time() - 3600) * 1000)
    expiring = _qwen_access_token_is_expiring(one_hour_ago_ms)
    assert expiring
|
||||
|
||||
|
||||
def test_expiring_token_within_skew():
    """An expiry 5 seconds inside the refresh-skew window counts as expiring."""
    seconds_left = QWEN_ACCESS_TOKEN_REFRESH_SKEW_SECONDS - 5
    inside_window_ms = int((time.time() + seconds_left) * 1000)
    expiring = _qwen_access_token_is_expiring(inside_window_ms)
    assert expiring
|
||||
|
||||
|
||||
def test_expiring_token_none_returns_true():
    """A missing expiry (None) is treated as expiring."""
    expiring = _qwen_access_token_is_expiring(None)
    assert expiring
|
||||
|
||||
|
||||
def test_expiring_token_non_numeric_returns_true():
    """A non-numeric expiry value is treated as expiring."""
    expiring = _qwen_access_token_is_expiring("not-a-number")
    assert expiring
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _refresh_qwen_cli_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_refresh_qwen_cli_tokens_success(qwen_env):
    """A 200 response with new tokens is reflected in the returned dict."""
    stale_tokens = _make_qwen_tokens(refresh_token="old-refresh")

    fake_response = MagicMock(status_code=200)
    fake_response.json.return_value = {
        "access_token": "new-access",
        "refresh_token": "new-refresh",
        "expires_in": 7200,
    }

    with patch("hermes_cli.auth.httpx") as httpx_stub:
        httpx_stub.post.return_value = fake_response
        refreshed = _refresh_qwen_cli_tokens(stale_tokens)

    assert refreshed["access_token"] == "new-access"
    assert refreshed["refresh_token"] == "new-refresh"
    assert "expiry_date" in refreshed
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_preserves_old_refresh_if_not_in_response(qwen_env):
    """If the response omits refresh_token, the previous one is kept."""
    stale_tokens = _make_qwen_tokens(refresh_token="keep-me")

    fake_response = MagicMock(status_code=200)
    # Deliberately no "refresh_token" key in the response body.
    fake_response.json.return_value = {
        "access_token": "new-access",
        "expires_in": 3600,
    }

    with patch("hermes_cli.auth.httpx") as httpx_stub:
        httpx_stub.post.return_value = fake_response
        refreshed = _refresh_qwen_cli_tokens(stale_tokens)

    assert refreshed["refresh_token"] == "keep-me"
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_missing_refresh_token():
    """An empty refresh_token yields "qwen_refresh_token_missing"."""
    no_refresh = {"access_token": "at", "refresh_token": ""}

    with pytest.raises(AuthError) as raised:
        _refresh_qwen_cli_tokens(no_refresh)

    assert raised.value.code == "qwen_refresh_token_missing"
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_http_error(qwen_env):
    """A 401 response from the token endpoint yields "qwen_refresh_failed"."""
    current_tokens = _make_qwen_tokens()
    rejected = MagicMock(status_code=401, text="unauthorized")

    with patch("hermes_cli.auth.httpx") as httpx_stub:
        httpx_stub.post.return_value = rejected
        with pytest.raises(AuthError) as raised:
            _refresh_qwen_cli_tokens(current_tokens)

    assert raised.value.code == "qwen_refresh_failed"
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_network_error(qwen_env):
    """An exception raised by httpx.post is surfaced as "qwen_refresh_failed"."""
    current_tokens = _make_qwen_tokens()

    with patch("hermes_cli.auth.httpx") as httpx_stub:
        # Simulate the transport failing before any response arrives.
        httpx_stub.post.side_effect = ConnectionError("timeout")
        with pytest.raises(AuthError) as raised:
            _refresh_qwen_cli_tokens(current_tokens)

    assert raised.value.code == "qwen_refresh_failed"
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_invalid_json_response(qwen_env):
    """A 200 response whose body fails to parse yields "qwen_refresh_invalid_json"."""
    current_tokens = _make_qwen_tokens()

    unparsable = MagicMock(status_code=200)
    unparsable.json.side_effect = ValueError("bad json")

    with patch("hermes_cli.auth.httpx") as httpx_stub:
        httpx_stub.post.return_value = unparsable
        with pytest.raises(AuthError) as raised:
            _refresh_qwen_cli_tokens(current_tokens)

    assert raised.value.code == "qwen_refresh_invalid_json"
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_missing_access_token_in_response(qwen_env):
    """A 200 response without access_token yields "qwen_refresh_invalid_response"."""
    current_tokens = _make_qwen_tokens()

    incomplete = MagicMock(status_code=200)
    incomplete.json.return_value = {"something": "but no access_token"}

    with patch("hermes_cli.auth.httpx") as httpx_stub:
        httpx_stub.post.return_value = incomplete
        with pytest.raises(AuthError) as raised:
            _refresh_qwen_cli_tokens(current_tokens)

    assert raised.value.code == "qwen_refresh_invalid_response"
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_default_expires_in(qwen_env):
    """When expires_in is missing, default to 6 hours."""
    current_tokens = _make_qwen_tokens()

    fake_response = MagicMock(status_code=200)
    fake_response.json.return_value = {"access_token": "new"}

    with patch("hermes_cli.auth.httpx") as httpx_stub:
        httpx_stub.post.return_value = fake_response
        refreshed = _refresh_qwen_cli_tokens(current_tokens)

    # expiry_date should be roughly now + 6h; allow 60s of slack.
    six_hours_ms = 6 * 60 * 60 * 1000
    expected_ms = int(time.time() * 1000) + six_hours_ms
    drift_ms = abs(refreshed["expiry_date"] - expected_ms)
    assert drift_ms < 60_000
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_saves_to_disk(qwen_env):
    """A successful refresh writes the new access_token to the credentials file."""
    current_tokens = _make_qwen_tokens()

    fake_response = MagicMock(status_code=200)
    fake_response.json.return_value = {
        "access_token": "disk-check",
        "expires_in": 3600,
    }

    with patch("hermes_cli.auth.httpx") as httpx_stub:
        httpx_stub.post.return_value = fake_response
        _refresh_qwen_cli_tokens(current_tokens)

    # The refreshed token must have been persisted, not just returned.
    persisted_file = qwen_env.joinpath(".qwen", "oauth_creds.json")
    assert persisted_file.exists()
    persisted = json.loads(persisted_file.read_text(encoding="utf-8"))
    assert persisted["access_token"] == "disk-check"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_qwen_runtime_credentials
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_resolve_qwen_runtime_credentials_fresh_token(qwen_env):
    """With refresh disabled, the stored token is returned with default metadata."""
    _write_qwen_creds(qwen_env, _make_qwen_tokens(access_token="fresh-at"))

    resolved = resolve_qwen_runtime_credentials(refresh_if_expiring=False)

    expected_fields = {
        "provider": "qwen-oauth",
        "api_key": "fresh-at",
        "base_url": DEFAULT_QWEN_BASE_URL,
        "source": "qwen-cli",
    }
    for field, expected in expected_fields.items():
        assert resolved[field] == expected
|
||||
|
||||
|
||||
def test_resolve_qwen_runtime_credentials_triggers_refresh(qwen_env):
    """An already-expired stored token causes exactly one refresh call."""
    # Store a token whose expiry lies one hour in the past.
    one_hour_ago_ms = int((time.time() - 3600) * 1000)
    _write_qwen_creds(
        qwen_env, _make_qwen_tokens(access_token="old", expiry_date=one_hour_ago_ms)
    )

    replacement = _make_qwen_tokens(access_token="refreshed-at")

    with patch(
        "hermes_cli.auth._refresh_qwen_cli_tokens", return_value=replacement
    ) as refresh_stub:
        resolved = resolve_qwen_runtime_credentials()

    refresh_stub.assert_called_once()
    assert resolved["api_key"] == "refreshed-at"
|
||||
|
||||
|
||||
def test_resolve_qwen_runtime_credentials_force_refresh(qwen_env):
    """force_refresh=True refreshes even though the stored token is still valid."""
    _write_qwen_creds(qwen_env, _make_qwen_tokens(access_token="old-at"))

    replacement = _make_qwen_tokens(access_token="force-refreshed")

    with patch(
        "hermes_cli.auth._refresh_qwen_cli_tokens", return_value=replacement
    ) as refresh_stub:
        resolved = resolve_qwen_runtime_credentials(force_refresh=True)

    refresh_stub.assert_called_once()
    assert resolved["api_key"] == "force-refreshed"
|
||||
|
||||
|
||||
def test_resolve_qwen_runtime_credentials_missing_access_token(qwen_env):
    """An empty stored access_token yields "qwen_access_token_missing"."""
    _write_qwen_creds(qwen_env, _make_qwen_tokens(access_token=""))

    with pytest.raises(AuthError) as raised:
        resolve_qwen_runtime_credentials(refresh_if_expiring=False)

    assert raised.value.code == "qwen_access_token_missing"
|
||||
|
||||
|
||||
def test_resolve_qwen_runtime_credentials_base_url_env_override(qwen_env, monkeypatch):
    """HERMES_QWEN_BASE_URL replaces the default base_url in the result."""
    _write_qwen_creds(qwen_env, _make_qwen_tokens(access_token="at"))
    monkeypatch.setenv("HERMES_QWEN_BASE_URL", "https://custom.qwen.ai/v1")

    resolved = resolve_qwen_runtime_credentials(refresh_if_expiring=False)

    assert resolved["base_url"] == "https://custom.qwen.ai/v1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_qwen_auth_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_qwen_auth_status_logged_in(qwen_env):
    """With valid stored credentials the status reports logged_in and the api_key."""
    _write_qwen_creds(qwen_env, _make_qwen_tokens(access_token="status-at"))

    report = get_qwen_auth_status()

    assert report["logged_in"] is True
    assert report["api_key"] == "status-at"
|
||||
|
||||
|
||||
def test_get_qwen_auth_status_not_logged_in(qwen_env):
    """Without a credentials file the status is logged-out and carries an error."""
    # qwen_env redirects the auth path to a file that was never created.
    report = get_qwen_auth_status()

    assert report["logged_in"] is False
    assert "error" in report
|
||||
63
tests/hermes_cli/test_banner_git_state.py
Normal file
63
tests/hermes_cli/test_banner_git_state.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def test_format_banner_version_label_without_git_state():
    """With no git state, the label is just the version and release date."""
    from hermes_cli import banner

    with patch.object(banner, "get_git_banner_state", return_value=None):
        label = banner.format_banner_version_label()

    expected = f"Hermes Agent v{banner.VERSION} ({banner.RELEASE_DATE})"
    assert label == expected
|
||||
|
||||
|
||||
def test_format_banner_version_label_on_upstream_main():
    """When local equals upstream, only the upstream hash is shown."""
    from hermes_cli import banner

    in_sync_state = {"upstream": "b2f477a3", "local": "b2f477a3", "ahead": 0}

    with patch.object(banner, "get_git_banner_state", return_value=in_sync_state):
        label = banner.format_banner_version_label()

    assert label.endswith("· upstream b2f477a3")
    assert "local" not in label
|
||||
|
||||
|
||||
def test_format_banner_version_label_with_carried_commits():
    """When ahead of upstream, the label shows both hashes and the commit count."""
    from hermes_cli import banner

    ahead_state = {"upstream": "b2f477a3", "local": "af8aad31", "ahead": 3}

    with patch.object(banner, "get_git_banner_state", return_value=ahead_state):
        label = banner.format_banner_version_label()

    for fragment in ("upstream b2f477a3", "local af8aad31", "+3 carried commits"):
        assert fragment in label
|
||||
|
||||
|
||||
def test_get_git_banner_state_reads_origin_and_head(tmp_path):
    """get_git_banner_state combines three git commands into one state dict."""
    from hermes_cli import banner

    repo_dir = tmp_path / "repo"
    (repo_dir / ".git").mkdir(parents=True)

    def _ok(stdout):
        # A stand-in for a successful subprocess result with the given stdout.
        return MagicMock(returncode=0, stdout=stdout)

    canned_replies = {
        ("git", "rev-parse", "--short=8", "origin/main"): _ok("b2f477a3\n"),
        ("git", "rev-parse", "--short=8", "HEAD"): _ok("af8aad31\n"),
        ("git", "rev-list", "--count", "origin/main..HEAD"): _ok("3\n"),
    }

    def fake_run(cmd, **kwargs):
        # Any git invocation outside the canned set fails the test loudly.
        reply = canned_replies.get(tuple(cmd))
        if reply is None:
            raise AssertionError(f"unexpected command: {cmd}")
        return reply

    with patch("hermes_cli.banner.subprocess.run", side_effect=fake_run):
        state = banner.get_git_banner_state(repo_dir)

    assert state == {"upstream": "b2f477a3", "local": "af8aad31", "ahead": 3}
|
||||
@@ -136,3 +136,73 @@ def test_check_gateway_service_linger_skips_when_service_not_installed(monkeypat
|
||||
out = capsys.readouterr().out
|
||||
assert out == ""
|
||||
assert issues == []
|
||||
|
||||
|
||||
# ── Memory provider section (doctor should only check the *active* provider) ──
|
||||
|
||||
|
||||
class TestDoctorMemoryProviderSection:
    """The ◆ Memory Provider section should respect memory.provider config."""

    def _make_hermes_home(self, tmp_path, provider=""):
        """Build a minimal HERMES_HOME under *tmp_path* holding only config.yaml.

        The ``memory`` mapping receives a ``provider`` key only when *provider*
        is truthy; otherwise it is written empty. Returns the home directory.
        """
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir(parents=True, exist_ok=True)

        import yaml

        memory_section = {}
        if provider:
            memory_section["provider"] = provider
        config_path = hermes_home / "config.yaml"
        config_path.write_text(yaml.dump({"memory": memory_section}))
        return hermes_home

    def _run_doctor_and_capture(self, monkeypatch, tmp_path, provider=""):
        """Run ``doctor_mod.run_doctor`` against a temp home and return its stdout."""
        import contextlib
        import io

        hermes_home = self._make_hermes_home(tmp_path, provider)
        project_root = tmp_path / "project"
        for attr, value in (
            ("HERMES_HOME", hermes_home),
            ("PROJECT_ROOT", project_root),
            ("_DHH", str(hermes_home)),
        ):
            monkeypatch.setattr(doctor_mod, attr, value)
        project_root.mkdir(exist_ok=True)

        # Swap in a model_tools stub that reports nothing, so doctor runs past it.
        monkeypatch.setitem(
            sys.modules,
            "model_tools",
            types.SimpleNamespace(
                check_tool_availability=lambda *a, **kw: ([], []),
                TOOLSET_REQUIREMENTS={},
            ),
        )

        # Best effort: stub the auth status lookups to avoid real API calls.
        # Any failure while importing or patching is intentionally ignored.
        with contextlib.suppress(Exception):
            from hermes_cli import auth as _auth_mod

            for getter in ("get_nous_auth_status", "get_codex_auth_status"):
                monkeypatch.setattr(_auth_mod, getter, lambda: {})

        captured = io.StringIO()
        with contextlib.redirect_stdout(captured):
            doctor_mod.run_doctor(Namespace(fix=False))
        return captured.getvalue()

    def test_no_provider_shows_builtin_ok(self, monkeypatch, tmp_path):
        """With no provider configured, only the built-in memory line appears."""
        report = self._run_doctor_and_capture(monkeypatch, tmp_path, provider="")

        for expected in ("Memory Provider", "Built-in memory active"):
            assert expected in report
        # Should NOT mention Honcho or Mem0 errors
        for unexpected in ("Honcho API key", "Mem0"):
            assert unexpected not in report

    def test_honcho_provider_not_installed_shows_fail(self, monkeypatch, tmp_path):
        """Honcho configured but unimportable must not be reported as built-in."""
        # A None entry in sys.modules makes importing the honcho client fail.
        monkeypatch.setitem(sys.modules, "plugins.memory.honcho.client", None)

        report = self._run_doctor_and_capture(
            monkeypatch, tmp_path, provider="honcho"
        )

        assert "Memory Provider" in report
        # Honcho is selected, so the built-in "active" line must be absent.
        assert "Built-in memory active" not in report

    def test_mem0_provider_not_installed_shows_fail(self, monkeypatch, tmp_path):
        """Mem0 configured but unimportable must not be reported as built-in."""
        # A None entry in sys.modules makes importing the mem0 plugin fail.
        monkeypatch.setitem(sys.modules, "plugins.memory.mem0", None)

        report = self._run_doctor_and_capture(monkeypatch, tmp_path, provider="mem0")

        assert "Memory Provider" in report
        assert "Built-in memory active" not in report
|
||||
|
||||
@@ -641,3 +641,69 @@ class TestEnsureUserSystemdEnv:
|
||||
result = gateway_cli._systemctl_cmd(system=True)
|
||||
assert result == ["systemctl"]
|
||||
assert calls == []
|
||||
|
||||
|
||||
class TestProfileArg:
    """Tests for _profile_arg — returns '--profile <name>' for named profiles."""

    @staticmethod
    def _pin_home(monkeypatch, fake_home):
        """Patch ``Path.home`` so it returns *fake_home*."""
        monkeypatch.setattr(Path, "home", lambda: fake_home)

    def test_default_hermes_home_returns_empty(self, tmp_path, monkeypatch):
        """Default ~/.hermes should not produce a --profile flag."""
        default_home = tmp_path / ".hermes"
        default_home.mkdir()
        self._pin_home(monkeypatch, tmp_path)

        assert gateway_cli._profile_arg(str(default_home)) == ""

    def test_named_profile_returns_flag(self, tmp_path, monkeypatch):
        """~/.hermes/profiles/mybot should return '--profile mybot'."""
        named_profile = tmp_path / ".hermes" / "profiles" / "mybot"
        named_profile.mkdir(parents=True)
        self._pin_home(monkeypatch, tmp_path)

        assert gateway_cli._profile_arg(str(named_profile)) == "--profile mybot"

    def test_hash_path_returns_empty(self, tmp_path, monkeypatch):
        """Arbitrary non-profile HERMES_HOME should return empty string."""
        unrelated_home = tmp_path / "custom" / "hermes"
        unrelated_home.mkdir(parents=True)
        self._pin_home(monkeypatch, tmp_path)

        assert gateway_cli._profile_arg(str(unrelated_home)) == ""

    def test_nested_profile_path_returns_empty(self, tmp_path, monkeypatch):
        """~/.hermes/profiles/mybot/subdir should NOT match — too deep."""
        too_deep = tmp_path / ".hermes" / "profiles" / "mybot" / "subdir"
        too_deep.mkdir(parents=True)
        self._pin_home(monkeypatch, tmp_path)

        assert gateway_cli._profile_arg(str(too_deep)) == ""

    def test_invalid_profile_name_returns_empty(self, tmp_path, monkeypatch):
        """Profile names with invalid chars should not match the regex."""
        malformed = tmp_path / ".hermes" / "profiles" / "My Bot!"
        malformed.mkdir(parents=True)
        self._pin_home(monkeypatch, tmp_path)

        assert gateway_cli._profile_arg(str(malformed)) == ""

    def test_systemd_unit_includes_profile(self, tmp_path, monkeypatch):
        """generate_systemd_unit should include --profile in ExecStart for named profiles."""
        profile_dir = tmp_path / ".hermes" / "profiles" / "mybot"
        profile_dir.mkdir(parents=True)
        self._pin_home(monkeypatch, tmp_path)
        # Point both the environment variable and the module accessor at the profile.
        monkeypatch.setenv("HERMES_HOME", str(profile_dir))
        monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir)

        unit_text = gateway_cli.generate_systemd_unit(system=False)

        for fragment in ("--profile mybot", "gateway run --replace"):
            assert fragment in unit_text

    def test_launchd_plist_includes_profile(self, tmp_path, monkeypatch):
        """generate_launchd_plist should include --profile in ProgramArguments for named profiles."""
        profile_dir = tmp_path / ".hermes" / "profiles" / "mybot"
        profile_dir.mkdir(parents=True)
        self._pin_home(monkeypatch, tmp_path)
        # Point both the environment variable and the module accessor at the profile.
        monkeypatch.setenv("HERMES_HOME", str(profile_dir))
        monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir)

        plist_text = gateway_cli.generate_launchd_plist()

        for element in ("<string>--profile</string>", "<string>mybot</string>"):
            assert element in plist_text
|
||||
|
||||
@@ -171,7 +171,11 @@ class TestGeminiModelNormalization:
|
||||
|
||||
class TestGeminiContextLength:
|
||||
def test_gemma_4_31b_context(self):
|
||||
ctx = get_model_context_length("gemma-4-31b-it", provider="gemini")
|
||||
# Mock external API lookups to test against hardcoded defaults
|
||||
# (models.dev and OpenRouter may return different values like 262144).
|
||||
with patch("agent.models_dev.lookup_models_dev_context", return_value=None), \
|
||||
patch("agent.model_metadata.fetch_model_metadata", return_value={}):
|
||||
ctx = get_model_context_length("gemma-4-31b-it", provider="gemini")
|
||||
assert ctx == 256000
|
||||
|
||||
def test_gemma_4_26b_context(self):
|
||||
@@ -1,6 +1,15 @@
|
||||
"""Tests for the hermes_cli models module."""
|
||||
|
||||
from hermes_cli.models import OPENROUTER_MODELS, menu_labels, model_ids, detect_provider_for_model
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from hermes_cli.models import (
|
||||
OPENROUTER_MODELS, menu_labels, model_ids, detect_provider_for_model,
|
||||
filter_nous_free_models, _NOUS_ALLOWED_FREE_MODELS,
|
||||
is_nous_free_tier, partition_nous_models_by_tier,
|
||||
check_nous_free_tier, clear_nous_free_tier_cache,
|
||||
_FREE_TIER_CACHE_TTL,
|
||||
)
|
||||
import hermes_cli.models as _models_mod
|
||||
|
||||
|
||||
class TestModelIds:
|
||||
@@ -124,3 +133,226 @@ class TestDetectProviderForModel:
|
||||
result = detect_provider_for_model("claude-opus-4-6", "openai-codex")
|
||||
assert result is not None
|
||||
assert result[0] not in ("nous",) # nous has claude models but shouldn't be suggested
|
||||
|
||||
|
||||
class TestFilterNousFreeModels:
    """Tests for filter_nous_free_models — Nous Portal free-model policy."""

    # Pricing fixtures: non-zero rates stand for a paid model, "0" rates for a free one.
    _PAID = {"prompt": "0.000003", "completion": "0.000015"}
    _FREE = {"prompt": "0", "completion": "0"}

    def test_paid_models_kept(self):
        """Regular paid models pass through unchanged."""
        catalog = ["anthropic/claude-opus-4.6", "openai/gpt-5.4"]
        price_table = dict.fromkeys(catalog, self._PAID)

        assert filter_nous_free_models(catalog, price_table) == catalog

    def test_free_non_allowlist_models_removed(self):
        """Free models NOT in the allowlist are filtered out."""
        paid_id = "anthropic/claude-opus-4.6"
        free_id = "arcee-ai/trinity-large-preview:free"
        price_table = {paid_id: self._PAID, free_id: self._FREE}

        kept = filter_nous_free_models([paid_id, free_id], price_table)

        assert kept == ["anthropic/claude-opus-4.6"]

    def test_allowlist_model_kept_when_free(self):
        """Allowlist models are kept when they report as free."""
        paid_id = "anthropic/claude-opus-4.6"
        allowlisted_id = "xiaomi/mimo-v2-pro"
        price_table = {paid_id: self._PAID, allowlisted_id: self._FREE}

        kept = filter_nous_free_models([paid_id, allowlisted_id], price_table)

        assert kept == ["anthropic/claude-opus-4.6", "xiaomi/mimo-v2-pro"]

    def test_allowlist_model_removed_when_paid(self):
        """Allowlist models are removed when they are NOT free."""
        paid_id = "anthropic/claude-opus-4.6"
        allowlisted_id = "xiaomi/mimo-v2-pro"
        price_table = {paid_id: self._PAID, allowlisted_id: self._PAID}

        kept = filter_nous_free_models([paid_id, allowlisted_id], price_table)

        assert kept == ["anthropic/claude-opus-4.6"]

    def test_no_pricing_returns_all(self):
        """When pricing data is unavailable, all models pass through."""
        catalog = [
            "anthropic/claude-opus-4.6",
            "nvidia/nemotron-3-super-120b-a12b:free",
        ]

        assert filter_nous_free_models(catalog, {}) == catalog

    def test_model_with_no_pricing_entry_treated_as_paid(self):
        """A model missing from the pricing dict is kept (assumed paid)."""
        catalog = ["anthropic/claude-opus-4.6", "openai/gpt-5.4"]
        # Only the first model is priced; gpt-5.4 has no pricing entry at all.
        price_table = {catalog[0]: self._PAID}

        kept = filter_nous_free_models(catalog, price_table)

        assert kept == catalog

    def test_mixed_scenario(self):
        """End-to-end: mix of paid, free-allowed, free-disallowed, allowlist-not-free."""
        # (model id, pricing, expected to survive the filter)
        scenarios = [
            ("anthropic/claude-opus-4.6", self._PAID, True),  # paid, not allowlist
            ("nvidia/nemotron-3-super-120b-a12b:free", self._FREE, False),  # free, not allowlist
            ("xiaomi/mimo-v2-pro", self._FREE, True),  # free, allowlist
            ("xiaomi/mimo-v2-omni", self._PAID, False),  # paid, allowlist
            ("openai/gpt-5.4", self._PAID, True),  # paid, not allowlist
        ]
        catalog = [model_id for model_id, _, _ in scenarios]
        price_table = {model_id: price for model_id, price, _ in scenarios}
        survivors = [model_id for model_id, _, keep in scenarios if keep]

        kept = filter_nous_free_models(catalog, price_table)

        assert kept == survivors

    def test_allowlist_contains_expected_models(self):
        """Sanity: the allowlist has the models we expect."""
        for model_id in ("xiaomi/mimo-v2-pro", "xiaomi/mimo-v2-omni"):
            assert model_id in _NOUS_ALLOWED_FREE_MODELS
|
||||
|
||||
|
||||
class TestIsNousFreeTier:
    """Tests for is_nous_free_tier — account tier detection."""

    def test_paid_plus_tier(self):
        """A subscription with a non-zero monthly charge is not free."""
        account = {"subscription": {"plan": "Plus", "tier": 2, "monthly_charge": 20}}
        assert is_nous_free_tier(account) is False

    def test_free_tier_by_charge(self):
        """A zero monthly charge marks the account as free tier."""
        account = {"subscription": {"plan": "Free", "tier": 0, "monthly_charge": 0}}
        assert is_nous_free_tier(account) is True

    def test_no_charge_field_not_free(self):
        """Missing monthly_charge defaults to not-free (don't block users)."""
        account = {"subscription": {"plan": "Free", "tier": 0}}
        assert is_nous_free_tier(account) is False

    def test_plan_name_alone_not_free(self):
        """Plan name alone is not enough — monthly_charge is required."""
        account = {"subscription": {"plan": "free"}}
        assert is_nous_free_tier(account) is False

    def test_empty_subscription_not_free(self):
        """Empty subscription dict defaults to not-free (don't block users)."""
        account = {"subscription": {}}
        assert is_nous_free_tier(account) is False

    def test_no_subscription_not_free(self):
        """Missing subscription key returns False."""
        account = {}
        assert is_nous_free_tier(account) is False

    def test_empty_response_not_free(self):
        """Completely empty response defaults to not-free."""
        account = {}
        assert is_nous_free_tier(account) is False
|
||||
|
||||
|
||||
class TestPartitionNousModelsByTier:
    """Tests for partition_nous_models_by_tier — free vs paid tier model split."""

    # Pricing fixtures: non-zero rates stand for a paid model, "0" rates for a free one.
    _PAID = {"prompt": "0.000003", "completion": "0.000015"}
    _FREE = {"prompt": "0", "completion": "0"}

    def test_paid_tier_all_selectable(self):
        """Paid users get all models as selectable, none unavailable."""
        catalog = ["anthropic/claude-opus-4.6", "xiaomi/mimo-v2-pro"]
        price_table = dict(zip(catalog, (self._PAID, self._FREE)))

        selectable, unavailable = partition_nous_models_by_tier(
            catalog, price_table, free_tier=False
        )

        assert selectable == catalog
        assert unavailable == []

    def test_free_tier_splits_correctly(self):
        """Free users see only free models; paid ones are unavailable."""
        catalog = ["anthropic/claude-opus-4.6", "xiaomi/mimo-v2-pro", "openai/gpt-5.4"]
        price_table = dict(zip(catalog, (self._PAID, self._FREE, self._PAID)))

        selectable, unavailable = partition_nous_models_by_tier(
            catalog, price_table, free_tier=True
        )

        assert selectable == ["xiaomi/mimo-v2-pro"]
        assert unavailable == ["anthropic/claude-opus-4.6", "openai/gpt-5.4"]

    def test_no_pricing_returns_all(self):
        """Without pricing data, all models are selectable."""
        catalog = ["anthropic/claude-opus-4.6", "openai/gpt-5.4"]

        selectable, unavailable = partition_nous_models_by_tier(
            catalog, {}, free_tier=True
        )

        assert selectable == catalog
        assert unavailable == []

    def test_all_free_models(self):
        """When all models are free, free-tier users can select all."""
        catalog = ["xiaomi/mimo-v2-pro", "xiaomi/mimo-v2-omni"]
        price_table = dict.fromkeys(catalog, self._FREE)

        selectable, unavailable = partition_nous_models_by_tier(
            catalog, price_table, free_tier=True
        )

        assert selectable == catalog
        assert unavailable == []

    def test_all_paid_models(self):
        """When all models are paid, free-tier users have none selectable."""
        catalog = ["anthropic/claude-opus-4.6", "openai/gpt-5.4"]
        price_table = dict.fromkeys(catalog, self._PAID)

        selectable, unavailable = partition_nous_models_by_tier(
            catalog, price_table, free_tier=True
        )

        assert selectable == []
        assert unavailable == catalog
|
||||
|
||||
|
||||
class TestCheckNousFreeTierCache:
    """Tests for the TTL cache on check_nous_free_tier()."""

    def setup_method(self):
        """Start every test from an empty cache."""
        clear_nous_free_tier_cache()

    def teardown_method(self):
        """Leave an empty cache behind after every test."""
        clear_nous_free_tier_cache()

    @patch("hermes_cli.models.fetch_nous_account_tier")
    @patch("hermes_cli.models.is_nous_free_tier", return_value=True)
    def test_result_is_cached(self, mock_is_free, mock_fetch):
        """Second call within TTL returns cached result without API call."""
        mock_fetch.return_value = {"subscription": {"monthly_charge": 0}}

        auth_state = patch(
            "hermes_cli.auth.get_provider_auth_state",
            return_value={"access_token": "tok"},
        )
        runtime_creds = patch("hermes_cli.auth.resolve_nous_runtime_credentials")
        with auth_state:
            with runtime_creds:
                first = check_nous_free_tier()
                second = check_nous_free_tier()

        assert first is True
        assert second is True
        # The account tier is fetched once; the second call is served from the cache.
        assert mock_fetch.call_count == 1

    @patch("hermes_cli.models.fetch_nous_account_tier")
    @patch("hermes_cli.models.is_nous_free_tier", return_value=False)
    def test_cache_expires_after_ttl(self, mock_is_free, mock_fetch):
        """After TTL expires, the API is called again."""
        mock_fetch.return_value = {"subscription": {"monthly_charge": 20}}

        auth_state = patch(
            "hermes_cli.auth.get_provider_auth_state",
            return_value={"access_token": "tok"},
        )
        runtime_creds = patch("hermes_cli.auth.resolve_nous_runtime_credentials")
        with auth_state:
            with runtime_creds:
                first = check_nous_free_tier()
                assert mock_fetch.call_count == 1

                # Age the cached entry past the TTL by rewriting its timestamp.
                cached_value, stored_at = _models_mod._free_tier_cache
                expired_at = stored_at - _FREE_TIER_CACHE_TTL - 1
                _models_mod._free_tier_cache = (cached_value, expired_at)

                second = check_nous_free_tier()
                assert mock_fetch.call_count == 2

        assert first is False
        assert second is False

    def test_clear_cache_forces_refresh(self):
        """clear_nous_free_tier_cache() invalidates the cached result."""
        import time

        # Seed the module-level cache by hand, then verify clearing resets it to None.
        seeded_entry = (True, time.monotonic())
        _models_mod._free_tier_cache = seeded_entry

        clear_nous_free_tier_cache()

        assert _models_mod._free_tier_cache is None

    def test_cache_ttl_is_short(self):
        """TTL should be short enough to catch upgrades quickly (<=5 min)."""
        five_minutes = 300
        assert _FREE_TIER_CACHE_TTL <= five_minutes
|
||||
|
||||
@@ -44,7 +44,62 @@ def test_get_nous_subscription_features_prefers_managed_modal_in_auto_mode(monke
|
||||
assert features.modal.direct_override is False
|
||||
|
||||
|
||||
def test_get_nous_subscription_features_prefers_camofox_over_managed_browserbase(monkeypatch):
|
||||
def test_get_nous_subscription_features_marks_browser_use_as_managed_when_gateway_ready(monkeypatch):
    """Browser Use is reported as Nous-managed when its managed gateway is ready."""
    stubs = {
        "get_env_value": lambda name: "",
        "get_nous_auth_status": lambda: {"logged_in": True},
        "managed_nous_tools_enabled": lambda: True,
        "_toolset_enabled": lambda config, key: key == "browser",
        "_has_agent_browser": lambda: True,
        "resolve_openai_audio_api_key": lambda: "",
        "has_direct_modal_credentials": lambda: False,
        # Only the browser-use vendor has a ready managed gateway.
        "is_managed_tool_gateway_ready": lambda vendor: vendor == "browser-use",
    }
    for attr, stub in stubs.items():
        monkeypatch.setattr(ns, attr, stub)

    config = {"browser": {"cloud_provider": "browser-use"}}
    browser = ns.get_nous_subscription_features(config).browser

    assert browser.available is True
    assert browser.active is True
    assert browser.managed_by_nous is True
    assert browser.direct_override is False
    assert browser.current_provider == "Browser Use"
|
||||
|
||||
|
||||
def test_get_nous_subscription_features_uses_direct_browserbase_when_no_managed_gateway(monkeypatch):
    """When direct Browserbase keys are set and no managed gateway is available,
    the unconfigured fallback should pick Browserbase as a direct provider."""
    direct_env = {
        "BROWSERBASE_API_KEY": "bb-key",
        "BROWSERBASE_PROJECT_ID": "bb-project",
    }
    stubs = {
        "get_env_value": lambda name: direct_env.get(name, ""),
        "get_nous_auth_status": lambda: {"logged_in": True},
        "managed_nous_tools_enabled": lambda: True,
        "_toolset_enabled": lambda config, key: key == "browser",
        "_has_agent_browser": lambda: True,
        "resolve_openai_audio_api_key": lambda: "",
        "has_direct_modal_credentials": lambda: False,
        # No managed gateway is available for any vendor.
        "is_managed_tool_gateway_ready": lambda vendor: False,
    }
    for attr, stub in stubs.items():
        monkeypatch.setattr(ns, attr, stub)

    # An empty config leaves the browser provider unconfigured.
    browser = ns.get_nous_subscription_features({}).browser

    assert browser.available is True
    assert browser.active is True
    assert browser.managed_by_nous is False
    assert browser.direct_override is True
    assert browser.current_provider == "Browserbase"
|
||||
|
||||
|
||||
def test_get_nous_subscription_features_prefers_camofox_over_managed_browser_use(monkeypatch):
|
||||
env = {"CAMOFOX_URL": "http://localhost:9377"}
|
||||
|
||||
monkeypatch.setattr(ns, "get_env_value", lambda name: env.get(name, ""))
|
||||
@@ -57,11 +112,11 @@ def test_get_nous_subscription_features_prefers_camofox_over_managed_browserbase
|
||||
monkeypatch.setattr(
|
||||
ns,
|
||||
"is_managed_tool_gateway_ready",
|
||||
lambda vendor: vendor == "browserbase",
|
||||
lambda vendor: vendor == "browser-use",
|
||||
)
|
||||
|
||||
features = ns.get_nous_subscription_features(
|
||||
{"browser": {"cloud_provider": "browserbase"}}
|
||||
{"browser": {"cloud_provider": "browser-use"}}
|
||||
)
|
||||
|
||||
assert features.browser.available is True
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user