Merge upstream/main and address Copilot review feedback
Resolve merge conflicts in web/src/{i18n/{en,zh,types}.ts,lib/api.ts}
by keeping both this branch's `profiles` additions and upstream's new
`models` page additions.
Copilot review feedback:
- POST /api/profiles/{name}/open-terminal endpoint (already
implemented): align its Windows branch to `cmd.exe /c start "" <cmd>` so
it matches the new test and spawns a fresh window, instead of using `/k`,
which reuses the parent console.
- Move backslash escaping out of the macOS AppleScript f-string
expression (Python <3.12 disallows backslashes inside f-string
expression parts).
- Patch `_get_wrapper_dir` via monkeypatch in
test_profiles_create_creates_wrapper_alias_when_safe so the test no
longer writes to the real `~/.local/bin`.
- Extend test_dashboard_browser_safe_imports to scan `.ts` files in
addition to `.tsx`.
- Switch upstream's new ModelsPage.tsx away from the `@nous-research/ui`
root barrel onto per-component subpaths to satisfy the stricter scan.
- Fix NouiTypography `leading-1.4` -> `leading-[1.4]` so Tailwind
actually emits the line-height for the `sm` variant.
- Guard ProfilesPage.openSoulEditor against out-of-order responses by
tracking the latest requested profile via a ref.
- Replace ProfilesPage's hand-rolled setup command with a fetch to
`/api/profiles/{name}/setup-command` so the copied command always
matches what the backend would actually run (handles wrapper-alias
collisions and reserved names correctly).
- Wire SOUL.md textarea label `htmlFor` -> textarea `id` so screen
readers and clicking the label work as expected.
This commit is contained in:
@@ -11,6 +11,7 @@ import acp
|
||||
from acp.agent.router import build_agent_router
|
||||
from acp.schema import (
|
||||
AgentCapabilities,
|
||||
AgentMessageChunk,
|
||||
AuthenticateResponse,
|
||||
AvailableCommandsUpdate,
|
||||
Implementation,
|
||||
@@ -27,6 +28,7 @@ from acp.schema import (
|
||||
SessionInfo,
|
||||
TextContentBlock,
|
||||
Usage,
|
||||
UserMessageChunk,
|
||||
)
|
||||
from acp_adapter.server import HermesACPAgent, HERMES_VERSION
|
||||
from acp_adapter.session import SessionManager
|
||||
@@ -224,6 +226,58 @@ class TestSessionOps:
|
||||
resp = await agent.load_session(cwd="/tmp", session_id="bogus")
|
||||
assert resp is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_replays_persisted_history_to_client(self, agent):
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
state.history = [
|
||||
{"role": "system", "content": "hidden system"},
|
||||
{"role": "user", "content": "what controls the / slash commands?"},
|
||||
{"role": "assistant", "content": "HermesACPAgent._ADVERTISED_COMMANDS controls them."},
|
||||
{"role": "tool", "content": "tool output should not replay"},
|
||||
]
|
||||
|
||||
mock_conn.session_update.reset_mock()
|
||||
resp = await agent.load_session(cwd="/tmp", session_id=new_resp.session_id)
|
||||
|
||||
assert isinstance(resp, LoadSessionResponse)
|
||||
calls = mock_conn.session_update.await_args_list
|
||||
replay_calls = [
|
||||
call for call in calls
|
||||
if getattr(call.kwargs.get("update"), "session_update", None)
|
||||
in {"user_message_chunk", "agent_message_chunk"}
|
||||
]
|
||||
assert len(replay_calls) == 2
|
||||
assert isinstance(replay_calls[0].kwargs["update"], UserMessageChunk)
|
||||
assert replay_calls[0].kwargs["update"].content.text == "what controls the / slash commands?"
|
||||
assert isinstance(replay_calls[1].kwargs["update"], AgentMessageChunk)
|
||||
assert replay_calls[1].kwargs["update"].content.text.startswith("HermesACPAgent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_session_replays_persisted_history_to_client(self, agent):
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
state.history = [{"role": "user", "content": "So tell me the current state"}]
|
||||
|
||||
mock_conn.session_update.reset_mock()
|
||||
resp = await agent.resume_session(cwd="/tmp", session_id=new_resp.session_id)
|
||||
|
||||
assert isinstance(resp, ResumeSessionResponse)
|
||||
updates = [call.kwargs["update"] for call in mock_conn.session_update.await_args_list]
|
||||
assert any(
|
||||
isinstance(update, UserMessageChunk)
|
||||
and update.content.text == "So tell me the current state"
|
||||
for update in updates
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_session_creates_new_if_missing(self, agent):
|
||||
resume_resp = await agent.resume_session(cwd="/tmp", session_id="nonexistent")
|
||||
|
||||
@@ -66,34 +66,29 @@ class TestBuildAnthropicClient:
|
||||
assert "claude-code-20250219" in betas
|
||||
assert "interleaved-thinking-2025-05-14" in betas
|
||||
assert "fine-grained-tool-streaming-2025-05-14" in betas
|
||||
# Default: 1M-context beta stays IN for OAuth so 1M-capable
|
||||
# subscriptions keep full context. The reactive recovery path
|
||||
# in run_agent.py flips it off only after a subscription
|
||||
# actually rejects the beta.
|
||||
assert "context-1m-2025-08-07" in betas
|
||||
assert "api_key" not in kwargs
|
||||
|
||||
def test_oauth_does_not_send_claude_code_spoof_headers(self):
|
||||
"""OAuth requests identify as Hermes — no claude-cli UA, no x-app: cli.
|
||||
|
||||
Anthropic's OAuth-gated Messages API accepts requests from non-Claude-Code
|
||||
clients as long as auth is correct and the OAuth beta headers are present.
|
||||
See commit that removed fingerprinting for the live-test write-up.
|
||||
"""
|
||||
def test_oauth_drop_context_1m_beta_strips_only_1m(self):
|
||||
"""drop_context_1m_beta=True strips context-1m-2025-08-07 while
|
||||
preserving every other OAuth-relevant beta."""
|
||||
with patch("agent.anthropic_adapter._anthropic_sdk") as mock_sdk:
|
||||
build_anthropic_client("sk-ant-oat01-" + "x" * 60)
|
||||
headers = mock_sdk.Anthropic.call_args[1]["default_headers"]
|
||||
assert "user-agent" not in {k.lower() for k in headers}
|
||||
assert "x-app" not in {k.lower() for k in headers}
|
||||
|
||||
def test_oauth_strips_context_1m_beta(self):
|
||||
"""context-1m-2025-08-07 is incompatible with OAuth auth — must be stripped.
|
||||
|
||||
Anthropic returns HTTP 400 "This authentication style is incompatible
|
||||
with the long context beta header." when OAuth traffic carries it.
|
||||
"""
|
||||
with patch("agent.anthropic_adapter._anthropic_sdk") as mock_sdk:
|
||||
build_anthropic_client("sk-ant-oat01-" + "x" * 60)
|
||||
betas = mock_sdk.Anthropic.call_args[1]["default_headers"]["anthropic-beta"]
|
||||
build_anthropic_client(
|
||||
"sk-ant-oat01-" + "x" * 60,
|
||||
drop_context_1m_beta=True,
|
||||
)
|
||||
kwargs = mock_sdk.Anthropic.call_args[1]
|
||||
betas = kwargs["default_headers"]["anthropic-beta"]
|
||||
assert "context-1m-2025-08-07" not in betas
|
||||
# But other common betas still flow through
|
||||
assert "interleaved-thinking-2025-05-14" in betas
|
||||
# Everything else must still be there.
|
||||
assert "oauth-2025-04-20" in betas
|
||||
assert "claude-code-20250219" in betas
|
||||
assert "interleaved-thinking-2025-05-14" in betas
|
||||
assert "fine-grained-tool-streaming-2025-05-14" in betas
|
||||
|
||||
def test_api_key_uses_api_key(self):
|
||||
with patch("agent.anthropic_adapter._anthropic_sdk") as mock_sdk:
|
||||
@@ -104,6 +99,7 @@ class TestBuildAnthropicClient:
|
||||
# API key auth should still get common betas
|
||||
betas = kwargs["default_headers"]["anthropic-beta"]
|
||||
assert "interleaved-thinking-2025-05-14" in betas
|
||||
assert "context-1m-2025-08-07" in betas
|
||||
assert "oauth-2025-04-20" not in betas # OAuth-only beta NOT present
|
||||
assert "claude-code-20250219" not in betas # OAuth-only beta NOT present
|
||||
|
||||
@@ -113,7 +109,7 @@ class TestBuildAnthropicClient:
|
||||
kwargs = mock_sdk.Anthropic.call_args[1]
|
||||
assert kwargs["base_url"] == "https://custom.api.com"
|
||||
assert kwargs["default_headers"] == {
|
||||
"anthropic-beta": "interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14"
|
||||
"anthropic-beta": "interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07"
|
||||
}
|
||||
|
||||
def test_minimax_anthropic_endpoint_uses_bearer_auth_for_regular_api_keys(self):
|
||||
@@ -990,6 +986,42 @@ class TestBuildAnthropicKwargs:
|
||||
)
|
||||
assert kwargs["model"] == "claude-sonnet-4-20250514"
|
||||
|
||||
def test_fast_mode_oauth_default_keeps_context_1m_beta(self):
|
||||
"""Default OAuth fast-mode requests still carry context-1m-2025-08-07."""
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="claude-opus-4-6",
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
tools=None,
|
||||
max_tokens=4096,
|
||||
reasoning_config=None,
|
||||
is_oauth=True,
|
||||
fast_mode=True,
|
||||
)
|
||||
betas = kwargs["extra_headers"]["anthropic-beta"]
|
||||
assert "fast-mode-2026-02-01" in betas
|
||||
assert "oauth-2025-04-20" in betas
|
||||
assert "context-1m-2025-08-07" in betas
|
||||
|
||||
def test_fast_mode_oauth_drop_context_1m_beta_strips_only_1m(self):
|
||||
"""drop_context_1m_beta=True strips context-1m from fast-mode
|
||||
extra_headers while preserving every other OAuth + fast-mode beta."""
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="claude-opus-4-6",
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
tools=None,
|
||||
max_tokens=4096,
|
||||
reasoning_config=None,
|
||||
is_oauth=True,
|
||||
fast_mode=True,
|
||||
drop_context_1m_beta=True,
|
||||
)
|
||||
betas = kwargs["extra_headers"]["anthropic-beta"]
|
||||
assert "context-1m-2025-08-07" not in betas
|
||||
assert "fast-mode-2026-02-01" in betas
|
||||
assert "oauth-2025-04-20" in betas
|
||||
assert "claude-code-20250219" in betas
|
||||
assert "interleaved-thinking-2025-05-14" in betas
|
||||
|
||||
def test_reasoning_config_maps_to_manual_thinking_for_pre_4_6_models(self):
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="claude-sonnet-4-20250514",
|
||||
|
||||
@@ -259,7 +259,7 @@ class TestAnthropicOAuthFlag:
|
||||
assert mock_build.call_args.args[0] == "sk-ant-oat01-pooled"
|
||||
|
||||
|
||||
class TestTryCodex:
|
||||
class TestBuildCodexClient:
|
||||
def test_pool_without_selected_entry_falls_back_to_auth_store(self):
|
||||
with (
|
||||
patch("agent.auxiliary_client._select_pool_entry", return_value=(True, None)),
|
||||
@@ -267,15 +267,23 @@ class TestTryCodex:
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
mock_openai.return_value = MagicMock()
|
||||
from agent.auxiliary_client import _try_codex
|
||||
from agent.auxiliary_client import _build_codex_client
|
||||
|
||||
client, model = _try_codex()
|
||||
client, model = _build_codex_client("gpt-5.4")
|
||||
|
||||
assert client is not None
|
||||
assert model == "gpt-5.2-codex"
|
||||
assert model == "gpt-5.4"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "codex-auth-token"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
def test_rejects_missing_model(self):
|
||||
"""Callers must pass an explicit model; no hardcoded default."""
|
||||
from agent.auxiliary_client import _build_codex_client
|
||||
|
||||
client, model = _build_codex_client("")
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
|
||||
class TestExpiredCodexFallback:
|
||||
"""Test that expired Codex tokens don't block the auto chain."""
|
||||
@@ -507,14 +515,14 @@ class TestGetTextAuxiliaryClient:
|
||||
patch("agent.auxiliary_client.OpenAI"),
|
||||
patch("hermes_cli.auth._read_codex_tokens", side_effect=AssertionError("legacy codex store should not run")),
|
||||
):
|
||||
from agent.auxiliary_client import _try_codex
|
||||
from agent.auxiliary_client import _build_codex_client
|
||||
|
||||
client, model = _try_codex()
|
||||
client, model = _build_codex_client("gpt-5.4")
|
||||
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.2-codex"
|
||||
assert model == "gpt-5.4"
|
||||
|
||||
def test_returns_none_when_nothing_available(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
@@ -783,11 +791,15 @@ class TestIsPaymentError:
|
||||
class TestGetProviderChain:
|
||||
"""_get_provider_chain() resolves functions at call time (testable)."""
|
||||
|
||||
def test_returns_five_entries(self):
|
||||
def test_returns_four_entries(self):
|
||||
chain = _get_provider_chain()
|
||||
assert len(chain) == 5
|
||||
assert len(chain) == 4
|
||||
labels = [label for label, _ in chain]
|
||||
assert labels == ["openrouter", "nous", "local/custom", "openai-codex", "api-key"]
|
||||
assert labels == ["openrouter", "nous", "local/custom", "api-key"]
|
||||
# Codex is deliberately NOT in this chain — see _get_provider_chain
|
||||
# docstring. ChatGPT-account Codex has a shifting model allow-list;
|
||||
# guessing a model to fall back on breaks more often than it helps.
|
||||
assert "openai-codex" not in labels
|
||||
|
||||
def test_picks_up_patched_functions(self):
|
||||
"""Patches on _try_* functions must be visible in the chain."""
|
||||
@@ -814,7 +826,6 @@ class TestTryPaymentFallback:
|
||||
with patch("agent.auxiliary_client._try_openrouter", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client._try_nous", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client._try_custom_endpoint", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client._try_codex", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value="openrouter"):
|
||||
client, model, label = _try_payment_fallback("openrouter")
|
||||
@@ -825,23 +836,26 @@ class TestTryPaymentFallback:
|
||||
"""'codex' should map to 'openai-codex' in the skip set."""
|
||||
mock_client = MagicMock()
|
||||
with patch("agent.auxiliary_client._try_openrouter", return_value=(mock_client, "or-model")), \
|
||||
patch("agent.auxiliary_client._try_codex", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value="openai-codex"):
|
||||
client, model, label = _try_payment_fallback("openai-codex", task="vision")
|
||||
assert client is mock_client
|
||||
assert label == "openrouter"
|
||||
|
||||
def test_skips_to_codex_when_or_and_nous_fail(self):
|
||||
mock_codex = MagicMock()
|
||||
def test_codex_not_in_fallback_chain(self):
|
||||
"""Codex is deliberately NOT a fallback rung (shifting model allow-list).
|
||||
|
||||
When OR/Nous/custom/api-key all fail, payment-fallback returns None —
|
||||
Codex is never tried with a guessed model.
|
||||
"""
|
||||
with patch("agent.auxiliary_client._try_openrouter", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client._try_nous", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client._try_custom_endpoint", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client._try_codex", return_value=(mock_codex, "gpt-5.2-codex")), \
|
||||
patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value="openrouter"):
|
||||
client, model, label = _try_payment_fallback("openrouter")
|
||||
assert client is mock_codex
|
||||
assert model == "gpt-5.2-codex"
|
||||
assert label == "openai-codex"
|
||||
assert client is None
|
||||
assert model is None
|
||||
assert label == ""
|
||||
|
||||
|
||||
class TestCallLlmPaymentFallback:
|
||||
@@ -1360,14 +1374,14 @@ class TestAuxiliaryAuthRefreshRetry:
|
||||
with (
|
||||
patch(
|
||||
"agent.auxiliary_client.resolve_vision_provider_client",
|
||||
side_effect=[("openai-codex", failing_client, "gpt-5.2-codex"), ("openai-codex", fresh_client, "gpt-5.2-codex")],
|
||||
side_effect=[("openai-codex", failing_client, "gpt-5.4"), ("openai-codex", fresh_client, "gpt-5.4")],
|
||||
),
|
||||
patch("agent.auxiliary_client._refresh_provider_credentials", return_value=True) as mock_refresh,
|
||||
):
|
||||
resp = call_llm(
|
||||
task="vision",
|
||||
provider="openai-codex",
|
||||
model="gpt-5.2-codex",
|
||||
model="gpt-5.4",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
|
||||
@@ -1384,14 +1398,14 @@ class TestAuxiliaryAuthRefreshRetry:
|
||||
fresh_client.chat.completions.create.return_value = _DummyResponse("fresh-non-vision")
|
||||
|
||||
with (
|
||||
patch("agent.auxiliary_client._resolve_task_provider_model", return_value=("openai-codex", "gpt-5.2-codex", None, None, None)),
|
||||
patch("agent.auxiliary_client._get_cached_client", side_effect=[(stale_client, "gpt-5.2-codex"), (fresh_client, "gpt-5.2-codex")]),
|
||||
patch("agent.auxiliary_client._resolve_task_provider_model", return_value=("openai-codex", "gpt-5.4", None, None, None)),
|
||||
patch("agent.auxiliary_client._get_cached_client", side_effect=[(stale_client, "gpt-5.4"), (fresh_client, "gpt-5.4")]),
|
||||
patch("agent.auxiliary_client._refresh_provider_credentials", return_value=True) as mock_refresh,
|
||||
):
|
||||
resp = call_llm(
|
||||
task="compression",
|
||||
provider="openai-codex",
|
||||
model="gpt-5.2-codex",
|
||||
model="gpt-5.4",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
|
||||
@@ -1439,14 +1453,14 @@ class TestAuxiliaryAuthRefreshRetry:
|
||||
with (
|
||||
patch(
|
||||
"agent.auxiliary_client.resolve_vision_provider_client",
|
||||
side_effect=[("openai-codex", failing_client, "gpt-5.2-codex"), ("openai-codex", fresh_client, "gpt-5.2-codex")],
|
||||
side_effect=[("openai-codex", failing_client, "gpt-5.4"), ("openai-codex", fresh_client, "gpt-5.4")],
|
||||
),
|
||||
patch("agent.auxiliary_client._refresh_provider_credentials", return_value=True) as mock_refresh,
|
||||
):
|
||||
resp = await async_call_llm(
|
||||
task="vision",
|
||||
provider="openai-codex",
|
||||
model="gpt-5.2-codex",
|
||||
model="gpt-5.4",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
|
||||
@@ -1635,3 +1649,106 @@ class TestCodexAdapterReasoningTranslation:
|
||||
)
|
||||
assert "reasoning" not in captured
|
||||
|
||||
|
||||
|
||||
class TestVisionAutoSkipsKimiCoding:
|
||||
"""_resolve_auto vision branch skips providers that have no vision on
|
||||
their main endpoint (e.g. Kimi Coding Plan /coding) and falls through
|
||||
to the aggregator chain instead of handing back a client that will 404
|
||||
on every request (#17076).
|
||||
"""
|
||||
|
||||
def test_kimi_coding_skipped_falls_through_to_openrouter(self, monkeypatch):
|
||||
"""kimi-coding as main + vision auto → OpenRouter (not kimi)."""
|
||||
fake_or_client = MagicMock(name="openrouter_client")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"agent.auxiliary_client._read_main_provider", lambda: "kimi-coding",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.auxiliary_client._read_main_model", lambda: "kimi-code",
|
||||
)
|
||||
# Guard: if the skip doesn't fire, _resolve_strict_vision_backend
|
||||
# and resolve_provider_client both would try kimi-coding — detect
|
||||
# either via the main-provider call and fail loud.
|
||||
rpc_mock = MagicMock(side_effect=AssertionError(
|
||||
"resolve_provider_client should NOT be called for kimi-coding "
|
||||
"on the vision auto path"))
|
||||
monkeypatch.setattr(
|
||||
"agent.auxiliary_client.resolve_provider_client", rpc_mock,
|
||||
)
|
||||
|
||||
def fake_strict(provider, model=None):
|
||||
if provider == "openrouter":
|
||||
return fake_or_client, "google/gemini-3-flash-preview"
|
||||
if provider == "nous":
|
||||
return None, None
|
||||
raise AssertionError(
|
||||
f"strict vision backend should not be called for {provider!r} "
|
||||
"when main provider is kimi-coding"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.auxiliary_client._resolve_strict_vision_backend",
|
||||
fake_strict,
|
||||
)
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
assert provider == "openrouter"
|
||||
assert client is fake_or_client
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
|
||||
def test_kimi_coding_cn_skipped_too(self, monkeypatch):
|
||||
"""Same skip applies to the CN variant."""
|
||||
fake_or_client = MagicMock(name="openrouter_client")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"agent.auxiliary_client._read_main_provider", lambda: "kimi-coding-cn",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.auxiliary_client._read_main_model", lambda: "kimi-code",
|
||||
)
|
||||
rpc_mock = MagicMock(side_effect=AssertionError(
|
||||
"resolve_provider_client should NOT be called for kimi-coding-cn"))
|
||||
monkeypatch.setattr(
|
||||
"agent.auxiliary_client.resolve_provider_client", rpc_mock,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.auxiliary_client._resolve_strict_vision_backend",
|
||||
lambda p, m=None: (fake_or_client, "gemini")
|
||||
if p == "openrouter"
|
||||
else (None, None),
|
||||
)
|
||||
|
||||
provider, client, _ = resolve_vision_provider_client()
|
||||
assert provider == "openrouter"
|
||||
assert client is fake_or_client
|
||||
|
||||
def test_explicit_override_to_kimi_coding_still_honored(self, monkeypatch):
|
||||
"""When a user *explicitly* requests kimi-coding for vision (e.g.
|
||||
they know what they're doing, or are running a future build that
|
||||
adds image_in capability to Kimi Code), the explicit path still
|
||||
routes to kimi-coding — only the auto branch applies the skip.
|
||||
"""
|
||||
monkeypatch.setattr(
|
||||
"agent.auxiliary_client._read_main_provider", lambda: "openrouter",
|
||||
)
|
||||
fake_kimi_client = MagicMock(name="kimi_client")
|
||||
gcc_mock = MagicMock(return_value=(fake_kimi_client, "kimi-code"))
|
||||
monkeypatch.setattr(
|
||||
"agent.auxiliary_client._get_cached_client", gcc_mock,
|
||||
)
|
||||
|
||||
provider, client, model = resolve_vision_provider_client(
|
||||
provider="kimi-coding",
|
||||
)
|
||||
assert provider == "kimi-coding"
|
||||
assert client is fake_kimi_client
|
||||
gcc_mock.assert_called_once()
|
||||
|
||||
def test_skip_set_covers_exactly_known_entries(self):
|
||||
"""Guard against accidental widening of the skip list."""
|
||||
from agent.auxiliary_client import _PROVIDERS_WITHOUT_VISION
|
||||
assert _PROVIDERS_WITHOUT_VISION == frozenset({
|
||||
"kimi-coding",
|
||||
"kimi-coding-cn",
|
||||
})
|
||||
|
||||
@@ -10,7 +10,7 @@ of auth correctness.
|
||||
``_codex_cloudflare_headers`` in ``agent.auxiliary_client`` centralizes the
|
||||
header set so the primary chat client (``run_agent.AIAgent.__init__`` +
|
||||
``_apply_client_headers_for_base_url``) and the auxiliary client paths
|
||||
(``_try_codex`` and the ``raw_codex`` branch of ``resolve_provider_client``)
|
||||
(``_build_codex_client`` and the ``raw_codex`` branch of ``resolve_provider_client``)
|
||||
all emit the same headers.
|
||||
|
||||
These tests pin:
|
||||
@@ -207,9 +207,10 @@ class TestPrimaryClientWiring:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAuxiliaryClientWiring:
|
||||
def test_try_codex_passes_codex_headers(self, monkeypatch):
|
||||
"""_try_codex builds the OpenAI client used for compression / vision /
|
||||
title generation when routed through Codex. Must emit codex headers."""
|
||||
def test_build_codex_client_passes_codex_headers(self, monkeypatch):
|
||||
"""_build_codex_client builds the OpenAI client used for compression /
|
||||
vision / title generation when routed through Codex. Must emit codex
|
||||
headers."""
|
||||
from agent import auxiliary_client
|
||||
token = _make_codex_jwt("acct-aux-try-codex")
|
||||
|
||||
@@ -225,7 +226,7 @@ class TestAuxiliaryClientWiring:
|
||||
)
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
mock_openai.return_value = MagicMock()
|
||||
client, model = auxiliary_client._try_codex()
|
||||
client, model = auxiliary_client._build_codex_client("gpt-5.4")
|
||||
assert client is not None
|
||||
headers = mock_openai.call_args.kwargs.get("default_headers") or {}
|
||||
assert headers.get("originator") == "codex_cli_rs"
|
||||
@@ -244,7 +245,7 @@ class TestAuxiliaryClientWiring:
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
mock_openai.return_value = MagicMock()
|
||||
client, model = auxiliary_client.resolve_provider_client(
|
||||
"openai-codex", raw_codex=True,
|
||||
"openai-codex", model="gpt-5.4", raw_codex=True,
|
||||
)
|
||||
assert client is not None
|
||||
headers = mock_openai.call_args.kwargs.get("default_headers") or {}
|
||||
|
||||
@@ -80,15 +80,19 @@ class CopilotACPClientSafetyTests(unittest.TestCase):
|
||||
secret_file = root / "config.env"
|
||||
secret_file.write_text("OPENAI_API_KEY=sk-proj-abc123def456ghi789jkl012")
|
||||
|
||||
response = self._dispatch(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "fs/read_text_file",
|
||||
"params": {"path": str(secret_file)},
|
||||
},
|
||||
cwd=str(root),
|
||||
)
|
||||
# agent.redact snapshots HERMES_REDACT_SECRETS at import time into
|
||||
# _REDACT_ENABLED, so patching os.environ is a no-op. Flip the
|
||||
# module-level constant directly for the duration of the call.
|
||||
with patch("agent.redact._REDACT_ENABLED", True):
|
||||
response = self._dispatch(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "fs/read_text_file",
|
||||
"params": {"path": str(secret_file)},
|
||||
},
|
||||
cwd=str(root),
|
||||
)
|
||||
|
||||
content = ((response.get("result") or {}).get("content") or "")
|
||||
self.assertNotIn("abc123def456", content)
|
||||
|
||||
@@ -271,10 +271,17 @@ def test_run_review_synchronous_invokes_llm_stub(curator_env, monkeypatch):
|
||||
_write_skill(skills_dir, "a")
|
||||
|
||||
calls = []
|
||||
monkeypatch.setattr(
|
||||
c, "_run_llm_review",
|
||||
lambda prompt: (calls.append(prompt), "stubbed-summary")[1],
|
||||
)
|
||||
def _stub(prompt):
|
||||
calls.append(prompt)
|
||||
return {
|
||||
"final": "stubbed-summary",
|
||||
"summary": "stubbed-summary",
|
||||
"model": "stub-model",
|
||||
"provider": "stub-provider",
|
||||
"tool_calls": [],
|
||||
"error": None,
|
||||
}
|
||||
monkeypatch.setattr(c, "_run_llm_review", _stub)
|
||||
|
||||
captured = []
|
||||
c.run_curator_review(on_summary=lambda s: captured.append(s), synchronous=True)
|
||||
@@ -478,3 +485,153 @@ def test_cli_pin_refuses_bundled_skill(curator_env, capsys):
|
||||
captured = capsys.readouterr()
|
||||
assert rc == 1
|
||||
assert "bundled" in captured.out.lower() or "hub" in captured.out.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# curator review-model resolution (canonical auxiliary.curator slot)
|
||||
#
|
||||
# Curator was unified with the rest of the aux task system in Apr 2026 so
|
||||
# `hermes model` → auxiliary picker, the dashboard Models tab, and the full
|
||||
# per-task config (timeout, base_url, api_key, extra_body) all work for it.
|
||||
# Voscko report: curator.auxiliary.{provider,model} was advertised but never
|
||||
# read. Fix wires curator through auxiliary.curator with a legacy fallback.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_review_model_defaults_to_main_when_slot_is_auto(curator_env):
|
||||
"""auxiliary.curator absent (or auto/empty) → use main model.provider/model."""
|
||||
curator = curator_env["curator"]
|
||||
cfg = {
|
||||
"model": {"provider": "openrouter", "default": "openai/gpt-5.5"},
|
||||
}
|
||||
assert curator._resolve_review_model(cfg) == ("openrouter", "openai/gpt-5.5")
|
||||
|
||||
# Explicit auto/empty slot — still main model.
|
||||
cfg["auxiliary"] = {"curator": {"provider": "auto", "model": ""}}
|
||||
assert curator._resolve_review_model(cfg) == ("openrouter", "openai/gpt-5.5")
|
||||
|
||||
|
||||
def test_review_model_honors_auxiliary_curator_slot(curator_env):
|
||||
"""auxiliary.curator.{provider,model} fully set → that pair wins."""
|
||||
curator = curator_env["curator"]
|
||||
cfg = {
|
||||
"model": {"provider": "openrouter", "default": "openai/gpt-5.5"},
|
||||
"auxiliary": {
|
||||
"curator": {
|
||||
"provider": "openrouter",
|
||||
"model": "openai/gpt-5.4-mini",
|
||||
},
|
||||
},
|
||||
}
|
||||
assert curator._resolve_review_model(cfg) == (
|
||||
"openrouter", "openai/gpt-5.4-mini",
|
||||
)
|
||||
|
||||
|
||||
def test_review_model_auxiliary_curator_partial_override_falls_back(curator_env):
|
||||
"""Only one of slot provider/model set → fall back to the main pair.
|
||||
|
||||
Prevents half-configured overrides from sending an empty side to
|
||||
resolve_runtime_provider.
|
||||
"""
|
||||
curator = curator_env["curator"]
|
||||
base_main = {"provider": "openrouter", "default": "openai/gpt-5.5"}
|
||||
|
||||
cfg_provider_only = {
|
||||
"model": dict(base_main),
|
||||
"auxiliary": {"curator": {"provider": "openrouter", "model": ""}},
|
||||
}
|
||||
assert curator._resolve_review_model(cfg_provider_only) == (
|
||||
"openrouter", "openai/gpt-5.5",
|
||||
)
|
||||
|
||||
cfg_model_only = {
|
||||
"model": dict(base_main),
|
||||
"auxiliary": {"curator": {"provider": "auto", "model": "gpt-5.4-mini"}},
|
||||
}
|
||||
assert curator._resolve_review_model(cfg_model_only) == (
|
||||
"openrouter", "openai/gpt-5.5",
|
||||
)
|
||||
|
||||
|
||||
def test_review_model_legacy_curator_auxiliary_still_works(curator_env, caplog):
|
||||
"""Pre-unification users set curator.auxiliary.{provider,model} — honor it.
|
||||
|
||||
Emits a deprecation log line but keeps their config working.
|
||||
"""
|
||||
curator = curator_env["curator"]
|
||||
cfg = {
|
||||
"model": {"provider": "openrouter", "default": "openai/gpt-5.5"},
|
||||
"curator": {
|
||||
"auxiliary": {
|
||||
"provider": "openrouter",
|
||||
"model": "openai/gpt-5.4-mini",
|
||||
},
|
||||
},
|
||||
}
|
||||
import logging
|
||||
with caplog.at_level(logging.INFO, logger="agent.curator"):
|
||||
result = curator._resolve_review_model(cfg)
|
||||
assert result == ("openrouter", "openai/gpt-5.4-mini")
|
||||
assert any(
|
||||
"deprecated curator.auxiliary" in rec.message for rec in caplog.records
|
||||
), "expected deprecation warning when legacy curator.auxiliary is used"
|
||||
|
||||
|
||||
def test_review_model_new_slot_wins_over_legacy(curator_env):
|
||||
"""When BOTH new and legacy are set, the canonical slot wins."""
|
||||
curator = curator_env["curator"]
|
||||
cfg = {
|
||||
"model": {"provider": "openrouter", "default": "openai/gpt-5.5"},
|
||||
"auxiliary": {
|
||||
"curator": {"provider": "nous", "model": "new-winner"},
|
||||
},
|
||||
"curator": {
|
||||
"auxiliary": {"provider": "openrouter", "model": "legacy-loser"},
|
||||
},
|
||||
}
|
||||
assert curator._resolve_review_model(cfg) == ("nous", "new-winner")
|
||||
|
||||
|
||||
def test_review_model_handles_missing_sections(curator_env):
|
||||
"""Missing auxiliary/curator sections never raise — fall back cleanly."""
|
||||
curator = curator_env["curator"]
|
||||
cfg = {"model": {"provider": "anthropic", "model": "claude-sonnet-4-6"}}
|
||||
assert curator._resolve_review_model(cfg) == (
|
||||
"anthropic", "claude-sonnet-4-6",
|
||||
)
|
||||
|
||||
# Completely empty config → ("auto", "") — resolve_runtime_provider
|
||||
# handles the auto-detection chain from there.
|
||||
assert curator._resolve_review_model({}) == ("auto", "")
|
||||
|
||||
|
||||
def test_curator_slot_is_canonical_aux_task():
    """Pin ``curator`` as a first-class slot in the aux-task registries.

    The shared registry test (test_aux_config.py) covers the main tasks
    across the four sources of truth; this test checks ``curator``
    explicitly so the unification cannot silently regress.
    """
    from hermes_cli.config import DEFAULT_CONFIG
    from hermes_cli.main import _AUX_TASKS
    from hermes_cli.web_server import _AUX_TASK_SLOTS

    # 1. DEFAULT_CONFIG.auxiliary — schema source
    aux_defaults = DEFAULT_CONFIG["auxiliary"]
    assert "curator" in aux_defaults, \
        "curator missing from DEFAULT_CONFIG['auxiliary']"
    curator_slot = aux_defaults["curator"]
    assert curator_slot["provider"] == "auto"
    assert curator_slot["model"] == ""
    assert curator_slot["timeout"] > 0, "curator timeout should be set (reviews run long)"

    # 2. hermes_cli/main.py _AUX_TASKS — CLI picker
    picker_keys = {key for key, _label, _blurb in _AUX_TASKS}
    assert "curator" in picker_keys, "curator missing from _AUX_TASKS (CLI picker)"

    # 3. hermes_cli/web_server.py _AUX_TASK_SLOTS — REST API allowlist
    assert "curator" in _AUX_TASK_SLOTS, \
        "curator missing from _AUX_TASK_SLOTS (dashboard REST API)"

    # 4. web/src/pages/ModelsPage.tsx is checked at build time; the tsx
    #    array and this tuple share a ``Must match _AUX_TASK_SLOTS`` comment.
|
||||
|
||||
258
tests/agent/test_curator_reports.py
Normal file
258
tests/agent/test_curator_reports.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""Tests for the curator per-run report writer (run.json + REPORT.md).
|
||||
|
||||
Reports live under ``~/.hermes/logs/curator/{YYYYMMDD-HHMMSS}/`` alongside
|
||||
the standard log dir, not inside the user's ``skills/`` data directory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def curator_env(tmp_path, monkeypatch):
    """Provide an isolated HERMES_HOME and freshly reloaded curator modules.

    Creates ``.hermes`` with ``skills/`` and ``logs/`` under ``tmp_path``,
    points ``HERMES_HOME`` and ``Path.home`` at the temporary tree, then
    reloads ``hermes_constants``, ``agent.curator`` and ``tools.skill_usage``.

    Yields a dict with the keys ``home``, ``curator`` and ``skill_usage``.
    """
    hermes_home = tmp_path / ".hermes"
    hermes_home.mkdir()
    for subdir in ("skills", "logs"):
        (hermes_home / subdir).mkdir()

    monkeypatch.setenv("HERMES_HOME", str(hermes_home))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)

    import importlib

    # Reload in this order, after the environment has been patched.
    import hermes_constants
    importlib.reload(hermes_constants)

    from agent import curator
    importlib.reload(curator)

    from tools import skill_usage
    importlib.reload(skill_usage)

    yield {
        "home": hermes_home,
        "curator": curator,
        "skill_usage": skill_usage,
    }
|
||||
|
||||
|
||||
def _make_llm_meta(**overrides):
|
||||
base = {
|
||||
"final": "short summary of the pass",
|
||||
"summary": "short summary",
|
||||
"model": "test-model",
|
||||
"provider": "test-provider",
|
||||
"tool_calls": [],
|
||||
"error": None,
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
def test_reports_root_is_under_logs_not_skills(curator_env):
    """The reports root is ``<home>/logs/curator`` and has no ``skills`` part.

    Run reports are operational telemetry, so they sit with the logs rather
    than with user-authored skill data.
    """
    reports_root = curator_env["curator"]._reports_root()
    expected_root = curator_env["home"] / "logs" / "curator"

    # Must be under logs/
    assert reports_root == expected_root
    # Must NOT be under skills/
    assert "skills" not in reports_root.parts
|
||||
|
||||
|
||||
def test_write_run_report_creates_both_files(curator_env):
    """A run writes ``run.json`` (machine) and ``REPORT.md`` (human) side by side."""
    curator = curator_env["curator"]
    started = datetime.now(timezone.utc)
    transition_counts = {
        "checked": 5,
        "marked_stale": 1,
        "archived": 0,
        "reactivated": 0,
    }

    report_dir = curator._write_run_report(
        started_at=started,
        elapsed_seconds=12.345,
        auto_counts=transition_counts,
        auto_summary="1 marked stale",
        before_report=[],
        before_names=set(),
        after_report=[],
        llm_meta=_make_llm_meta(),
    )

    assert report_dir is not None
    assert report_dir.is_dir()
    for filename in ("run.json", "REPORT.md"):
        assert (report_dir / filename).exists()

    # The directory name is a timestamp under logs/curator/
    assert report_dir.parent == curator._reports_root()
|
||||
|
||||
|
||||
def test_run_json_has_expected_shape(curator_env):
    """``run.json`` exposes the top-level keys, diff lists and counters.

    One skill disappears between the snapshots and one appears, and three
    different tools are each called once.
    """
    curator = curator_env["curator"]
    started = datetime.now(timezone.utc)

    keeper = {"name": "keeper", "state": "active", "pinned": True}
    skills_before = [
        {"name": "old-thing", "state": "active", "pinned": False},
        dict(keeper),
    ]
    skills_after = [
        dict(keeper),
        {"name": "new-umbrella", "state": "active", "pinned": False},
    ]
    recorded_calls = [
        {"name": "skills_list", "arguments": "{}"},
        {"name": "skill_manage", "arguments": '{"action":"create"}'},
        {"name": "terminal", "arguments": "mv ..."},
    ]

    report_dir = curator._write_run_report(
        started_at=started,
        elapsed_seconds=42.0,
        auto_counts={"checked": 2, "marked_stale": 0, "archived": 0, "reactivated": 0},
        auto_summary="no changes",
        before_report=skills_before,
        before_names={row["name"] for row in skills_before},
        after_report=skills_after,
        llm_meta=_make_llm_meta(
            final="I consolidated the whole universe.",
            tool_calls=recorded_calls,
        ),
    )
    payload = json.loads((report_dir / "run.json").read_text())

    # top-level shape
    required_keys = (
        "started_at", "duration_seconds", "model", "provider",
        "auto_transitions", "counts", "tool_call_counts",
        "archived", "added", "state_transitions",
        "llm_final", "llm_summary", "llm_error", "tool_calls",
    )
    for k in required_keys:
        assert k in payload, f"missing key: {k}"

    # Diff logic
    assert payload["archived"] == ["old-thing"]
    assert payload["added"] == ["new-umbrella"]

    # Counts reflect the diff
    counts = payload["counts"]
    assert counts["before"] == 2
    assert counts["after"] == 2
    assert counts["archived_this_run"] == 1
    assert counts["added_this_run"] == 1

    # Tool call counts are aggregated
    for tool_name in ("skills_list", "skill_manage", "terminal"):
        assert payload["tool_call_counts"][tool_name] == 1
    assert payload["counts"]["tool_calls_total"] == 3
|
||||
|
||||
|
||||
def test_report_md_is_human_readable(curator_env):
    """``REPORT.md`` contains the expected headings, names and LLM text."""
    curator = curator_env["curator"]
    started = datetime.now(timezone.utc)

    report_dir = curator._write_run_report(
        started_at=started,
        elapsed_seconds=75.0,
        auto_counts={"checked": 10, "marked_stale": 2, "archived": 1, "reactivated": 0},
        auto_summary="2 marked stale, 1 archived",
        before_report=[{"name": "foo", "state": "active", "pinned": False}],
        before_names={"foo"},
        after_report=[{"name": "foo-umbrella", "state": "active", "pinned": False}],
        llm_meta=_make_llm_meta(
            final="Consolidated foo-like skills into foo-umbrella.",
            model="claude-opus-4.7",
            provider="openrouter",
        ),
    )
    report_text = (report_dir / "REPORT.md").read_text()

    # Structural checks
    for heading in (
        "# Curator run",
        "Auto-transitions",
        "LLM consolidation pass",
        "Recovery",
    ):
        assert heading in report_text

    # The model / provider we passed in show up
    for identifier in ("claude-opus-4.7", "openrouter"):
        assert identifier in report_text

    # The added/archived lists are present
    for fragment in (
        "Skills archived",
        "`foo`",
        "New skills this run",
        "`foo-umbrella`",
    ):
        assert fragment in report_text

    # The full LLM final response is included verbatim (no 240-char truncation)
    assert "Consolidated foo-like skills into foo-umbrella." in report_text
|
||||
|
||||
|
||||
def test_same_second_reruns_get_unique_dirs(curator_env):
    """Two reports with the same ``started_at`` end up in different directories.

    The second directory name must start with the first one's name, i.e. it
    only adds a disambiguating suffix.
    """
    curator = curator_env["curator"]
    report_args = {
        "started_at": datetime(2026, 4, 29, 5, 33, 34, tzinfo=timezone.utc),
        "elapsed_seconds": 1.0,
        "auto_counts": {"checked": 0, "marked_stale": 0, "archived": 0, "reactivated": 0},
        "auto_summary": "no changes",
        "before_report": [],
        "before_names": set(),
        "after_report": [],
        "llm_meta": _make_llm_meta(),
    }

    first_dir = curator._write_run_report(**report_args)
    second_dir = curator._write_run_report(**report_args)

    assert first_dir != second_dir
    assert first_dir is not None and second_dir is not None
    # Second dir has a numeric disambiguator suffix
    assert second_dir.name.startswith(first_dir.name)
|
||||
|
||||
|
||||
def test_report_captures_llm_error_and_continues(curator_env):
    """An LLM error in ``llm_meta`` is written to both report files."""
    curator = curator_env["curator"]
    failed_meta = _make_llm_meta(
        error="HTTP 400: No models provided",
        final="",
        summary="error",
    )

    report_dir = curator._write_run_report(
        started_at=datetime.now(timezone.utc),
        elapsed_seconds=2.0,
        auto_counts={"checked": 0, "marked_stale": 0, "archived": 0, "reactivated": 0},
        auto_summary="no changes",
        before_report=[],
        before_names=set(),
        after_report=[],
        llm_meta=failed_meta,
    )

    report_text = (report_dir / "REPORT.md").read_text()
    assert "HTTP 400" in report_text

    payload = json.loads((report_dir / "run.json").read_text())
    assert payload["llm_error"] == "HTTP 400: No models provided"
|
||||
|
||||
|
||||
def test_state_transitions_captured_in_report(curator_env):
    """A state change between the before/after snapshots is reported.

    Here one skill goes from ``active`` to ``stale``; the transition must
    appear in ``run.json`` and be rendered in ``REPORT.md``.
    """
    curator = curator_env["curator"]
    started = datetime.now(timezone.utc)

    snapshot_before = [{"name": "getting-old", "state": "active", "pinned": False}]
    snapshot_after = [{"name": "getting-old", "state": "stale", "pinned": False}]

    report_dir = curator._write_run_report(
        started_at=started,
        elapsed_seconds=1.0,
        auto_counts={"checked": 1, "marked_stale": 1, "archived": 0, "reactivated": 0},
        auto_summary="1 marked stale",
        before_report=snapshot_before,
        before_names={row["name"] for row in snapshot_before},
        after_report=snapshot_after,
        llm_meta=_make_llm_meta(),
    )

    payload = json.loads((report_dir / "run.json").read_text())
    expected_transitions = [
        {"name": "getting-old", "from": "active", "to": "stale"}
    ]
    assert payload["state_transitions"] == expected_transitions

    report_text = (report_dir / "REPORT.md").read_text()
    for fragment in ("State transitions", "getting-old", "active → stale"):
        assert fragment in report_text
|
||||
242
tests/agent/test_deepseek_anthropic_thinking.py
Normal file
242
tests/agent/test_deepseek_anthropic_thinking.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""Regression guard: preserve thinking blocks on DeepSeek's /anthropic endpoint.
|
||||
|
||||
DeepSeek's ``api.deepseek.com/anthropic`` route speaks the Anthropic Messages
|
||||
protocol but, when thinking mode is enabled, requires ``thinking`` blocks from
|
||||
prior assistant turns to round-trip on subsequent requests. The generic
|
||||
third-party path strips them (signatures are Anthropic-proprietary and other
|
||||
proxies cannot validate them), so without a DeepSeek-specific carve-out the
|
||||
next tool-call turn fails with HTTP 400::
|
||||
|
||||
The content[].thinking in the thinking mode must be passed back to the
|
||||
API.
|
||||
|
||||
DeepSeek's compatibility matrix lists ``thinking`` as supported but
|
||||
``redacted_thinking`` and ``cache_control`` on thinking blocks as not
|
||||
supported. Handling is the same as Kimi's ``/coding`` endpoint: strip
|
||||
Anthropic-signed blocks (DeepSeek can't validate them) but preserve unsigned
|
||||
blocks that Hermes synthesises from ``reasoning_content``.
|
||||
|
||||
See hermes-agent#16748.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestDeepSeekAnthropicPreservesThinking:
    """Thinking-block handling of ``convert_messages_to_anthropic`` for DeepSeek."""

    @staticmethod
    def _tool_call_turn(question, reasoning, call_id, function_name="f"):
        """Return a user / assistant-tool-call / tool-result message triple."""
        return [
            {"role": "user", "content": question},
            {
                "role": "assistant",
                "reasoning_content": reasoning,
                "tool_calls": [
                    {
                        "id": call_id,
                        "type": "function",
                        "function": {"name": function_name, "arguments": "{}"},
                    }
                ],
            },
            {"role": "tool", "tool_call_id": call_id, "content": "ok"},
        ]

    @staticmethod
    def _thinking_blocks(message):
        """Return the dict blocks of ``message["content"]`` typed ``thinking``."""
        return [
            block for block in message["content"]
            if isinstance(block, dict) and block.get("type") == "thinking"
        ]

    @pytest.mark.parametrize(
        "base_url",
        [
            "https://api.deepseek.com/anthropic",
            "https://api.deepseek.com/anthropic/",
            "https://api.deepseek.com/anthropic/v1",
            "https://API.DeepSeek.com/anthropic",
        ],
    )
    def test_unsigned_thinking_block_survives_replay(self, base_url: str) -> None:
        """An unsigned block built from ``reasoning_content`` is kept on replay."""
        from agent.anthropic_adapter import convert_messages_to_anthropic

        history = self._tool_call_turn(
            "hi", "planning the tool call", "call_1", function_name="skill_view"
        )
        _system, converted = convert_messages_to_anthropic(
            history, base_url=base_url
        )

        first_assistant = next(m for m in converted if m["role"] == "assistant")
        kept = self._thinking_blocks(first_assistant)
        assert len(kept) == 1, (
            f"DeepSeek /anthropic ({base_url}) must preserve unsigned thinking "
            "blocks synthesised from reasoning_content — upstream rejects "
            "replayed tool-call messages without them."
        )
        assert kept[0]["thinking"] == "planning the tool call"
        # Synthesised block — never has a signature
        assert "signature" not in kept[0]

    def test_unsigned_thinking_preserved_on_non_latest_assistant_turn(self) -> None:
        """Every assistant turn keeps its thinking block, not only the last one."""
        from agent.anthropic_adapter import convert_messages_to_anthropic

        history = (
            self._tool_call_turn("q1", "r1", "call_1")
            + self._tool_call_turn("q2", "r2", "call_2")
        )
        _system, converted = convert_messages_to_anthropic(
            history, base_url="https://api.deepseek.com/anthropic"
        )

        assistant_turns = [m for m in converted if m["role"] == "assistant"]
        assert len(assistant_turns) == 2
        for turn, expected_reasoning in zip(assistant_turns, ("r1", "r2")):
            kept = self._thinking_blocks(turn)
            assert len(kept) == 1
            assert kept[0]["thinking"] == expected_reasoning

    def test_signed_anthropic_thinking_block_is_stripped(self) -> None:
        """A thinking block that carries a ``signature`` is removed.

        DeepSeek issues its own signatures and cannot validate Anthropic's,
        so signed blocks are stripped and unsigned ones kept — the same split
        as the Kimi policy.
        """
        from agent.anthropic_adapter import convert_messages_to_anthropic

        signed_block = {
            "type": "thinking",
            "thinking": "anthropic-signed payload",
            "signature": "anthropic-sig-xyz",
        }
        history = [
            {"role": "user", "content": "hi"},
            {
                "role": "assistant",
                "content": [signed_block, {"type": "text", "text": "hello"}],
            },
            {"role": "user", "content": "again"},
        ]
        _system, converted = convert_messages_to_anthropic(
            history, base_url="https://api.deepseek.com/anthropic"
        )

        first_assistant = next(m for m in converted if m["role"] == "assistant")
        assert self._thinking_blocks(first_assistant) == [], (
            "Signed Anthropic thinking blocks must be stripped on DeepSeek — "
            "DeepSeek cannot validate Anthropic-proprietary signatures."
        )

    def test_cache_control_stripped_from_thinking_block(self) -> None:
        """No thinking block in the converted output carries ``cache_control``.

        DeepSeek's compatibility matrix lists cache_control on thinking
        blocks as ignored, and cache markers interfere with signature
        validation on upstreams that do check them, so Hermes strips them
        everywhere.
        """
        from agent.anthropic_adapter import convert_messages_to_anthropic

        history = self._tool_call_turn("hi", "r1", "call_1")
        # Injecting cache_control on the synthesised block would mean
        # converting, mutating and converting again, which is indirect.
        # This checks the simpler invariant on the converted output.
        _system, converted = convert_messages_to_anthropic(
            history, base_url="https://api.deepseek.com/anthropic"
        )

        for message in converted:
            content = message.get("content")
            if isinstance(content, list):
                for block in content:
                    if not isinstance(block, dict):
                        continue
                    if block.get("type") in ("thinking", "redacted_thinking"):
                        assert "cache_control" not in block

    def test_openai_compat_deepseek_base_is_not_matched(self) -> None:
        """Only ``/anthropic`` DeepSeek URLs are detected as the Anthropic route.

        The OpenAI-compatible ``api.deepseek.com`` base never reaches this
        adapter, but the detector should still fail closed so an accidental
        misuse does not quietly send signed Anthropic blocks to an OpenAI
        endpoint.
        """
        from agent.anthropic_adapter import _is_deepseek_anthropic_endpoint

        expectations = (
            ("https://api.deepseek.com", False),
            ("https://api.deepseek.com/v1", False),
            ("https://api.deepseek.com/anthropic", True),
            ("https://api.deepseek.com/anthropic/v1", True),
        )
        for url, expected in expectations:
            assert _is_deepseek_anthropic_endpoint(url) is expected

    def test_non_deepseek_third_party_still_strips_all_thinking(self) -> None:
        """A MiniMax ``/anthropic`` base URL yields no thinking blocks.

        Other third-party Anthropic endpoints keep the generic strip-all
        behaviour because they reject unsigned blocks outright.
        """
        from agent.anthropic_adapter import convert_messages_to_anthropic

        history = self._tool_call_turn("hi", "r1", "call_1")
        _system, converted = convert_messages_to_anthropic(
            history, base_url="https://api.minimax.io/anthropic"
        )

        first_assistant = next(m for m in converted if m["role"] == "assistant")
        assert self._thinking_blocks(first_assistant) == [], (
            "Non-DeepSeek third-party endpoints must keep the generic "
            "strip-all-thinking behaviour — unsigned blocks get rejected."
        )
|
||||
@@ -57,7 +57,9 @@ class TestFailoverReason:
|
||||
"context_overflow", "payload_too_large", "image_too_large",
|
||||
"model_not_found", "format_error",
|
||||
"provider_policy_blocked",
|
||||
"thinking_signature", "long_context_tier", "unknown",
|
||||
"thinking_signature", "long_context_tier",
|
||||
"oauth_long_context_beta_forbidden",
|
||||
"unknown",
|
||||
}
|
||||
actual = {r.value for r in FailoverReason}
|
||||
assert expected == actual
|
||||
@@ -458,6 +460,40 @@ class TestClassifyApiError:
|
||||
result = classify_api_error(e, provider="anthropic")
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
# ── Provider-specific: Anthropic OAuth 1M-context beta forbidden ──
|
||||
|
||||
def test_anthropic_oauth_1m_beta_forbidden(self):
    """A 400 with the beta-availability message maps to its own reason.

    The result must be ``oauth_long_context_beta_forbidden``, retryable,
    and must not ask for compression.
    """
    error = MockAPIError(
        "The long context beta is not yet available for this subscription.",
        status_code=400,
    )

    classified = classify_api_error(
        error, provider="anthropic", model="claude-sonnet-4.6"
    )

    assert classified.reason == FailoverReason.oauth_long_context_beta_forbidden
    assert classified.retryable is True
    assert classified.should_compress is False
|
||||
|
||||
def test_anthropic_oauth_1m_beta_forbidden_does_not_collide_with_tier_gate(self):
    """The 429 'extra usage' tier gate is still ``long_context_tier``.

    Its message also mentions 'long context', so this guards against the
    beta-forbidden rule capturing it.
    """
    error = MockAPIError(
        "Extra usage is required for long context requests over 200k tokens",
        status_code=429,
    )

    classified = classify_api_error(
        error, provider="anthropic", model="claude-sonnet-4.6"
    )

    assert classified.reason == FailoverReason.long_context_tier
|
||||
|
||||
def test_400_without_beta_phrase_is_not_1m_beta_forbidden(self):
    """A 400 that mentions 'long context' without the beta phrase is not matched."""
    error = MockAPIError(
        "long context window exceeded",
        status_code=400,
    )

    classified = classify_api_error(error, provider="anthropic")

    assert classified.reason != FailoverReason.oauth_long_context_beta_forbidden
|
||||
|
||||
# ── Transport errors ──
|
||||
|
||||
def test_read_timeout(self):
|
||||
|
||||
@@ -94,13 +94,16 @@ class TestKimiCodingSkipsAnthropicThinking:
|
||||
)
|
||||
assert "thinking" in kwargs
|
||||
|
||||
def test_kimi_root_endpoint_unaffected(self) -> None:
|
||||
"""Only the /coding route is special-cased — plain api.kimi.com is not.
|
||||
def test_kimi_root_endpoint_via_anthropic_transport_omits_thinking(self) -> None:
|
||||
"""Plain ``api.kimi.com`` hit via the Anthropic transport also omits thinking.
|
||||
|
||||
``api.kimi.com`` without ``/coding`` uses the chat_completions transport
|
||||
(see runtime_provider._detect_api_mode_for_url); build_anthropic_kwargs
|
||||
should never see it, but if it somehow does we should not suppress
|
||||
thinking there — that path has different semantics.
|
||||
Auto-detection routes ``api.kimi.com/v1`` to ``chat_completions`` by
|
||||
default, but users can explicitly configure
|
||||
``api_mode: anthropic_messages`` against any Kimi host. The upstream
|
||||
validation (reasoning_content required on replayed tool-call
|
||||
messages) is the same regardless of URL path, so the thinking
|
||||
suppression must apply to every Kimi host, not just ``/coding``.
|
||||
See #17057.
|
||||
"""
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
|
||||
@@ -112,4 +115,98 @@ class TestKimiCodingSkipsAnthropicThinking:
|
||||
reasoning_config={"enabled": True, "effort": "medium"},
|
||||
base_url="https://api.kimi.com/v1",
|
||||
)
|
||||
assert "thinking" not in kwargs
|
||||
|
||||
# ── #17057: custom / proxied Kimi-compatible endpoints ──────────
|
||||
@pytest.mark.parametrize(
    "base_url,model",
    [
        # Custom host with Kimi-family model — the reporter's case
        ("http://my-kimi-proxy.internal", "kimi-2.6"),
        ("https://llm.example.com/anthropic", "kimi-k2.5"),
        ("https://llm.example.com/anthropic", "moonshot-v1-8k"),
        ("https://llm.example.com/anthropic", "kimi_thinking"),
        ("https://llm.example.com/anthropic", "moonshotai/kimi-k2.5"),
        # Official Moonshot host (previously uncovered)
        ("https://api.moonshot.ai/anthropic", "moonshot-v1-32k"),
        ("https://api.moonshot.cn/anthropic", "moonshot-v1-32k"),
    ],
)
def test_kimi_family_custom_endpoint_omits_thinking(
    self, base_url: str, model: str
) -> None:
    """Kimi-family models on custom or proxied hosts get no thinking kwargs.

    With reasoning enabled, the built kwargs must contain neither
    ``thinking`` nor ``output_config``.
    """
    from agent.anthropic_adapter import build_anthropic_kwargs

    request_kwargs = build_anthropic_kwargs(
        model=model,
        messages=[{"role": "user", "content": "hello"}],
        tools=None,
        max_tokens=4096,
        reasoning_config={"enabled": True, "effort": "medium"},
        base_url=base_url,
    )

    assert "thinking" not in request_kwargs, (
        f"Kimi-family endpoint ({base_url}, {model}) must not receive "
        f"Anthropic thinking — upstream validates reasoning_content on "
        f"replayed tool-call history we don't preserve."
    )
    assert "output_config" not in request_kwargs
|
||||
|
||||
def test_custom_endpoint_non_kimi_model_keeps_thinking(self) -> None:
    """A non-Kimi model on a custom endpoint still gets ``thinking`` enabled.

    Guards against over-broad model-family matching: only model names with
    a Kimi/Moonshot prefix should trigger suppression.
    """
    from agent.anthropic_adapter import build_anthropic_kwargs

    request_kwargs = build_anthropic_kwargs(
        model="MiniMax-M2.7",
        messages=[{"role": "user", "content": "hello"}],
        tools=None,
        max_tokens=4096,
        reasoning_config={"enabled": True, "effort": "medium"},
        base_url="https://my-llm-proxy.example.com/anthropic",
    )

    assert "thinking" in request_kwargs
    assert request_kwargs["thinking"]["type"] == "enabled"
|
||||
|
||||
def test_kimi_family_replay_preserves_unsigned_thinking(self) -> None:
    """Unsigned thinking from ``reasoning_content`` survives on a custom Kimi host.

    The block has to outlive the third-party signature-stripping pass so
    the upstream's message-history validation passes.
    """
    from agent.anthropic_adapter import convert_messages_to_anthropic

    tool_call = {
        "id": "call_1",
        "type": "function",
        "function": {"name": "skill_view", "arguments": "{}"},
    }
    history = [
        {"role": "user", "content": "hi"},
        {
            "role": "assistant",
            "reasoning_content": "planning the tool call",
            "tool_calls": [tool_call],
        },
        {"role": "tool", "tool_call_id": "call_1", "content": "ok"},
    ]

    _, converted = convert_messages_to_anthropic(
        history,
        base_url="http://my-kimi-proxy.internal",
        model="kimi-2.6",
    )

    # The assistant message still carries the unsigned thinking block
    # synthesised from reasoning_content (required by Kimi's history
    # validation). A plain third-party endpoint would have stripped it.
    first_assistant = next(m for m in converted if m["role"] == "assistant")
    kept = [
        block for block in first_assistant["content"]
        if isinstance(block, dict) and block.get("type") == "thinking"
    ]
    assert len(kept) == 1
    assert kept[0]["thinking"] == "planning the tool call"
|
||||
|
||||
320
tests/agent/test_memory_session_switch.py
Normal file
320
tests/agent/test_memory_session_switch.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""Tests for the on_session_switch hook and session_id propagation.
|
||||
|
||||
Covers #6672: memory providers must be notified when AIAgent.session_id
|
||||
rotates mid-process (via /resume, /branch, /reset, /new, or context
|
||||
compression). Without the notification, providers that cache per-session
|
||||
state in initialize() (Hindsight, and any plugin that stores session_id
|
||||
for scoped writes) keep writing into the old session's record.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.memory_manager import MemoryManager
|
||||
from agent.memory_provider import MemoryProvider
|
||||
|
||||
|
||||
class _RecordingProvider(MemoryProvider):
    """Memory provider double that logs each lifecycle call it receives.

    Every hook appends a dict describing its arguments to the matching
    ``*_calls`` list so tests can assert on what was forwarded.
    """

    def __init__(self, name="rec"):
        self._name = name
        # One list per recorded hook, in call order.
        self.switch_calls: list[dict] = []
        self.sync_calls: list[dict] = []
        self.queue_calls: list[dict] = []
        self.initialize_calls: list[dict] = []

    @property
    def name(self) -> str:
        """Name given at construction time."""
        return self._name

    def is_available(self) -> bool:  # pragma: no cover - unused
        """Always report the provider as available."""
        return True

    def initialize(self, session_id, **kwargs):
        """Record the session id together with any extra keyword arguments."""
        record = {"session_id": session_id, **kwargs}
        self.initialize_calls.append(record)

    def get_tool_schemas(self):
        """Expose no tools."""
        return []

    def sync_turn(self, user_content, assistant_content, *, session_id=""):
        """Record one synced turn."""
        record = {
            "user": user_content,
            "asst": assistant_content,
            "session_id": session_id,
        }
        self.sync_calls.append(record)

    def queue_prefetch(self, query, *, session_id=""):
        """Record one prefetch request."""
        record = {"query": query, "session_id": session_id}
        self.queue_calls.append(record)

    def on_session_switch(
        self,
        new_session_id,
        *,
        parent_session_id="",
        reset=False,
        **kwargs,
    ):
        """Record a session switch; unknown keyword arguments go under ``extra``."""
        record = {
            "new": new_session_id,
            "parent": parent_session_id,
            "reset": reset,
            "extra": kwargs,
        }
        self.switch_calls.append(record)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryProvider ABC — default on_session_switch is a no-op
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _MinimalProvider(MemoryProvider):
    """Bare provider that leaves ``on_session_switch`` to the ABC default."""

    @property
    def name(self) -> str:
        """Fixed provider name."""
        return "minimal"

    def is_available(self) -> bool:
        """Always report the provider as available."""
        return True

    def initialize(self, session_id, **kwargs):  # pragma: no cover - unused
        """Accept the call and do nothing."""
        return None

    def get_tool_schemas(self):
        """Expose no tools."""
        return []
|
||||
|
||||
|
||||
def test_abc_default_on_session_switch_is_noop():
    """The inherited ``on_session_switch`` accepts every call style silently."""
    provider = _MinimalProvider()

    # Each call style below must be accepted without raising: positional
    # only, with a parent, with reset, and with an extra keyword.
    keyword_variants = (
        {},
        {"parent_session_id": "old-id"},
        {"parent_session_id": "old-id", "reset": True},
        {"parent_session_id": "old-id", "reset": True, "reason": "new_session"},
    )
    for extra_kwargs in keyword_variants:
        provider.on_session_switch("new-id", **extra_kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryManager.on_session_switch — fan-out
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_manager_fans_out_to_all_providers():
    """One manager-level switch must reach every registered provider once."""
    mm = MemoryManager()
    # Only one external provider is allowed; use the builtin slot for the first.
    builtin = _RecordingProvider(name="builtin")
    external = _RecordingProvider(name="hindsight")
    for provider in (builtin, external):
        mm.add_provider(provider)

    mm.on_session_switch("new-sid", parent_session_id="old-sid", reset=False, reason="resume")

    for provider in (builtin, external):
        assert len(provider.switch_calls) == 1
        call = provider.switch_calls[0]
        assert call["new"] == "new-sid"
        assert call["parent"] == "old-sid"
        assert call["reset"] is False
        assert call["extra"] == {"reason": "resume"}
|
||||
|
||||
|
||||
def test_manager_ignores_empty_session_id():
    """Empty string session_id must not trigger provider hooks.

    Prevents accidental fires during shutdown when self.session_id may be
    cleared. Providers expect a meaningful id to switch TO.
    """
    mm = MemoryManager()
    recorder = _RecordingProvider()
    mm.add_provider(recorder)

    for empty_id in ("", None):
        mm.on_session_switch(empty_id)  # type: ignore[arg-type]

    assert recorder.switch_calls == []
|
||||
|
||||
|
||||
def test_manager_isolates_provider_failures():
    """A provider that raises must not block other providers."""

    class _Broken(_RecordingProvider):
        def on_session_switch(self, *args, **kwargs):  # type: ignore[override]
            raise RuntimeError("boom")

    mm = MemoryManager()
    # MemoryManager rejects a second external provider, so pair broken
    # (builtin slot) with a good external one.
    broken = _Broken(name="builtin")
    good = _RecordingProvider(name="good")
    mm.add_provider(broken)
    mm.add_provider(good)

    # Must not raise — exceptions in one provider are swallowed + logged
    mm.on_session_switch("new-sid", parent_session_id="old-sid")

    assert len(good.switch_calls) == 1
    assert good.switch_calls[0]["new"] == "new-sid"
|
||||
|
||||
|
||||
def test_manager_reset_flag_preserved():
    """``reset=True`` and extra keyword arguments must reach the provider unchanged."""
    mm = MemoryManager()
    recorder = _RecordingProvider()
    mm.add_provider(recorder)

    mm.on_session_switch("new-sid", reset=True, reason="new_session")

    recorded = recorder.switch_calls[0]
    assert recorded["reset"] is True
    assert recorded["extra"] == {"reason": "new_session"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryManager.sync_all / queue_prefetch_all — session_id propagation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_sync_all_propagates_session_id_to_providers():
    """run_agent.py's sync_all call must pass session_id through to providers.

    Without this, a provider that updates _session_id defensively in
    sync_turn (as Hindsight does at hindsight/__init__.py:1199) never
    sees the new id and keeps writing under the old one.
    """
    mm = MemoryManager()
    recorder = _RecordingProvider()
    mm.add_provider(recorder)

    mm.sync_all("hello", "world", session_id="sess-42")

    expected = {"user": "hello", "asst": "world", "session_id": "sess-42"}
    assert recorder.sync_calls == [expected]
|
||||
|
||||
|
||||
def test_queue_prefetch_all_propagates_session_id_to_providers():
    """queue_prefetch_all must forward both the query and the session_id."""
    mm = MemoryManager()
    recorder = _RecordingProvider()
    mm.add_provider(recorder)

    mm.queue_prefetch_all("next query", session_id="sess-42")

    assert recorder.queue_calls == [{"query": "next query", "session_id": "sess-42"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hindsight reference implementation — state-flush semantics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_hindsight_provider():
    """Build a bare HindsightMemoryProvider that skips network setup.

    We instantiate without importing optional deps at class-level by
    bypassing __init__ and seeding the attributes on_session_switch
    reads/writes. This keeps the test hermetic. The calling test is skipped
    (via ``pytest.importorskip``) when the plugin module cannot be imported.
    """
    import queue as _queue
    import threading

    hindsight_mod = pytest.importorskip("plugins.memory.hindsight")
    provider = object.__new__(hindsight_mod.HindsightMemoryProvider)

    # Pre-switch session state: a populated buffer under the "old" session.
    provider._session_id = "old-sid"
    provider._parent_session_id = ""
    provider._document_id = "old-sid-20260101_000000_000000"
    provider._session_turns = ["turn-1", "turn-2"]
    provider._turn_counter = 2
    provider._turn_index = 2

    # Attrs read by _build_metadata / _build_retain_kwargs when the
    # buffer-flush path on session switch fires. Empty strings keep the
    # metadata minimal but well-formed.
    for blank_attr in (
        "_retain_source",
        "_platform",
        "_user_id",
        "_user_name",
        "_chat_id",
        "_chat_name",
        "_chat_type",
        "_thread_id",
        "_agent_identity",
        "_agent_workspace",
    ):
        setattr(provider, blank_attr, "")
    provider._retain_tags = []
    provider._retain_context = "test-context"
    provider._retain_async = False
    provider._bank_id = "test-bank"

    # Prefetch state the switch path drains/clears.
    provider._prefetch_thread = None
    provider._prefetch_lock = threading.Lock()
    provider._prefetch_result = ""

    # Sync thread tracking (legacy alias at the writer).
    provider._sync_thread = None

    # Writer queue infra the flush-on-switch path enqueues onto. We stub
    # _ensure_writer / _register_atexit so no real thread is spawned;
    # tests exercising flush delivery live in
    # tests/plugins/memory/test_hindsight_provider.py where the full
    # writer-queue wiring is in place.
    provider._retain_queue = _queue.Queue()
    provider._shutting_down = threading.Event()
    provider._atexit_registered = True
    provider._ensure_writer = lambda: None
    provider._register_atexit = lambda: None

    # Stub the network-touching helper so any enqueued flush closure is
    # a no-op if ever drained in a unit test.
    provider._run_hindsight_operation = lambda _op: None
    return provider
|
||||
|
||||
|
||||
def test_hindsight_on_session_switch_updates_session_id_and_mints_fresh_doc():
    """Switching must adopt the new id, remember the parent, and mint a new doc id."""
    provider = _make_hindsight_provider()
    previous_doc = provider._document_id

    provider.on_session_switch(
        "new-sid", parent_session_id="old-sid", reset=False, reason="resume"
    )

    assert provider._session_id == "new-sid"
    assert provider._parent_session_id == "old-sid"
    # Document id MUST be fresh — else next retain overwrites old session doc
    assert provider._document_id != previous_doc
    assert provider._document_id.startswith("new-sid-")
|
||||
|
||||
|
||||
def test_hindsight_on_session_switch_clears_turn_buffers():
    """Accumulated _session_turns must not leak into the next session.

    Hindsight batches turns under a single _document_id. If the buffer
    isn't cleared on switch, the next retain under the new _document_id
    flushes turns that belong to the previous session.
    """
    provider = _make_hindsight_provider()

    provider.on_session_switch("new-sid", parent_session_id="old-sid")

    assert provider._session_turns == []
    assert provider._turn_counter == 0
    assert provider._turn_index == 0
|
||||
|
||||
|
||||
def test_hindsight_on_session_switch_clears_on_reset_true():
    """reset=True (from /new, /reset) must also flush buffers."""
    provider = _make_hindsight_provider()

    provider.on_session_switch("new-sid", reset=True, reason="new_session")

    assert provider._session_id == "new-sid"
    assert provider._session_turns == []
    assert provider._turn_counter == 0
|
||||
|
||||
|
||||
def test_hindsight_on_session_switch_ignores_empty_id():
    """Empty new_session_id must be a no-op to avoid corrupting state."""
    provider = _make_hindsight_provider()

    def _snapshot():
        # Copy the turn list so later in-place mutation would be detected.
        return (
            provider._session_id,
            provider._document_id,
            list(provider._session_turns),
            provider._turn_counter,
        )

    before = _snapshot()
    provider.on_session_switch("")
    provider.on_session_switch(None)  # type: ignore[arg-type]
    after = _snapshot()

    assert before == after
|
||||
|
||||
|
||||
def test_hindsight_preserves_parent_across_empty_parent_arg():
    """Omitting parent_session_id must NOT overwrite an existing one."""
    provider = _make_hindsight_provider()
    provider._parent_session_id = "original-parent"

    provider.on_session_switch("new-sid")  # no parent passed

    assert provider._parent_session_id == "original-parent"
|
||||
@@ -308,10 +308,15 @@ class TestMinimaxPreserveDots:
|
||||
from agent.anthropic_adapter import normalize_model_name
|
||||
assert normalize_model_name("MiniMax-M2.7", preserve_dots=True) == "MiniMax-M2.7"
|
||||
|
||||
def test_normalize_converts_without_preserve(self):
|
||||
def test_normalize_preserves_non_anthropic_dots_without_preserve(self):
|
||||
from agent.anthropic_adapter import normalize_model_name
|
||||
# Without preserve_dots, dots become hyphens (broken for MiniMax)
|
||||
assert normalize_model_name("MiniMax-M2.7", preserve_dots=False) == "MiniMax-M2-7"
|
||||
# Non-Anthropic model families use dots as canonical version separators;
|
||||
# only Claude/Anthropic names are hyphen-normalized by default.
|
||||
assert normalize_model_name("MiniMax-M2.7", preserve_dots=False) == "MiniMax-M2.7"
|
||||
|
||||
def test_normalize_still_converts_claude_dots_without_preserve(self):
    """Claude names still get their version dot hyphenated when preserve_dots is off."""
    from agent.anthropic_adapter import normalize_model_name

    normalized = normalize_model_name("claude-opus-4.6", preserve_dots=False)
    assert normalized == "claude-opus-4-6"
|
||||
|
||||
|
||||
class TestMinimaxSwitchModelCredentialGuard:
|
||||
|
||||
@@ -205,11 +205,22 @@ class TestDetectOpenclawResidue:
|
||||
|
||||
|
||||
class TestOpenclawResidueHint:
|
||||
def test_hint_mentions_cleanup_command(self):
|
||||
def test_hint_mentions_migrate_command(self):
|
||||
# `migrate` is the non-destructive path — should lead the banner.
|
||||
msg = openclaw_residue_hint_cli()
|
||||
assert "hermes claw cleanup" in msg
|
||||
assert "hermes claw migrate" in msg
|
||||
assert "~/.openclaw" in msg
|
||||
|
||||
def test_hint_mentions_cleanup_command(self):
    """The banner must name the ``hermes claw cleanup`` command."""
    # `cleanup` is mentioned as the follow-up archive step.
    assert "hermes claw cleanup" in openclaw_residue_hint_cli()
|
||||
|
||||
def test_hint_warns_cleanup_breaks_openclaw(self):
    """The banner must warn that cleanup stops OpenClaw from working."""
    # Archiving the directory breaks OpenClaw for users still running it —
    # the banner must flag that side effect.
    msg = openclaw_residue_hint_cli().lower()
    # "stop working" is a substring of "openclaw will stop working", so the
    # single check accepts exactly the same messages as the former `or`.
    assert "stop working" in msg
|
||||
|
||||
def test_hint_not_empty(self):
    """The banner must contain at least one non-whitespace character."""
    assert openclaw_residue_hint_cli().strip()
|
||||
|
||||
|
||||
160
tests/agent/test_skill_commands_reload.py
Normal file
160
tests/agent/test_skill_commands_reload.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Tests for ``agent.skill_commands.reload_skills``.
|
||||
|
||||
Covers the helper that powers ``/reload-skills`` (CLI + gateway slash command).
|
||||
The helper rescans the skills directory and returns a diff of what changed.
|
||||
It does NOT invalidate the skills system-prompt cache — skills are invoked
|
||||
at runtime via ``/skill-name``, ``skills_list``, or ``skill_view`` and don't
|
||||
need to live in the system prompt.
|
||||
|
||||
``added`` and ``removed`` are lists of ``{"name": str, "description": str}``
|
||||
dicts. Descriptions pass through verbatim — they are NOT truncated.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _write_skill(skills_dir: Path, name: str, description: str = "") -> Path:
|
||||
skill_dir = skills_dir / name
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
textwrap.dedent(
|
||||
f"""\
|
||||
---
|
||||
name: {name}
|
||||
description: {description or f'{name} skill'}
|
||||
---
|
||||
body
|
||||
"""
|
||||
)
|
||||
)
|
||||
return skill_dir
|
||||
|
||||
|
||||
@pytest.fixture
def hermes_home(monkeypatch):
    """Isolate HERMES_HOME for ``reload_skills`` tests.

    Rather than popping cache-bearing modules from ``sys.modules`` (which
    races against pytest-xdist's parallel workers), we monkeypatch the
    module-level ``HERMES_HOME`` / ``SKILLS_DIR`` constants in place so the
    isolation is local to this fixture's scope.

    Yields the temporary home directory as a ``Path``. The directory is
    removed on teardown, and also when setup itself fails part-way.
    """
    td = tempfile.mkdtemp(prefix="hermes-reload-skills-")
    # try/finally: code after ``yield`` is skipped when setup raises before
    # the yield, which would otherwise leak the temp directory.
    try:
        monkeypatch.setenv("HERMES_HOME", td)
        home = Path(td)
        (home / "skills").mkdir(parents=True, exist_ok=True)

        # Import lazily (inside fixture) so the modules are already resident,
        # then redirect their captured paths at the new temp dir.
        import tools.skills_tool as _st
        import agent.skill_commands as _sc

        monkeypatch.setattr(_st, "HERMES_HOME", home, raising=False)
        monkeypatch.setattr(_st, "SKILLS_DIR", home / "skills", raising=False)
        # Reset the in-process slash-command cache so each test starts from zero.
        monkeypatch.setattr(_sc, "_skill_commands", {}, raising=False)

        yield home
    finally:
        shutil.rmtree(td, ignore_errors=True)
|
||||
|
||||
|
||||
class TestReloadSkillsHelper:
    """``agent.skill_commands.reload_skills``."""

    def test_returns_expected_keys(self, hermes_home):
        """An empty skills dir yields the full key set with empty diffs."""
        from agent.skill_commands import reload_skills

        result = reload_skills()

        assert set(result) == {"added", "removed", "unchanged", "total", "commands"}
        assert result["total"] == 0
        assert result["added"] == []
        assert result["removed"] == []

    def test_detects_newly_added_skill_with_description(self, hermes_home):
        """A skill written after priming shows up in ``added`` with its description."""
        from agent.skill_commands import reload_skills, get_skill_commands

        # Prime the cache so subsequent diff is meaningful
        get_skill_commands()

        _write_skill(hermes_home / "skills", "demo", "a demo skill")
        result = reload_skills()

        assert result["added"] == [{"name": "demo", "description": "a demo skill"}]
        assert result["removed"] == []
        assert result["total"] == 1
        assert result["commands"] == 1

    def test_detects_removed_skill_carries_description(self, hermes_home):
        """A deleted skill is reported in ``removed`` together with its description."""
        from agent.skill_commands import reload_skills

        skill_dir = _write_skill(hermes_home / "skills", "demo", "soon to be gone")
        # First reload: demo present
        first = reload_skills()
        assert first["total"] == 1
        assert first["added"] == [{"name": "demo", "description": "soon to be gone"}]

        # Remove and reload — the description must survive the removal diff
        # (we cached it from the pre-rescan snapshot).
        shutil.rmtree(skill_dir)
        second = reload_skills()

        assert second["removed"] == [{"name": "demo", "description": "soon to be gone"}]
        assert second["added"] == []
        assert second["total"] == 0

    def test_description_passes_through_verbatim(self, hermes_home):
        """``description`` must be the full SKILL.md frontmatter string — no
        truncation. The system prompt renders skills as
        ``    - name: description`` without a length cap, and the reload
        note mirrors that format, so truncating here would make the diff
        render differently from the original catalog."""
        from agent.skill_commands import reload_skills, get_skill_commands

        get_skill_commands()  # prime
        long_desc = "x" * 200
        _write_skill(hermes_home / "skills", "longdesc", long_desc)

        result = reload_skills()

        assert len(result["added"]) == 1
        assert result["added"][0]["description"] == long_desc

    def test_unchanged_skills_appear_in_unchanged_list(self, hermes_home):
        """A skill present before and after the reload is listed as unchanged."""
        from agent.skill_commands import reload_skills, get_skill_commands

        _write_skill(hermes_home / "skills", "alpha")
        # Prime cache
        get_skill_commands()

        # Call reload again with no FS changes
        result = reload_skills()

        assert "alpha" in result["unchanged"]
        assert result["added"] == []
        assert result["removed"] == []

    def test_does_not_invalidate_prompt_cache_snapshot(self, hermes_home):
        """reload_skills must NOT delete the skills prompt-cache snapshot.

        Skills are called at runtime — the system prompt doesn't need to
        mention them for the model to use them — so reloading them should
        preserve prefix caching.
        """
        from agent.prompt_builder import _skills_prompt_snapshot_path
        from agent.skill_commands import reload_skills

        snapshot = _skills_prompt_snapshot_path()
        snapshot.parent.mkdir(parents=True, exist_ok=True)
        snapshot.write_text("{}")
        assert snapshot.exists()

        reload_skills()

        assert snapshot.exists(), (
            "prompt cache snapshot should be preserved — skills don't live "
            "in the system prompt so there's no reason to invalidate it"
        )
|
||||
@@ -122,21 +122,25 @@ class TestChatCompletionsBuildKwargs:
|
||||
)
|
||||
assert kw["extra_body"]["think"] is False
|
||||
|
||||
def test_gemini_without_explicit_reasoning_config_keeps_existing_behavior(self, transport):
|
||||
def test_gemini_native_without_explicit_reasoning_config_keeps_existing_behavior(self, transport):
    """Without a reasoning_config, no thinking-related keys appear in extra_body."""
    msgs = [{"role": "user", "content": "Hi"}]
    kw = transport.build_kwargs(
        model="gemini-3-flash-preview",
        messages=msgs,
        provider_name="gemini",
        base_url="https://generativelanguage.googleapis.com/v1beta",
    )
    extra_body = kw.get("extra_body", {})
    for forbidden_key in ("thinking_config", "google", "extra_body"):
        assert forbidden_key not in extra_body
|
||||
|
||||
def test_gemini_flash_reasoning_maps_to_thinking_config(self, transport):
|
||||
def test_gemini_native_flash_reasoning_maps_to_top_level_thinking_config(self, transport):
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
kw = transport.build_kwargs(
|
||||
model="gemini-3-flash-preview",
|
||||
messages=msgs,
|
||||
provider_name="gemini",
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta",
|
||||
reasoning_config={"enabled": True, "effort": "high"},
|
||||
)
|
||||
assert kw["extra_body"]["thinking_config"] == {
|
||||
@@ -144,52 +148,85 @@ class TestChatCompletionsBuildKwargs:
|
||||
"thinkingLevel": "high",
|
||||
}
|
||||
|
||||
def test_gemini_25_reasoning_only_enables_visible_thoughts(self, transport):
|
||||
def test_gemini_openai_compat_flash_reasoning_maps_to_nested_google_thinking_config(self, transport):
    """On the ``/openai`` base URL the thinking config is nested under extra_body.google."""
    msgs = [{"role": "user", "content": "Hi"}]
    kw = transport.build_kwargs(
        model="gemini-3-flash-preview",
        messages=msgs,
        provider_name="gemini",
        base_url="https://generativelanguage.googleapis.com/v1beta/openai",
        reasoning_config={"enabled": True, "effort": "high"},
    )
    outer = kw["extra_body"]
    assert "thinking_config" not in outer
    assert outer["extra_body"]["google"]["thinking_config"] == {
        "include_thoughts": True,
        "thinking_level": "high",
    }
|
||||
|
||||
def test_gemini_native_25_reasoning_only_enables_visible_thoughts(self, transport):
    """For gemini-2.5-flash the thinking config carries only ``includeThoughts``."""
    msgs = [{"role": "user", "content": "Hi"}]
    kw = transport.build_kwargs(
        model="gemini-2.5-flash",
        messages=msgs,
        provider_name="gemini",
        base_url="https://generativelanguage.googleapis.com/v1beta",
        reasoning_config={"enabled": True, "effort": "high"},
    )
    assert kw["extra_body"]["thinking_config"] == {
        "includeThoughts": True,
    }
|
||||
|
||||
def test_gemini_pro_reasoning_clamps_to_supported_levels(self, transport):
|
||||
def test_gemini_openai_compat_pro_reasoning_clamps_to_supported_levels(self, transport):
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
kw = transport.build_kwargs(
|
||||
model="google/gemini-3.1-pro-preview",
|
||||
messages=msgs,
|
||||
provider_name="gemini",
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
reasoning_config={"enabled": True, "effort": "medium"},
|
||||
)
|
||||
assert kw["extra_body"]["thinking_config"] == {
|
||||
"includeThoughts": True,
|
||||
"thinkingLevel": "low",
|
||||
assert kw["extra_body"]["extra_body"]["google"]["thinking_config"] == {
|
||||
"include_thoughts": True,
|
||||
"thinking_level": "low",
|
||||
}
|
||||
|
||||
def test_gemini_disabled_reasoning_hides_thoughts(self, transport):
|
||||
def test_gemini_native_disabled_reasoning_hides_thoughts(self, transport):
    """Disabled reasoning yields ``includeThoughts: False`` at the top level."""
    msgs = [{"role": "user", "content": "Hi"}]
    kw = transport.build_kwargs(
        model="gemini-3-flash-preview",
        messages=msgs,
        provider_name="gemini",
        base_url="https://generativelanguage.googleapis.com/v1beta",
        reasoning_config={"enabled": False},
    )
    assert kw["extra_body"]["thinking_config"] == {
        "includeThoughts": False,
    }
|
||||
|
||||
def test_gemini_xhigh_clamps_to_high(self, transport):
|
||||
def test_gemini_openai_compat_xhigh_clamps_to_high(self, transport):
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
kw = transport.build_kwargs(
|
||||
model="gemini-3-flash-preview",
|
||||
messages=msgs,
|
||||
provider_name="gemini",
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
reasoning_config={"enabled": True, "effort": "xhigh"},
|
||||
)
|
||||
assert kw["extra_body"]["thinking_config"]["thinkingLevel"] == "high"
|
||||
assert kw["extra_body"]["extra_body"]["google"]["thinking_config"]["thinking_level"] == "high"
|
||||
|
||||
def test_google_gemini_cli_keeps_top_level_thinking_config(self, transport):
    """The ``google-gemini-cli`` provider keeps thinking_config at the top level."""
    msgs = [{"role": "user", "content": "Hi"}]
    kw = transport.build_kwargs(
        model="gemini-3-flash-preview",
        messages=msgs,
        provider_name="google-gemini-cli",
        reasoning_config={"enabled": True, "effort": "high"},
    )
    extra_body = kw["extra_body"]
    assert extra_body["thinking_config"] == {
        "includeThoughts": True,
        "thinkingLevel": "high",
    }
    assert "google" not in extra_body
|
||||
|
||||
def test_gemini_flash_minimal_clamps_to_low(self, transport):
|
||||
# Gemini 3 Flash documents low/medium/high; "minimal" isn't accepted,
|
||||
@@ -199,11 +236,12 @@ class TestChatCompletionsBuildKwargs:
|
||||
model="gemini-3-flash-preview",
|
||||
messages=msgs,
|
||||
provider_name="gemini",
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
reasoning_config={"enabled": True, "effort": "minimal"},
|
||||
)
|
||||
assert kw["extra_body"]["thinking_config"] == {
|
||||
"includeThoughts": True,
|
||||
"thinkingLevel": "low",
|
||||
assert kw["extra_body"]["extra_body"]["google"]["thinking_config"] == {
|
||||
"include_thoughts": True,
|
||||
"thinking_level": "low",
|
||||
}
|
||||
|
||||
def test_max_tokens_with_fn(self, transport):
|
||||
|
||||
@@ -192,6 +192,33 @@ class TestBranchCommandCLI:
|
||||
|
||||
assert cli_instance._resumed is True
|
||||
|
||||
def test_branch_fires_on_session_switch_hook(self, cli_instance, session_db):
    """The /branch command must notify memory providers of the rotation.

    Without this, providers that cache per-session state in
    initialize() keep writing under the old session_id. See #6672.
    """
    from cli import HermesCLI

    # Wire a real-ish agent object with a MagicMock memory_manager
    agent = MagicMock()
    mm = MagicMock()
    agent._memory_manager = mm
    cli_instance.agent = agent
    original_id = cli_instance.session_id

    HermesCLI._handle_branch_command(cli_instance, "/branch")

    # Hook must have been called exactly once with the new session_id,
    # parent pointing at the branched-from session, reset=False, and
    # reason="branch" for diagnostics.
    assert mm.on_session_switch.call_count == 1
    # call_args unpacks to (positional args, keyword args).
    args, kwargs = mm.on_session_switch.call_args
    assert args[0] == cli_instance.session_id
    assert kwargs["parent_session_id"] == original_id
    assert kwargs["reset"] is False
    assert kwargs["reason"] == "branch"
|
||||
|
||||
def test_fork_alias(self):
|
||||
"""The /fork alias should resolve to 'branch'."""
|
||||
from hermes_cli.commands import resolve_command
|
||||
|
||||
@@ -296,6 +296,30 @@ class TestRootLevelProviderOverride:
|
||||
# Root-level "opencode-go" must NOT leak through
|
||||
assert cfg["model"]["provider"] != "opencode-go"
|
||||
|
||||
def test_terminal_vercel_runtime_bridged_to_env(self, tmp_path, monkeypatch):
    """Classic CLI must expose terminal.vercel_runtime to terminal_tool.py."""
    import yaml

    hermes_home = tmp_path / ".hermes"
    hermes_home.mkdir()
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))
    # Start from a clean slate so the assertion below can only pass if
    # load_cli_config() sets the variable itself.
    monkeypatch.delenv("TERMINAL_VERCEL_RUNTIME", raising=False)

    terminal_section = {
        "backend": "vercel_sandbox",
        "vercel_runtime": "python3.13",
    }
    config_path = hermes_home / "config.yaml"
    config_path.write_text(yaml.safe_dump({"terminal": terminal_section}))

    import cli

    monkeypatch.setattr(cli, "_hermes_home", hermes_home)
    cfg = cli.load_cli_config()

    assert cfg["terminal"]["vercel_runtime"] == "python3.13"
    assert os.environ["TERMINAL_VERCEL_RUNTIME"] == "python3.13"
|
||||
|
||||
def test_normalize_root_model_keys_moves_to_model(self):
|
||||
"""_normalize_root_model_keys migrates root keys into model section."""
|
||||
from hermes_cli.config import _normalize_root_model_keys
|
||||
|
||||
@@ -49,8 +49,15 @@ class TestCLILoadingIndicator:
|
||||
seen["status"] = cli_obj._command_status
|
||||
print("reload done")
|
||||
|
||||
# /reload-mcp now wraps the actual reload in a prompt-cache-invalidation
|
||||
# confirmation prompt (commit 4d7fc0f37). This test exercises the
|
||||
# loading-indicator path, not the confirmation UX, so pre-approve the
|
||||
# reload via config so the handler goes straight into _reload_mcp().
|
||||
fake_cfg = {"approvals": {"mcp_reload_confirm": False}}
|
||||
|
||||
with patch.object(cli_obj, "_reload_mcp", side_effect=fake_reload), \
|
||||
patch.object(cli_obj, "_invalidate") as invalidate_mock:
|
||||
patch.object(cli_obj, "_invalidate") as invalidate_mock, \
|
||||
patch("cli.load_cli_config", return_value=fake_cfg):
|
||||
assert cli_obj.process_command("/reload-mcp")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
|
||||
99
tests/cli/test_cli_reload_skills.py
Normal file
99
tests/cli/test_cli_reload_skills.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Tests for the ``/reload-skills`` CLI slash command (``HermesCLI._reload_skills``).
|
||||
|
||||
The CLI handler prints the diff (name + description) for the user and —
|
||||
when any skills were added or removed — queues a one-shot note on
|
||||
``self._pending_skills_reload_note``. The note is prepended to the NEXT
|
||||
user message (see cli.py ~L8770, same pattern as
|
||||
``_pending_model_switch_note``) and cleared after use, so no phantom user
|
||||
turn is persisted to ``conversation_history``.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def _make_cli():
    """Build a minimal HermesCLI shell exposing ``_reload_skills``.

    ``__init__`` is bypassed via ``object.__new__``; only the attributes
    seeded below exist on the returned object.
    """
    import cli as cli_mod

    shell = object.__new__(cli_mod.HermesCLI)
    shell._command_running = False
    shell.conversation_history = []
    shell.agent = None
    return shell
|
||||
|
||||
|
||||
class TestReloadSkillsCLI:
    """Behaviour of ``HermesCLI._reload_skills`` with ``reload_skills`` patched out."""

    def test_reports_added_and_removed_and_queues_note(self, capsys):
        """A non-empty diff is printed and queued as a one-shot note."""
        shell = _make_cli()
        diff = {
            "added": [
                {"name": "alpha", "description": "Run alpha to do xyz"},
                {"name": "beta", "description": "Run beta to do abc"},
            ],
            "removed": [
                {"name": "gamma", "description": "Old removed skill"},
            ],
            "unchanged": ["delta"],
            "total": 3,
            "commands": 3,
        }
        with patch("agent.skill_commands.reload_skills", return_value=diff):
            shell._reload_skills()

        out = capsys.readouterr().out
        assert "Added Skills:" in out
        assert "- alpha: Run alpha to do xyz" in out
        assert "- beta: Run beta to do abc" in out
        assert "Removed Skills:" in out
        assert "- gamma: Old removed skill" in out
        assert "3 skill(s) available" in out

        # Must NOT pollute conversation_history — alternation-safe.
        assert shell.conversation_history == []

        # One-shot note queued with system-prompt-style formatting.
        note = getattr(shell, "_pending_skills_reload_note", None)
        assert note is not None
        assert note.startswith("[USER INITIATED SKILLS RELOAD:")
        assert note.endswith("Use skills_list to see the updated catalog.]")
        assert "Added Skills:" in note
        assert " - alpha: Run alpha to do xyz" in note
        assert " - beta: Run beta to do abc" in note
        assert "Removed Skills:" in note
        assert " - gamma: Old removed skill" in note

    def test_reports_no_changes_and_queues_nothing(self, capsys):
        """An empty diff prints a "no changes" message and queues no note."""
        shell = _make_cli()
        diff = {
            "added": [],
            "removed": [],
            "unchanged": ["alpha"],
            "total": 1,
            "commands": 1,
        }
        with patch("agent.skill_commands.reload_skills", return_value=diff):
            shell._reload_skills()

        out = capsys.readouterr().out
        assert "No new skills detected" in out
        assert "1 skill(s) available" in out
        assert shell.conversation_history == []
        assert getattr(shell, "_pending_skills_reload_note", None) is None

    def test_handles_reload_failure_gracefully(self, capsys):
        """An exception from the helper is reported, not raised, and queues no note."""
        shell = _make_cli()
        with patch(
            "agent.skill_commands.reload_skills",
            side_effect=RuntimeError("boom"),
        ):
            shell._reload_skills()

        out = capsys.readouterr().out
        assert "Skills reload failed" in out
        assert "boom" in out
        assert shell.conversation_history == []
        assert getattr(shell, "_pending_skills_reload_note", None) is None
|
||||
@@ -20,6 +20,7 @@ test runner at ``scripts/run_tests.sh``.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
@@ -174,7 +175,10 @@ _HERMES_BEHAVIORAL_VARS = frozenset({
|
||||
"HERMES_SESSION_KEY",
|
||||
"HERMES_GATEWAY_SESSION",
|
||||
"HERMES_PLATFORM",
|
||||
"HERMES_MODEL",
|
||||
"HERMES_INFERENCE_MODEL",
|
||||
"HERMES_INFERENCE_PROVIDER",
|
||||
"HERMES_TUI_PROVIDER",
|
||||
"HERMES_MANAGED",
|
||||
"HERMES_DEV",
|
||||
"HERMES_CONTAINER",
|
||||
@@ -184,6 +188,14 @@ _HERMES_BEHAVIORAL_VARS = frozenset({
|
||||
"HERMES_BACKGROUND_NOTIFICATIONS",
|
||||
"HERMES_EXEC_ASK",
|
||||
"HERMES_HOME_MODE",
|
||||
"TERMINAL_CWD",
|
||||
"TERMINAL_ENV",
|
||||
"TERMINAL_VERCEL_RUNTIME",
|
||||
"TERMINAL_CONTAINER_CPU",
|
||||
"TERMINAL_CONTAINER_DISK",
|
||||
"TERMINAL_CONTAINER_MEMORY",
|
||||
"TERMINAL_CONTAINER_PERSISTENT",
|
||||
"TERMINAL_DOCKER_RUN_AS_HOST_USER",
|
||||
"BROWSER_CDP_URL",
|
||||
"CAMOFOX_URL",
|
||||
# Platform allowlists — not credentials, but if set from any source
|
||||
@@ -326,6 +338,14 @@ def _reset_module_state():
|
||||
that don't exist yet (test collection before production import) are
|
||||
skipped silently — production import later creates fresh empty state.
|
||||
"""
|
||||
# --- logging — quiet/one-shot paths mutate process-global logger state ---
|
||||
logging.disable(logging.NOTSET)
|
||||
for _logger_name in ("tools", "run_agent", "trajectory_compressor", "cron", "hermes_cli"):
|
||||
_logger = logging.getLogger(_logger_name)
|
||||
_logger.disabled = False
|
||||
_logger.setLevel(logging.NOTSET)
|
||||
_logger.propagate = True
|
||||
|
||||
# --- tools.approval — the single biggest source of cross-test pollution ---
|
||||
try:
|
||||
from tools import approval as _approval_mod
|
||||
@@ -380,6 +400,26 @@ def _reset_module_state():
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# --- tools.terminal_tool — active environment/cwd cache ---
|
||||
# File tools prefer a live terminal cwd when one is cached for the task.
|
||||
# Clear terminal environments between tests so a prior terminal call can't
|
||||
# override TERMINAL_CWD in path-resolution tests.
|
||||
try:
|
||||
from tools import terminal_tool as _term_mod
|
||||
_envs_to_cleanup = []
|
||||
with _term_mod._env_lock:
|
||||
_envs_to_cleanup = list(_term_mod._active_environments.values())
|
||||
_term_mod._active_environments.clear()
|
||||
_term_mod._last_activity.clear()
|
||||
_term_mod._creation_locks.clear()
|
||||
for _env in _envs_to_cleanup:
|
||||
try:
|
||||
_env.cleanup()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# --- tools.credential_files — ContextVar<dict> ---
|
||||
try:
|
||||
from tools import credential_files as _credf_mod
|
||||
|
||||
87
tests/cron/test_compute_next_run_last_run_at.py
Normal file
87
tests/cron/test_compute_next_run_last_run_at.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Test that compute_next_run uses last_run_at for cron jobs.
|
||||
|
||||
Regression test for: cron jobs computing next_run_at from _hermes_now()
|
||||
instead of from last_run_at, making them inconsistent with interval jobs.
|
||||
"""
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
pytest.importorskip("croniter")
|
||||
|
||||
from cron.jobs import compute_next_run
|
||||
|
||||
|
||||
class TestCronComputeNextRunUsesLastRunAt:
    """compute_next_run MUST use last_run_at as the croniter base for cron jobs,
    consistent with how interval jobs work."""

    def test_cron_uses_last_run_at_for_every_6h_schedule(self, monkeypatch):
        """For a schedule like 'every 6 hours', the base time matters.

        If last_run_at is Apr 6 14:10, next should be Apr 6 18:00.
        If now is Apr 10 22:00, next should be Apr 11 00:00.
        compute_next_run must use last_run_at, not now."""
        morocco = ZoneInfo("Africa/Casablanca")

        # Job last ran April 6 at 14:10
        last_run = datetime(2026, 4, 6, 14, 10, 0, tzinfo=morocco)

        # But now it's April 10 at 22:00 (e.g., gateway restarted)
        now = datetime(2026, 4, 10, 22, 0, 0, tzinfo=morocco)
        monkeypatch.setattr("cron.jobs._hermes_now", lambda: now)

        schedule = {"kind": "cron", "expr": "0 */6 * * *"}  # every 6 hours

        result = compute_next_run(schedule, last_run_at=last_run.isoformat())
        assert result is not None
        next_dt = datetime.fromisoformat(result)

        # With last_run_at as base (Apr 6 14:10), next is Apr 6 18:00.
        # With now as base (Apr 10 22:00), next is Apr 11 00:00.
        # The fix should use last_run_at, returning Apr 6 18:00
        # (stale detection in get_due_jobs() fast-forwards from there).
        assert next_dt.date().isoformat() == "2026-04-06", (
            f"Expected next run on Apr 6 (from last_run_at), got {next_dt}"
        )
        assert next_dt.hour == 18

    def test_cron_without_last_run_at_uses_now(self, monkeypatch):
        """When last_run_at is NOT provided, compute_next_run falls back to
        _hermes_now() as the croniter base (existing behavior)."""
        morocco = ZoneInfo("Africa/Casablanca")

        now = datetime(2026, 4, 10, 22, 0, 0, tzinfo=morocco)
        monkeypatch.setattr("cron.jobs._hermes_now", lambda: now)

        schedule = {"kind": "cron", "expr": "0 */6 * * *"}

        result = compute_next_run(schedule)
        assert result is not None
        next_dt = datetime.fromisoformat(result)

        # Without last_run_at, should compute from now -> Apr 11 00:00
        assert next_dt.date().isoformat() == "2026-04-11", (
            f"Expected next run on Apr 11 (from now), got {next_dt}"
        )
        assert next_dt.hour == 0

    def test_cron_weekly_consistent_with_interval(self, monkeypatch):
        """Both cron and interval jobs should anchor to last_run_at when
        provided, producing consistent behavior after a crash/restart."""
        morocco = ZoneInfo("Africa/Casablanca")

        last_run = datetime(2026, 4, 6, 14, 10, 0, tzinfo=morocco)
        now = datetime(2026, 4, 10, 22, 0, 0, tzinfo=morocco)
        monkeypatch.setattr("cron.jobs._hermes_now", lambda: now)

        cron_schedule = {"kind": "cron", "expr": "0 14 * * 1"}
        interval_schedule = {"kind": "interval", "minutes": 7 * 24 * 60}

        cron_result = compute_next_run(cron_schedule, last_run_at=last_run.isoformat())
        interval_result = compute_next_run(interval_schedule, last_run_at=last_run.isoformat())

        # Guard like the sibling tests do, so a missing result fails as a
        # readable assertion instead of a TypeError inside fromisoformat().
        assert cron_result is not None
        assert interval_result is not None

        # Both should be after last_run_at
        cron_dt = datetime.fromisoformat(cron_result)
        interval_dt = datetime.fromisoformat(interval_result)
        assert cron_dt > last_run, f"Cron next {cron_dt} should be after last_run {last_run}"
        assert interval_dt > last_run, f"Interval next {interval_dt} should be after last_run {last_run}"
|
||||
@@ -169,10 +169,20 @@ class TestInactivityTimeout:
|
||||
|
||||
assert result["final_response"] == "Done"
|
||||
|
||||
def _parse_cron_timeout(self, raw_value):
|
||||
"""Mirror the defensive parsing logic from cron/scheduler.py run_job()."""
|
||||
if raw_value:
|
||||
try:
|
||||
return float(raw_value)
|
||||
except (ValueError, TypeError):
|
||||
return 600.0
|
||||
return 600.0
|
||||
|
||||
def test_timeout_env_var_parsing(self, monkeypatch):
|
||||
"""HERMES_CRON_TIMEOUT env var is respected."""
|
||||
monkeypatch.setenv("HERMES_CRON_TIMEOUT", "1200")
|
||||
_cron_timeout = float(os.getenv("HERMES_CRON_TIMEOUT", 600))
|
||||
raw = os.getenv("HERMES_CRON_TIMEOUT", "").strip()
|
||||
_cron_timeout = self._parse_cron_timeout(raw)
|
||||
assert _cron_timeout == 1200.0
|
||||
|
||||
_cron_inactivity_limit = _cron_timeout if _cron_timeout > 0 else None
|
||||
@@ -181,10 +191,27 @@ class TestInactivityTimeout:
|
||||
def test_timeout_zero_means_unlimited(self, monkeypatch):
    """HERMES_CRON_TIMEOUT=0 yields None (unlimited)."""
    monkeypatch.setenv("HERMES_CRON_TIMEOUT", "0")
    # Parse only through the defensive helper; the earlier direct
    # ``float(os.getenv(...))`` assignment was dead (immediately overwritten)
    # and bypassed the parsing path this test is meant to exercise.
    raw = os.getenv("HERMES_CRON_TIMEOUT", "").strip()
    _cron_timeout = self._parse_cron_timeout(raw)
    # A non-positive timeout is translated to "no inactivity limit".
    _cron_inactivity_limit = _cron_timeout if _cron_timeout > 0 else None
    assert _cron_inactivity_limit is None
|
||||
|
||||
def test_timeout_invalid_value_falls_back_to_default(self, monkeypatch):
    """A non-numeric HERMES_CRON_TIMEOUT ('abc') parses to 600s instead of raising."""
    monkeypatch.setenv("HERMES_CRON_TIMEOUT", "abc")
    env_text = os.getenv("HERMES_CRON_TIMEOUT", "").strip()
    timeout_seconds = self._parse_cron_timeout(env_text)
    assert timeout_seconds == 600.0

    # Positive timeout -> used as the limit; otherwise no limit.
    if timeout_seconds > 0:
        inactivity_limit = timeout_seconds
    else:
        inactivity_limit = None
    assert inactivity_limit == 600.0
|
||||
|
||||
def test_timeout_empty_string_uses_default(self, monkeypatch):
    """An empty HERMES_CRON_TIMEOUT value parses to the 600s default."""
    monkeypatch.setenv("HERMES_CRON_TIMEOUT", "")
    env_text = os.getenv("HERMES_CRON_TIMEOUT", "").strip()
    parsed_timeout = self._parse_cron_timeout(env_text)
    assert parsed_timeout == 600.0
|
||||
|
||||
def test_timeout_error_includes_diagnostics(self):
|
||||
"""The TimeoutError message should include last activity info."""
|
||||
agent = SlowFakeAgent(
|
||||
|
||||
@@ -265,6 +265,7 @@ class TestRunJobTerminalCwd:
|
||||
class FakeAgent:
|
||||
def __init__(self, **kwargs):
|
||||
observed["skip_context_files"] = kwargs.get("skip_context_files")
|
||||
observed["load_soul_identity"] = kwargs.get("load_soul_identity")
|
||||
observed["terminal_cwd_during_init"] = os.environ.get(
|
||||
"TERMINAL_CWD", "_UNSET_"
|
||||
)
|
||||
@@ -335,6 +336,7 @@ class TestRunJobTerminalCwd:
|
||||
|
||||
# AIAgent was built with skip_context_files=False (feature ON).
|
||||
assert observed["skip_context_files"] is False
|
||||
assert observed["load_soul_identity"] is True
|
||||
# TERMINAL_CWD was pointing at the job workdir while the agent ran.
|
||||
assert observed["terminal_cwd_during_init"] == str(tmp_path.resolve())
|
||||
assert observed["terminal_cwd_during_run"] == str(tmp_path.resolve())
|
||||
@@ -373,6 +375,8 @@ class TestRunJobTerminalCwd:
|
||||
|
||||
# Feature is OFF — skip_context_files stays True.
|
||||
assert observed["skip_context_files"] is True
|
||||
# Cron still forces SOUL.md identity even when cwd context files stay off.
|
||||
assert observed["load_soul_identity"] is True
|
||||
# TERMINAL_CWD saw the same value during init as it had before.
|
||||
assert observed["terminal_cwd_during_init"] == before
|
||||
# And after run_job completes, it's still the sentinel (nothing
|
||||
|
||||
@@ -279,6 +279,44 @@ class TestResolveDeliveryTarget:
|
||||
"thread_id": None,
|
||||
}
|
||||
|
||||
def test_list_form_deliver_is_normalized(self, monkeypatch):
    """A one-element list ``deliver=['telegram']`` resolves like the string 'telegram'.

    Regression test for #17139: array-shaped ``deliver`` values (as sent by
    MCP clients / scripts) used to fail with "no delivery target resolved
    for deliver=['telegram']" because ``str(['telegram'])`` was handed to
    ``split(',')`` unchanged.
    """
    monkeypatch.setenv("TELEGRAM_HOME_CHANNEL", "-4004")
    list_form_job = {"deliver": ["telegram"], "origin": None}

    resolved = _resolve_delivery_target(list_form_job)

    expected_target = {
        "platform": "telegram",
        "chat_id": "-4004",
        "thread_id": None,
    }
    assert resolved == expected_target
|
||||
|
||||
def test_list_form_multiple_platforms_normalized(self, monkeypatch):
    """A two-element ``deliver`` list yields one target per listed platform."""
    from cron.scheduler import _resolve_delivery_targets

    home_channels = (
        ("TELEGRAM_HOME_CHANNEL", "-111"),
        ("DISCORD_HOME_CHANNEL", "-222"),
    )
    for env_name, channel_id in home_channels:
        monkeypatch.setenv(env_name, channel_id)

    multi_job = {"deliver": ["telegram", "discord"], "origin": None}
    resolved = _resolve_delivery_targets(multi_job)

    # Compare order-independently: only the set of platforms matters here.
    resolved_platforms = [target["platform"] for target in resolved]
    resolved_platforms.sort()
    assert resolved_platforms == ["discord", "telegram"]
|
||||
|
||||
def test_empty_list_form_deliver_resolves_to_local(self):
    """An empty ``deliver`` list produces no delivery targets (local only)."""
    from cron.scheduler import _resolve_delivery_targets

    job_without_targets = {"deliver": []}
    resolved = _resolve_delivery_targets(job_without_targets)
    assert resolved == []
|
||||
|
||||
|
||||
class TestDeliverResultWrapping:
|
||||
"""Verify that cron deliveries are wrapped with header/footer and no longer mirrored."""
|
||||
@@ -513,14 +551,14 @@ class TestDeliverResultWrapping:
|
||||
patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro):
|
||||
_deliver_result(
|
||||
job,
|
||||
"MEDIA:/tmp/voice.ogg",
|
||||
"[[audio_as_voice]]\nMEDIA:/tmp/voice.ogg",
|
||||
adapters={Platform.TELEGRAM: adapter},
|
||||
loop=loop,
|
||||
)
|
||||
|
||||
# Text send should NOT be called (no text after stripping MEDIA tag)
|
||||
adapter.send.assert_not_called()
|
||||
# Audio should still be delivered
|
||||
# Audio should still be delivered as a voice bubble
|
||||
adapter.send_voice.assert_called_once()
|
||||
|
||||
def test_live_adapter_sends_cleaned_text_not_raw(self):
|
||||
@@ -989,6 +1027,80 @@ class TestRunJobSessionPersistence:
|
||||
assert os.getenv("HERMES_CRON_AUTO_DELIVER_THREAD_ID") is None
|
||||
fake_db.close.assert_called_once()
|
||||
|
||||
def test_run_job_clears_stale_auto_delivery_thread_id_between_jobs(self, tmp_path, monkeypatch):
    """Run a thread-targeted job, then a threadless one, and compare what each agent saw.

    The fake agent records the auto-deliver session values visible during
    ``run_conversation``; the second job must see ``thread_id`` as None even
    though the first job carried thread 42. Afterwards none of the
    auto-deliver variables may be present in ``os.environ``.
    """
    auto_deliver_vars = (
        "HERMES_CRON_AUTO_DELIVER_PLATFORM",
        "HERMES_CRON_AUTO_DELIVER_CHAT_ID",
        "HERMES_CRON_AUTO_DELIVER_THREAD_ID",
    )
    recorded_fields = ("platform", "chat_id", "thread_id")

    threaded_job = {
        "id": "threaded-job",
        "name": "threaded",
        "prompt": "hello",
        "deliver": "telegram:-1001:42",
    }
    threadless_job = {
        "id": "threadless-job",
        "name": "threadless",
        "prompt": "hello again",
        "deliver": "telegram:-2002",
    }
    fake_db = MagicMock()
    observations = []

    # Start from a clean process environment so leftovers cannot mask a leak.
    for var_name in auto_deliver_vars:
        monkeypatch.delenv(var_name, raising=False)

    class FakeAgent:
        def __init__(self, *args, **kwargs):
            pass

        def run_conversation(self, *args, **kwargs):
            from gateway.session_context import get_session_env

            # Empty values are normalised to None so both "unset" shapes compare equal.
            snapshot = {
                field: get_session_env(var_name) or None
                for field, var_name in zip(recorded_fields, auto_deliver_vars)
            }
            observations.append(snapshot)
            return {"final_response": "ok"}

    runtime_provider = {
        "api_key": "***",
        "base_url": "https://example.invalid/v1",
        "provider": "openrouter",
        "api_mode": "chat_completions",
    }

    with patch("cron.scheduler._hermes_home", tmp_path), \
         patch("hermes_state.SessionDB", return_value=fake_db), \
         patch(
             "hermes_cli.runtime_provider.resolve_runtime_provider",
             return_value=runtime_provider,
         ), \
         patch("run_agent.AIAgent", FakeAgent):
        for job in (threaded_job, threadless_job):
            success, output, final_response, error = run_job(job)
            assert success is True
            assert error is None
            assert final_response == "ok"
            assert "ok" in output

    assert observations == [
        {"platform": "telegram", "chat_id": "-1001", "thread_id": "42"},
        {"platform": "telegram", "chat_id": "-2002", "thread_id": None},
    ]
    for var_name in auto_deliver_vars:
        assert os.getenv(var_name) is None
    # One SessionDB close per executed job.
    assert fake_db.close.call_count == 2
|
||||
|
||||
|
||||
class TestRunJobConfigLogging:
|
||||
"""Verify that config.yaml parse failures are logged, not silently swallowed."""
|
||||
|
||||
72
tests/gateway/_plugin_adapter_loader.py
Normal file
72
tests/gateway/_plugin_adapter_loader.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Shared helper for loading platform-plugin ``adapter.py`` modules in tests.
|
||||
|
||||
Every platform plugin under ``plugins/platforms/<name>/`` ships its own
|
||||
``adapter.py``. If two tests independently do::
|
||||
|
||||
sys.path.insert(0, "plugins/platforms/irc")
|
||||
from adapter import IRCAdapter
|
||||
|
||||
sys.path.insert(0, "plugins/platforms/teams")
|
||||
from adapter import TeamsAdapter
|
||||
|
||||
…then whichever collects first in an xdist worker wins
|
||||
``sys.modules["adapter"]``, and the other raises ``ImportError`` at
|
||||
collection time. The fallout cascades across unrelated tests sharing that
|
||||
worker because ``sys.path`` is still polluted.
|
||||
|
||||
Use :func:`load_plugin_adapter` instead of ad-hoc ``sys.path`` tricks.
|
||||
It loads the adapter from an explicit file path under a unique module
|
||||
name (``plugin_adapter_<plugin_name>``), so it cannot collide with any
|
||||
other plugin's adapter module.
|
||||
|
||||
The ``tests/gateway/conftest.py`` guard rejects the anti-pattern at
|
||||
collection time so this can't regress when new plugin adapter tests are
|
||||
added.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
_REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
_PLUGINS_DIR = _REPO_ROOT / "plugins" / "platforms"
|
||||
|
||||
|
||||
def load_plugin_adapter(plugin_name: str) -> ModuleType:
    """Import ``plugins/platforms/<plugin_name>/adapter.py`` in isolation.

    The module is registered under the unique name
    ``plugin_adapter_<plugin_name>`` in ``sys.modules``. No ``sys.path``
    mutation. Safe to call multiple times — repeat calls return the
    already-loaded module.

    Args:
        plugin_name: Directory name of the plugin under ``_PLUGINS_DIR``.

    Returns:
        The executed adapter module.

    Raises:
        FileNotFoundError: ``adapter.py`` does not exist for that plugin.
        ImportError: no import spec/loader could be built for the file.
        BaseException: anything raised while executing the adapter module is
            re-raised after the partially initialised module is unregistered.
    """
    module_name = f"plugin_adapter_{plugin_name}"
    cached = sys.modules.get(module_name)
    if cached is not None:
        return cached

    adapter_path = _PLUGINS_DIR / plugin_name / "adapter.py"
    if not adapter_path.is_file():
        # The sibling listing is only a hint. Guard it so a missing plugins
        # directory still produces this message rather than a second
        # FileNotFoundError raised from iterdir() while formatting it.
        known_plugins = (
            sorted(p.name for p in _PLUGINS_DIR.iterdir() if p.is_dir())
            if _PLUGINS_DIR.is_dir()
            else []
        )
        raise FileNotFoundError(
            f"Plugin adapter not found: {adapter_path}. "
            f"Known plugins: {known_plugins}"
        )

    spec = importlib.util.spec_from_file_location(module_name, adapter_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not build import spec for {adapter_path}")

    module = importlib.util.module_from_spec(spec)
    # Register BEFORE exec so the module can find itself if needed (some
    # modules do ``sys.modules[__name__]`` reflection during import).
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # BaseException, not Exception: SystemExit / KeyboardInterrupt raised
        # mid-import must not leave a half-initialised module behind, because
        # the cache check above would hand it out on the next call.
        sys.modules.pop(module_name, None)
        raise
    return module
|
||||
@@ -12,11 +12,32 @@ ImportError fallback, causing 30+ downstream test failures wherever
|
||||
|
||||
Individual test files may still call their own ``_ensure_telegram_mock``
|
||||
— it short-circuits when the mock is already present.
|
||||
|
||||
Plugin-adapter anti-pattern guard
|
||||
---------------------------------
|
||||
Tests for platform plugins (``plugins/platforms/<name>/adapter.py``)
|
||||
must load the adapter via
|
||||
:func:`tests.gateway._plugin_adapter_loader.load_plugin_adapter`, not by
|
||||
adding the plugin directory to ``sys.path`` and doing a bare
|
||||
``from adapter import ...``. The guard at the bottom of this file
|
||||
scans test module ASTs at collection time and fails collection with a
|
||||
pointer to the helper if the anti-pattern is detected.
|
||||
|
||||
Rationale: every plugin ships its own ``adapter.py``, and two tests each
|
||||
inserting their plugin dir on ``sys.path[0]`` race for
|
||||
``sys.modules["adapter"]`` in the same xdist worker. Whichever collects
|
||||
first wins; the other fails with ``ImportError``, and the polluted
|
||||
``sys.path`` cascades into unrelated tests. See PR #17764 for the
|
||||
incident.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _ensure_telegram_mock() -> None:
|
||||
"""Install a comprehensive telegram mock in sys.modules.
|
||||
@@ -197,3 +218,128 @@ def _ensure_discord_mock() -> None:
|
||||
# Run at collection time — before any test file's module-level imports.
|
||||
_ensure_telegram_mock()
|
||||
_ensure_discord_mock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Plugin-adapter anti-pattern guard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_GATEWAY_DIR = Path(__file__).resolve().parent
|
||||
_GUARD_HINT = (
|
||||
"Plugin adapter tests must use "
|
||||
"``from tests.gateway._plugin_adapter_loader import load_plugin_adapter`` "
|
||||
"and call ``load_plugin_adapter('<plugin_name>')`` instead of inserting "
|
||||
"``plugins/platforms/<name>/`` on sys.path and doing a bare ``import "
|
||||
"adapter`` / ``from adapter import ...``. See the 'Plugin-adapter "
|
||||
"anti-pattern guard' docstring in tests/gateway/conftest.py."
|
||||
)
|
||||
|
||||
|
||||
def _scan_for_plugin_adapter_antipattern(source: str) -> list[str]:
|
||||
"""Return a list of offending-line descriptions, or [] if clean.
|
||||
|
||||
Flags two things:
|
||||
1. ``sys.path.insert(..., <something mentioning 'plugins/platforms'>)``
|
||||
2. ``import adapter`` or ``from adapter import ...`` at module level.
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(source)
|
||||
except SyntaxError:
|
||||
return [] # Let pytest surface the real syntax error.
|
||||
|
||||
offenses: list[str] = []
|
||||
|
||||
for node in ast.walk(tree):
|
||||
# sys.path.insert(0, ".../plugins/platforms/...")
|
||||
if isinstance(node, ast.Call):
|
||||
func = node.func
|
||||
target_name: str | None = None
|
||||
if isinstance(func, ast.Attribute):
|
||||
# sys.path.insert / sys.path.append
|
||||
if (
|
||||
isinstance(func.value, ast.Attribute)
|
||||
and isinstance(func.value.value, ast.Name)
|
||||
and func.value.value.id == "sys"
|
||||
and func.value.attr == "path"
|
||||
and func.attr in ("insert", "append", "extend")
|
||||
):
|
||||
target_name = f"sys.path.{func.attr}"
|
||||
|
||||
if target_name is not None:
|
||||
call_src = ast.unparse(node)
|
||||
# Match both the string-literal form
|
||||
# ``.../plugins/platforms/...`` and the Path-operator form
|
||||
# ``Path(...) / 'plugins' / 'platforms' / ...`` that
|
||||
# plugin tests typically use.
|
||||
_src_no_ws = "".join(call_src.split())
|
||||
if (
|
||||
"plugins/platforms" in call_src
|
||||
or "plugins\\platforms" in call_src
|
||||
or "'plugins'/'platforms'" in _src_no_ws
|
||||
or '"plugins"/"platforms"' in _src_no_ws
|
||||
):
|
||||
offenses.append(
|
||||
f"line {node.lineno}: {target_name}(...) points into "
|
||||
f"plugins/platforms/"
|
||||
)
|
||||
|
||||
# Bare `import adapter` / `from adapter import ...` anywhere (module level
|
||||
# OR inside functions — both are symptoms of the same pattern).
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
if alias.name == "adapter":
|
||||
offenses.append(
|
||||
f"line {node.lineno}: ``import adapter`` "
|
||||
f"(bare — resolves to whichever plugin's adapter.py "
|
||||
f"is first on sys.path)"
|
||||
)
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module == "adapter" and node.level == 0:
|
||||
offenses.append(
|
||||
f"line {node.lineno}: ``from adapter import ...`` "
|
||||
f"(bare — resolves to whichever plugin's adapter.py "
|
||||
f"is first on sys.path)"
|
||||
)
|
||||
|
||||
return offenses
|
||||
|
||||
|
||||
def pytest_configure(config):
    """Reject plugin-adapter tests that use the sys.path anti-pattern.

    Scans every ``test_*.py`` file under ``tests/gateway/`` and, if any of
    them matches the anti-pattern, fails the whole session with a
    ``pytest.UsageError`` that lists the offending files and lines followed
    by the guard hint.

    Configs that carry a ``workerinput`` attribute (xdist worker processes)
    return immediately, so the filesystem is not scanned once per worker.
    """
    if hasattr(config, "workerinput"):
        return

    violations: list[str] = []
    # sorted(): rglob order depends on the filesystem; a stable order keeps
    # the error report reproducible. The ``test_*.py`` glob already excludes
    # conftest.py and _plugin_adapter_loader.py, so no name filter is needed.
    for path in sorted(_GATEWAY_DIR.rglob("test_*.py")):
        try:
            source = path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            # Unreadable or non-UTF-8 file: leave it to pytest's own
            # collection to report, rather than crashing configuration.
            continue
        # Cheap prefilter before AST parsing. It must be at least as broad as
        # the scanner: every flagged form contains either the word
        # "adapter" (bare imports) or "platforms" (all sys.path spellings,
        # including ``Path(...) / 'plugins' / 'platforms'``).
        if "adapter" not in source and "platforms" not in source:
            continue
        offenses = _scan_for_plugin_adapter_antipattern(source)
        if offenses:
            violations.append(
                f"  {path.relative_to(_GATEWAY_DIR.parent.parent)}:\n    "
                + "\n    ".join(offenses)
            )

    if violations:
        raise pytest.UsageError(
            "Plugin-adapter-import anti-pattern detected in gateway tests:\n"
            + "\n".join(violations)
            + "\n\n"
            + _GUARD_HINT
        )
|
||||
|
||||
|
||||
@@ -170,6 +170,22 @@ class TestAgentConfigSignature:
|
||||
)
|
||||
assert sig_a == sig_b
|
||||
|
||||
def test_tool_registry_generation_change_busts_cache(self):
    """Signatures that differ only in ``tools.registry_generation`` must not be equal."""
    from gateway.run import GatewayRunner

    runtime = {"api_key": "k", "base_url": "u", "provider": "p"}

    def signature_for(generation):
        # Everything except the registry generation is held constant.
        return GatewayRunner._agent_config_signature(
            "m", runtime, ["telegram"], "",
            cache_keys={"tools.registry_generation": generation},
        )

    older_signature = signature_for(10)
    newer_signature = signature_for(11)

    assert older_signature != newer_signature
|
||||
|
||||
|
||||
class TestExtractCacheBustingConfig:
|
||||
"""Verify _extract_cache_busting_config pulls the documented subset of
|
||||
@@ -229,6 +245,17 @@ class TestExtractCacheBustingConfig:
|
||||
out = GatewayRunner._extract_cache_busting_config(None)
|
||||
for section, key in GatewayRunner._CACHE_BUSTING_CONFIG_KEYS:
|
||||
assert out[f"{section}.{key}"] is None
|
||||
assert "tools.registry_generation" in out
|
||||
|
||||
def test_extract_includes_live_tool_registry_generation(self, monkeypatch):
    """The extracted config reports the registry's current ``_generation`` value."""
    from gateway.run import GatewayRunner
    from tools.registry import registry

    patched_generation = 12345
    monkeypatch.setattr(registry, "_generation", patched_generation)

    extracted = GatewayRunner._extract_cache_busting_config({})

    assert extracted["tools.registry_generation"] == patched_generation
|
||||
|
||||
def test_full_round_trip_busts_cache_on_real_edit(self):
|
||||
"""End-to-end: simulate a config edit on main and verify the
|
||||
|
||||
@@ -314,6 +314,7 @@ def _create_app(adapter: APIServerAdapter) -> web.Application:
|
||||
app.router.add_get("/health/detailed", adapter._handle_health_detailed)
|
||||
app.router.add_get("/v1/health", adapter._handle_health)
|
||||
app.router.add_get("/v1/models", adapter._handle_models)
|
||||
app.router.add_get("/v1/capabilities", adapter._handle_capabilities)
|
||||
app.router.add_post("/v1/chat/completions", adapter._handle_chat_completions)
|
||||
app.router.add_post("/v1/responses", adapter._handle_responses)
|
||||
app.router.add_get("/v1/responses/{response_id}", adapter._handle_get_response)
|
||||
@@ -491,6 +492,46 @@ class TestModelsEndpoint:
|
||||
assert resp.status == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /v1/capabilities endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCapabilitiesEndpoint:
    """Checks the payload and auth behaviour of ``GET /v1/capabilities``."""

    @pytest.mark.asyncio
    async def test_capabilities_advertises_plugin_safe_contract(self, adapter):
        """Without an API key the endpoint answers 200 with the expected fields."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.get("/v1/capabilities")
            assert resp.status == 200
            payload = await resp.json()

            identity = {
                "object": "hermes.api_server.capabilities",
                "platform": "hermes-agent",
                "model": "hermes-agent",
            }
            for field, expected in identity.items():
                assert payload[field] == expected

            auth = payload["auth"]
            assert auth["type"] == "bearer"
            assert auth["required"] is False

            features = payload["features"]
            for flag in ("chat_completions", "run_status", "run_events_sse"):
                assert features[flag] is True
            assert features["session_continuity_header"] == "X-Hermes-Session-Id"

            assert payload["endpoints"]["run_status"]["path"] == "/v1/runs/{run_id}"

    @pytest.mark.asyncio
    async def test_capabilities_requires_auth_when_key_configured(self, auth_adapter):
        """With a key configured: 401 without a bearer token, 200 with it."""
        app = _create_app(auth_adapter)
        async with TestClient(TestServer(app)) as cli:
            anonymous = await cli.get("/v1/capabilities")
            assert anonymous.status == 401

            bearer_headers = {"Authorization": "Bearer sk-secret"}
            authed = await cli.get("/v1/capabilities", headers=bearer_headers)
            assert authed.status == 200
            payload = await authed.json()
            assert payload["auth"]["required"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /v1/chat/completions endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -647,17 +688,17 @@ class TestChatCompletionsEndpoint:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_includes_tool_progress(self, adapter):
|
||||
"""tool_progress_callback fires → progress appears as custom SSE event, not in delta.content."""
|
||||
"""tool_start_callback fires → progress appears as custom SSE event, not in delta.content."""
|
||||
import asyncio
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
async def _mock_run_agent(**kwargs):
|
||||
cb = kwargs.get("stream_delta_callback")
|
||||
tp_cb = kwargs.get("tool_progress_callback")
|
||||
# Simulate tool progress before streaming content
|
||||
if tp_cb:
|
||||
tp_cb("tool.started", "terminal", "ls -la", {"command": "ls -la"})
|
||||
ts_cb = kwargs.get("tool_start_callback")
|
||||
# Simulate the structured tool start the gateway now consumes.
|
||||
if ts_cb:
|
||||
ts_cb("call_terminal_1", "terminal", {"command": "ls -la"})
|
||||
if cb:
|
||||
await asyncio.sleep(0.05)
|
||||
cb("Here are the files.")
|
||||
@@ -683,7 +724,10 @@ class TestChatCompletionsEndpoint:
|
||||
# markers instead of calling tools (#6972).
|
||||
assert "event: hermes.tool.progress" in body
|
||||
assert '"tool": "terminal"' in body
|
||||
assert '"label": "ls -la"' in body
|
||||
# ``label`` is now derived by ``build_tool_preview`` from the
|
||||
# tool args rather than passed by the caller, so we assert
|
||||
# only that *some* label exists rather than a literal value.
|
||||
assert '"label":' in body
|
||||
# The progress marker must NOT appear inside any
|
||||
# chat.completion.chunk delta.content field.
|
||||
import json as _json
|
||||
@@ -703,17 +747,17 @@ class TestChatCompletionsEndpoint:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_tool_progress_skips_internal_events(self, adapter):
|
||||
"""Internal events (name starting with _) are not streamed."""
|
||||
"""Internal tool calls (name starting with ``_``) are not streamed."""
|
||||
import asyncio
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
async def _mock_run_agent(**kwargs):
|
||||
cb = kwargs.get("stream_delta_callback")
|
||||
tp_cb = kwargs.get("tool_progress_callback")
|
||||
if tp_cb:
|
||||
tp_cb("tool.started", "_thinking", "some internal state", {})
|
||||
tp_cb("tool.started", "web_search", "Python docs", {"query": "Python docs"})
|
||||
ts_cb = kwargs.get("tool_start_callback")
|
||||
if ts_cb:
|
||||
ts_cb("call_internal_1", "_thinking", {"text": "some internal state"})
|
||||
ts_cb("call_search_1", "web_search", {"query": "Python docs"})
|
||||
if cb:
|
||||
await asyncio.sleep(0.05)
|
||||
cb("Found it.")
|
||||
@@ -735,10 +779,142 @@ class TestChatCompletionsEndpoint:
|
||||
body = await resp.text()
|
||||
# Internal _thinking event should NOT appear anywhere
|
||||
assert "some internal state" not in body
|
||||
assert "call_internal_1" not in body
|
||||
# Real tool progress should appear as custom SSE event
|
||||
assert "event: hermes.tool.progress" in body
|
||||
assert '"tool": "web_search"' in body
|
||||
assert '"label": "Python docs"' in body
|
||||
# Label is derived from the args dict by build_tool_preview;
|
||||
# asserting on the structural fact (label exists, call id
|
||||
# is correlated) rather than a literal preview string keeps
|
||||
# the test robust against preview-formatter tweaks.
|
||||
assert '"label":' in body
|
||||
assert '"toolCallId": "call_search_1"' in body
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_emits_tool_lifecycle_with_call_id(self, adapter):
    """Regression test for #16588: streamed tool lifecycle is correlatable.

    The mocked agent fires ``tool_start_callback`` and
    ``tool_complete_callback`` once each for the same call id.  The
    streamed ``/v1/chat/completions`` body must then contain exactly two
    ``event: hermes.tool.progress`` events: a ``running`` one followed by
    a ``completed`` one, both carrying that ``toolCallId``.
    """
    import asyncio
    import json as _json

    async def _mock_run_agent(**kwargs):
        on_delta = kwargs.get("stream_delta_callback")
        on_start = kwargs.get("tool_start_callback")
        on_complete = kwargs.get("tool_complete_callback")
        # Only the structured callbacks are driven here;
        # ``tool_progress_callback`` is deliberately left alone so each
        # tool start is expected to produce exactly one event.
        if on_start:
            on_start("call_terminal_1", "terminal", {"command": "ls -la"})
        if on_complete:
            on_complete("call_terminal_1", "terminal", {"command": "ls -la"}, "ok")
        if on_delta:
            await asyncio.sleep(0.05)
            on_delta("done.")
        return (
            {"final_response": "done.", "messages": [], "api_calls": 1},
            {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2},
        )

    request_body = {
        "model": "test",
        "messages": [{"role": "user", "content": "list"}],
        "stream": True,
    }

    app = _create_app(adapter)
    async with TestClient(TestServer(app)) as cli:
        with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent):
            resp = await cli.post("/v1/chat/completions", json=request_body)
            assert resp.status == 200
            body = await resp.text()

        # Collect one (status, toolCallId) tuple per progress event so the
        # assertions check correlation event-by-event rather than against
        # the body as a whole.
        pairs: list[tuple[str | None, str | None]] = []
        sse_lines = body.splitlines()
        for idx, sse_line in enumerate(sse_lines):
            if sse_line.strip() != "event: hermes.tool.progress":
                continue
            # The payload is the first ``data: `` line among the next three.
            data_line = next(
                (ln for ln in sse_lines[idx + 1: idx + 4] if ln.startswith("data: ")),
                None,
            )
            if data_line is None:
                continue
            try:
                payload = _json.loads(data_line[len("data: "):])
            except _json.JSONDecodeError:
                continue
            pairs.append((payload.get("status"), payload.get("toolCallId")))

        # Exactly one event per lifecycle half, each with the same call id.
        assert len(pairs) == 2, f"expected 2 events (running+completed), got {pairs}"
        assert pairs[0] == ("running", "call_terminal_1"), pairs
        assert pairs[1] == ("completed", "call_terminal_1"), pairs
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_tool_lifecycle_skips_internal_and_orphan_completes(self, adapter):
    """Internal tools and unmatched completions emit no lifecycle events.

    The mocked agent reports a start/complete pair for an underscore-
    prefixed tool and a lone completion for a call id that was never
    started.  Neither call id, nor any ``running`` / ``completed``
    status, may appear in the streamed body.
    """
    import asyncio

    async def _mock_run_agent(**kwargs):
        on_delta = kwargs.get("stream_delta_callback")
        on_start = kwargs.get("tool_start_callback")
        on_complete = kwargs.get("tool_complete_callback")
        # Internal tool: both halves must be filtered.
        if on_start:
            on_start("call_internal_1", "_thinking", {})
        if on_complete:
            on_complete("call_internal_1", "_thinking", {}, "")
        # Completion with no preceding start: an orphan, must be dropped.
        if on_complete:
            on_complete("call_orphan_1", "web_search", {}, "ok")
        if on_delta:
            await asyncio.sleep(0.05)
            on_delta("ok.")
        return (
            {"final_response": "ok.", "messages": [], "api_calls": 1},
            {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2},
        )

    request_body = {
        "model": "test",
        "messages": [{"role": "user", "content": "ok"}],
        "stream": True,
    }

    app = _create_app(adapter)
    async with TestClient(TestServer(app)) as cli:
        with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent):
            resp = await cli.post("/v1/chat/completions", json=request_body)
            assert resp.status == 200
            body = await resp.text()

        # Nothing about either call may reach the wire.
        for forbidden in (
            "call_internal_1",
            "call_orphan_1",
            '"status": "running"',
            '"status": "completed"',
        ):
            assert forbidden not in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_user_message_returns_400(self, adapter):
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Tests for /v1/runs endpoints: start, events, and stop.
|
||||
"""Tests for /v1/runs endpoints: start, status, events, and stop.
|
||||
|
||||
Covers:
|
||||
- POST /v1/runs — start a run (202)
|
||||
- GET /v1/runs/{run_id} — poll run status
|
||||
- GET /v1/runs/{run_id}/events — SSE event stream
|
||||
- POST /v1/runs/{run_id}/stop — interrupt a running agent
|
||||
- Auth, error handling, and cleanup
|
||||
@@ -46,6 +47,7 @@ def _create_runs_app(adapter: APIServerAdapter) -> web.Application:
|
||||
app = web.Application(middlewares=mws)
|
||||
app["api_server_adapter"] = adapter
|
||||
app.router.add_post("/v1/runs", adapter._handle_runs)
|
||||
app.router.add_get("/v1/runs/{run_id}", adapter._handle_get_run)
|
||||
app.router.add_get("/v1/runs/{run_id}/events", adapter._handle_run_events)
|
||||
app.router.add_post("/v1/runs/{run_id}/stop", adapter._handle_stop_run)
|
||||
return app
|
||||
@@ -116,6 +118,13 @@ class TestStartRun:
|
||||
assert data["status"] == "started"
|
||||
assert data["run_id"].startswith("run_")
|
||||
|
||||
status_resp = await cli.get(f"/v1/runs/{data['run_id']}")
|
||||
assert status_resp.status == 200
|
||||
status = await status_resp.json()
|
||||
assert status["run_id"] == data["run_id"]
|
||||
assert status["status"] in {"queued", "running", "completed"}
|
||||
assert status["object"] == "hermes.run"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_invalid_json_returns_400(self, adapter):
|
||||
app = _create_runs_app(adapter)
|
||||
@@ -143,6 +152,18 @@ class TestStartRun:
|
||||
resp = await cli.post("/v1/runs", json={"input": ""})
|
||||
assert resp.status == 400
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_invalid_history_does_not_allocate_run(self, adapter):
    """A dict-valued ``conversation_history`` yields 400 and no run state."""
    bad_request = {"input": "hello", "conversation_history": {"role": "user"}}
    async with TestClient(TestServer(_create_runs_app(adapter))) as cli:
        resp = await cli.post("/v1/runs", json=bad_request)
        assert resp.status == 400
        # The rejected request must not leave stream or status entries behind.
        assert adapter._run_streams == {}
        assert adapter._run_statuses == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_requires_auth(self, auth_adapter):
|
||||
app = _create_runs_app(auth_adapter)
|
||||
@@ -170,6 +191,89 @@ class TestStartRun:
|
||||
assert resp.status == 202
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /v1/runs/{run_id} — poll run status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunStatus:
    """Tests for ``GET /v1/runs/{run_id}`` (run status polling)."""

    @staticmethod
    def _finished_agent(prompt_tokens, completion_tokens, total_tokens):
        """Return a mock agent whose conversation finishes with ``"done"``.

        The three token counts are exposed through the
        ``session_*_tokens`` attributes set on the mock.
        """
        agent = MagicMock()
        agent.run_conversation.return_value = {"final_response": "done"}
        agent.session_prompt_tokens = prompt_tokens
        agent.session_completion_tokens = completion_tokens
        agent.session_total_tokens = total_tokens
        return agent

    @staticmethod
    async def _poll_until_completed(cli, run_id, attempts=20, delay=0.05):
        """Poll the status endpoint until the run reports ``completed``.

        Every poll must answer HTTP 200.  Returns the last status payload
        seen, which is not ``completed`` if the attempts run out; callers
        assert on it.
        """
        status = None
        for _ in range(attempts):
            status_resp = await cli.get(f"/v1/runs/{run_id}")
            assert status_resp.status == 200
            status = await status_resp.json()
            if status["status"] == "completed":
                break
            await asyncio.sleep(delay)
        return status

    @pytest.mark.asyncio
    async def test_status_completed_run_includes_output_and_usage(self, adapter):
        """A completed run reports its output, usage and last event."""
        app = _create_runs_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_create_agent") as mock_create:
                mock_create.return_value = self._finished_agent(4, 2, 6)

                resp = await cli.post("/v1/runs", json={"input": "hello"})
                data = await resp.json()
                status = await self._poll_until_completed(cli, data["run_id"])

                assert status["status"] == "completed"
                assert status["output"] == "done"
                assert status["usage"]["total_tokens"] == 6
                assert status["last_event"] == "run.completed"

    @pytest.mark.asyncio
    async def test_status_reflects_explicit_session_id(self, adapter):
        """A caller-supplied ``session_id`` is echoed in the run status."""
        app = _create_runs_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_create_agent") as mock_create:
                mock_agent = self._finished_agent(0, 0, 0)
                mock_create.return_value = mock_agent

                resp = await cli.post(
                    "/v1/runs",
                    json={"input": "hello", "session_id": "space-session"},
                )
                data = await resp.json()
                status = await self._poll_until_completed(cli, data["run_id"])

                # Without this check a run that never finished would still
                # reach the assertions below.
                assert status["status"] == "completed"
                mock_agent.run_conversation.assert_called_once()
                # task_id stays "default" so the Runs API shares one sandbox
                # container with CLI/gateway; session_id is surfaced in status
                # for external UIs to correlate runs with their own session IDs.
                assert mock_agent.run_conversation.call_args.kwargs["task_id"] == "default"
                assert status["session_id"] == "space-session"

    @pytest.mark.asyncio
    async def test_status_not_found_returns_404(self, adapter):
        """Polling an unknown run id answers 404."""
        app = _create_runs_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.get("/v1/runs/run_nonexistent")
            assert resp.status == 404

    @pytest.mark.asyncio
    async def test_status_requires_auth(self, auth_adapter):
        """With the auth-enabled adapter, an unauthenticated poll answers 401."""
        app = _create_runs_app(auth_adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.get("/v1/runs/run_any")
            assert resp.status == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /v1/runs/{run_id}/events — SSE event stream
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -257,6 +361,11 @@ class TestStopRun:
|
||||
# Agent interrupt should have been called
|
||||
mock_agent.interrupt.assert_called_once_with("Stop requested via API")
|
||||
|
||||
status_resp = await cli.get(f"/v1/runs/{run_id}")
|
||||
assert status_resp.status == 200
|
||||
status_data = await status_resp.json()
|
||||
assert status_data["status"] in {"stopping", "cancelled"}
|
||||
|
||||
# Refs should be cleaned up
|
||||
await asyncio.sleep(0.5)
|
||||
assert run_id not in adapter._active_run_agents
|
||||
|
||||
@@ -33,6 +33,11 @@ def _simulate_config_bridge(cfg: dict, initial_env: dict | None = None):
|
||||
"backend": "TERMINAL_ENV",
|
||||
"cwd": "TERMINAL_CWD",
|
||||
"timeout": "TERMINAL_TIMEOUT",
|
||||
"vercel_runtime": "TERMINAL_VERCEL_RUNTIME",
|
||||
"container_persistent": "TERMINAL_CONTAINER_PERSISTENT",
|
||||
"container_cpu": "TERMINAL_CONTAINER_CPU",
|
||||
"container_memory": "TERMINAL_CONTAINER_MEMORY",
|
||||
"container_disk": "TERMINAL_CONTAINER_DISK",
|
||||
}
|
||||
for cfg_key, env_var in terminal_env_map.items():
|
||||
if cfg_key in terminal_cfg:
|
||||
@@ -240,3 +245,24 @@ class TestTildeExpansion:
|
||||
}
|
||||
result = _simulate_config_bridge(cfg)
|
||||
assert result["TERMINAL_CWD"] == os.path.expanduser("~/nested")
|
||||
|
||||
|
||||
class TestVercelTerminalBridge:
    """Vercel sandbox terminal settings are bridged to ``TERMINAL_*`` env vars."""

    def test_vercel_terminal_settings_bridge(self):
        """Each terminal config key lands in its env var as a string."""
        terminal_cfg = {
            "backend": "vercel_sandbox",
            "vercel_runtime": "python3.13",
            "container_persistent": True,
            "container_cpu": 2,
            "container_memory": 4096,
            "container_disk": 51200,
        }
        result = _simulate_config_bridge(
            {"terminal": terminal_cfg},
            {"MESSAGING_CWD": "/from/env"},
        )

        expected_env = {
            "TERMINAL_ENV": "vercel_sandbox",
            "TERMINAL_VERCEL_RUNTIME": "python3.13",
            "TERMINAL_CONTAINER_PERSISTENT": "True",
            "TERMINAL_CONTAINER_CPU": "2",
            "TERMINAL_CONTAINER_MEMORY": "4096",
            "TERMINAL_CONTAINER_DISK": "51200",
        }
        for env_var, value in expected_env.items():
            assert result[env_var] == value
|
||||
|
||||
502
tests/gateway/test_irc_adapter.py
Normal file
502
tests/gateway/test_irc_adapter.py
Normal file
@@ -0,0 +1,502 @@
|
||||
"""Tests for the IRC platform adapter plugin."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from tests.gateway._plugin_adapter_loader import load_plugin_adapter
|
||||
|
||||
# Load plugins/platforms/irc/adapter.py under a unique module name
|
||||
# (plugin_adapter_irc) so it cannot collide with other plugin adapters
|
||||
# loaded by sibling tests in the same xdist worker.
|
||||
_irc_mod = load_plugin_adapter("irc")
|
||||
|
||||
_parse_irc_message = _irc_mod._parse_irc_message
|
||||
_extract_nick = _irc_mod._extract_nick
|
||||
IRCAdapter = _irc_mod.IRCAdapter
|
||||
check_requirements = _irc_mod.check_requirements
|
||||
validate_config = _irc_mod.validate_config
|
||||
register = _irc_mod.register
|
||||
|
||||
|
||||
class TestIRCProtocolHelpers:
    """Tests for ``_parse_irc_message`` and ``_extract_nick``."""

    def test_parse_simple_command(self):
        """A prefix-less PING yields an empty prefix and one trailing param."""
        parsed = _parse_irc_message("PING :server.example.com")
        assert parsed["command"] == "PING"
        assert parsed["params"] == ["server.example.com"]
        assert parsed["prefix"] == ""

    def test_parse_prefixed_message(self):
        """A PRIVMSG splits into prefix, command, target and trailing text."""
        parsed = _parse_irc_message(":nick!user@host PRIVMSG #channel :Hello world")
        assert parsed["prefix"] == "nick!user@host"
        assert parsed["command"] == "PRIVMSG"
        assert parsed["params"] == ["#channel", "Hello world"]

    def test_parse_numeric_reply(self):
        """Numeric replies keep the numeric as the command string."""
        parsed = _parse_irc_message(":server 001 hermes-bot :Welcome to IRC")
        assert parsed["prefix"] == "server"
        assert parsed["command"] == "001"
        assert parsed["params"] == ["hermes-bot", "Welcome to IRC"]

    def test_parse_nick_collision(self):
        """The 433 numeric is parsed as the command."""
        parsed = _parse_irc_message(":server 433 * hermes-bot :Nickname is already in use")
        assert parsed["command"] == "433"

    def test_extract_nick_full_prefix(self):
        """Only the part before ``!`` is returned for a full prefix."""
        assert _extract_nick("nick!user@host") == "nick"

    def test_extract_nick_bare(self):
        """A prefix without ``!`` is returned unchanged."""
        assert _extract_nick("server.example.com") == "server.example.com"
|
||||
|
||||
|
||||
# ── IRC Adapter ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestIRCAdapterInit:
    """Tests for how ``IRCAdapter`` resolves settings from env and config."""

    def test_init_from_env(self, monkeypatch):
        """With no ``extra`` config, settings come from ``IRC_*`` env vars."""
        env_settings = {
            "IRC_SERVER": "irc.test.net",
            "IRC_PORT": "6667",
            "IRC_NICKNAME": "testbot",
            "IRC_CHANNEL": "#test",
            "IRC_USE_TLS": "false",
        }
        for name, value in env_settings.items():
            monkeypatch.setenv(name, value)

        from gateway.config import PlatformConfig
        adapter = IRCAdapter(PlatformConfig(enabled=True))

        assert adapter.server == "irc.test.net"
        assert adapter.port == 6667
        assert adapter.nickname == "testbot"
        assert adapter.channel == "#test"
        assert adapter.use_tls is False

    def test_init_from_config_extra(self, monkeypatch):
        """With the env vars cleared, settings come from ``extra``."""
        for name in ("IRC_SERVER", "IRC_PORT", "IRC_NICKNAME", "IRC_CHANNEL", "IRC_USE_TLS"):
            monkeypatch.delenv(name, raising=False)

        from gateway.config import PlatformConfig
        extra = {
            "server": "irc.libera.chat",
            "port": 6697,
            "nickname": "hermes",
            "channel": "#hermes-dev",
            "use_tls": True,
        }
        adapter = IRCAdapter(PlatformConfig(enabled=True, extra=extra))

        assert adapter.server == "irc.libera.chat"
        assert adapter.port == 6697
        assert adapter.nickname == "hermes"
        assert adapter.channel == "#hermes-dev"
        assert adapter.use_tls is True

    def test_env_overrides_config(self, monkeypatch):
        """``IRC_SERVER`` in the environment wins over ``extra["server"]``."""
        monkeypatch.setenv("IRC_SERVER", "env-server.net")

        from gateway.config import PlatformConfig
        extra = {"server": "config-server.net", "channel": "#ch"}
        adapter = IRCAdapter(PlatformConfig(enabled=True, extra=extra))
        assert adapter.server == "env-server.net"
|
||||
|
||||
|
||||
class TestIRCAdapterSend:
    """Tests for ``IRCAdapter.send`` with and without a connected writer."""

    @pytest.fixture
    def adapter(self, monkeypatch):
        """Adapter built purely from ``extra`` config, with no writer attached."""
        for name in ("IRC_SERVER", "IRC_PORT", "IRC_NICKNAME", "IRC_CHANNEL", "IRC_USE_TLS"):
            monkeypatch.delenv(name, raising=False)
        from gateway.config import PlatformConfig
        extra = {
            "server": "localhost",
            "port": 6667,
            "nickname": "testbot",
            "channel": "#test",
            "use_tls": False,
        }
        return IRCAdapter(PlatformConfig(enabled=True, extra=extra))

    @pytest.mark.asyncio
    async def test_send_not_connected(self, adapter):
        """Sending without a writer fails with a "Not connected" error."""
        outcome = await adapter.send("#test", "hello")
        assert outcome.success is False
        assert "Not connected" in outcome.error

    @pytest.mark.asyncio
    async def test_send_success(self, adapter):
        """A short message goes out as a PRIVMSG on the writer."""
        fake_writer = MagicMock()
        fake_writer.is_closing = MagicMock(return_value=False)
        fake_writer.write = MagicMock()
        fake_writer.drain = AsyncMock()
        adapter._writer = fake_writer

        outcome = await adapter.send("#test", "hello world")
        assert outcome.success is True
        assert outcome.message_id is not None
        # The last write must carry the PRIVMSG.
        fake_writer.write.assert_called()
        payload = fake_writer.write.call_args[0][0]
        assert b"PRIVMSG #test :hello world" in payload

    @pytest.mark.asyncio
    async def test_send_splits_long_messages(self, adapter):
        """A 1000-character message results in more than one write."""
        fake_writer = MagicMock()
        fake_writer.is_closing = MagicMock(return_value=False)
        fake_writer.write = MagicMock()
        fake_writer.drain = AsyncMock()
        adapter._writer = fake_writer

        outcome = await adapter.send("#test", "x" * 1000)
        assert outcome.success is True
        # One write per PRIVMSG chunk.
        assert fake_writer.write.call_count > 1
|
||||
|
||||
|
||||
class TestIRCAdapterMessageParsing:
    """Tests for ``IRCAdapter._handle_line`` on protocol and chat traffic."""

    _IRC_ENV_KEYS = ("IRC_SERVER", "IRC_PORT", "IRC_NICKNAME", "IRC_CHANNEL", "IRC_USE_TLS")

    @classmethod
    def _build_adapter(cls, monkeypatch, **extra_overrides):
        """Return a registered adapter with nick ``hermes`` on ``#test``.

        The ``IRC_*`` env vars are cleared first so only ``extra`` config
        applies; ``extra_overrides`` are merged into that config.
        """
        for key in cls._IRC_ENV_KEYS:
            monkeypatch.delenv(key, raising=False)
        from gateway.config import PlatformConfig
        extra = {
            "server": "localhost",
            "port": 6667,
            "nickname": "hermes",
            "channel": "#test",
            "use_tls": False,
        }
        extra.update(extra_overrides)
        a = IRCAdapter(PlatformConfig(enabled=True, extra=extra))
        a._current_nick = "hermes"
        a._registered = True
        return a

    @staticmethod
    def _attach_writer(adapter):
        """Install an open mock stream writer on *adapter* and return it."""
        writer = MagicMock()
        writer.is_closing = MagicMock(return_value=False)
        writer.write = MagicMock()
        writer.drain = AsyncMock()
        adapter._writer = writer
        return writer

    @staticmethod
    def _capture_dispatch(adapter):
        """Replace ``_dispatch_message`` with a recorder.

        Returns the list that receives the keyword arguments of every
        dispatch.  A mock message handler is installed as well.
        """
        dispatched = []

        async def capture_dispatch(**kwargs):
            dispatched.append(kwargs)

        adapter._dispatch_message = capture_dispatch
        adapter._message_handler = AsyncMock()
        return dispatched

    @pytest.fixture
    def adapter(self, monkeypatch):
        """Default registered adapter with no allowlist configured."""
        return self._build_adapter(monkeypatch)

    @pytest.mark.asyncio
    async def test_handle_ping(self, adapter):
        """PING is answered with a PONG carrying the same token."""
        writer = self._attach_writer(adapter)

        await adapter._handle_line("PING :test-server")
        sent = writer.write.call_args[0][0]
        assert b"PONG :test-server" in sent

    @pytest.mark.asyncio
    async def test_handle_welcome(self, adapter):
        """The 001 numeric marks the adapter registered and sets the event."""
        adapter._registered = False
        adapter._registration_event = asyncio.Event()

        await adapter._handle_line(":server 001 hermes :Welcome to IRC")
        assert adapter._registered is True
        assert adapter._registration_event.is_set()

    @pytest.mark.asyncio
    async def test_handle_nick_collision(self, adapter):
        """A 433 reply makes the adapter retry with an underscore suffix."""
        writer = self._attach_writer(adapter)

        await adapter._handle_line(":server 433 * hermes :Nickname in use")
        assert adapter._current_nick == "hermes_"
        sent = writer.write.call_args[0][0]
        assert b"NICK hermes_" in sent

    @pytest.mark.asyncio
    async def test_handle_addressed_channel_message(self, adapter):
        """Messages addressed to the bot (nick: msg) should be dispatched."""
        # Record dispatches in place of the adapter's real _dispatch_message.
        dispatched = self._capture_dispatch(adapter)
        adapter._message_handler = AsyncMock(return_value="response")

        await adapter._handle_line(":user!u@host PRIVMSG #test :hermes: hello there")
        assert len(dispatched) == 1
        assert dispatched[0]["text"] == "hello there"
        assert dispatched[0]["chat_id"] == "#test"

    @pytest.mark.asyncio
    async def test_ignores_unaddressed_channel_message(self, adapter):
        """Channel chatter without the bot's nick prefix is not dispatched."""
        dispatched = self._capture_dispatch(adapter)

        await adapter._handle_line(":user!u@host PRIVMSG #test :just talking")
        assert len(dispatched) == 0

    @pytest.mark.asyncio
    async def test_handle_dm(self, adapter):
        """DMs (target == bot nick) should always be dispatched."""
        dispatched = self._capture_dispatch(adapter)

        await adapter._handle_line(":user!u@host PRIVMSG hermes :private message")
        assert len(dispatched) == 1
        assert dispatched[0]["text"] == "private message"
        assert dispatched[0]["chat_type"] == "dm"
        assert dispatched[0]["chat_id"] == "user"

    @pytest.mark.asyncio
    async def test_ignores_own_messages(self, adapter):
        """Lines whose sender nick is the bot's own are not dispatched."""
        dispatched = self._capture_dispatch(adapter)

        await adapter._handle_line(":hermes!bot@host PRIVMSG #test :my own msg")
        assert len(dispatched) == 0

    @pytest.mark.asyncio
    async def test_ctcp_action_converted(self, adapter):
        """CTCP ACTION (/me) should be converted to text."""
        dispatched = self._capture_dispatch(adapter)

        await adapter._handle_line(":user!u@host PRIVMSG hermes :\x01ACTION waves\x01")
        assert len(dispatched) == 1
        assert dispatched[0]["text"] == "* user waves"

    @pytest.mark.asyncio
    async def test_allowed_users_case_insensitive(self, monkeypatch):
        """Allowlist should match nicks case-insensitively."""
        adapter = self._build_adapter(monkeypatch, allowed_users=["Admin", "BOB"])
        dispatched = self._capture_dispatch(adapter)

        # "admin" matches "Admin" in allowlist
        await adapter._handle_line(":admin!u@host PRIVMSG #test :hermes: hello")
        assert len(dispatched) == 1
        assert dispatched[0]["text"] == "hello"

    @pytest.mark.asyncio
    async def test_unauthorized_user_blocked(self, monkeypatch):
        """Nicks not in allowlist should be ignored."""
        adapter = self._build_adapter(monkeypatch, allowed_users=["Admin", "BOB"])
        dispatched = self._capture_dispatch(adapter)

        await adapter._handle_line(":eve!u@host PRIVMSG #test :hermes: hello")
        assert len(dispatched) == 0

    @pytest.mark.asyncio
    async def test_nick_collision_retry(self, adapter):
        """Multiple 433 responses should keep incrementing the suffix."""
        self._attach_writer(adapter)

        await adapter._handle_line(":server 433 * hermes :Nickname in use")
        assert adapter._current_nick == "hermes_"
        await adapter._handle_line(":server 433 * hermes_ :Nickname in use")
        assert adapter._current_nick == "hermes_1"
        await adapter._handle_line(":server 433 * hermes_1 :Nickname in use")
        assert adapter._current_nick == "hermes_2"
|
||||
|
||||
|
||||
class TestIRCAdapterSplitting:
    """Tests for ``IRCAdapter._split_message``."""

    def test_split_respects_byte_limit(self):
        """Multi-byte characters should not exceed IRC byte limit."""
        from gateway.config import PlatformConfig
        adapter = IRCAdapter(
            PlatformConfig(enabled=True, extra={"server": "x", "channel": "#x"})
        )
        adapter._current_nick = "bot"

        # 100 japanese chars = 300 bytes in utf-8
        # NOTE(review): 300 bytes is below the 512-byte bound checked here, so
        # this may not force a split -- confirm the adapter's chunk budget.
        chunks = adapter._split_message("あ" * 100, "#test")
        for chunk in chunks:
            # Size of the complete PRIVMSG line as encoded on the wire.
            wire_len = len(f"PRIVMSG #test :{chunk}\r\n".encode("utf-8"))
            assert wire_len <= 512, f"line over 512 bytes: {wire_len}"

    def test_split_prefers_word_boundary(self):
        """Whole words survive intact in the split output."""
        from gateway.config import PlatformConfig
        adapter = IRCAdapter(
            PlatformConfig(enabled=True, extra={"server": "x", "channel": "#x"})
        )
        adapter._current_nick = "bot"

        chunks = adapter._split_message("hello world foo bar baz qux", "#test")
        # Should not split in the middle of "world"
        for word in ("hello", "world"):
            assert any(word in chunk for chunk in chunks)
|
||||
|
||||
|
||||
class TestIRCProtocolHelpersExtra:
    """Edge-case inputs for ``_parse_irc_message``."""

    def test_parse_malformed_no_space(self):
        """A line starting with : but no space should not crash."""
        parsed = _parse_irc_message(":justaprefix")
        assert (parsed["prefix"], parsed["command"], parsed["params"]) == (
            "justaprefix",
            "",
            [],
        )

    def test_parse_empty(self):
        """An empty line parses to empty prefix, command and params."""
        parsed = _parse_irc_message("")
        assert (parsed["prefix"], parsed["command"], parsed["params"]) == ("", "", [])
|
||||
|
||||
|
||||
class TestIRCAdapterMarkdown:
    """Tests for ``IRCAdapter._strip_markdown``."""

    def test_strip_bold(self):
        """Double-asterisk emphasis markers are removed."""
        assert IRCAdapter._strip_markdown("**bold**") == "bold"

    def test_strip_italic(self):
        """Single-asterisk emphasis markers are removed."""
        assert IRCAdapter._strip_markdown("*italic*") == "italic"

    def test_strip_code(self):
        """Inline-code backticks are removed."""
        assert IRCAdapter._strip_markdown("`code`") == "code"

    def test_strip_link(self):
        """A markdown link becomes ``text (url)``."""
        result = IRCAdapter._strip_markdown("[click here](https://example.com)")
        assert result == "click here (https://example.com)"

    def test_strip_image(self):
        """A markdown image collapses to its URL."""
        # NOTE(review): the input literal had been reduced to "" (the image
        # markup was consumed by a markdown renderer); the alt text used
        # here is a reconstruction -- confirm against the upstream test.
        result = IRCAdapter._strip_markdown("![alt](https://example.com/img.png)")
        assert result == "https://example.com/img.png"
|
||||
|
||||
|
||||
# ── Requirements / validation ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestIRCRequirements:
    """Tests for ``check_requirements`` and ``validate_config``."""

    def test_check_requirements_with_env(self, monkeypatch):
        """Server and channel both present in the environment: satisfied."""
        monkeypatch.setenv("IRC_SERVER", "irc.test.net")
        monkeypatch.setenv("IRC_CHANNEL", "#test")
        assert check_requirements() is True

    def test_check_requirements_missing_server(self, monkeypatch):
        """No ``IRC_SERVER``: requirements are not met."""
        monkeypatch.delenv("IRC_SERVER", raising=False)
        monkeypatch.setenv("IRC_CHANNEL", "#test")
        assert check_requirements() is False

    def test_check_requirements_missing_channel(self, monkeypatch):
        """No ``IRC_CHANNEL``: requirements are not met."""
        monkeypatch.setenv("IRC_SERVER", "irc.test.net")
        monkeypatch.delenv("IRC_CHANNEL", raising=False)
        assert check_requirements() is False

    def test_validate_config_from_extra(self, monkeypatch):
        """Server and channel supplied via ``extra`` validate without env vars."""
        for name in ("IRC_SERVER", "IRC_CHANNEL"):
            monkeypatch.delenv(name, raising=False)
        from gateway.config import PlatformConfig
        config = PlatformConfig(extra={"server": "irc.test.net", "channel": "#test"})
        assert validate_config(config) is True

    def test_validate_config_missing(self, monkeypatch):
        """Empty ``extra`` and no env vars: validation fails."""
        for name in ("IRC_SERVER", "IRC_CHANNEL"):
            monkeypatch.delenv(name, raising=False)
        from gateway.config import PlatformConfig
        config = PlatformConfig(extra={})
        assert validate_config(config) is False
|
||||
|
||||
|
||||
# ── Plugin registration ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestIRCPluginRegistration:
    """Test the register() entry point."""

    def test_register_adds_to_registry(self, monkeypatch):
        """``register`` calls ``ctx.register_platform`` once, naming ``irc``."""
        monkeypatch.setenv("IRC_SERVER", "irc.test.net")
        monkeypatch.setenv("IRC_CHANNEL", "#test")

        from gateway.platform_registry import platform_registry

        # Clean up if already registered
        platform_registry.unregister("irc")

        ctx = MagicMock()
        register(ctx)
        ctx.register_platform.assert_called_once()

        # Accept the name as a keyword or as the first positional argument.
        # The previous one-line assert parsed as ``(A or B) if args else C``
        # and looked up ``kwargs["name"]`` first, raising KeyError whenever
        # the name was passed positionally.
        call = ctx.register_platform.call_args
        if "name" in call.kwargs:
            registered_name = call.kwargs["name"]
        else:
            registered_name = call.args[0] if call.args else None
        assert registered_name == "irc"
|
||||
@@ -1276,9 +1276,10 @@ class TestMatrixUploadAndSend:
|
||||
mock_client.send_message_event = AsyncMock(return_value="$event")
|
||||
adapter._client = mock_client
|
||||
|
||||
result = await adapter._upload_and_send(
|
||||
"!room:example.org", b"secret", "secret.txt", "text/plain", "m.file",
|
||||
)
|
||||
with patch.dict("sys.modules", _make_fake_mautrix()):
|
||||
result = await adapter._upload_and_send(
|
||||
"!room:example.org", b"secret", "secret.txt", "text/plain", "m.file",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
# Should have uploaded ciphertext, not plaintext
|
||||
|
||||
@@ -323,6 +323,55 @@ class TestExtractMedia:
|
||||
assert "Here" in cleaned
|
||||
assert "After" in cleaned
|
||||
|
||||
def test_media_tag_supports_unquoted_flac_paths_with_spaces(self):
    """An unquoted MEDIA path containing a space is extracted whole."""
    expected_media = [("/tmp/Jane Doe/speech.flac", False)]
    media, cleaned = BasePlatformAdapter.extract_media(
        "MEDIA:/tmp/Jane Doe/speech.flac"
    )
    assert media == expected_media
    assert cleaned == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# should_send_media_as_audio
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestShouldSendMediaAsAudio:
    """Audio-routing policy shared by gateway + scheduler + send_message."""

    def test_unknown_extension_returns_false(self):
        """Non-audio extensions are never routed as audio."""
        from gateway.platforms.base import should_send_media_as_audio

        for platform, ext in ((None, ".png"), ("telegram", ".pdf")):
            assert should_send_media_as_audio(platform, ext) is False

    def test_non_telegram_platforms_route_all_audio(self):
        """Discord and Slack route every listed audio extension as audio."""
        from gateway.platforms.base import should_send_media_as_audio

        audio_exts = (".mp3", ".m4a", ".wav", ".flac", ".ogg", ".opus")
        for ext in audio_exts:
            for platform in ("discord", "slack"):
                assert should_send_media_as_audio(platform, ext) is True

    def test_telegram_mp3_and_m4a_route_to_audio(self):
        """Telegram routes .mp3 and .m4a as audio."""
        from gateway.platforms.base import should_send_media_as_audio

        for ext in (".mp3", ".m4a"):
            assert should_send_media_as_audio("telegram", ext) is True

    def test_telegram_wav_and_flac_fall_through_to_document(self):
        """Telegram does not route .wav or .flac as audio."""
        from gateway.platforms.base import should_send_media_as_audio

        for ext in (".wav", ".flac"):
            assert should_send_media_as_audio("telegram", ext) is False

    def test_telegram_ogg_opus_only_when_voice_flagged(self):
        """Telegram routes .ogg/.opus as audio only with ``is_voice=True``."""
        from gateway.platforms.base import should_send_media_as_audio

        voice_exts = (".ogg", ".opus")
        for ext in voice_exts:
            assert should_send_media_as_audio("telegram", ext, is_voice=True) is True
        for ext in voice_exts:
            assert should_send_media_as_audio("telegram", ext) is False

    def test_accepts_platform_enum(self):
        """``Platform`` enum members are accepted in place of string names."""
        from gateway.config import Platform
        from gateway.platforms.base import should_send_media_as_audio

        expectations = (
            (Platform.TELEGRAM, ".mp3", True),
            (Platform.TELEGRAM, ".flac", False),
            (Platform.DISCORD, ".flac", True),
        )
        for platform, ext, expected in expectations:
            assert should_send_media_as_audio(platform, ext) is expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# truncate_message
|
||||
|
||||
99
tests/gateway/test_platform_connected_checkers.py
Normal file
99
tests/gateway/test_platform_connected_checkers.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Verify that every gateway platform — built-in and plugin — has a connection
|
||||
checker so ``GatewayConfig.get_connected_platforms()`` doesn't silently drop
|
||||
platforms with bespoke auth requirements.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, _PLATFORM_CONNECTED_CHECKERS, _BUILTIN_PLATFORM_VALUES
|
||||
|
||||
|
||||
def test_all_builtins_have_checker_or_generic_token_path():
    """Every built-in Platform member must be reachable by either:

    1. The generic ``config.token or config.api_key`` check, OR
    2. A platform-specific entry in ``_PLATFORM_CONNECTED_CHECKERS``.

    This guarantees ``get_connected_platforms()`` doesn't silently ignore
    a built-in just because nobody added it to the checker dict.
    """
    # Platforms covered by the generic token/api_key branch
    generic_token_values = {
        p.value
        for p in (
            Platform.TELEGRAM,
            Platform.DISCORD,
            Platform.SLACK,
            Platform.MATRIX,
            Platform.MATTERMOST,
            Platform.HOMEASSISTANT,
        )
    }

    # Platforms with a bespoke checker (iterating the dict yields its keys)
    checker_values = {p.value for p in _PLATFORM_CONNECTED_CHECKERS}

    # Every built-in should be in one of the two sets; "local" is exempted
    # from the check.
    all_builtins = set(_BUILTIN_PLATFORM_VALUES)
    missing = all_builtins - generic_token_values - checker_values - {"local"}

    # The hint names the set actually defined in this test.
    assert not missing, (
        "Built-in platforms missing a connection checker: "
        f"{sorted(missing)}. "
        "Add them to _PLATFORM_CONNECTED_CHECKERS or generic_token_values."
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("platform, checker", list(_PLATFORM_CONNECTED_CHECKERS.items()))
def test_checker_handles_minimal_config(platform, checker):
    """Each bespoke checker must not crash on a minimal PlatformConfig."""
    # Keyword arguments to MagicMock configure attributes on the mock.
    minimal_config = MagicMock(extra={}, token=None, api_key=None, enabled=True)

    # Should return a bool without raising
    outcome = checker(minimal_config)
    assert isinstance(outcome, bool)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("platform, checker", list(_PLATFORM_CONNECTED_CHECKERS.items()))
def test_checker_returns_true_when_configured(platform, checker, monkeypatch):
    """Each bespoke checker must return True when the config looks valid."""
    mock_config = MagicMock()
    mock_config.token = None
    mock_config.api_key = None
    mock_config.enabled = True

    # Platform-specific ``extra`` payloads that make each checker succeed.
    synthetic_extras = {
        Platform.WEIXIN: {"account_id": "123", "token": "***"},
        Platform.SIGNAL: {"http_url": "http://signal:8080"},
        Platform.EMAIL: {"address": "hermes@example.com"},
        Platform.SMS: {},
        Platform.API_SERVER: {},
        Platform.WEBHOOK: {},
        Platform.WHATSAPP: {},
        Platform.FEISHU: {"app_id": "app"},
        Platform.WECOM: {"bot_id": "bot"},
        Platform.WECOM_CALLBACK: {"corp_id": "corp"},
        Platform.BLUEBUBBLES: {"server_url": "http://bb:1234", "password": "pw"},
        Platform.QQBOT: {"app_id": "app", "client_secret": "sec"},
        Platform.YUANBAO: {"app_id": "app", "app_secret": "sec"},
        Platform.DINGTALK: {"client_id": "id", "client_secret": "sec"},
    }

    if platform not in synthetic_extras:
        pytest.skip(f"No synthetic config defined for {platform.value}")

    # SMS additionally gets a Twilio account SID in the environment.
    if platform == Platform.SMS:
        monkeypatch.setenv("TWILIO_ACCOUNT_SID", "ACtest")
    mock_config.extra = synthetic_extras[platform]

    result = checker(mock_config)
    assert result is True, f"{platform.value} checker should return True with valid-looking config"
|
||||
@@ -14,8 +14,15 @@ from gateway.run import GatewayRunner
|
||||
class StubAdapter(BasePlatformAdapter):
|
||||
"""Adapter whose connect() result can be controlled."""
|
||||
|
||||
def __init__(self, *, succeed=True, fatal_error=None, fatal_retryable=True):
|
||||
super().__init__(PlatformConfig(enabled=True, token="test"), Platform.TELEGRAM)
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
platform=Platform.TELEGRAM,
|
||||
succeed=True,
|
||||
fatal_error=None,
|
||||
fatal_retryable=True,
|
||||
):
|
||||
super().__init__(PlatformConfig(enabled=True, token="test"), platform)
|
||||
self._succeed = succeed
|
||||
self._fatal_error = fatal_error
|
||||
self._fatal_retryable = fatal_retryable
|
||||
@@ -65,6 +72,85 @@ def _make_runner():
|
||||
|
||||
# --- Startup queueing ---
|
||||
|
||||
class TestStartupPlatformIsolation:
    """Verify one blocked platform cannot prevent later platforms from starting."""

    @pytest.mark.asyncio
    async def test_start_continues_after_platform_connect_timeout(self, tmp_path):
        """A timeout on Telegram should queue it and still connect Feishu."""
        enabled_platforms = (Platform.TELEGRAM, Platform.FEISHU)

        runner = _make_runner()
        runner.config = GatewayConfig(
            platforms={
                plat: PlatformConfig(enabled=True, token="test")
                for plat in enabled_platforms
            },
            sessions_dir=tmp_path,
        )
        runner.hooks = MagicMock()
        runner.hooks.loaded_hooks = []
        runner.hooks.emit = AsyncMock()

        # Replace the runner's collaborators with mocks.
        runner._suspend_stuck_loop_sessions = MagicMock(return_value=0)
        runner._update_runtime_status = MagicMock()
        runner._update_platform_runtime_status = MagicMock()
        runner._sync_voice_mode_state_to_adapter = MagicMock()
        runner._send_update_notification = AsyncMock(return_value=True)
        runner._send_restart_notification = AsyncMock()

        stub_adapters = {plat: StubAdapter(platform=plat) for plat in enabled_platforms}
        runner._create_adapter = MagicMock(
            side_effect=lambda platform, _config: stub_adapters[platform]
        )
        # First connect attempt times out, the second succeeds.
        connect_outcomes = [
            TimeoutError("telegram connect timed out after 30s"),
            True,
        ]
        runner._connect_adapter_with_timeout = AsyncMock(side_effect=connect_outcomes)

        def fake_create_task(coro):
            # Close the coroutine instead of scheduling it.
            coro.close()
            return MagicMock()

        with patch("gateway.status.write_runtime_status"), \
                patch("hermes_cli.plugins.discover_plugins"), \
                patch("hermes_cli.config.load_config", return_value={}), \
                patch("agent.shell_hooks.register_from_config"), \
                patch(
                    "tools.process_registry.process_registry.recover_from_checkpoint",
                    return_value=0,
                ), \
                patch(
                    "gateway.channel_directory.build_channel_directory",
                    new=AsyncMock(return_value={"platforms": {}}),
                ), \
                patch("gateway.run.asyncio.create_task", side_effect=fake_create_task):
            started = await runner.start()
            assert started is True

        assert Platform.TELEGRAM in runner._failed_platforms
        assert Platform.FEISHU in runner.adapters
        assert Platform.TELEGRAM not in runner.adapters
        assert runner._create_adapter.call_count == 2

    @pytest.mark.asyncio
    async def test_connect_adapter_timeout_raises_retryable_exception(self, monkeypatch):
        """The timeout helper turns a hanging connect into a caught startup error."""
        runner = _make_runner()
        adapter = StubAdapter()

        async def _hanging_connect():
            await asyncio.sleep(60)
            return True

        adapter.connect = _hanging_connect
        monkeypatch.setenv("HERMES_GATEWAY_PLATFORM_CONNECT_TIMEOUT", "0.001")

        with pytest.raises(TimeoutError, match="telegram connect timed out"):
            await runner._connect_adapter_with_timeout(adapter, Platform.TELEGRAM)
|
||||
|
||||
|
||||
class TestStartupFailureQueuing:
|
||||
"""Verify that failed platforms are queued during startup."""
|
||||
|
||||
|
||||
396
tests/gateway/test_platform_registry.py
Normal file
396
tests/gateway/test_platform_registry.py
Normal file
@@ -0,0 +1,396 @@
|
||||
"""Tests for the platform adapter registry and dynamic Platform enum."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from dataclasses import dataclass
|
||||
|
||||
from gateway.platform_registry import PlatformRegistry, PlatformEntry, platform_registry
|
||||
from gateway.config import Platform, PlatformConfig, GatewayConfig
|
||||
|
||||
|
||||
# ── Platform enum dynamic members ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPlatformEnumDynamic:
    """Test that Platform enum accepts unknown values for plugin platforms."""

    def test_builtin_members_still_work(self):
        """Built-in members keep their value and by-value lookup."""
        telegram = Platform.TELEGRAM
        assert telegram.value == "telegram"
        assert Platform("telegram") is Platform.TELEGRAM

    def test_dynamic_member_created(self):
        """Looking up "irc" yields a member named IRC with value "irc"."""
        member = Platform("irc")
        assert member.value == "irc"
        assert member.name == "IRC"

    def test_dynamic_member_identity_stable(self):
        """Same value returns same object (cached)."""
        first = Platform("irc")
        second = Platform("irc")
        assert first is second

    def test_dynamic_member_case_normalised(self):
        """Mixed case normalised to lowercase."""
        upper = Platform("IRC")
        lower = Platform("irc")
        assert upper is lower
        assert upper.value == "irc"

    def test_dynamic_member_with_hyphens(self):
        """Registered plugin platforms with hyphens work once registered."""
        from gateway.platform_registry import platform_registry as registry

        registry.register(
            PlatformEntry(
                name="my-platform",
                label="My Platform",
                adapter_factory=lambda cfg: MagicMock(),
                check_fn=lambda: True,
                source="plugin",
            )
        )
        try:
            member = Platform("my-platform")
            assert member.value == "my-platform"
            assert member.name == "MY_PLATFORM"
        finally:
            # Always remove the entry from the shared registry.
            registry.unregister("my-platform")

    def test_dynamic_member_rejects_unregistered(self):
        """Arbitrary strings are rejected to prevent enum pollution."""
        with pytest.raises(ValueError):
            Platform("totally-fake-platform")

    def test_dynamic_member_rejects_non_string(self):
        """A non-string value raises ValueError."""
        with pytest.raises(ValueError):
            Platform(123)

    def test_dynamic_member_rejects_empty(self):
        """The empty string raises ValueError."""
        with pytest.raises(ValueError):
            Platform("")

    def test_dynamic_member_rejects_whitespace_only(self):
        """A whitespace-only string raises ValueError."""
        with pytest.raises(ValueError):
            Platform(" ")
|
||||
|
||||
|
||||
# ── PlatformRegistry ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPlatformRegistry:
    """Test the PlatformRegistry itself."""

    def _make_entry(self, name="test", check_ok=True, validate_ok=True, factory_ok=True):
        """Build a plugin ``PlatformEntry`` and the mock its factory returns.

        ``check_ok`` and ``validate_ok`` are what ``check_fn`` and
        ``validate_config`` return; with ``factory_ok`` false the factory
        raises ``RuntimeError`` instead of returning the mock.

        Returns an ``(entry, adapter_mock)`` tuple.
        """
        adapter_mock = MagicMock()

        def _factory(cfg, _m=adapter_mock):
            if factory_ok:
                return _m
            raise RuntimeError("factory error")

        def _check():
            return check_ok

        def _validate(cfg):
            return validate_ok

        entry = PlatformEntry(
            name=name,
            label=name.title(),
            adapter_factory=_factory,
            check_fn=_check,
            validate_config=_validate,
            required_env=[],
            source="plugin",
        )
        return entry, adapter_mock

    def test_register_and_get(self):
        """A registered entry is returned by ``get`` and reported as registered."""
        registry = PlatformRegistry()
        entry, _ = self._make_entry("alpha")
        registry.register(entry)
        assert registry.get("alpha") is entry
        assert registry.is_registered("alpha")

    def test_get_unknown_returns_none(self):
        """``get`` returns None for a name that was never registered."""
        assert PlatformRegistry().get("nonexistent") is None

    def test_unregister(self):
        """``unregister`` returns True once, then False for the same name."""
        registry = PlatformRegistry()
        entry, _ = self._make_entry("beta")
        registry.register(entry)
        assert registry.unregister("beta") is True
        assert registry.get("beta") is None
        assert registry.unregister("beta") is False  # already gone

    def test_create_adapter_success(self):
        """``create_adapter`` returns the object produced by the factory."""
        registry = PlatformRegistry()
        entry, expected_adapter = self._make_entry("gamma")
        registry.register(entry)
        created = registry.create_adapter("gamma", MagicMock())
        assert created is expected_adapter

    def test_create_adapter_unknown_name(self):
        """``create_adapter`` returns None for an unregistered name."""
        registry = PlatformRegistry()
        created = registry.create_adapter("unknown", MagicMock())
        assert created is None

    def test_create_adapter_check_fails(self):
        """``create_adapter`` returns None when ``check_fn`` returns False."""
        registry = PlatformRegistry()
        entry, _ = self._make_entry("delta", check_ok=False)
        registry.register(entry)
        created = registry.create_adapter("delta", MagicMock())
        assert created is None

    def test_create_adapter_validate_fails(self):
        """``create_adapter`` returns None when ``validate_config`` returns False."""
        registry = PlatformRegistry()
        entry, _ = self._make_entry("epsilon", validate_ok=False)
        registry.register(entry)
        created = registry.create_adapter("epsilon", MagicMock())
        assert created is None

    def test_create_adapter_factory_exception(self):
        """A raising factory yields None rather than propagating."""

        def _exploding_factory(cfg):
            raise RuntimeError("boom")

        registry = PlatformRegistry()
        registry.register(
            PlatformEntry(
                name="broken",
                label="Broken",
                adapter_factory=_exploding_factory,
                check_fn=lambda: True,
                validate_config=None,
                source="plugin",
            )
        )
        # factory raises → create_adapter returns None instead of propagating
        assert registry.create_adapter("broken", MagicMock()) is None

    def test_create_adapter_no_validate(self):
        """When validate_config is None, skip validation."""
        registry = PlatformRegistry()
        expected_adapter = MagicMock()
        registry.register(
            PlatformEntry(
                name="novalidate",
                label="NoValidate",
                adapter_factory=lambda cfg: expected_adapter,
                check_fn=lambda: True,
                validate_config=None,
                source="plugin",
            )
        )
        assert registry.create_adapter("novalidate", MagicMock()) is expected_adapter

    def test_all_entries(self):
        """``all_entries`` yields every registered entry."""
        registry = PlatformRegistry()
        for entry_name in ("one", "two"):
            entry, _ = self._make_entry(entry_name)
            registry.register(entry)
        assert {e.name for e in registry.all_entries()} == {"one", "two"}

    def test_plugin_entries(self):
        """``plugin_entries`` leaves out the entry whose source is "builtin"."""
        registry = PlatformRegistry()
        plugin_entry, _ = self._make_entry("plugged")
        builtin_entry = PlatformEntry(
            name="core",
            label="Core",
            adapter_factory=lambda cfg: MagicMock(),
            check_fn=lambda: True,
            source="builtin",
        )
        for entry in (plugin_entry, builtin_entry):
            registry.register(entry)
        assert {e.name for e in registry.plugin_entries()} == {"plugged"}

    def test_re_register_replaces(self):
        """Registering the same name again replaces the earlier entry."""
        registry = PlatformRegistry()
        original, _ = self._make_entry("dup")
        replacement = PlatformEntry(
            name="dup",
            label="Dup v2",
            adapter_factory=lambda cfg: "v2",
            check_fn=lambda: True,
            source="plugin",
        )
        registry.register(original)
        registry.register(replacement)
        assert registry.get("dup").label == "Dup v2"
|
||||
|
||||
|
||||
# ── GatewayConfig integration ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGatewayConfigPluginPlatform:
    """Test that GatewayConfig parses and validates plugin platforms."""

    def test_from_dict_accepts_plugin_platform(self):
        """``from_dict`` keeps both the telegram and the irc platform keys."""
        platforms = {
            "telegram": {"enabled": True, "token": "test-token"},
            "irc": {"enabled": True, "extra": {"server": "irc.libera.chat"}},
        }
        cfg = GatewayConfig.from_dict({"platforms": platforms})
        parsed_values = {p.value for p in cfg.platforms}
        assert "telegram" in parsed_values
        assert "irc" in parsed_values

    def test_get_connected_platforms_includes_registered_plugin(self):
        """Plugin platform with registry entry passes get_connected_platforms."""
        # Register a fake plugin platform
        from gateway.platform_registry import platform_registry as registry

        def _has_token(cfg):
            return bool(cfg.extra.get("token"))

        registry.register(
            PlatformEntry(
                name="testplat",
                label="TestPlat",
                adapter_factory=lambda cfg: MagicMock(),
                check_fn=lambda: True,
                validate_config=_has_token,
                source="plugin",
            )
        )
        try:
            cfg = GatewayConfig.from_dict(
                {"platforms": {"testplat": {"enabled": True, "extra": {"token": "abc"}}}}
            )
            connected_values = {p.value for p in cfg.get_connected_platforms()}
            assert "testplat" in connected_values
        finally:
            registry.unregister("testplat")

    def test_get_connected_platforms_excludes_unregistered_plugin(self):
        """Plugin platform without registry entry is excluded."""
        cfg = GatewayConfig.from_dict(
            {"platforms": {"unknown_plugin": {"enabled": True, "extra": {"token": "abc"}}}}
        )
        connected_values = {p.value for p in cfg.get_connected_platforms()}
        assert "unknown_plugin" not in connected_values

    def test_get_connected_platforms_excludes_invalid_config(self):
        """Plugin platform with failing validate_config is excluded."""
        from gateway.platform_registry import platform_registry as registry

        def _always_invalid(cfg):
            return False  # always fails

        registry.register(
            PlatformEntry(
                name="badconfig",
                label="BadConfig",
                adapter_factory=lambda cfg: MagicMock(),
                check_fn=lambda: True,
                validate_config=_always_invalid,
                source="plugin",
            )
        )
        try:
            cfg = GatewayConfig.from_dict(
                {"platforms": {"badconfig": {"enabled": True, "extra": {}}}}
            )
            connected_values = {p.value for p in cfg.get_connected_platforms()}
            assert "badconfig" not in connected_values
        finally:
            registry.unregister("badconfig")
|
||||
|
||||
|
||||
# ── Extended PlatformEntry fields ─────────────────────────────────────
|
||||
|
||||
|
||||
class TestPlatformEntryExtendedFields:
    """Test the auth, message length, and display fields on PlatformEntry."""

    def test_default_field_values(self):
        """An entry built from only the four basic fields gets these defaults."""
        minimal = PlatformEntry(
            name="test",
            label="Test",
            adapter_factory=lambda cfg: None,
            check_fn=lambda: True,
        )
        for field_name, default in (
            ("allowed_users_env", ""),
            ("allow_all_env", ""),
            ("max_message_length", 0),
        ):
            assert getattr(minimal, field_name) == default
        assert minimal.pii_safe is False
        assert minimal.emoji == "🔌"
        assert minimal.allow_update_command is True

    def test_custom_auth_fields(self):
        """Explicitly supplied auth/length/emoji values are stored unchanged."""
        custom = {
            "allowed_users_env": "IRC_ALLOWED_USERS",
            "allow_all_env": "IRC_ALLOW_ALL_USERS",
            "max_message_length": 450,
            "emoji": "💬",
        }
        entry = PlatformEntry(
            name="irc",
            label="IRC",
            adapter_factory=lambda cfg: None,
            check_fn=lambda: True,
            pii_safe=False,
            **custom,
        )
        assert entry.allowed_users_env == "IRC_ALLOWED_USERS"
        assert entry.allow_all_env == "IRC_ALLOW_ALL_USERS"
        assert entry.max_message_length == 450
        assert entry.emoji == "💬"
|
||||
|
||||
|
||||
# ── Cron platform resolution ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCronPlatformResolution:
    """Test that cron delivery accepts plugin platform names."""

    def test_builtin_platform_resolves(self):
        """Built-in platform names resolve via Platform() call."""
        resolved = Platform("telegram")
        assert resolved is Platform.TELEGRAM

    def test_plugin_platform_resolves(self):
        """Plugin platform names create dynamic enum members."""
        resolved = Platform("irc")
        assert resolved.value == "irc"

    def test_invalid_platform_type_rejected(self):
        """Non-string values are still rejected."""
        with pytest.raises(ValueError):
            Platform(None)
|
||||
|
||||
|
||||
# ── platforms.py integration ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPlatformsMerge:
    """Test get_all_platforms() merges with registry."""

    def test_get_all_platforms_includes_builtins(self):
        """Every key of ``PLATFORMS`` appears in the merged mapping."""
        from hermes_cli.platforms import get_all_platforms, PLATFORMS

        merged = get_all_platforms()
        for builtin_key in PLATFORMS:
            assert builtin_key in merged

    def test_get_all_platforms_includes_plugin(self):
        """A registered plugin entry shows up in the merge with its label."""
        from hermes_cli.platforms import get_all_platforms
        from gateway.platform_registry import platform_registry as registry

        plugin_entry = PlatformEntry(
            name="testmerge",
            label="TestMerge",
            adapter_factory=lambda cfg: None,
            check_fn=lambda: True,
            source="plugin",
            emoji="🧪",
        )
        registry.register(plugin_entry)
        try:
            merged = get_all_platforms()
            assert "testmerge" in merged
            assert "TestMerge" in merged["testmerge"].label
        finally:
            registry.unregister("testmerge")

    def test_platform_label_plugin_fallback(self):
        """``platform_label`` includes the label of a registered plugin entry."""
        from hermes_cli.platforms import platform_label
        from gateway.platform_registry import platform_registry as registry

        plugin_entry = PlatformEntry(
            name="labeltest",
            label="LabelTest",
            adapter_factory=lambda cfg: None,
            check_fn=lambda: True,
            source="plugin",
            emoji="🏷️",
        )
        registry.register(plugin_entry)
        try:
            rendered = platform_label("labeltest")
            assert "LabelTest" in rendered
        finally:
            registry.unregister("labeltest")
|
||||
230
tests/gateway/test_plugin_platform_interface.py
Normal file
230
tests/gateway/test_plugin_platform_interface.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
Interface compliance tests for all plugin-based gateway platforms.
|
||||
|
||||
Discovers platforms dynamically under ``plugins/platforms/`` — no manual
|
||||
enumeration — and verifies each one implements the required contract.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# This file lives at <repo>/tests/gateway/, so the repository root is three
# levels above the file itself; stopping after two yields <repo>/tests, where
# no plugins/platforms directory is expected, and discovery would come back
# empty so the parametrised tests would silently run zero cases.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
PLATFORMS_DIR = PROJECT_ROOT / "plugins" / "platforms"
|
||||
|
||||
|
||||
def _discover_platform_plugins() -> list[str]:
|
||||
"""Return names of all bundled platform plugins."""
|
||||
if not PLATFORMS_DIR.is_dir():
|
||||
return []
|
||||
names = []
|
||||
for child in sorted(PLATFORMS_DIR.iterdir()):
|
||||
if child.is_dir() and (child / "__init__.py").exists():
|
||||
names.append(child.name)
|
||||
return names
|
||||
|
||||
|
||||
# Dynamically parametrise over discovered platforms
|
||||
_PLATFORM_NAMES = _discover_platform_plugins()
|
||||
|
||||
|
||||
@pytest.fixture
def clean_registry():
    """Yield with a clean platform registry, restoring state afterwards."""
    from gateway.platform_registry import platform_registry

    entries = platform_registry._entries
    snapshot = dict(entries)
    entries.clear()
    yield platform_registry
    # Drop whatever the test registered, then put the saved entries back.
    platform_registry._entries.clear()
    platform_registry._entries.update(snapshot)
|
||||
|
||||
|
||||
class _MockPluginContext:
|
||||
"""Minimal mock of hermes_cli.plugins.PluginContext.
|
||||
|
||||
Only implements register_platform so we can exercise the plugin's
|
||||
register() entrypoint without importing the real plugin system.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.registered_names: list[str] = []
|
||||
|
||||
def register_platform(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
label: str,
|
||||
adapter_factory: Any,
|
||||
check_fn: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
from gateway.platform_registry import platform_registry, PlatformEntry
|
||||
|
||||
entry = PlatformEntry(
|
||||
name=name,
|
||||
label=label,
|
||||
adapter_factory=adapter_factory,
|
||||
check_fn=check_fn,
|
||||
**kwargs,
|
||||
)
|
||||
platform_registry.register(entry)
|
||||
self.registered_names.append(name)
|
||||
|
||||
|
||||
def _import_platform_module(name: str) -> ModuleType:
|
||||
"""Import plugins.platforms.<name> in a test-safe way."""
|
||||
# Make sure the project root is on sys.path so relative imports work
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
module = importlib.import_module(f"plugins.platforms.{name}")
|
||||
return module
|
||||
|
||||
|
||||
@pytest.mark.parametrize("platform_name", _PLATFORM_NAMES)
def test_plugin_exposes_register_function(platform_name: str):
    """Every platform plugin must expose a callable register function."""
    plugin = _import_platform_module(platform_name)
    assert hasattr(plugin, "register"), f"{platform_name} missing register()"
    register_fn = plugin.register
    assert callable(register_fn), f"{platform_name}.register not callable"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("platform_name", _PLATFORM_NAMES)
def test_plugin_registers_valid_platform_entry(platform_name: str, clean_registry):
    """Calling register() must create a valid PlatformEntry."""
    plugin = _import_platform_module(platform_name)
    context = _MockPluginContext()
    plugin.register(context)

    assert platform_name in context.registered_names

    from gateway.platform_registry import platform_registry

    registered = platform_registry.get(platform_name)
    assert registered is not None, f"{platform_name} did not register an entry"
    assert registered.name == platform_name
    assert registered.label
    assert callable(registered.adapter_factory)
    assert callable(registered.check_fn)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("platform_name", _PLATFORM_NAMES)
def test_platform_entry_has_required_fields(platform_name: str, clean_registry):
    """PlatformEntry must have the mandatory metadata fields."""
    plugin = _import_platform_module(platform_name)
    plugin.register(_MockPluginContext())

    from gateway.platform_registry import platform_registry

    registered = platform_registry.get(platform_name)
    assert registered is not None

    # Mandatory fields
    assert isinstance(registered.name, str) and registered.name
    assert isinstance(registered.label, str) and registered.label
    assert callable(registered.adapter_factory)
    assert callable(registered.check_fn)

    # Optional but recommended fields: must be callable when provided
    for hook_name in ("validate_config", "is_connected", "setup_fn"):
        hook = getattr(registered, hook_name)
        if hook is not None:
            assert callable(hook)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("platform_name", _PLATFORM_NAMES)
def test_adapter_factory_produces_valid_adapter(platform_name: str, clean_registry):
    """The adapter factory must return an object with the base interface."""
    module = _import_platform_module(platform_name)
    ctx = _MockPluginContext()
    module.register(ctx)

    from gateway.platform_registry import platform_registry

    entry = platform_registry.get(platform_name)
    assert entry is not None

    # Build a minimal synthetic config that shouldn't crash __init__
    mock_config = MagicMock()
    mock_config.extra = {}
    mock_config.enabled = True
    mock_config.token = None
    mock_config.api_key = None
    mock_config.home_channel = None
    mock_config.reply_to_mode = "first"

    adapter = entry.adapter_factory(mock_config)
    assert adapter is not None, f"{platform_name} adapter_factory returned None"

    # Required adapter interface
    assert hasattr(adapter, "connect") and callable(adapter.connect)
    assert hasattr(adapter, "disconnect") and callable(adapter.disconnect)
    assert hasattr(adapter, "send") and callable(adapter.send)
    assert hasattr(adapter, "name")

    # Should be a BasePlatformAdapter subclass if importable.  Only the import
    # is guarded: AssertionError is an Exception, so keeping the assert inside
    # a broad ``except Exception`` would turn a failed isinstance check into a
    # skip instead of a test failure.
    try:
        from gateway.platforms.base import BasePlatformAdapter
    except ImportError:
        pytest.skip("BasePlatformAdapter not available for isinstance check")

    assert isinstance(adapter, BasePlatformAdapter)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("platform_name", _PLATFORM_NAMES)
def test_check_fn_returns_bool(platform_name: str, clean_registry):
    """The registered ``check_fn`` must hand back a genuine ``bool``."""
    plugin_module = _import_platform_module(platform_name)
    plugin_ctx = _MockPluginContext()
    plugin_module.register(plugin_ctx)

    from gateway.platform_registry import platform_registry

    registered = platform_registry.get(platform_name)
    assert registered is not None

    outcome = registered.check_fn()
    assert isinstance(outcome, bool), (
        f"{platform_name}.check_fn() returned {type(outcome)}, expected bool"
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("platform_name", _PLATFORM_NAMES)
def test_validate_config_if_present(platform_name: str, clean_registry):
    """A provided ``validate_config`` must accept a config object and return ``bool``.

    Platforms that register no ``validate_config`` are skipped.
    """
    plugin_module = _import_platform_module(platform_name)
    plugin_ctx = _MockPluginContext()
    plugin_module.register(plugin_ctx)

    from gateway.platform_registry import platform_registry

    registered = platform_registry.get(platform_name)
    assert registered is not None

    if registered.validate_config is None:
        pytest.skip("No validate_config provided")

    # Only ``extra`` is pinned; every other attribute is an auto-created mock.
    stub_config = MagicMock()
    stub_config.extra = {}

    verdict = registered.validate_config(stub_config)
    assert isinstance(verdict, bool)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("platform_name", _PLATFORM_NAMES)
def test_is_connected_if_present(platform_name: str, clean_registry):
    """A provided ``is_connected`` must accept a config object and return ``bool``.

    Platforms that register no ``is_connected`` are skipped.
    """
    plugin_module = _import_platform_module(platform_name)
    plugin_ctx = _MockPluginContext()
    plugin_module.register(plugin_ctx)

    from gateway.platform_registry import platform_registry

    registered = platform_registry.get(platform_name)
    assert registered is not None

    if registered.is_connected is None:
        pytest.skip("No is_connected provided")

    # Only ``extra`` is pinned; every other attribute is an auto-created mock.
    stub_config = MagicMock()
    stub_config.extra = {}

    connected = registered.is_connected(stub_config)
    assert isinstance(connected, bool)
|
||||
200
tests/gateway/test_reload_skills_command.py
Normal file
200
tests/gateway/test_reload_skills_command.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Tests for the ``/reload-skills`` gateway slash command handler.
|
||||
|
||||
Verifies:
|
||||
* dispatcher routes ``/reload-skills`` to ``_handle_reload_skills_command``
|
||||
* the underscored alias ``/reload_skills`` is not flagged as unknown
|
||||
* the handler invokes ``agent.skill_commands.reload_skills`` and renders a
|
||||
human-readable diff
|
||||
* when any skills changed, a one-shot note is queued on
|
||||
``runner._pending_skills_reload_notes[session_key]`` (the agent loop
|
||||
consumes and clears it on the next user turn — see ``gateway/run.py``
|
||||
near the ``_has_fresh_tool_tail`` block)
|
||||
* the handler does NOT append to the session transcript out-of-band —
|
||||
message alternation must not be broken by a phantom user turn
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_source() -> SessionSource:
    """Return the Telegram DM ``SessionSource`` used by every test in this module."""
    source_fields = {
        "platform": Platform.TELEGRAM,
        "user_id": "u1",
        "chat_id": "c1",
        "user_name": "tester",
        "chat_type": "dm",
    }
    return SessionSource(**source_fields)
|
||||
|
||||
|
||||
def _make_event(text: str) -> MessageEvent:
    """Wrap *text* in a ``MessageEvent`` from the shared test source (message id ``"m1"``)."""
    origin = _make_source()
    return MessageEvent(text=text, source=origin, message_id="m1")
|
||||
|
||||
|
||||
def _make_runner():
    """Build a ``GatewayRunner`` stub for handler tests without running ``__init__``.

    The instance is allocated with ``object.__new__`` and only the attributes
    assigned below exist on it: a Telegram-only config, a mocked adapter,
    mocked hooks and session store, empty per-session state, and permissive
    auth/env callables.
    """
    from gateway.run import GatewayRunner

    runner = object.__new__(GatewayRunner)

    telegram_config = PlatformConfig(enabled=True, token="***")
    runner.config = GatewayConfig(platforms={Platform.TELEGRAM: telegram_config})

    telegram_adapter = MagicMock()
    telegram_adapter.send = AsyncMock()
    runner.adapters = {Platform.TELEGRAM: telegram_adapter}
    runner._voice_mode = {}
    runner.hooks = SimpleNamespace(
        emit=AsyncMock(),
        emit_collect=AsyncMock(return_value=[]),
        loaded_hooks=False,
    )

    # Session store: every lookup resolves to the same pre-built entry.
    store = MagicMock()
    store.get_or_create_session.return_value = SessionEntry(
        session_key=build_session_key(_make_source()),
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
    )
    store.load_transcript.return_value = []
    store.has_any_sessions.return_value = True
    store.append_to_transcript = MagicMock()
    store.rewrite_transcript = MagicMock()
    store.update_session = MagicMock()
    runner.session_store = store

    # Plain state that ``__init__`` would normally have set up.
    initial_state = {
        "_running_agents": {},
        "_pending_messages": {},
        "_pending_approvals": {},
        "_session_db": None,
        "_reasoning_config": None,
        "_provider_routing": {},
        "_fallback_model": None,
        "_show_reasoning": False,
    }
    for attr_name, initial_value in initial_state.items():
        setattr(runner, attr_name, initial_value)

    runner._is_user_authorized = lambda _source: True
    runner._set_session_env = lambda _context: None
    runner._should_send_voice_reply = lambda *_args, **_kwargs: False

    # Use the real _session_key_for_source binding so the key matches what
    # the agent-loop consumer will look up later.
    runner._session_key_for_source = GatewayRunner._session_key_for_source.__get__(
        runner, GatewayRunner
    )
    return runner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reload_skills_handler_queues_note_on_diff(monkeypatch):
    """Diff non-empty → handler queues a one-shot note and does NOT touch transcript."""
    import agent.skill_commands as skill_commands_mod

    reload_outcome = {
        "added": [
            {"name": "alpha", "description": "Run alpha to do xyz"},
            {"name": "beta", "description": "Run beta to do abc"},
        ],
        "removed": [
            {"name": "gamma", "description": "Old removed skill"},
        ],
        "unchanged": ["delta"],
        "total": 3,
        "commands": 3,
    }
    monkeypatch.setattr(skill_commands_mod, "reload_skills", lambda: reload_outcome)

    runner = _make_runner()
    event = _make_event("/reload-skills")
    reply = await runner._handle_reload_skills_command(event)

    # User-facing reply lists the diff and the new total.
    assert reply is not None
    expected_in_reply = (
        "Skills Reloaded",
        "Added Skills:",
        "- alpha: Run alpha to do xyz",
        "- beta: Run beta to do abc",
        "Removed Skills:",
        "- gamma: Old removed skill",
        "3 skill(s) available",
    )
    for fragment in expected_in_reply:
        assert fragment in reply

    # MUST NOT write to the session transcript — that would break alternation.
    runner.session_store.append_to_transcript.assert_not_called()

    # MUST have queued a one-shot note keyed on the session.
    queued = getattr(runner, "_pending_skills_reload_notes", None)
    assert queued is not None
    key = runner._session_key_for_source(event.source)
    assert key in queued

    note = queued[key]
    assert note.startswith("[USER INITIATED SKILLS RELOAD:")
    assert note.endswith("Use skills_list to see the updated catalog.]")
    expected_in_note = (
        "Added Skills:",
        " - alpha: Run alpha to do xyz",
        " - beta: Run beta to do abc",
        "Removed Skills:",
        " - gamma: Old removed skill",
    )
    for fragment in expected_in_note:
        assert fragment in note
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reload_skills_handler_reports_no_changes(monkeypatch):
    """No diff → no queued note, no transcript write."""
    import agent.skill_commands as skill_commands_mod

    unchanged_outcome = {
        "added": [],
        "removed": [],
        "unchanged": ["alpha"],
        "total": 1,
        "commands": 1,
    }
    monkeypatch.setattr(
        skill_commands_mod, "reload_skills", lambda: unchanged_outcome
    )

    runner = _make_runner()
    event = _make_event("/reload-skills")
    reply = await runner._handle_reload_skills_command(event)

    for fragment in ("No new skills detected", "1 skill(s) available"):
        assert fragment in reply
    runner.session_store.append_to_transcript.assert_not_called()

    # No queued note when nothing changed.
    queued = getattr(runner, "_pending_skills_reload_notes", None)
    assert not queued  # None or empty dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatcher_routes_reload_skills(monkeypatch):
    """``/reload-skills`` must reach ``_handle_reload_skills_command``.

    The handler is replaced with an ``AsyncMock`` returning a sentinel string;
    the dispatcher's return value must be that sentinel.
    """
    import gateway.run as gateway_run

    marker = "reload-skills handler reached"
    runner = _make_runner()
    runner._handle_reload_skills_command = AsyncMock(return_value=marker)  # type: ignore[attr-defined]

    def _fake_agent_kwargs():
        return {"api_key": "***"}

    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", _fake_agent_kwargs)

    incoming = _make_event("/reload-skills")
    dispatched = await runner._handle_message(incoming)
    assert dispatched == marker
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_underscored_alias_not_flagged_unknown(monkeypatch):
    """Telegram autocomplete sends ``/reload_skills`` for ``/reload-skills``.

    Whatever the dispatcher returns for the underscored form, it must not be
    an "Unknown command" reply; a ``None`` result is accepted as-is.
    """
    import gateway.run as gateway_run

    runner = _make_runner()
    runner._handle_reload_skills_command = AsyncMock(return_value="ok")  # type: ignore[attr-defined]

    def _fake_agent_kwargs():
        return {"api_key": "***"}

    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", _fake_agent_kwargs)

    incoming = _make_event("/reload_skills")
    dispatched = await runner._handle_message(incoming)
    if dispatched is None:
        return
    assert "Unknown command" not in dispatched
|
||||
@@ -230,3 +230,30 @@ class TestHandleResumeCommand:
|
||||
|
||||
assert real_key not in runner._running_agents
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
async def test_resume_evicts_cached_agent(self, tmp_path):
    """Gateway /resume evicts the cached AIAgent so the next message
    rebuilds with the correct session_id end-to-end — mirrors /branch
    and /reset. Without this, the cached agent's memory provider keeps
    writing into the wrong session. See #6672.

    The database handle is closed in a ``finally`` block so a failing
    assertion (or an error in the handler) does not leak the connection
    opened under ``tmp_path``.
    """
    import threading
    from hermes_state import SessionDB

    db = SessionDB(db_path=tmp_path / "state.db")
    try:
        db.create_session("old_session", "telegram")
        db.set_session_title("old_session", "Old Work")
        db.create_session("current_session_001", "telegram")

        event = _make_event(text="/resume Old Work")
        runner = _make_runner(
            session_db=db,
            current_session_id="current_session_001",
            event=event,
        )

        # Seed the cache with a fake agent
        real_key = _session_key_for_event(event)
        runner._agent_cache = {real_key: (MagicMock(), object())}
        runner._agent_cache_lock = threading.RLock()

        await runner._handle_resume_command(event)

        assert real_key not in runner._agent_cache
    finally:
        db.close()
|
||||
|
||||
@@ -67,14 +67,20 @@ class NonEditingProgressCaptureAdapter(ProgressCaptureAdapter):
|
||||
|
||||
class FakeAgent:
|
||||
def __init__(self, **kwargs):
|
||||
# Capture anything passed via kwargs (older code path) but don't
|
||||
# freeze it — production now assigns tool_progress_callback after
|
||||
# construction (see gateway/run.py around the agent-cache hit),
|
||||
# so we must read it at call time, not at init.
|
||||
self.tool_progress_callback = kwargs.get("tool_progress_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
self.tool_progress_callback("tool.started", "terminal", "pwd", {})
|
||||
time.sleep(0.35)
|
||||
self.tool_progress_callback("tool.started", "browser_navigate", "https://example.com", {})
|
||||
time.sleep(0.35)
|
||||
cb = self.tool_progress_callback
|
||||
if cb is not None:
|
||||
cb("tool.started", "terminal", "pwd", {})
|
||||
time.sleep(0.35)
|
||||
cb("tool.started", "browser_navigate", "https://example.com", {})
|
||||
time.sleep(0.35)
|
||||
return {
|
||||
"final_response": "done",
|
||||
"messages": [],
|
||||
@@ -251,6 +257,14 @@ async def test_run_agent_progress_does_not_use_event_message_id_for_telegram_dm(
|
||||
async def test_run_agent_progress_uses_event_message_id_for_slack_dm(monkeypatch, tmp_path):
|
||||
"""Slack DM progress should keep event ts fallback threading."""
|
||||
monkeypatch.setenv("HERMES_TOOL_PROGRESS_MODE", "all")
|
||||
# Since PR #8006, Slack's built-in display tier sets tool_progress="off"
|
||||
# by default. Override via config so this test still exercises the
|
||||
# progress-callback path the Slack DM event_message_id threading depends on.
|
||||
import yaml
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
yaml.dump({"display": {"platforms": {"slack": {"tool_progress": "all"}}}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
|
||||
@@ -12,9 +12,13 @@ from gateway.session import (
|
||||
build_session_context_prompt,
|
||||
build_session_key,
|
||||
canonical_whatsapp_identifier,
|
||||
normalize_whatsapp_identifier,
|
||||
)
|
||||
|
||||
# Legacy name preserved for these tests; product renamed the function to
|
||||
# canonical_whatsapp_identifier. Keep the tests referencing the old name
|
||||
# working without duplicating the suite.
|
||||
normalize_whatsapp_identifier = canonical_whatsapp_identifier
|
||||
|
||||
|
||||
class TestSessionSourceRoundtrip:
|
||||
def test_full_roundtrip(self):
|
||||
@@ -85,8 +89,13 @@ class TestSessionSourceRoundtrip:
|
||||
assert restored.chat_topic is None
|
||||
assert restored.chat_type == "dm"
|
||||
|
||||
def test_invalid_platform_raises(self):
|
||||
with pytest.raises((ValueError, KeyError)):
|
||||
def test_unknown_platform_rejected_for_bad_names(self):
|
||||
"""Arbitrary platform names are rejected (no accidental enum pollution).
|
||||
|
||||
Only bundled platform plugins (discovered under ``plugins/platforms/``)
|
||||
and runtime-registered plugins get dynamic enum members.
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
SessionSource.from_dict({"platform": "nonexistent", "chat_id": "1"})
|
||||
|
||||
|
||||
|
||||
@@ -800,15 +800,23 @@ class TestSignalSendDocumentViaHelper:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send() returns message_id from timestamp (#4647)
|
||||
# Signal streaming edit capability / message_id behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalStreamingCapabilities:
|
||||
"""Signal must opt out of edit-based streaming behavior."""
|
||||
|
||||
def test_signal_declares_no_message_editing(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
|
||||
assert adapter.SUPPORTS_MESSAGE_EDITING is False
|
||||
|
||||
|
||||
class TestSignalSendReturnsMessageId:
|
||||
"""Signal send() must return a timestamp-based message_id so the stream
|
||||
consumer can follow its edit→fallback path correctly."""
|
||||
"""Signal send() should not pretend sent messages are editable."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_returns_timestamp_as_message_id(self, monkeypatch):
|
||||
async def test_send_returns_none_message_id_even_with_timestamp(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
mock_rpc, _ = _stub_rpc({"timestamp": 1712345678000})
|
||||
adapter._rpc = mock_rpc
|
||||
@@ -817,7 +825,7 @@ class TestSignalSendReturnsMessageId:
|
||||
result = await adapter.send(chat_id="+155****4567", content="hello")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "1712345678000"
|
||||
assert result.message_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_returns_none_message_id_when_no_timestamp(self, monkeypatch):
|
||||
@@ -997,3 +1005,100 @@ class TestSignalTypingBackoff:
|
||||
|
||||
assert "+155****4567" not in adapter._typing_failures
|
||||
assert "+155****4567" not in adapter._typing_skip_until
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reply quote extraction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalQuoteExtraction:
    """Verify Signal reply quote fields are propagated to MessageEvent."""

    @staticmethod
    async def _deliver(adapter, data_message):
        """Push one envelope carrying *data_message* through *adapter*.

        ``adapter.handle_message`` is replaced with a coroutine that records
        the event it receives; that recorded event is returned.
        """
        seen = {}

        async def _record(event):
            seen["event"] = event

        adapter.handle_message = _record

        await adapter._handle_envelope({
            "envelope": {
                "sourceNumber": "+15550001111",
                "sourceUuid": "uuid-sender",
                "sourceName": "Tester",
                "timestamp": 1000000000,
                "dataMessage": data_message,
            }
        })
        return seen["event"]

    @pytest.mark.asyncio
    async def test_handle_envelope_sets_reply_context_from_quote(self, monkeypatch):
        """A quote with id and text fills both reply fields on the event."""
        adapter = _make_signal_adapter(monkeypatch)
        event = await self._deliver(adapter, {
            "message": "yes I agree",
            "quote": {
                "id": 99,
                "text": "want to grab lunch?",
                "author": "+15550002222",
            },
        })

        assert event.text == "yes I agree"
        assert event.reply_to_message_id == "99"
        assert event.reply_to_text == "want to grab lunch?"

    @pytest.mark.asyncio
    async def test_handle_envelope_without_quote_leaves_reply_fields_none(self, monkeypatch):
        """A data message with no quote leaves both reply fields as ``None``."""
        adapter = _make_signal_adapter(monkeypatch)
        event = await self._deliver(adapter, {
            "message": "plain message",
        })

        assert event.text == "plain message"
        assert event.reply_to_message_id is None
        assert event.reply_to_text is None

    @pytest.mark.asyncio
    async def test_handle_envelope_quote_without_text_sets_only_reply_id(self, monkeypatch):
        """A quote lacking ``text`` still sets the reply id but not the reply text."""
        adapter = _make_signal_adapter(monkeypatch)
        event = await self._deliver(adapter, {
            "message": "reply without quote text",
            "quote": {
                "id": 123,
                "author": "+15550002222",
            },
        })

        assert event.reply_to_message_id == "123"
        assert event.reply_to_text is None
|
||||
|
||||
452
tests/gateway/test_signal_format.py
Normal file
452
tests/gateway/test_signal_format.py
Normal file
@@ -0,0 +1,452 @@
|
||||
"""Tests for Signal _markdown_to_signal() formatting.
|
||||
|
||||
Covers the markdown-to-bodyRanges conversion pipeline: bold, italic,
|
||||
strikethrough, monospace, code blocks, headings, and — critically — the
|
||||
false-positive regressions that caused spurious italics in production.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.signal import SignalAdapter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _m2s(text: str):
    """Run ``SignalAdapter._markdown_to_signal`` on *text* and return its result.

    The tests in this module unpack that result as ``(plain_text, styles)``.
    """
    converted = SignalAdapter._markdown_to_signal(text)
    return converted
|
||||
|
||||
|
||||
def _style_types(styles: list[str]) -> list[str]:
|
||||
"""Extract just the STYLE part from '0:4:BOLD' strings."""
|
||||
return [s.rsplit(":", 1)[1] for s in styles]
|
||||
|
||||
|
||||
def _find_style(styles: list[str], style_type: str) -> list[str]:
|
||||
"""Return only styles matching a given type."""
|
||||
return [s for s in styles if s.endswith(f":{style_type}")]
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Basic formatting
|
||||
# ===========================================================================
|
||||
|
||||
class TestMarkdownToSignalBasic:
    """Core formatting: bold, italic, strikethrough, monospace."""

    @staticmethod
    def _assert_single_style(markdown, expected_text, style_name):
        """Convert *markdown* and require the given plain text with exactly one *style_name* style."""
        text, styles = _m2s(markdown)
        assert text == expected_text
        assert len(styles) == 1
        assert styles[0].endswith(f":{style_name}")

    def test_bold_double_asterisk(self):
        self._assert_single_style("hello **world**", "hello world", "BOLD")

    def test_bold_double_underscore(self):
        self._assert_single_style("hello __world__", "hello world", "BOLD")

    def test_italic_single_asterisk(self):
        self._assert_single_style("hello *world*", "hello world", "ITALIC")

    def test_italic_single_underscore(self):
        self._assert_single_style("hello _world_", "hello world", "ITALIC")

    def test_strikethrough(self):
        self._assert_single_style("hello ~~world~~", "hello world", "STRIKETHROUGH")

    def test_inline_monospace(self):
        self._assert_single_style("run `ls -la` now", "run ls -la now", "MONOSPACE")

    def test_fenced_code_block(self):
        """Fence markers are removed and at least one MONOSPACE style is emitted."""
        text, styles = _m2s("before\n```\ncode here\n```\nafter")
        assert "code here" in text
        assert "```" not in text
        monospace_flags = [s.endswith(":MONOSPACE") for s in styles]
        assert any(monospace_flags)

    def test_heading_becomes_bold(self):
        self._assert_single_style("## Section Title", "Section Title", "BOLD")

    def test_multiple_styles(self):
        text, styles = _m2s("**bold** and *italic*")
        assert text == "bold and italic"
        kinds = _style_types(styles)
        for expected_kind in ("BOLD", "ITALIC"):
            assert expected_kind in kinds

    def test_plain_text_no_styles(self):
        text, styles = _m2s("just plain text")
        assert text == "just plain text"
        assert styles == []

    def test_empty_string(self):
        text, styles = _m2s("")
        assert text == ""
        assert styles == []
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Italic false-positive regressions
|
||||
# ===========================================================================
|
||||
|
||||
class TestItalicFalsePositives:
    """Regressions from signal-italic-false-positive-fix.md and
    signal-italic-bullet-list-fix.md."""

    @staticmethod
    def _italics(markdown):
        """Convert *markdown* and return ``(plain_text, italic_styles)``."""
        plain, styles = _m2s(markdown)
        return plain, _find_style(styles, "ITALIC")

    # --- snake_case (original fix) ---

    def test_snake_case_not_italic(self):
        """snake_case identifiers must NOT be italicized."""
        plain, italics = self._italics("the config_file is ready")
        assert plain == "the config_file is ready"
        assert italics == []

    def test_multiple_snake_case(self):
        _, italics = self._italics("set OPENAI_API_KEY and ANTHROPIC_API_KEY")
        assert italics == []

    def test_snake_case_path(self):
        _, italics = self._italics("/tools/delegate_tool.py")
        assert italics == []

    def test_snake_case_between_words(self):
        """file_path and error_code — underscores between words."""
        _, italics = self._italics("file_path and error_code")
        assert italics == []

    # --- Bullet lists (second fix) ---

    def test_bullet_list_not_italic(self):
        """* item lines must NOT be treated as italic delimiters."""
        bullets = "\n".join(["* item one", "* item two", "* item three"])
        _, italics = self._italics(bullets)
        assert italics == []

    def test_bullet_list_with_content_before(self):
        _, italics = self._italics(
            "Here are things:\n\n* first thing\n* second thing"
        )
        assert italics == []

    def test_bullet_list_file_paths(self):
        """Real-world case that triggered the bug."""
        bullets = "\n".join([
            "* tools/delegate_tool.py — delegation",
            "* tools/file_tools.py — file operations",
            "* tools/web_tools.py — web operations",
        ])
        _, italics = self._italics(bullets)
        assert italics == []

    def test_bullet_with_italic_inside(self):
        """Italic *inside* a bullet item should still work."""
        plain, italics = self._italics(
            "* this has *emphasis* inside\n* plain item"
        )
        assert len(italics) == 1
        # The italic should cover "emphasis", not the whole bullet
        assert "emphasis" in plain

    # --- Cross-line spans (DOTALL removal) ---

    def test_star_italic_no_cross_line(self):
        """*foo\\nbar* must NOT match as italic (no DOTALL)."""
        _, italics = self._italics("*foo\nbar*")
        assert italics == []

    def test_underscore_italic_no_cross_line(self):
        """_foo\\nbar_ must NOT match as italic (no DOTALL)."""
        _, italics = self._italics("_foo\nbar_")
        assert italics == []

    def test_star_italic_multiline_response(self):
        """Multi-paragraph response with * should not false-positive."""
        response = (
            "I checked the following files:\n\n"
            "* tools/delegate_tool.py — sub-agent delegation\n"
            "* tools/file_tools.py — file read/write/search\n"
            "* tools/web_tools.py — web search/extract\n\n"
            "Everything looks good."
        )
        _, italics = self._italics(response)
        assert italics == []

    # --- Legitimate italic still works ---

    def test_star_italic_still_works(self):
        plain, italics = self._italics("this is *italic* text")
        assert plain == "this is italic text"
        assert len(italics) == 1

    def test_underscore_italic_still_works(self):
        plain, italics = self._italics("this is _italic_ text")
        assert plain == "this is italic text"
        assert len(italics) == 1

    def test_multiple_italic_same_line(self):
        plain, italics = self._italics("*foo* and *bar* ok")
        assert plain == "foo and bar ok"
        assert len(italics) == 2

    def test_italic_single_word(self):
        plain, italics = self._italics("*word*")
        assert plain == "word"
        assert len(italics) == 1

    def test_italic_multi_word(self):
        plain, italics = self._italics("*several words here*")
        assert plain == "several words here"
        assert len(italics) == 1
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Style position accuracy
|
||||
# ===========================================================================
|
||||
|
||||
class TestStylePositions:
    """Verify that start:length positions map to the correct text."""

    def _extract(self, text: str, style_str: str) -> str:
        """Return the slice of *text* addressed by a ``'start:length:STYLE'`` string.

        The start and length are applied to the UTF-16-LE encoding of *text*
        (two bytes per code unit), so characters that take two UTF-16 code
        units shift later offsets accordingly.
        """
        fields = style_str.split(":")
        offset = int(fields[0])
        span = int(fields[1])
        # Each UTF-16 code unit is two bytes in the little-endian encoding.
        first_byte = offset * 2
        last_byte = (offset + span) * 2
        utf16_bytes = text.encode("utf-16-le")
        return utf16_bytes[first_byte:last_byte].decode("utf-16-le")

    def test_bold_position(self):
        text, styles = _m2s("hello **world** end")
        assert len(styles) == 1
        (only_style,) = styles
        assert self._extract(text, only_style) == "world"

    def test_italic_position(self):
        text, styles = _m2s("hello *world* end")
        assert len(styles) == 1
        (only_style,) = styles
        assert self._extract(text, only_style) == "world"

    def test_multiple_styles_positions(self):
        text, styles = _m2s("**bold** then *italic*")
        assert len(styles) == 2
        covered = set()
        for style in styles:
            covered.add(self._extract(text, style))
        assert covered == {"bold", "italic"}

    def test_emoji_utf16_offset(self):
        """Emoji (multi-byte UTF-16) before a styled span."""
        text, styles = _m2s("👋 **hello**")
        assert text == "👋 hello"
        assert len(styles) == 1
        (only_style,) = styles
        assert self._extract(text, only_style) == "hello"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Edge cases
|
||||
# ===========================================================================
|
||||
|
||||
class TestEdgeCases:
    """Tricky inputs that have caused issues or could regress."""

    def test_bold_inside_bullet(self):
        """Bold inside a bullet list item."""
        _, styles = _m2s("* **important** item\n* normal item")
        bold_spans = _find_style(styles, "BOLD")
        italic_spans = _find_style(styles, "ITALIC")
        assert len(bold_spans) == 1
        assert italic_spans == []

    def test_code_span_with_underscores(self):
        """`snake_case_var` — backtick takes priority over underscore."""
        text, styles = _m2s("use `my_var_name` here")
        assert text == "use my_var_name here"
        kinds = _style_types(styles)
        assert "MONOSPACE" in kinds
        assert "ITALIC" not in kinds

    def test_bold_and_italic_nested(self):
        """***bold+italic*** — bold captured, not italic (bold pattern first)."""
        text, _styles = _m2s("***word***")
        # ** matches bold around *word*, or *** is ambiguous;
        # either way there should be no false italic of the whole string
        assert "word" in text

    def test_lone_asterisk(self):
        """A single * with no pair should not cause issues."""
        text, _styles = _m2s("5 * 3 = 15")
        # Should not crash; any italic match would be a false positive
        assert all(token in text for token in ("5", "15"))

    def test_lone_underscore(self):
        """A single _ with no pair."""
        text, _styles = _m2s("this _ that")
        assert text == "this _ that"

    def test_consecutive_underscored_words(self):
        """_foo and _bar (leading underscores, no closers)."""
        _, styles = _m2s("call _init and _setup")
        assert _find_style(styles, "ITALIC") == []

    def test_mixed_formatting_no_bleed(self):
        """Multiple format types don't bleed into each other."""
        text, styles = _m2s("**bold** and `code` and *italic* and ~~strike~~")
        assert text == "bold and code and italic and strike"
        kinds = sorted(_style_types(styles))
        assert kinds == ["BOLD", "ITALIC", "MONOSPACE", "STRIKETHROUGH"]
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# signal-markdown-strip-patch: core conversion pipeline
|
||||
# ===========================================================================
|
||||
|
||||
class TestMarkdownStripPatch:
    """Tests for the original signal-markdown-strip-patch.

    Areas covered: fenced code blocks (with and without a language tag),
    link pass-through, headings rendered as bold, UTF-16 offset
    correctness for characters outside the BMP, newline collapsing, and
    complete removal of styling markers from the plain text.
    """

    def test_fenced_code_block_with_language_tag(self):
        """A fenced block with a language tag drops fence and tag; the code is MONOSPACE."""
        plain, spans = _m2s("```python\nprint('hello')\n```")
        assert "```" not in plain
        assert "python" not in plain  # language tag stripped
        assert "print('hello')" in plain
        assert any(span.endswith(":MONOSPACE") for span in spans)

    def test_fenced_code_block_multiline(self):
        """Every line inside a multi-line fenced block survives conversion."""
        source = "```\nline1\nline2\nline3\n```"
        plain, _ = _m2s(source)
        for expected_line in ("line1", "line2", "line3"):
            assert expected_line in plain
        assert "```" not in plain

    def test_links_preserved(self):
        """The URL of a markdown link is still present in the output text."""
        source = "Check [this link](https://example.com) for details"
        plain, _ = _m2s(source)
        # Only the URL is checked: the link may stay as markdown or be
        # reduced to its target, both are acceptable.
        assert "https://example.com" in plain

    def test_heading_h1(self):
        """A level-1 heading becomes plain text covered by a single BOLD span."""
        plain, spans = _m2s("# Main Title")
        assert plain == "Main Title"
        assert len(spans) == 1
        assert spans[0].endswith(":BOLD")

    def test_heading_h3(self):
        """A level-3 heading becomes plain text covered by a single BOLD span."""
        plain, spans = _m2s("### Sub Section")
        assert plain == "Sub Section"
        assert len(spans) == 1
        assert spans[0].endswith(":BOLD")

    def test_multiple_headings(self):
        """Two headings in one document produce two separate BOLD spans."""
        source = "## First\n\nSome text\n\n## Second"
        plain, spans = _m2s(source)
        assert "First" in plain
        assert "Second" in plain
        assert "##" not in plain
        bold_spans = _find_style(spans, "BOLD")
        assert len(bold_spans) == 2

    def test_no_raw_markdown_markers_in_output(self):
        """Bold, strikethrough and code delimiters never leak into the plain text."""
        source = "**bold** and *italic* and ~~struck~~ and `code` and ## heading"
        plain, _ = _m2s(source)
        # The trailing "##" is not at line start, so it may legitimately
        # remain; only the inline styling markers are checked.
        for marker in ("**", "~~", "`"):
            assert marker not in plain

    def test_utf16_surrogate_pair_emoji(self):
        """Style offsets are counted in UTF-16 code units, so astral emoji count twice."""
        # U+1F389 needs a surrogate pair, i.e. two UTF-16 code units each.
        plain, spans = _m2s("🎉🎉 **test**")
        assert "test" in plain
        assert len(spans) == 1
        # Span format is "start:length:STYLE"; two emoji (4 units) plus one
        # space put "test" at offset 5 with length 4.
        start, length = map(int, spans[0].split(":")[:2])
        assert start == 5
        assert length == 4

    def test_consecutive_newlines_collapsed(self):
        """Runs of three or more newlines never appear in the output."""
        plain, _ = _m2s("first\n\n\n\n\nsecond")
        assert "\n\n\n" not in plain
        assert "first" in plain
        assert "second" in plain

    def test_empty_bold_not_crash(self):
        """An empty bold marker pair (``****``) is tolerated without raising."""
        plain, _ = _m2s("before **** after")
        # Exact rendering of the empty marker is unspecified; the
        # surrounding text just has to survive.
        assert "before" in plain
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# signal-streaming-patch: SUPPORTS_MESSAGE_EDITING and send() behavior
|
||||
# ===========================================================================
|
||||
|
||||
class TestSignalStreamingPatch:
    """Tests for signal-streaming-patch: cursor suppression and edit support.

    Both checks target adapter-level properties that keep the streaming
    cursor out of Signal messages.
    """

    def test_signal_does_not_support_editing(self, monkeypatch):
        """``SignalAdapter.SUPPORTS_MESSAGE_EDITING`` is exactly ``False``."""
        monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "")
        # Imported after the environment is patched, as in the other test.
        from gateway.platforms.signal import SignalAdapter

        assert SignalAdapter.SUPPORTS_MESSAGE_EDITING is False

    @pytest.mark.asyncio
    async def test_send_returns_no_message_id(self, monkeypatch):
        """``send()`` reports ``message_id=None`` so the stream consumer takes the no-edit path."""
        monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "")
        from gateway.platforms.signal import SignalAdapter
        from gateway.config import PlatformConfig

        platform_config = PlatformConfig(enabled=True)
        platform_config.extra = {
            "http_url": "http://localhost:8080",
            "account": "+15551234567",
        }
        signal_adapter = SignalAdapter(platform_config)

        async def fake_rpc(method, params, rpc_id=None):
            # Stand-in for the adapter's RPC call: always reports a timestamp.
            return {"timestamp": 1234567890}

        signal_adapter._rpc = fake_rpc

        outcome = await signal_adapter.send(chat_id="+15559876543", content="Hello")
        assert outcome.message_id is None
|
||||
560
tests/gateway/test_teams.py
Normal file
560
tests/gateway/test_teams.py
Normal file
@@ -0,0 +1,560 @@
|
||||
"""Tests for the Microsoft Teams platform adapter plugin."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig, HomeChannel
|
||||
from tests.gateway._plugin_adapter_loader import load_plugin_adapter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SDK Mock — install in sys.modules before importing the adapter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_teams_mock():
    """Install a stand-in ``microsoft_teams`` SDK in ``sys.modules``.

    Nothing is done when a ``microsoft_teams`` module that carries a
    ``__file__`` attribute is already imported. Otherwise a hierarchy of
    empty ``types.ModuleType`` objects is built, populated with the
    mock classes and callables defined below, and registered with
    ``sys.modules.setdefault`` so entries that already exist are never
    overwritten.

    After registration each stub is also bound as an attribute on its
    parent module. The import system only creates that binding when it
    loads a submodule itself, so without it ``import microsoft_teams.apps``
    followed by ``microsoft_teams.apps.App`` raises ``AttributeError``.
    """
    from typing import Callable

    already_loaded = sys.modules.get("microsoft_teams")
    if already_loaded is not None and hasattr(already_loaded, "__file__"):
        return

    # Build the module hierarchy
    microsoft_teams = types.ModuleType("microsoft_teams")
    microsoft_teams_apps = types.ModuleType("microsoft_teams.apps")
    microsoft_teams_api = types.ModuleType("microsoft_teams.api")
    microsoft_teams_api_activities = types.ModuleType("microsoft_teams.api.activities")
    microsoft_teams_api_activities_typing = types.ModuleType("microsoft_teams.api.activities.typing")
    microsoft_teams_api_activities_invoke = types.ModuleType("microsoft_teams.api.activities.invoke")
    microsoft_teams_api_activities_invoke_adaptive_card = types.ModuleType(
        "microsoft_teams.api.activities.invoke.adaptive_card"
    )
    microsoft_teams_api_models = types.ModuleType("microsoft_teams.api.models")
    microsoft_teams_api_models_adaptive_card = types.ModuleType("microsoft_teams.api.models.adaptive_card")
    microsoft_teams_api_models_invoke_response = types.ModuleType("microsoft_teams.api.models.invoke_response")
    microsoft_teams_cards = types.ModuleType("microsoft_teams.cards")
    microsoft_teams_apps_http = types.ModuleType("microsoft_teams.apps.http")
    microsoft_teams_apps_http_adapter = types.ModuleType("microsoft_teams.apps.http.adapter")

    # App class mock
    class MockApp:
        def __init__(self, **kwargs):
            self._client_id = kwargs.get("client_id")
            self.server = MagicMock()
            self.server.handle_request = AsyncMock(return_value={"status": 200, "body": None})
            self.credentials = MagicMock()
            self.credentials.client_id = self._client_id

        @property
        def id(self):
            return self._client_id

        def on_message(self, func):
            # Decorator-style registration: remember the handler, hand it back.
            self._message_handler = func
            return func

        def on_card_action(self, func):
            self._card_action_handler = func
            return func

        async def initialize(self):
            pass

        async def send(self, conversation_id, activity):
            result = MagicMock()
            result.id = "sent-activity-id"
            return result

        async def start(self, port=3978):
            pass

        async def stop(self):
            pass

    microsoft_teams_apps.App = MockApp
    microsoft_teams_apps.ActivityContext = MagicMock

    # MessageActivity mock
    microsoft_teams_api.MessageActivity = MagicMock
    microsoft_teams_api.ConversationReference = MagicMock
    microsoft_teams_api.MessageActivityInput = MagicMock

    # TypingActivityInput mock
    class MockTypingActivityInput:
        pass

    microsoft_teams_api_activities_typing.TypingActivityInput = MockTypingActivityInput

    # Adaptive card invoke activity mock
    microsoft_teams_api_activities_invoke_adaptive_card.AdaptiveCardInvokeActivity = MagicMock

    # Adaptive card response mocks
    microsoft_teams_api_models_adaptive_card.AdaptiveCardActionCardResponse = MagicMock
    microsoft_teams_api_models_adaptive_card.AdaptiveCardActionMessageResponse = MagicMock

    # Invoke response mocks
    class MockInvokeResponse:
        def __init__(self, status=200, body=None):
            self.status = status
            self.body = body

    microsoft_teams_api_models_invoke_response.InvokeResponse = MockInvokeResponse
    microsoft_teams_api_models_invoke_response.AdaptiveCardInvokeResponse = MagicMock

    # Cards mocks
    class MockAdaptiveCard:
        # Fluent builder stand-in: every ``with_*`` call returns the card itself.
        def with_version(self, v):
            return self

        def with_body(self, body):
            return self

        def with_actions(self, actions):
            return self

    microsoft_teams_cards.AdaptiveCard = MockAdaptiveCard
    microsoft_teams_cards.ExecuteAction = MagicMock
    microsoft_teams_cards.TextBlock = MagicMock

    # HttpRequest TypedDict mock
    def HttpRequest(body=None, headers=None):
        return {"body": body, "headers": headers}

    # HttpResponse TypedDict mock
    HttpResponse = dict
    HttpMethod = str
    HttpRouteHandler = Callable

    microsoft_teams_apps_http_adapter.HttpRequest = HttpRequest
    microsoft_teams_apps_http_adapter.HttpResponse = HttpResponse
    microsoft_teams_apps_http_adapter.HttpMethod = HttpMethod
    microsoft_teams_apps_http_adapter.HttpRouteHandler = HttpRouteHandler

    # Wire the hierarchy
    stubs = {
        "microsoft_teams": microsoft_teams,
        "microsoft_teams.apps": microsoft_teams_apps,
        "microsoft_teams.api": microsoft_teams_api,
        "microsoft_teams.api.activities": microsoft_teams_api_activities,
        "microsoft_teams.api.activities.typing": microsoft_teams_api_activities_typing,
        "microsoft_teams.api.activities.invoke": microsoft_teams_api_activities_invoke,
        "microsoft_teams.api.activities.invoke.adaptive_card": microsoft_teams_api_activities_invoke_adaptive_card,
        "microsoft_teams.api.models": microsoft_teams_api_models,
        "microsoft_teams.api.models.adaptive_card": microsoft_teams_api_models_adaptive_card,
        "microsoft_teams.api.models.invoke_response": microsoft_teams_api_models_invoke_response,
        "microsoft_teams.cards": microsoft_teams_cards,
        "microsoft_teams.apps.http": microsoft_teams_apps_http,
        "microsoft_teams.apps.http.adapter": microsoft_teams_apps_http_adapter,
    }
    for dotted_name, stub in stubs.items():
        sys.modules.setdefault(dotted_name, stub)

    # Second pass: bind every registered module on its parent. Existing
    # attributes are left alone so nothing already wired is replaced.
    for dotted_name in stubs:
        parent_name, _, leaf = dotted_name.rpartition(".")
        if not parent_name:
            continue
        parent = sys.modules[parent_name]
        if not hasattr(parent, leaf):
            setattr(parent, leaf, sys.modules[dotted_name])
|
||||
|
||||
|
||||
# The SDK stand-ins must be in sys.modules before the adapter module is loaded.
_ensure_teams_mock()

# Load plugins/platforms/teams/adapter.py under a unique module name
# (plugin_adapter_teams) so it cannot collide with sibling plugin adapters.
_teams_mod = load_plugin_adapter("teams")

# Force both dependency flags on; individual tests switch them off again
# with monkeypatch where they need the "missing dependency" path.
_teams_mod.TEAMS_SDK_AVAILABLE = True
_teams_mod.AIOHTTP_AVAILABLE = True

# Module-level aliases for the adapter's public entry points used below.
(
    TeamsAdapter,
    check_requirements,
    check_teams_requirements,
    validate_config,
    register,
) = (
    _teams_mod.TeamsAdapter,
    _teams_mod.check_requirements,
    _teams_mod.check_teams_requirements,
    _teams_mod.validate_config,
    _teams_mod.register,
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_config(**extra):
    """Return an enabled ``PlatformConfig`` whose ``extra`` is the given keyword arguments."""
    config = PlatformConfig(enabled=True, extra=extra)
    return config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Requirements
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTeamsRequirements:
    """Dependency-flag checks and credential validation for the Teams plugin."""

    @staticmethod
    def _drop_credentials(monkeypatch):
        """Remove all three Teams credential variables from the environment."""
        for variable in ("TEAMS_CLIENT_ID", "TEAMS_CLIENT_SECRET", "TEAMS_TENANT_ID"):
            monkeypatch.delenv(variable, raising=False)

    def test_returns_false_when_sdk_missing(self, monkeypatch):
        """``check_requirements`` is False once the SDK flag is off."""
        monkeypatch.setattr(_teams_mod, "TEAMS_SDK_AVAILABLE", False)
        assert check_requirements() is False

    def test_returns_false_when_aiohttp_missing(self, monkeypatch):
        """``check_requirements`` is False once the aiohttp flag is off."""
        monkeypatch.setattr(_teams_mod, "AIOHTTP_AVAILABLE", False)
        assert check_requirements() is False

    def test_returns_true_when_deps_available(self, monkeypatch):
        """``check_requirements`` is True with both flags on."""
        monkeypatch.setattr(_teams_mod, "TEAMS_SDK_AVAILABLE", True)
        monkeypatch.setattr(_teams_mod, "AIOHTTP_AVAILABLE", True)
        assert check_requirements() is True

    def test_alias_matches(self, monkeypatch):
        """``check_teams_requirements`` is also True with both flags on."""
        monkeypatch.setattr(_teams_mod, "TEAMS_SDK_AVAILABLE", True)
        monkeypatch.setattr(_teams_mod, "AIOHTTP_AVAILABLE", True)
        assert check_teams_requirements() is True

    def test_validate_config_with_env(self, monkeypatch):
        """Credentials supplied only through environment variables validate."""
        monkeypatch.setenv("TEAMS_CLIENT_ID", "test-id")
        monkeypatch.setenv("TEAMS_CLIENT_SECRET", "test-secret")
        monkeypatch.setenv("TEAMS_TENANT_ID", "test-tenant")
        assert validate_config(_make_config()) is True

    def test_validate_config_from_extra(self, monkeypatch):
        """Credentials supplied only through ``extra`` validate."""
        self._drop_credentials(monkeypatch)
        config = _make_config(client_id="id", client_secret="secret", tenant_id="tenant")
        assert validate_config(config) is True

    def test_validate_config_missing(self, monkeypatch):
        """No credentials anywhere fails validation."""
        self._drop_credentials(monkeypatch)
        assert validate_config(_make_config()) is False

    def test_validate_config_missing_tenant(self, monkeypatch):
        """Client id and secret without a tenant id fails validation."""
        monkeypatch.setenv("TEAMS_CLIENT_ID", "test-id")
        monkeypatch.setenv("TEAMS_CLIENT_SECRET", "test-secret")
        monkeypatch.delenv("TEAMS_TENANT_ID", raising=False)
        assert validate_config(_make_config()) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Adapter Init
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTeamsAdapterInit:
    """Constructor behaviour of ``TeamsAdapter``: credential and port resolution."""

    @pytest.fixture(autouse=True)
    def _isolate_teams_env(self, monkeypatch):
        """Start every test without ambient ``TEAMS_*`` variables.

        The adapter reads these variables (the env-based tests below rely
        on that), so a value left in the developer's or CI environment
        would otherwise leak into tests that expect defaults or
        ``extra``-supplied values. Tests that need a variable set it
        themselves afterwards through the same ``monkeypatch`` fixture.
        """
        for variable in (
            "TEAMS_CLIENT_ID",
            "TEAMS_CLIENT_SECRET",
            "TEAMS_TENANT_ID",
            "TEAMS_PORT",
        ):
            monkeypatch.delenv(variable, raising=False)

    def test_reads_config_from_extra(self):
        """Credentials given in ``extra`` end up on the adapter."""
        config = _make_config(
            client_id="cfg-id",
            client_secret="cfg-secret",
            tenant_id="cfg-tenant",
        )
        adapter = TeamsAdapter(config)
        assert adapter._client_id == "cfg-id"
        assert adapter._client_secret == "cfg-secret"
        assert adapter._tenant_id == "cfg-tenant"

    def test_falls_back_to_env_vars(self, monkeypatch):
        """With an empty ``extra`` the credentials come from the environment."""
        monkeypatch.setenv("TEAMS_CLIENT_ID", "env-id")
        monkeypatch.setenv("TEAMS_CLIENT_SECRET", "env-secret")
        monkeypatch.setenv("TEAMS_TENANT_ID", "env-tenant")
        adapter = TeamsAdapter(_make_config())
        assert adapter._client_id == "env-id"
        assert adapter._client_secret == "env-secret"
        assert adapter._tenant_id == "env-tenant"

    def test_default_port(self):
        """Without any port setting the adapter uses 3978."""
        adapter = TeamsAdapter(_make_config(client_id="id", client_secret="secret", tenant_id="tenant"))
        assert adapter._port == 3978

    def test_custom_port_from_extra(self):
        """A ``port`` entry in ``extra`` is honoured."""
        adapter = TeamsAdapter(_make_config(client_id="id", client_secret="secret", tenant_id="tenant", port=4000))
        assert adapter._port == 4000

    def test_custom_port_from_env(self, monkeypatch):
        """``TEAMS_PORT`` is honoured and exposed as an integer."""
        monkeypatch.setenv("TEAMS_PORT", "5000")
        adapter = TeamsAdapter(_make_config(client_id="id", client_secret="secret", tenant_id="tenant"))
        assert adapter._port == 5000

    def test_platform_value(self):
        """The adapter's platform value is the string ``"teams"``."""
        adapter = TeamsAdapter(_make_config(client_id="id", client_secret="secret", tenant_id="tenant"))
        assert adapter.platform.value == "teams"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Plugin registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTeamsPluginRegistration:
    """Checks on the keyword arguments ``register`` passes to ``ctx.register_platform``."""

    @staticmethod
    def _registration_kwargs():
        """Run ``register`` against a mock context and return the captured keyword arguments."""
        plugin_ctx = MagicMock()
        register(plugin_ctx)
        return plugin_ctx.register_platform.call_args[1]

    def test_register_calls_ctx(self):
        """``register`` calls ``register_platform`` exactly once."""
        plugin_ctx = MagicMock()
        register(plugin_ctx)
        plugin_ctx.register_platform.assert_called_once()

    def test_register_name(self):
        """The platform is registered under the name ``"teams"``."""
        assert self._registration_kwargs()["name"] == "teams"

    def test_register_auth_env_vars(self):
        """The allow-list environment variable names are the Teams-specific ones."""
        captured = self._registration_kwargs()
        assert captured["allowed_users_env"] == "TEAMS_ALLOWED_USERS"
        assert captured["allow_all_env"] == "TEAMS_ALLOW_ALL_USERS"

    def test_register_max_message_length(self):
        """The registered message length cap is 28000."""
        assert self._registration_kwargs()["max_message_length"] == 28000

    def test_register_has_setup_fn(self):
        """A callable ``setup_fn`` is supplied."""
        assert callable(self._registration_kwargs().get("setup_fn"))

    def test_register_has_platform_hint(self):
        """A truthy ``platform_hint`` is supplied."""
        assert self._registration_kwargs().get("platform_hint")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Connect / Disconnect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTeamsConnect:
    """``connect()`` failure paths and ``disconnect()`` cleanup."""

    @staticmethod
    def _configured_adapter():
        """Adapter built with a full set of dummy credentials."""
        return TeamsAdapter(_make_config(
            client_id="id", client_secret="secret", tenant_id="tenant",
        ))

    @pytest.mark.asyncio
    async def test_connect_fails_without_sdk(self, monkeypatch):
        """``connect()`` returns False when the SDK flag is off."""
        monkeypatch.setattr(_teams_mod, "TEAMS_SDK_AVAILABLE", False)
        adapter = self._configured_adapter()
        connected = await adapter.connect()
        assert connected is False

    @pytest.mark.asyncio
    async def test_connect_fails_without_credentials(self):
        """``connect()`` returns False when all three credentials are blank."""
        adapter = TeamsAdapter(_make_config())
        # Blank the fields directly so the outcome does not depend on how
        # the constructor resolved them.
        adapter._client_id = ""
        adapter._client_secret = ""
        adapter._tenant_id = ""
        connected = await adapter.connect()
        assert connected is False

    @pytest.mark.asyncio
    async def test_disconnect_cleans_up(self):
        """``disconnect()`` resets state and awaits the runner's ``cleanup`` once."""
        adapter = self._configured_adapter()
        runner = AsyncMock()
        adapter._running = True
        adapter._runner = runner
        adapter._app = MagicMock()

        await adapter.disconnect()

        assert adapter._running is False
        assert adapter._app is None
        assert adapter._runner is None
        runner.cleanup.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Send
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTeamsSend:
    """Outbound paths: ``send()`` results and ``send_typing()`` dispatch."""

    @staticmethod
    def _configured_adapter():
        """Adapter built with a full set of dummy credentials."""
        return TeamsAdapter(_make_config(
            client_id="id", client_secret="secret", tenant_id="tenant",
        ))

    @pytest.mark.asyncio
    async def test_send_returns_error_without_app(self):
        """Without an app object ``send()`` fails with a "not initialized" error."""
        adapter = self._configured_adapter()
        adapter._app = None
        outcome = await adapter.send("conv-id", "Hello")
        assert outcome.success is False
        assert "not initialized" in outcome.error

    @pytest.mark.asyncio
    async def test_send_calls_app_send(self):
        """``send()`` forwards to ``app.send`` and reports the returned activity id."""
        adapter = self._configured_adapter()
        sent_activity = MagicMock()
        sent_activity.id = "msg-123"
        app = MagicMock()
        app.send = AsyncMock(return_value=sent_activity)
        adapter._app = app

        outcome = await adapter.send("conv-id", "Hello")

        assert outcome.success is True
        assert outcome.message_id == "msg-123"
        app.send.assert_awaited_once_with("conv-id", "Hello")

    @pytest.mark.asyncio
    async def test_send_handles_error(self):
        """An exception from ``app.send`` becomes a failed result carrying its message."""
        adapter = self._configured_adapter()
        app = MagicMock()
        app.send = AsyncMock(side_effect=Exception("Network error"))
        adapter._app = app

        outcome = await adapter.send("conv-id", "Hello")

        assert outcome.success is False
        assert "Network error" in outcome.error

    @pytest.mark.asyncio
    async def test_send_typing(self):
        """``send_typing()`` awaits ``app.send`` once, addressed to the conversation."""
        adapter = self._configured_adapter()
        app = MagicMock()
        app.send = AsyncMock()
        adapter._app = app

        await adapter.send_typing("conv-id")

        app.send.assert_awaited_once()
        positional_args = app.send.call_args[0]
        assert positional_args[0] == "conv-id"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Message Handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTeamsMessageHandling:
    """Inbound activity handling through ``TeamsAdapter._on_message``.

    Each test feeds a mocked activity to the adapter and inspects the
    event passed to the (mocked) ``handle_message``. The event is always
    fetched through ``_forwarded_event``, which first asserts that the
    handler was awaited, so a dropped message fails with a clear
    assertion rather than a ``TypeError`` from subscripting a ``None``
    ``call_args``.
    """

    def _make_activity(
        self,
        *,
        text="Hello",
        from_id="user-123",
        from_aad_id="aad-456",
        from_name="Test User",
        conversation_id="19:abc@thread.v2",
        conversation_type="personal",
        tenant_id="tenant-789",
        activity_id="activity-001",
        attachments=None,
    ):
        """Build a ``MagicMock`` activity carrying the given sender and conversation fields."""
        activity = MagicMock()
        activity.text = text
        activity.id = activity_id
        activity.from_ = MagicMock()
        activity.from_.id = from_id
        activity.from_.aad_object_id = from_aad_id
        activity.from_.name = from_name
        activity.conversation = MagicMock()
        activity.conversation.id = conversation_id
        activity.conversation.conversation_type = conversation_type
        activity.conversation.name = "Test Chat"
        activity.conversation.tenant_id = tenant_id
        activity.attachments = attachments or []
        return activity

    def _make_ctx(self, activity):
        """Wrap *activity* in a mock context exposing it as ``ctx.activity``."""
        ctx = MagicMock()
        ctx.activity = activity
        return ctx

    def _make_adapter(self):
        """Adapter whose app id is ``"bot-id"`` and whose ``handle_message`` is an ``AsyncMock``."""
        adapter = TeamsAdapter(_make_config(
            client_id="bot-id", client_secret="secret", tenant_id="tenant",
        ))
        adapter._app = MagicMock()
        adapter._app.id = "bot-id"
        adapter.handle_message = AsyncMock()
        return adapter

    async def _deliver(self, adapter, **activity_fields):
        """Send one activity built from *activity_fields* through ``_on_message``."""
        activity = self._make_activity(**activity_fields)
        await adapter._on_message(self._make_ctx(activity))

    @staticmethod
    def _forwarded_event(adapter):
        """Return the event given to ``handle_message``; fail clearly if it never ran."""
        adapter.handle_message.assert_awaited()
        return adapter.handle_message.call_args[0][0]

    @pytest.mark.asyncio
    async def test_personal_message_creates_dm_event(self):
        """A ``personal`` conversation maps to chat type ``dm``."""
        adapter = self._make_adapter()
        await self._deliver(adapter, conversation_type="personal")

        adapter.handle_message.assert_awaited_once()
        event = self._forwarded_event(adapter)
        assert event.source.chat_type == "dm"

    @pytest.mark.asyncio
    async def test_group_message_creates_group_event(self):
        """A ``groupChat`` conversation maps to chat type ``group``."""
        adapter = self._make_adapter()
        await self._deliver(adapter, conversation_type="groupChat")

        event = self._forwarded_event(adapter)
        assert event.source.chat_type == "group"

    @pytest.mark.asyncio
    async def test_channel_message_creates_channel_event(self):
        """A ``channel`` conversation maps to chat type ``channel``."""
        adapter = self._make_adapter()
        await self._deliver(adapter, conversation_type="channel")

        event = self._forwarded_event(adapter)
        assert event.source.chat_type == "channel"

    @pytest.mark.asyncio
    async def test_user_id_uses_aad_object_id(self):
        """The event's user id is the sender's AAD object id, not the Teams id."""
        adapter = self._make_adapter()
        await self._deliver(adapter, from_aad_id="aad-stable-id", from_id="teams-id")

        event = self._forwarded_event(adapter)
        assert event.source.user_id == "aad-stable-id"

    @pytest.mark.asyncio
    async def test_self_message_filtered(self):
        """An activity whose sender id equals the app id is not forwarded."""
        adapter = self._make_adapter()
        await self._deliver(adapter, from_id="bot-id")

        adapter.handle_message.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_bot_mention_stripped_from_text(self):
        """A leading ``<at>…</at>`` mention is removed from the event text."""
        adapter = self._make_adapter()
        await self._deliver(
            adapter,
            text="<at>Hermes</at> what is the weather?",
            from_id="user-id",
        )

        event = self._forwarded_event(adapter)
        assert event.text == "what is the weather?"

    @pytest.mark.asyncio
    async def test_deduplication(self):
        """Delivering the same activity twice forwards it only once."""
        adapter = self._make_adapter()
        activity = self._make_activity(activity_id="msg-dup-001", from_id="user-id")
        ctx = self._make_ctx(activity)

        await adapter._on_message(ctx)
        await adapter._on_message(ctx)

        assert adapter.handle_message.await_count == 1
|
||||
@@ -453,6 +453,87 @@ class TestMediaGroups:
|
||||
adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendVoice — outbound audio delivery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendVoice:
    """Tests for TelegramAdapter.send_voice() routing across audio formats."""

    @pytest.fixture()
    def connected_adapter(self, adapter):
        """The shared ``adapter`` fixture with an ``AsyncMock`` bot attached."""
        adapter._bot = AsyncMock()
        return adapter

    @staticmethod
    def _write_clip(tmp_path, filename, magic):
        """Create a small dummy audio file: *magic* followed by 32 zero bytes."""
        clip = tmp_path / filename
        clip.write_bytes(magic + b"\x00" * 32)
        return clip

    @staticmethod
    def _stub_bot(adapter, *, delivering, message_id):
        """Replace the three bot send methods with fresh ``AsyncMock`` objects.

        Only the method named by *delivering* returns a message (with the
        given ``message_id``); the other two use the ``AsyncMock`` default.
        """
        reply = MagicMock()
        reply.message_id = message_id
        for method_name in ("send_voice", "send_audio", "send_document"):
            if method_name == delivering:
                stub = AsyncMock(return_value=reply)
            else:
                stub = AsyncMock()
            setattr(adapter._bot, method_name, stub)

    @pytest.mark.asyncio
    async def test_flac_falls_back_to_document(self, connected_adapter, tmp_path):
        """FLAC is delivered with ``send_document``; ``send_audio``/``send_voice`` stay unused."""
        clip = self._write_clip(tmp_path, "clip.flac", b"fLaC")
        self._stub_bot(connected_adapter, delivering="send_document", message_id=101)

        outcome = await connected_adapter.send_voice(
            chat_id="12345",
            audio_path=str(clip),
            caption="Audio",
        )

        assert outcome.success is True
        assert outcome.message_id == "101"
        connected_adapter._bot.send_document.assert_awaited_once()
        connected_adapter._bot.send_audio.assert_not_awaited()
        connected_adapter._bot.send_voice.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_wav_falls_back_to_document(self, connected_adapter, tmp_path):
        """WAV is delivered with ``send_document`` rather than ``send_audio``."""
        clip = self._write_clip(tmp_path, "clip.wav", b"RIFF")
        self._stub_bot(connected_adapter, delivering="send_document", message_id=102)

        outcome = await connected_adapter.send_voice(
            chat_id="12345",
            audio_path=str(clip),
        )

        assert outcome.success is True
        connected_adapter._bot.send_document.assert_awaited_once()
        connected_adapter._bot.send_audio.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_mp3_routes_to_send_audio(self, connected_adapter, tmp_path):
        """MP3 is delivered with ``send_audio`` and never touches ``send_document``."""
        clip = self._write_clip(tmp_path, "clip.mp3", b"ID3")
        self._stub_bot(connected_adapter, delivering="send_audio", message_id=103)

        outcome = await connected_adapter.send_voice(
            chat_id="12345",
            audio_path=str(clip),
        )

        assert outcome.success is True
        connected_adapter._bot.send_audio.assert_awaited_once()
        connected_adapter._bot.send_document.assert_not_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendDocument — outbound file attachment delivery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -5,7 +5,14 @@ from unittest.mock import AsyncMock
|
||||
from gateway.config import Platform, PlatformConfig, load_gateway_config
|
||||
|
||||
|
||||
def _make_adapter(require_mention=None, free_response_chats=None, mention_patterns=None, ignored_threads=None):
|
||||
def _make_adapter(
|
||||
require_mention=None,
|
||||
free_response_chats=None,
|
||||
mention_patterns=None,
|
||||
ignored_threads=None,
|
||||
allow_from=None,
|
||||
group_allow_from=None,
|
||||
):
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
|
||||
extra = {}
|
||||
@@ -17,6 +24,10 @@ def _make_adapter(require_mention=None, free_response_chats=None, mention_patter
|
||||
extra["mention_patterns"] = mention_patterns
|
||||
if ignored_threads is not None:
|
||||
extra["ignored_threads"] = ignored_threads
|
||||
if allow_from is not None:
|
||||
extra["allow_from"] = allow_from
|
||||
if group_allow_from is not None:
|
||||
extra["group_allow_from"] = group_allow_from
|
||||
|
||||
adapter = object.__new__(TelegramAdapter)
|
||||
adapter.platform = Platform.TELEGRAM
|
||||
@@ -34,6 +45,7 @@ def _group_message(
|
||||
text="hello",
|
||||
*,
|
||||
chat_id=-100,
|
||||
from_user_id=111,
|
||||
thread_id=None,
|
||||
reply_to_bot=False,
|
||||
entities=None,
|
||||
@@ -50,10 +62,24 @@ def _group_message(
|
||||
caption_entities=caption_entities or [],
|
||||
message_thread_id=thread_id,
|
||||
chat=SimpleNamespace(id=chat_id, type="group"),
|
||||
from_user=SimpleNamespace(id=from_user_id),
|
||||
reply_to_message=reply_to_message,
|
||||
)
|
||||
|
||||
|
||||
def _dm_message(text="hello", *, from_user_id=111):
|
||||
return SimpleNamespace(
|
||||
text=text,
|
||||
caption=None,
|
||||
entities=[],
|
||||
caption_entities=[],
|
||||
message_thread_id=None,
|
||||
chat=SimpleNamespace(id=from_user_id, type="private"),
|
||||
from_user=SimpleNamespace(id=from_user_id),
|
||||
reply_to_message=None,
|
||||
)
|
||||
|
||||
|
||||
def _mention_entity(text, mention="@hermes_bot"):
|
||||
offset = text.index(mention)
|
||||
return SimpleNamespace(type="mention", offset=offset, length=len(mention))
|
||||
@@ -173,6 +199,68 @@ def test_config_bridges_telegram_group_settings(monkeypatch, tmp_path):
|
||||
assert __import__("os").environ["TELEGRAM_FREE_RESPONSE_CHATS"] == "-123"
|
||||
|
||||
|
||||
def test_config_bridges_telegram_user_allowlists(monkeypatch, tmp_path):
    """Allowlists in ``config.yaml`` must surface as ``TELEGRAM_*`` env vars.

    With none of the three variables set beforehand, loading the gateway
    config is expected to export the YAML lists as comma-joined strings.
    """
    import os

    home_dir = tmp_path / ".hermes"
    home_dir.mkdir()
    yaml_text = (
        "telegram:\n"
        " allow_from:\n"
        " - \"111\"\n"
        " - \"222\"\n"
        " group_allow_from:\n"
        " - \"333\"\n"
        " group_allowed_chats:\n"
        " - \"-100\"\n"
    )
    (home_dir / "config.yaml").write_text(yaml_text, encoding="utf-8")

    # Env var name -> value expected after the config has been bridged.
    expected_env = {
        "TELEGRAM_ALLOWED_USERS": "111,222",
        "TELEGRAM_GROUP_ALLOWED_USERS": "333",
        "TELEGRAM_GROUP_ALLOWED_CHATS": "-100",
    }

    monkeypatch.setenv("HERMES_HOME", str(home_dir))
    for var_name in expected_env:
        monkeypatch.delenv(var_name, raising=False)

    config = load_gateway_config()

    assert config is not None
    for var_name, value in expected_env.items():
        assert os.environ[var_name] == value
|
||||
|
||||
|
||||
def test_config_env_overrides_telegram_user_allowlists(monkeypatch, tmp_path):
    """Pre-set ``TELEGRAM_*_USERS`` env vars must survive loading ``config.yaml``.

    The YAML names different users than the environment; after loading, the
    environment values are expected to be unchanged.
    """
    import os

    home_dir = tmp_path / ".hermes"
    home_dir.mkdir()
    yaml_text = (
        "telegram:\n"
        " allow_from: \"111\"\n"
        " group_allow_from: \"222\"\n"
    )
    (home_dir / "config.yaml").write_text(yaml_text, encoding="utf-8")

    # Values set in the environment before the config is read.
    preset_env = {
        "TELEGRAM_ALLOWED_USERS": "999",
        "TELEGRAM_GROUP_ALLOWED_USERS": "888",
    }

    monkeypatch.setenv("HERMES_HOME", str(home_dir))
    for var_name, value in preset_env.items():
        monkeypatch.setenv(var_name, value)

    config = load_gateway_config()

    assert config is not None
    for var_name, value in preset_env.items():
        assert os.environ[var_name] == value
|
||||
|
||||
|
||||
def test_dm_allow_from_is_enforced_by_gateway_authorization_not_trigger_gate():
    """``_should_process_message`` accepts DMs from listed and unlisted senders alike.

    User 111 is in ``allow_from`` and user 333 is not; both must pass here.
    """
    adapter = _make_adapter(allow_from=["111", "222"])

    for sender_id in (111, 333):
        dm = _dm_message("hello", from_user_id=sender_id)
        assert adapter._should_process_message(dm) is True
|
||||
|
||||
|
||||
def test_group_allow_from_is_enforced_by_gateway_authorization_not_trigger_gate():
    """``_should_process_message`` accepts a group message from a sender outside ``group_allow_from``."""
    adapter = _make_adapter(group_allow_from=["111"])
    unlisted_sender_msg = _group_message("hello", from_user_id=333)

    assert adapter._should_process_message(unlisted_sender_msg) is True
|
||||
|
||||
|
||||
def test_config_bridges_telegram_ignored_threads(monkeypatch, tmp_path):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
195
tests/gateway/test_tts_media_routing.py
Normal file
195
tests/gateway/test_tts_media_routing.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""
|
||||
Tests for cross-platform audio/voice media routing.
|
||||
|
||||
These tests pin the expected delivery path for audio media files across
|
||||
Telegram (where Bot-API sendAudio only accepts MP3/M4A and .ogg/.opus
|
||||
only renders as a voice bubble when explicitly flagged) and via
|
||||
``GatewayRunner._deliver_media_from_response``.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
class _MediaRoutingAdapter(BasePlatformAdapter):
    """Minimal ``BasePlatformAdapter`` subclass, registered as Telegram, for routing tests."""

    def __init__(self):
        platform_config = PlatformConfig(enabled=True, token="test")
        super().__init__(platform_config, Platform.TELEGRAM)

    async def connect(self):
        """Always report a successful connection."""
        return True

    async def disconnect(self):
        """Do nothing."""
        return None

    async def send(self, chat_id, content=None, **kwargs):
        """Return a successful result tagged ``"text"`` without sending anything."""
        return SendResult(success=True, message_id="text")

    async def get_chat_info(self, chat_id):
        """Describe every chat as a DM carrying the requested id."""
        return dict(id=chat_id, type="dm")
|
||||
|
||||
|
||||
def _event(thread_id=None):
    """Build a Telegram DM text event for chat ``"chat-1"``, optionally inside *thread_id*."""
    return MessageEvent(
        text="make speech",
        message_type=MessageType.TEXT,
        source=SessionSource(
            platform=Platform.TELEGRAM,
            chat_id="chat-1",
            chat_type="dm",
            thread_id=thread_id,
        ),
        message_id="msg-1",
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_base_adapter_routes_telegram_flac_media_tag_to_document_sender():
    """A ``MEDIA:`` tag naming a ``.flac`` file must go to ``send_document``, not ``send_voice``."""
    media_path = "/tmp/speech.flac"
    event = _event()
    adapter = _MediaRoutingAdapter()
    adapter._message_handler = AsyncMock(return_value="MEDIA:" + media_path)
    adapter.send_voice = AsyncMock(return_value=SendResult(success=True, message_id="voice"))
    adapter.send_document = AsyncMock(return_value=SendResult(success=True, message_id="doc"))

    session_key = build_session_key(event.source)
    await adapter._process_message_background(event, session_key)

    expected_call = {"chat_id": "chat-1", "file_path": media_path, "metadata": None}
    adapter.send_document.assert_awaited_once_with(**expected_call)
    adapter.send_voice.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_base_adapter_routes_non_voice_telegram_ogg_media_tag_to_document_sender():
    """An ``.ogg`` ``MEDIA:`` tag without the voice marker must go to ``send_document``."""
    media_path = "/tmp/speech.ogg"
    event = _event()
    adapter = _MediaRoutingAdapter()
    adapter._message_handler = AsyncMock(return_value="MEDIA:" + media_path)
    adapter.send_voice = AsyncMock(return_value=SendResult(success=True, message_id="voice"))
    adapter.send_document = AsyncMock(return_value=SendResult(success=True, message_id="doc"))

    session_key = build_session_key(event.source)
    await adapter._process_message_background(event, session_key)

    expected_call = {"chat_id": "chat-1", "file_path": media_path, "metadata": None}
    adapter.send_document.assert_awaited_once_with(**expected_call)
    adapter.send_voice.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_base_adapter_routes_voice_tagged_telegram_ogg_media_tag_to_voice_sender():
    """An ``.ogg`` ``MEDIA:`` tag preceded by ``[[audio_as_voice]]`` must go to ``send_voice``."""
    media_path = "/tmp/speech.ogg"
    event = _event()
    adapter = _MediaRoutingAdapter()
    handler_reply = "[[audio_as_voice]]\nMEDIA:" + media_path
    adapter._message_handler = AsyncMock(return_value=handler_reply)
    adapter.send_voice = AsyncMock(return_value=SendResult(success=True, message_id="voice"))
    adapter.send_document = AsyncMock(return_value=SendResult(success=True, message_id="doc"))

    session_key = build_session_key(event.source)
    await adapter._process_message_background(event, session_key)

    expected_call = {"chat_id": "chat-1", "audio_path": media_path, "metadata": None}
    adapter.send_voice.assert_awaited_once_with(**expected_call)
    adapter.send_document.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_streaming_delivery_routes_telegram_flac_media_tag_to_document_sender():
    """``_deliver_media_from_response`` must hand a ``.flac`` ``MEDIA:`` tag to ``send_document``.

    The event carries ``thread_id="topic-1"``, which is expected to reach the
    sender as ``metadata``.
    """
    event = _event(thread_id="topic-1")

    # One AsyncMock per sender, each returning a SendResult tagged with its kind.
    senders = {
        sender_name: AsyncMock(return_value=SendResult(success=True, message_id=tag))
        for sender_name, tag in (
            ("send_voice", "voice"),
            ("send_document", "doc"),
            ("send_image_file", "image"),
            ("send_video", "video"),
        )
    }
    adapter = SimpleNamespace(
        name="test",
        extract_media=BasePlatformAdapter.extract_media,
        extract_images=BasePlatformAdapter.extract_images,
        extract_local_files=BasePlatformAdapter.extract_local_files,
        **senders,
    )

    # Called unbound with a placeholder object in the runner position.
    await GatewayRunner._deliver_media_from_response(
        object(), "MEDIA:/tmp/speech.flac", event, adapter
    )

    adapter.send_document.assert_awaited_once_with(
        chat_id="chat-1",
        file_path="/tmp/speech.flac",
        metadata={"thread_id": "topic-1"},
    )
    adapter.send_voice.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_streaming_delivery_routes_non_voice_telegram_ogg_media_tag_to_document_sender():
    """``_deliver_media_from_response`` must hand an untagged ``.ogg`` ``MEDIA:`` tag to ``send_document``.

    The event carries ``thread_id="topic-1"``, which is expected to reach the
    sender as ``metadata``.
    """
    event = _event(thread_id="topic-1")

    # One AsyncMock per sender, each returning a SendResult tagged with its kind.
    senders = {
        sender_name: AsyncMock(return_value=SendResult(success=True, message_id=tag))
        for sender_name, tag in (
            ("send_voice", "voice"),
            ("send_document", "doc"),
            ("send_image_file", "image"),
            ("send_video", "video"),
        )
    }
    adapter = SimpleNamespace(
        name="test",
        extract_media=BasePlatformAdapter.extract_media,
        extract_images=BasePlatformAdapter.extract_images,
        extract_local_files=BasePlatformAdapter.extract_local_files,
        **senders,
    )

    # Called unbound with a placeholder object in the runner position.
    await GatewayRunner._deliver_media_from_response(
        object(), "MEDIA:/tmp/speech.ogg", event, adapter
    )

    adapter.send_document.assert_awaited_once_with(
        chat_id="chat-1",
        file_path="/tmp/speech.ogg",
        metadata={"thread_id": "topic-1"},
    )
    adapter.send_voice.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_streaming_delivery_routes_telegram_mp3_media_tag_to_voice_sender():
    """MP3 audio on Telegram must go through send_voice (which routes to
    sendAudio internally); Telegram accepts MP3 for the audio player.

    The event carries ``thread_id="topic-1"``, which is expected to reach the
    sender as ``metadata``.
    """
    event = _event(thread_id="topic-1")

    # One AsyncMock per sender, each returning a SendResult tagged with its kind.
    senders = {
        sender_name: AsyncMock(return_value=SendResult(success=True, message_id=tag))
        for sender_name, tag in (
            ("send_voice", "voice"),
            ("send_document", "doc"),
            ("send_image_file", "image"),
            ("send_video", "video"),
        )
    }
    adapter = SimpleNamespace(
        name="test",
        extract_media=BasePlatformAdapter.extract_media,
        extract_images=BasePlatformAdapter.extract_images,
        extract_local_files=BasePlatformAdapter.extract_local_files,
        **senders,
    )

    # Called unbound with a placeholder object in the runner position.
    await GatewayRunner._deliver_media_from_response(
        object(), "MEDIA:/tmp/speech.mp3", event, adapter
    )

    adapter.send_voice.assert_awaited_once_with(
        chat_id="chat-1",
        audio_path="/tmp/speech.mp3",
        metadata={"thread_id": "topic-1"},
    )
    adapter.send_document.assert_not_awaited()
|
||||
@@ -16,6 +16,8 @@ def _clear_auth_env(monkeypatch) -> None:
|
||||
"WHATSAPP_ALLOWED_USERS",
|
||||
"SLACK_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS",
|
||||
"SIGNAL_GROUP_ALLOWED_USERS",
|
||||
"TELEGRAM_GROUP_ALLOWED_CHATS",
|
||||
"EMAIL_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS",
|
||||
"MATTERMOST_ALLOWED_USERS",
|
||||
@@ -178,7 +180,109 @@ def test_qq_group_allowlist_does_not_authorize_other_groups(monkeypatch):
|
||||
assert runner._is_user_authorized(source) is False
|
||||
|
||||
|
||||
def test_telegram_group_allowlist_authorizes_forum_chat_without_user_allowlist(monkeypatch):
|
||||
def test_telegram_group_user_allowlist_authorizes_forum_sender_without_dm_allowlist(monkeypatch):
    """A sender listed in ``TELEGRAM_GROUP_ALLOWED_USERS`` is authorized in a forum chat.

    No other allowlist variable is set after ``_clear_auth_env``.
    """
    _clear_auth_env(monkeypatch)
    monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_USERS", "999")

    telegram_only = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="t")}
    )
    runner, _adapter = _make_runner(Platform.TELEGRAM, telegram_only)

    listed_forum_sender = SessionSource(
        platform=Platform.TELEGRAM,
        user_id="999",
        chat_id="-1001878443972",
        user_name="tester",
        chat_type="forum",
    )

    assert runner._is_user_authorized(listed_forum_sender) is True
|
||||
|
||||
|
||||
def test_telegram_group_user_allowlist_rejects_other_senders(monkeypatch):
    """A group sender absent from ``TELEGRAM_GROUP_ALLOWED_USERS`` is not authorized."""
    _clear_auth_env(monkeypatch)
    monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_USERS", "999")

    telegram_only = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="t")}
    )
    runner, _adapter = _make_runner(Platform.TELEGRAM, telegram_only)

    unlisted_group_sender = SessionSource(
        platform=Platform.TELEGRAM,
        user_id="123",
        chat_id="-1001878443972",
        user_name="tester",
        chat_type="group",
    )

    assert runner._is_user_authorized(unlisted_group_sender) is False
|
||||
|
||||
|
||||
def test_telegram_group_user_allowlist_wildcard_authorizes_any_sender(monkeypatch):
    """``TELEGRAM_GROUP_ALLOWED_USERS="*"`` authorizes an arbitrary group sender."""
    _clear_auth_env(monkeypatch)
    monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_USERS", "*")

    telegram_only = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="t")}
    )
    runner, _adapter = _make_runner(Platform.TELEGRAM, telegram_only)

    arbitrary_group_sender = SessionSource(
        platform=Platform.TELEGRAM,
        user_id="123",
        chat_id="-1001878443972",
        user_name="tester",
        chat_type="group",
    )

    assert runner._is_user_authorized(arbitrary_group_sender) is True
|
||||
|
||||
|
||||
def test_telegram_group_user_allowlist_does_not_authorize_dms(monkeypatch):
    """Being listed in ``TELEGRAM_GROUP_ALLOWED_USERS`` does not authorize the same user in a DM."""
    _clear_auth_env(monkeypatch)
    monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_USERS", "999")

    telegram_only = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="t")}
    )
    runner, _adapter = _make_runner(Platform.TELEGRAM, telegram_only)

    # Same user id as the allowlist entry, but arriving over a DM.
    listed_user_dm = SessionSource(
        platform=Platform.TELEGRAM,
        user_id="999",
        chat_id="999",
        user_name="tester",
        chat_type="dm",
    )

    assert runner._is_user_authorized(listed_user_dm) is False
|
||||
|
||||
|
||||
def test_telegram_group_chat_allowlist_authorizes_group_chat_without_user_allowlist(monkeypatch):
    """A chat listed in ``TELEGRAM_GROUP_ALLOWED_CHATS`` authorizes its sender with no user allowlist set."""
    _clear_auth_env(monkeypatch)
    monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_CHATS", "-1001878443972")

    telegram_only = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="t")}
    )
    runner, _adapter = _make_runner(Platform.TELEGRAM, telegram_only)

    sender_in_listed_chat = SessionSource(
        platform=Platform.TELEGRAM,
        user_id="999",
        chat_id="-1001878443972",
        user_name="tester",
        chat_type="forum",
    )

    assert runner._is_user_authorized(sender_in_listed_chat) is True
|
||||
|
||||
|
||||
def test_telegram_group_users_legacy_chat_ids_still_authorize(monkeypatch):
|
||||
"""Backward-compat: PR #15027 shipped TELEGRAM_GROUP_ALLOWED_USERS as a
|
||||
chat-ID allowlist. PR #17686 renamed it to sender IDs and added
|
||||
TELEGRAM_GROUP_ALLOWED_CHATS. Users on the old guidance must keep working:
|
||||
chat-ID-shaped values (starting with "-") in the _USERS var are honored as
|
||||
chat IDs with a deprecation warning.
|
||||
"""
|
||||
_clear_auth_env(monkeypatch)
|
||||
monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_USERS", "-1001878443972")
|
||||
|
||||
@@ -198,6 +302,58 @@ def test_telegram_group_allowlist_authorizes_forum_chat_without_user_allowlist(m
|
||||
assert runner._is_user_authorized(source) is True
|
||||
|
||||
|
||||
def test_telegram_group_users_legacy_does_not_cross_chats(monkeypatch):
    """Legacy chat-ID value only authorizes the listed chat, not any group."""
    _clear_auth_env(monkeypatch)
    monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_USERS", "-1001878443972")

    telegram_only = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="t")}
    )
    runner, _adapter = _make_runner(Platform.TELEGRAM, telegram_only)

    # The chat id here differs from the one placed in the env var.
    sender_in_other_chat = SessionSource(
        platform=Platform.TELEGRAM,
        user_id="999",
        chat_id="-1009999999999",
        user_name="tester",
        chat_type="group",
    )

    assert runner._is_user_authorized(sender_in_other_chat) is False
|
||||
|
||||
|
||||
def test_telegram_group_users_mixed_sender_and_legacy_chat(monkeypatch):
    """Mixed values: positive user ID gates senders; negative chat ID gates chat."""
    _clear_auth_env(monkeypatch)
    monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_USERS", "999,-1001878443972")

    telegram_only = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="t")}
    )
    runner, _adapter = _make_runner(Platform.TELEGRAM, telegram_only)

    def group_source(user_id, chat_id):
        # Both scenarios differ only in who is sending and in which chat.
        return SessionSource(
            platform=Platform.TELEGRAM,
            user_id=user_id,
            chat_id=chat_id,
            user_name="tester",
            chat_type="group",
        )

    # Legacy chat ID path: any sender in the listed chat is authorized
    legacy_chat_source = group_source("123", "-1001878443972")
    assert runner._is_user_authorized(legacy_chat_source) is True

    # Sender path: listed sender user ID authorized in any group
    sender_source = group_source("999", "-1009999999999")
    assert runner._is_user_authorized(sender_source) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_dm_pairs_by_default(monkeypatch):
|
||||
_clear_auth_env(monkeypatch)
|
||||
|
||||
@@ -758,3 +758,33 @@ class TestWeixinVoiceSending:
|
||||
assert voice_item["encode_type"] == 6
|
||||
assert voice_item["sample_rate"] == 24000
|
||||
assert voice_item["bits_per_sample"] == 16
|
||||
|
||||
|
||||
class TestIsStaleSessionRet:
    """Regression test for #17228: distinguish stale-session ret=-2 from rate-limit ret=-2."""

    def test_ret_minus_2_with_unknown_error_is_stale(self):
        verdict = weixin._is_stale_session_ret(-2, None, "unknown error")
        assert verdict is True

    def test_errcode_minus_2_with_unknown_error_is_stale(self):
        verdict = weixin._is_stale_session_ret(None, -2, "unknown error")
        assert verdict is True

    def test_unknown_error_case_insensitive(self):
        verdict = weixin._is_stale_session_ret(-2, None, "Unknown Error")
        assert verdict is True

    def test_ret_minus_2_with_freq_limit_is_not_stale(self):
        # Genuine rate limit — must NOT be treated as stale session.
        verdict = weixin._is_stale_session_ret(-2, None, "freq limit")
        assert verdict is False

    def test_ret_minus_2_with_no_errmsg_is_not_stale(self):
        # Neither a missing nor an empty message counts as stale.
        for errmsg in (None, ""):
            assert weixin._is_stale_session_ret(-2, None, errmsg) is False

    def test_errcode_minus_14_is_not_matched_here(self):
        # -14 is handled by the separate SESSION_EXPIRED_ERRCODE path; the
        # helper only disambiguates -2 from a genuine rate limit.
        verdict = weixin._is_stale_session_ret(-14, None, "session expired")
        assert verdict is False

    def test_success_codes_are_not_stale(self):
        # Zero codes, or no codes at all, are never stale regardless of message.
        for ret, errcode, errmsg in ((0, 0, ""), (None, None, "unknown error")):
            assert weixin._is_stale_session_ret(ret, errcode, errmsg) is False
|
||||
|
||||
@@ -1097,3 +1097,63 @@ class TestHuggingFaceModels:
|
||||
from hermes_cli.models import _PROVIDER_LABELS
|
||||
assert "huggingface" in _PROVIDER_LABELS
|
||||
assert _PROVIDER_LABELS["huggingface"] == "Hugging Face"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MiniMax OAuth provider tests (added by feat/minimax-oauth-provider)
|
||||
# =============================================================================
|
||||
|
||||
class TestMinimaxOAuthProvider:
    """Tests for the minimax-oauth OAuth provider."""

    def test_minimax_oauth_in_provider_registry(self):
        """The registry entry exists and carries the expected id and auth type."""
        assert "minimax-oauth" in PROVIDER_REGISTRY
        entry = PROVIDER_REGISTRY["minimax-oauth"]
        assert entry.auth_type == "oauth_minimax"
        assert entry.id == "minimax-oauth"

    def test_minimax_oauth_has_correct_endpoints(self):
        """Global URLs sit on the config itself; CN URLs sit under ``extra``."""
        from hermes_cli.auth import (
            MINIMAX_OAUTH_GLOBAL_BASE,
            MINIMAX_OAUTH_GLOBAL_INFERENCE,
            MINIMAX_OAUTH_CN_BASE,
            MINIMAX_OAUTH_CN_INFERENCE,
        )

        entry = PROVIDER_REGISTRY["minimax-oauth"]
        assert entry.portal_base_url == MINIMAX_OAUTH_GLOBAL_BASE
        assert entry.inference_base_url == MINIMAX_OAUTH_GLOBAL_INFERENCE
        assert entry.extra["cn_portal_base_url"] == MINIMAX_OAUTH_CN_BASE
        assert entry.extra["cn_inference_base_url"] == MINIMAX_OAUTH_CN_INFERENCE

    def test_minimax_oauth_alias_resolves_portal(self):
        assert resolve_provider("minimax-portal") == "minimax-oauth"

    def test_minimax_oauth_alias_resolves_global(self):
        assert resolve_provider("minimax-global") == "minimax-oauth"

    def test_minimax_oauth_alias_resolves_underscore(self):
        assert resolve_provider("minimax_oauth") == "minimax-oauth"

    def test_minimax_oauth_listed_in_canonical_providers(self):
        from hermes_cli.models import CANONICAL_PROVIDERS

        canonical_slugs = [provider.slug for provider in CANONICAL_PROVIDERS]
        assert "minimax-oauth" in canonical_slugs

    def test_minimax_oauth_models_alias_in_models_py(self):
        from hermes_cli.models import _PROVIDER_ALIASES

        for alias in ("minimax-portal", "minimax-global", "minimax_oauth"):
            assert _PROVIDER_ALIASES.get(alias) == "minimax-oauth"

    def test_minimax_oauth_has_models(self):
        from hermes_cli.models import _PROVIDER_MODELS

        registered_models = _PROVIDER_MODELS.get("minimax-oauth", [])
        assert len(registered_models) >= 1

    def test_minimax_oauth_aux_model_registered(self):
        from agent.auxiliary_client import _API_KEY_PROVIDER_AUX_MODELS

        assert "minimax-oauth" in _API_KEY_PROVIDER_AUX_MODELS
        assert _API_KEY_PROVIDER_AUX_MODELS["minimax-oauth"]  # non-empty
|
||||
|
||||
@@ -1446,23 +1446,36 @@ def test_seed_custom_pool_respects_config_suppression(tmp_path, monkeypatch):
|
||||
def test_credential_sources_registry_has_expected_steps():
|
||||
"""Sanity check — the registry contains the expected RemovalSteps.
|
||||
|
||||
Guards against accidentally dropping a step during future refactors.
|
||||
If you add a new credential source, add it to the expected set below.
|
||||
Adding a new credential source is routine, so this is a structural
|
||||
invariant check (every step has a description, every step is unique,
|
||||
core steps are present) rather than a frozen snapshot. Frozen
|
||||
snapshots of catalog-like data violate the AGENTS.md "don't write
|
||||
change-detector tests" rule — they break every time someone adds a
|
||||
provider.
|
||||
"""
|
||||
from agent.credential_sources import _REGISTRY
|
||||
|
||||
descriptions = {step.description for step in _REGISTRY}
|
||||
expected = {
|
||||
descriptions = [step.description for step in _REGISTRY]
|
||||
# No empty descriptions, no duplicates.
|
||||
assert all(d for d in descriptions), "Every removal step must have a description"
|
||||
assert len(descriptions) == len(set(descriptions)), (
|
||||
f"Registry has duplicate step descriptions: {descriptions}"
|
||||
)
|
||||
# Core steps must be present — these are the ones the rest of the code
|
||||
# assumes exist. When deliberately dropping one, update this list.
|
||||
required = {
|
||||
"gh auth token / COPILOT_GITHUB_TOKEN / GH_TOKEN",
|
||||
"Any env-seeded credential (XAI_API_KEY, DEEPSEEK_API_KEY, etc.)",
|
||||
"~/.claude/.credentials.json",
|
||||
"~/.hermes/.anthropic_oauth.json",
|
||||
"auth.json providers.nous",
|
||||
"auth.json providers.openai-codex + ~/.codex/auth.json",
|
||||
"auth.json providers.minimax-oauth",
|
||||
"~/.qwen/oauth_creds.json",
|
||||
"Custom provider config.yaml api_key field",
|
||||
}
|
||||
assert descriptions == expected, f"Registry mismatch. Got: {descriptions}"
|
||||
missing = required - set(descriptions)
|
||||
assert not missing, f"Registry missing required steps: {missing}"
|
||||
|
||||
|
||||
def test_credential_sources_find_step_returns_none_for_manual():
|
||||
|
||||
@@ -526,6 +526,11 @@ class TestCmdMigrate:
|
||||
class TestCmdCleanup:
|
||||
"""Test the cleanup command handler."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_openclaw_running(self):
|
||||
with patch.object(claw_mod, "_detect_openclaw_processes", return_value=[]):
|
||||
yield
|
||||
|
||||
def test_no_dirs_found(self, tmp_path, capsys):
|
||||
args = Namespace(source=None, dry_run=False, yes=False)
|
||||
with patch.object(claw_mod, "_find_openclaw_dirs", return_value=[]):
|
||||
|
||||
@@ -72,7 +72,10 @@ class TestLoadConfigExpansion:
|
||||
|
||||
monkeypatch.setenv("GOOGLE_API_KEY", "gsk-test-key")
|
||||
monkeypatch.setenv("TELEGRAM_BOT_TOKEN", "1234567:ABC-token")
|
||||
monkeypatch.setattr("hermes_cli.config.get_config_path", lambda: config_file)
|
||||
# Patch the imported function's own globals. Other tests may reload
|
||||
# hermes_cli.config, making string-target monkeypatches hit a different
|
||||
# module object than this collection-time imported load_config().
|
||||
monkeypatch.setitem(load_config.__globals__, "get_config_path", lambda: config_file)
|
||||
|
||||
config = load_config()
|
||||
|
||||
@@ -86,7 +89,7 @@ class TestLoadConfigExpansion:
|
||||
config_file.write_text(config_yaml)
|
||||
|
||||
monkeypatch.delenv("NOT_SET_XYZ_123", raising=False)
|
||||
monkeypatch.setattr("hermes_cli.config.get_config_path", lambda: config_file)
|
||||
monkeypatch.setitem(load_config.__globals__, "get_config_path", lambda: config_file)
|
||||
|
||||
config = load_config()
|
||||
|
||||
|
||||
@@ -105,7 +105,7 @@ def test_get_container_exec_info_defaults():
|
||||
)
|
||||
|
||||
with patch("hermes_constants.is_container", return_value=False), \
|
||||
patch("hermes_cli.config.get_hermes_home", return_value=hermes_home), \
|
||||
patch.dict(get_container_exec_info.__globals__, {"get_hermes_home": lambda: hermes_home}), \
|
||||
patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("HERMES_DEV", None)
|
||||
info = get_container_exec_info()
|
||||
|
||||
@@ -7,9 +7,10 @@ WEB_SRC = Path(__file__).resolve().parents[2] / "web" / "src"
|
||||
|
||||
def test_dashboard_does_not_import_nous_ui_root_barrel():
    """Fail if any dashboard ``.tsx`` or ``.ts`` file imports the ``@nous-research/ui`` root barrel.

    Every matching file under ``WEB_SRC`` is read once and searched for a
    double- or single-quoted ``from "@nous-research/ui"`` clause. Offending
    paths are collected relative to ``WEB_SRC`` so the assertion failure
    lists them all at once.
    """
    barrel_clauses = ('from "@nous-research/ui"', "from '@nous-research/ui'")

    offenders = []
    # A single pass per extension: the earlier ``.tsx``-only scan is dropped
    # so no file is read, or reported, twice.
    for ext in ("*.tsx", "*.ts"):
        for path in WEB_SRC.rglob(ext):
            content = path.read_text(encoding="utf-8")
            if any(clause in content for clause in barrel_clauses):
                offenders.append(str(path.relative_to(WEB_SRC)))

    assert offenders == []
|
||||
|
||||
181
tests/hermes_cli/test_dashboard_lifecycle_flags.py
Normal file
181
tests/hermes_cli/test_dashboard_lifecycle_flags.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Tests for ``hermes dashboard --stop`` / ``--status`` flags.
|
||||
|
||||
These flags share the detection + kill path with the post-``hermes update``
|
||||
cleanup, so the heavy coverage of SIGTERM / SIGKILL / Windows taskkill lives
|
||||
in ``test_update_stale_dashboard.py``. This file just verifies the flag
|
||||
dispatch: argparse wiring, no-op when nothing is running, and correct
|
||||
exit codes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.main import cmd_dashboard, _report_dashboard_status
|
||||
|
||||
|
||||
def _ns(**kw):
|
||||
"""Build an argparse.Namespace with dashboard defaults plus overrides."""
|
||||
defaults = dict(
|
||||
port=9119, host="127.0.0.1", no_open=False, insecure=False,
|
||||
tui=False, stop=False, status=False,
|
||||
)
|
||||
defaults.update(kw)
|
||||
return argparse.Namespace(**defaults)
|
||||
|
||||
|
||||
class TestDashboardStatus:
    """Checks ``cmd_dashboard`` when called with ``status=True``."""

    def test_status_no_processes(self, capsys):
        """With no PIDs found, the command exits 0 and says nothing is running."""
        no_pids = patch("hermes_cli.main._find_stale_dashboard_pids", return_value=[])
        with no_pids, pytest.raises(SystemExit) as exc:
            cmd_dashboard(_ns(status=True))
        assert exc.value.code == 0
        printed = capsys.readouterr().out
        assert "No hermes dashboard processes running" in printed

    def test_status_with_processes(self, capsys):
        """With two PIDs found, the command exits 0 and lists both."""
        two_pids = patch(
            "hermes_cli.main._find_stale_dashboard_pids",
            return_value=[12345, 12346],
        )
        with two_pids, pytest.raises(SystemExit) as exc:
            cmd_dashboard(_ns(status=True))
        # Status is informational — always exits 0.
        assert exc.value.code == 0
        printed = capsys.readouterr().out
        assert "2 hermes dashboard process(es) running" in printed
        for pid in (12345, 12346):
            assert f"PID {pid}" in printed

    def test_status_does_not_try_to_import_fastapi(self):
        """`--status` must not require dashboard runtime deps — it's a
        process-table scan only. We prove this by making fastapi import
        fail and confirming --status still succeeds."""
        real_import = __import__

        def fake_import(name, *a, **kw):
            # Only fastapi is blocked; every other import goes through untouched.
            if name == "fastapi":
                raise ImportError("fastapi missing")
            return real_import(name, *a, **kw)

        no_pids = patch("hermes_cli.main._find_stale_dashboard_pids", return_value=[])
        fastapi_blocked = patch("builtins.__import__", side_effect=fake_import)
        with no_pids, fastapi_blocked, pytest.raises(SystemExit) as exc:
            cmd_dashboard(_ns(status=True))
        assert exc.value.code == 0
|
||||
|
||||
|
||||
class TestDashboardStop:
    """``cmd_dashboard(stop=True)``: kill running dashboards and report via exit code."""

    def test_stop_when_nothing_running(self, capsys):
        """With no processes found, the notice is printed and the exit code is 0."""
        with patch("hermes_cli.main._find_stale_dashboard_pids", return_value=[]):
            with pytest.raises(SystemExit) as exit_info:
                cmd_dashboard(_ns(stop=True))

        assert exit_info.value.code == 0
        printed = capsys.readouterr().out
        assert "No hermes dashboard processes running" in printed

    def test_stop_kills_and_exits_zero_when_all_killed(self, capsys):
        """Exit 0 when the verification scan after the kill comes back empty."""
        # Scan #1 discovers two processes; scan #2 (post-kill check) finds none.
        scan_results = iter([[12345, 12346], []])

        with patch(
            "hermes_cli.main._find_stale_dashboard_pids",
            side_effect=lambda: next(scan_results),
        ):
            with patch("hermes_cli.main._kill_stale_dashboard_processes") as kill_mock:
                with pytest.raises(SystemExit) as exit_info:
                    cmd_dashboard(_ns(stop=True))

        kill_mock.assert_called_once()

        # The kill helper must receive a ``reason`` mentioning "stop"; the
        # default wording ("running backend no longer matches the updated
        # frontend") belongs to the post-`hermes update` path instead.
        passed_kwargs = kill_mock.call_args.kwargs
        assert "reason" in passed_kwargs
        assert "stop" in passed_kwargs["reason"].lower()
        assert exit_info.value.code == 0

    def test_stop_exits_nonzero_if_kill_leaves_survivors(self):
        """Exit 1 when the second scan still reports PIDs.

        A non-zero code lets scripts notice that the stop failed
        (e.g. permission denied).
        """
        # The same PID shows up before and after the kill attempt.
        scan_results = iter([[12345], [12345]])

        with patch(
            "hermes_cli.main._find_stale_dashboard_pids",
            side_effect=lambda: next(scan_results),
        ):
            with patch("hermes_cli.main._kill_stale_dashboard_processes"):
                with pytest.raises(SystemExit) as exit_info:
                    cmd_dashboard(_ns(stop=True))

        assert exit_info.value.code == 1

    def test_stop_does_not_try_to_import_fastapi(self):
        """``--stop`` must work without the dashboard runtime deps, like ``--status``."""
        real_import = __import__

        def import_without_fastapi(name, *args, **kwargs):
            # Fail only for "fastapi"; everything else imports normally.
            if name == "fastapi":
                raise ImportError("fastapi missing")
            return real_import(name, *args, **kwargs)

        with patch("hermes_cli.main._find_stale_dashboard_pids", return_value=[]):
            with patch("builtins.__import__", side_effect=import_without_fastapi):
                with pytest.raises(SystemExit) as exit_info:
                    cmd_dashboard(_ns(stop=True))

        assert exit_info.value.code == 0
|
||||
|
||||
|
||||
class TestLifecycleFlagsTakePrecedence:
    """Lifecycle flags short-circuit before the server-start path.

    When ``--stop`` and ``--status`` are both given, ``--status`` wins
    because it is checked first in ``cmd_dashboard``. The safety property
    that matters: neither flag may fall through and start a server — a
    user typing ``hermes dashboard --stop`` must never ALSO launch a new
    dashboard.
    """

    def test_status_wins_over_stop(self, capsys):
        """With both flags set, the kill helper is never invoked."""
        with patch("hermes_cli.main._find_stale_dashboard_pids", return_value=[]):
            with patch("hermes_cli.main._kill_stale_dashboard_processes") as kill_mock:
                with pytest.raises(SystemExit):
                    cmd_dashboard(_ns(status=True, stop=True))

        # --status took over, so the kill path stayed untouched.
        kill_mock.assert_not_called()

    def test_stop_does_not_fall_through_to_server_start(self):
        """Worst-case regression guard: ``--stop`` must exit before any server start.

        Otherwise the user would relaunch the dashboard they just asked
        to stop.
        """
        start_state = {"start": False}

        def record_start(**kwargs):
            start_state["start"] = True

        # Stand-in for hermes_cli.web_server so that importing it is harmless
        # and any start_server call is recorded.
        web_server_stub = MagicMock()
        web_server_stub.start_server = record_start

        with patch("hermes_cli.main._find_stale_dashboard_pids", return_value=[]):
            with patch.dict(sys.modules, {"hermes_cli.web_server": web_server_stub}):
                with pytest.raises(SystemExit):
                    cmd_dashboard(_ns(stop=True))

        assert start_state["start"] is False
|
||||
|
||||
|
||||
class TestArgparseWiring:
    """Smoke test for the ``--stop`` / ``--status`` dashboard flags."""

    def test_flags_are_registered(self):
        """Drive ``cmd_dashboard`` with ``status=True`` and expect exit code 0.

        The argparse tree is not rebuilt or introspected here; the test
        hands ``cmd_dashboard`` a ready-made namespace and only checks
        that the status path runs to a clean exit.
        """
        import importlib

        # Imported for its side effect only: proves the CLI entry point is
        # importable. The function itself is never called.
        from hermes_cli.main import main as _cli_main  # noqa: F401

        main_module = importlib.import_module("hermes_cli.main")

        with patch("hermes_cli.main._find_stale_dashboard_pids", return_value=[]):
            with pytest.raises(SystemExit) as exit_info:
                main_module.cmd_dashboard(_ns(status=True))

        assert exit_info.value.code == 0
|
||||
@@ -161,6 +161,38 @@ def test_check_gateway_service_linger_skips_when_service_not_installed(monkeypat
|
||||
assert issues == []
|
||||
|
||||
|
||||
def test_doctor_reports_vercel_backend_diagnostics(monkeypatch, tmp_path):
    """Doctor output covers the Vercel sandbox settings without leaking the token.

    The environment selects the ``vercel_sandbox`` backend with a runtime,
    a custom disk size, a token and a team id, but no project id.
    """
    vercel_env = {
        "TERMINAL_ENV": "vercel_sandbox",
        "TERMINAL_VERCEL_RUNTIME": "python3.13",
        "TERMINAL_CONTAINER_DISK": "2048",
        "VERCEL_TOKEN": "super-secret-value",
        "VERCEL_TEAM_ID": "team",
    }
    for env_name, env_value in vercel_env.items():
        monkeypatch.setenv(env_name, env_value)
    monkeypatch.delenv("VERCEL_PROJECT_ID", raising=False)

    # Report the "vercel" package as installed and every other spec as missing.
    monkeypatch.setattr(
        doctor_mod.importlib.util,
        "find_spec",
        lambda name: object() if name == "vercel" else None,
    )

    # Minimal stand-in for the model_tools module.
    model_tools_stub = types.SimpleNamespace(
        check_tool_availability=lambda *a, **kw: ([], []),
        TOOLSET_REQUIREMENTS={},
    )
    monkeypatch.setitem(sys.modules, "model_tools", model_tools_stub)

    captured = io.StringIO()
    with contextlib.redirect_stdout(captured):
        doctor_mod.run_doctor(Namespace(fix=False))
    report = captured.getvalue()

    assert "Vercel runtime" in report
    assert "python3.13" in report
    assert "Vercel custom disk unsupported" in report
    assert "Vercel auth incomplete" in report
    assert "VERCEL_PROJECT_ID" in report
    assert "Vercel auth mode: incomplete access token" in report
    assert "Vercel auth present env: VERCEL_TOKEN, VERCEL_TEAM_ID" in report
    assert "Vercel auth missing env: VERCEL_PROJECT_ID" in report
    # The token value itself must never be echoed.
    assert "super-secret-value" not in report
    assert "snapshot filesystem only" in report
|
||||
|
||||
|
||||
# ── Memory provider section (doctor should only check the *active* provider) ──
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,26 @@ from gateway.restart import (
|
||||
)
|
||||
|
||||
|
||||
class TestUserSystemdPrivateSocketPreflight:
    """The systemd private socket is accepted when the D-Bus socket is absent.

    In both tests ``Path.exists`` is stubbed so that only
    ``/tmp/private-socket`` exists.
    """

    def test_preflight_accepts_private_socket_without_dbus_bus(self, monkeypatch):
        """``_preflight_user_systemd`` is expected to complete without raising."""
        private_socket = "/tmp/private-socket"

        monkeypatch.setattr(gateway_cli, "_ensure_user_systemd_env", lambda: None)
        monkeypatch.setattr(
            gateway_cli, "_user_dbus_socket_path", lambda: Path("/tmp/missing-bus")
        )
        monkeypatch.setattr(
            gateway_cli, "_user_systemd_private_socket_path", lambda: Path(private_socket)
        )
        monkeypatch.setattr(Path, "exists", lambda self: str(self) == private_socket)

        gateway_cli._preflight_user_systemd(auto_enable_linger=False)

    def test_wait_for_user_dbus_socket_accepts_private_socket(self, monkeypatch):
        """The wait helper returns True and refreshes the env exactly once."""
        private_socket = "/tmp/private-socket"
        env_calls = []

        monkeypatch.setattr(
            gateway_cli, "_ensure_user_systemd_env", lambda: env_calls.append("env")
        )
        monkeypatch.setattr(
            gateway_cli, "_user_dbus_socket_path", lambda: Path("/tmp/missing-bus")
        )
        monkeypatch.setattr(
            gateway_cli, "_user_systemd_private_socket_path", lambda: Path(private_socket)
        )
        monkeypatch.setattr(Path, "exists", lambda self: str(self) == private_socket)

        ready = gateway_cli._wait_for_user_dbus_socket(timeout=0.1)

        assert ready is True
        assert env_calls == ["env"]
|
||||
|
||||
|
||||
class TestSystemdServiceRefresh:
|
||||
def test_systemd_install_repairs_outdated_unit_without_force(self, tmp_path, monkeypatch):
|
||||
unit_path = tmp_path / "hermes-gateway.service"
|
||||
@@ -235,7 +255,8 @@ class TestLaunchdServiceRecovery:
|
||||
target = f"{domain}/{label}"
|
||||
|
||||
def fake_run(cmd, check=False, **kwargs):
|
||||
calls.append(cmd)
|
||||
if cmd and cmd[0] == "launchctl":
|
||||
calls.append(cmd)
|
||||
if cmd == ["launchctl", "kickstart", target] and calls.count(cmd) == 1:
|
||||
raise gateway_cli.subprocess.CalledProcessError(3, cmd, stderr="Could not find service")
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
@@ -262,7 +283,8 @@ class TestLaunchdServiceRecovery:
|
||||
target = f"{domain}/{label}"
|
||||
|
||||
def fake_run(cmd, check=False, **kwargs):
|
||||
calls.append(cmd)
|
||||
if cmd and cmd[0] == "launchctl":
|
||||
calls.append(cmd)
|
||||
if cmd == ["launchctl", "kickstart", target] and calls.count(cmd) == 1:
|
||||
raise gateway_cli.subprocess.CalledProcessError(113, cmd, stderr="Could not find service")
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
@@ -1105,6 +1127,10 @@ class TestPreflightUserSystemd:
|
||||
gateway_cli, "_user_dbus_socket_path",
|
||||
lambda: type("P", (), {"exists": lambda self: True})(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "_user_systemd_private_socket_path",
|
||||
lambda: type("P", (), {"exists": lambda self: False})(),
|
||||
)
|
||||
# Should not raise, no subprocess calls needed.
|
||||
gateway_cli._preflight_user_systemd()
|
||||
|
||||
@@ -1114,6 +1140,10 @@ class TestPreflightUserSystemd:
|
||||
gateway_cli, "_user_dbus_socket_path",
|
||||
lambda: type("P", (), {"exists": lambda self: False})(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "_user_systemd_private_socket_path",
|
||||
lambda: type("P", (), {"exists": lambda self: False})(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "get_systemd_linger_status", lambda: (False, ""),
|
||||
)
|
||||
@@ -1142,6 +1172,10 @@ class TestPreflightUserSystemd:
|
||||
gateway_cli, "_user_dbus_socket_path",
|
||||
lambda: type("P", (), {"exists": lambda self: False})(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "_user_systemd_private_socket_path",
|
||||
lambda: type("P", (), {"exists": lambda self: False})(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "get_systemd_linger_status",
|
||||
lambda: (None, "loginctl not found"),
|
||||
@@ -1159,6 +1193,10 @@ class TestPreflightUserSystemd:
|
||||
gateway_cli, "_user_dbus_socket_path",
|
||||
lambda: type("P", (), {"exists": lambda self: False})(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "_user_systemd_private_socket_path",
|
||||
lambda: type("P", (), {"exists": lambda self: False})(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "get_systemd_linger_status", lambda: (True, ""),
|
||||
)
|
||||
@@ -1177,6 +1215,10 @@ class TestPreflightUserSystemd:
|
||||
gateway_cli, "_user_dbus_socket_path",
|
||||
lambda: type("P", (), {"exists": lambda self: False})(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "_user_systemd_private_socket_path",
|
||||
lambda: type("P", (), {"exists": lambda self: False})(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "get_systemd_linger_status", lambda: (False, ""),
|
||||
)
|
||||
|
||||
@@ -224,22 +224,21 @@ class TestArgparseFlagsRegistered:
|
||||
assert args.ignore_rules is True
|
||||
|
||||
def test_main_py_registers_both_flags(self):
|
||||
"""E2E: the real hermes_cli/main.py parser accepts both flags.
|
||||
"""E2E: the real hermes parser accepts both flags."""
|
||||
from hermes_cli._parser import build_top_level_parser
|
||||
|
||||
We invoke the real argparse tree builder from hermes_cli.main.
|
||||
"""
|
||||
import hermes_cli.main as hm
|
||||
parser, _subparsers, chat_parser = build_top_level_parser()
|
||||
|
||||
top_dests = {a.dest for a in parser._actions}
|
||||
chat_dests = {a.dest for a in chat_parser._actions}
|
||||
assert "ignore_user_config" in top_dests
|
||||
assert "ignore_rules" in top_dests
|
||||
assert "ignore_user_config" in chat_dests
|
||||
assert "ignore_rules" in chat_dests
|
||||
|
||||
# hm has a helper that builds the argparse tree inside main().
|
||||
# We can extract it by catching the SystemExit on --help.
|
||||
# Simpler: just grep the source for the flag strings. Both approaches
|
||||
# are brittle; we use a combined test.
|
||||
import inspect
|
||||
src = inspect.getsource(hm)
|
||||
assert '"--ignore-user-config"' in src, \
|
||||
"chat subparser must register --ignore-user-config"
|
||||
assert '"--ignore-rules"' in src, \
|
||||
"chat subparser must register --ignore-rules"
|
||||
# And the cmd_chat env-var wiring must be present
|
||||
import inspect
|
||||
import hermes_cli.main as hm
|
||||
src = inspect.getsource(hm)
|
||||
assert "HERMES_IGNORE_USER_CONFIG" in src
|
||||
assert "HERMES_IGNORE_RULES" in src
|
||||
|
||||
91
tests/hermes_cli/test_mcp_reload_confirm_gate.py
Normal file
91
tests/hermes_cli/test_mcp_reload_confirm_gate.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Tests for the approvals.mcp_reload_confirm config gate.
|
||||
|
||||
When the user runs /reload-mcp, the MCP tool set is rebuilt which
|
||||
invalidates the provider prompt cache for the active session. That's
|
||||
expensive on long-context / high-reasoning models. The config gate
|
||||
adds a three-option confirmation (Approve Once / Always Approve /
|
||||
Cancel); "Always Approve" flips this key to false so subsequent reloads
|
||||
run silently.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
|
||||
|
||||
class TestMcpReloadConfirmDefault:
    """Shape and default value of ``approvals.mcp_reload_confirm`` in DEFAULT_CONFIG."""

    def test_default_config_has_the_key(self):
        """``approvals`` is a dict that contains the new key."""
        approvals_section = DEFAULT_CONFIG.get("approvals")
        assert isinstance(approvals_section, dict)
        assert "mcp_reload_confirm" in approvals_section

    def test_default_is_true(self):
        """Confirmation is on by default — the safe choice for new installs."""
        assert DEFAULT_CONFIG["approvals"]["mcp_reload_confirm"] is True

    def test_shape_matches_other_approval_keys(self):
        """The key sits at the same flat level as ``mode`` / ``timeout`` / ``cron_mode``."""
        approvals_section = DEFAULT_CONFIG["approvals"]
        mode = approvals_section.get("mode")
        timeout = approvals_section.get("timeout")
        cron_mode = approvals_section.get("cron_mode")
        confirm = approvals_section.get("mcp_reload_confirm")

        assert isinstance(mode, str)
        assert isinstance(timeout, int)
        assert isinstance(cron_mode, str)
        assert isinstance(confirm, bool)
|
||||
|
||||
|
||||
class TestUserConfigMerge:
    """``load_config`` merges DEFAULT_CONFIG into an existing user config.

    A user config that predates ``approvals.mcp_reload_confirm`` should
    receive the default, while an explicit user value must survive the
    merge.
    """

    @staticmethod
    def _load_with_user_approvals(tmp_path, monkeypatch, approvals):
        """Write a user config with *approvals* and return ``load_config()``'s result.

        ``hermes_cli.config`` is reloaded while HERMES_HOME points at a
        temporary ``.hermes`` directory so the module honors that home.
        Afterwards the environment override is undone and the module is
        reloaded once more, so state derived from the temporary home does
        not leak into tests that run later in the same process.
        """
        import importlib

        import yaml

        import hermes_cli.config as cfg_mod

        home = tmp_path / ".hermes"
        home.mkdir()
        (home / "config.yaml").write_text(yaml.safe_dump({"approvals": approvals}))

        try:
            with monkeypatch.context() as scoped:
                scoped.setenv("HERMES_HOME", str(home))
                # Fresh re-import so the module picks up the patched HERMES_HOME.
                importlib.reload(cfg_mod)
                return cfg_mod.load_config()
        finally:
            # The ``with`` block has already restored HERMES_HOME here;
            # reloading again returns the module to its pre-test state.
            importlib.reload(cfg_mod)

    def test_existing_user_config_without_key_gets_default(self, tmp_path, monkeypatch):
        """A legacy config lacking the key is filled in with the default (True)."""
        cfg = self._load_with_user_approvals(
            tmp_path,
            monkeypatch,
            {"mode": "manual", "timeout": 60, "cron_mode": "deny"},
        )
        assert cfg["approvals"]["mcp_reload_confirm"] is True

    def test_existing_user_config_with_false_key_survives_merge(
        self, tmp_path, monkeypatch,
    ):
        """A user who chose "Always Approve" (key=false) keeps that setting.

        The default of True must not override the stored value.
        """
        cfg = self._load_with_user_approvals(
            tmp_path,
            monkeypatch,
            {
                "mode": "manual",
                "timeout": 60,
                "cron_mode": "deny",
                "mcp_reload_confirm": False,
            },
        )
        assert cfg["approvals"]["mcp_reload_confirm"] is False
|
||||
@@ -188,6 +188,23 @@ class TestCreateProfile:
|
||||
assert not (profile_dir / "gateway_state.json").exists()
|
||||
assert not (profile_dir / "processes.json").exists()
|
||||
|
||||
def test_clone_all_excludes_sibling_profiles_tree(self, profile_env):
    """--clone-all from the default home must skip ``profiles/`` but keep other data.

    Copying ``profiles/*`` into a new profile would nest sibling profiles
    inside it.
    """
    hermes_home = profile_env / ".hermes"

    # A sibling profile that must NOT be cloned.
    siblings_dir = hermes_home / "profiles"
    siblings_dir.mkdir(exist_ok=True)
    other_profile = siblings_dir / "other"
    other_profile.mkdir(parents=True, exist_ok=True)
    (other_profile / "marker.txt").write_text("sibling data")

    # Ordinary data that MUST be cloned.
    memories_dir = hermes_home / "memories"
    memories_dir.mkdir(exist_ok=True)
    (memories_dir / "note.md").write_text("remember this")

    cloned = create_profile("coder", clone_all=True, no_alias=True)

    cloned_note = cloned / "memories" / "note.md"
    assert cloned_note.read_text() == "remember this"
    assert not (cloned / "profiles").exists()
|
||||
|
||||
def test_clone_config_missing_files_skipped(self, profile_env):
|
||||
"""Clone config gracefully skips files that don't exist in source."""
|
||||
profile_dir = create_profile("coder", clone_config=True, no_alias=True)
|
||||
|
||||
@@ -96,10 +96,17 @@ class TestPtyBridgeIO:
|
||||
@skip_on_windows
|
||||
class TestPtyBridgeResize:
|
||||
def test_resize_updates_child_winsize(self):
|
||||
# tput reads COLUMNS/LINES from the TTY ioctl (TIOCGWINSZ).
|
||||
# Spawn a shell, resize, then ask tput for the dimensions.
|
||||
# Query the TTY ioctl directly instead of using tput, which requires
|
||||
# TERM and fails in GitHub Actions' non-interactive environment.
|
||||
winsize_script = (
|
||||
"import fcntl, struct, termios, time; "
|
||||
"time.sleep(0.1); "
|
||||
"rows, cols, *_ = struct.unpack('HHHH', "
|
||||
"fcntl.ioctl(0, termios.TIOCGWINSZ, b'\\0' * 8)); "
|
||||
"print(cols); print(rows)"
|
||||
)
|
||||
bridge = PtyBridge.spawn(
|
||||
["/bin/sh", "-c", "sleep 0.1; tput cols; tput lines"],
|
||||
[sys.executable, "-c", winsize_script],
|
||||
cols=80,
|
||||
rows=24,
|
||||
)
|
||||
|
||||
155
tests/hermes_cli/test_relaunch.py
Normal file
155
tests/hermes_cli/test_relaunch.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""Tests for hermes_cli.relaunch — unified self-relaunch utility."""
|
||||
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli import relaunch as relaunch_mod
|
||||
|
||||
|
||||
class TestResolveHermesBin:
    """Resolution order of ``resolve_hermes_bin``: argv[0] first, then PATH, else None."""

    def test_prefers_absolute_argv0_when_executable(self, monkeypatch):
        """An absolute argv[0] that is an executable file is returned as-is."""
        hermes_path = "/nix/store/abc/bin/hermes"
        monkeypatch.setattr(sys, "argv", [hermes_path])
        monkeypatch.setattr(
            relaunch_mod.os.path, "isfile", lambda p: p == hermes_path
        )
        monkeypatch.setattr(
            relaunch_mod.os, "access", lambda p, mode: p == hermes_path
        )

        assert relaunch_mod.resolve_hermes_bin() == hermes_path

    def test_resolves_relative_argv0(self, monkeypatch, tmp_path):
        """A bare ``hermes`` argv[0] resolves against the current directory."""
        script = tmp_path / "hermes"
        script.write_text("#!/bin/sh\n")
        script.chmod(0o755)

        monkeypatch.setattr(sys, "argv", [str(script.name)])
        monkeypatch.chdir(tmp_path)
        # Hide any real `hermes` on PATH so it cannot satisfy the lookup.
        monkeypatch.setattr(relaunch_mod.shutil, "which", lambda _name: None)

        assert relaunch_mod.resolve_hermes_bin() == str(script)

    def test_falls_back_to_path_which(self, monkeypatch):
        """When argv[0] is not a usable path, the PATH lookup result is used."""
        monkeypatch.setattr(sys, "argv", ["-c"])  # not a real path

        def which_only_hermes(name):
            return "/usr/bin/hermes" if name == "hermes" else None

        monkeypatch.setattr(relaunch_mod.shutil, "which", which_only_hermes)

        assert relaunch_mod.resolve_hermes_bin() == "/usr/bin/hermes"

    def test_returns_none_when_unresolvable(self, monkeypatch):
        """With neither argv[0] nor PATH usable, the result is None."""
        monkeypatch.setattr(sys, "argv", ["-c"])
        monkeypatch.setattr(relaunch_mod.shutil, "which", lambda _name: None)

        assert relaunch_mod.resolve_hermes_bin() is None
|
||||
|
||||
|
||||
class TestExtractInheritedFlags:
    """``_extract_inherited_flags`` keeps inheritable flags and drops the rest."""

    def test_extracts_tui_and_dev(self):
        """Boolean flags are kept; the trailing subcommand is dropped."""
        kept = relaunch_mod._extract_inherited_flags(["--tui", "--dev", "chat"])
        assert kept == ["--tui", "--dev"]

    def test_extracts_profile_with_value(self):
        """A value-taking long flag is kept together with its value."""
        kept = relaunch_mod._extract_inherited_flags(["--profile", "work", "chat"])
        assert kept == ["--profile", "work"]

    def test_extracts_short_p_with_value(self):
        """The short ``-p`` alias is kept together with its value."""
        kept = relaunch_mod._extract_inherited_flags(["-p", "work"])
        assert kept == ["-p", "work"]

    def test_extracts_equals_form(self):
        """``--flag=value`` tokens pass through unchanged."""
        kept = relaunch_mod._extract_inherited_flags(
            ["--profile=work", "--model=anthropic/claude-sonnet-4"]
        )
        assert kept == [
            "--profile=work",
            "--model=anthropic/claude-sonnet-4",
        ]

    def test_skips_unknown_flags(self):
        """An unknown flag and the token after it are not inherited."""
        kept = relaunch_mod._extract_inherited_flags(["--foo", "bar", "--tui"])
        assert kept == ["--tui"]

    def test_does_not_consume_flag_like_value(self):
        """``--resume`` and its argument are not carried over after ``--tui``."""
        kept = relaunch_mod._extract_inherited_flags(["--tui", "--resume", "abc123"])
        assert kept == ["--tui"]

    def test_preserves_multiple_skills(self):
        """Repeated ``-s`` occurrences are all kept, in order."""
        kept = relaunch_mod._extract_inherited_flags(
            ["-s", "foo", "-s", "bar", "--tui"]
        )
        assert kept == ["-s", "foo", "-s", "bar", "--tui"]
|
||||
|
||||
|
||||
class TestInheritedFlagTable:
    """Sanity checks on ``_INHERITED_FLAGS_TABLE``, which drives flag extraction."""

    @staticmethod
    def _flag_table():
        """Return the table as a dict of flag -> takes_value."""
        return dict(relaunch_mod._INHERITED_FLAGS_TABLE)

    def test_short_and_long_aliases_are_paired(self):
        """A short alias and its long form agree on whether they take a value."""
        table = self._flag_table()
        alias_pairs = [
            ("-p", "--profile"),
            ("-m", "--model"),
            ("-s", "--skills"),
        ]
        for short, long_ in alias_pairs:
            assert table[short] == table[long_], f"{short}/{long_} disagree"

    def test_store_true_flags_do_not_take_value(self):
        """Boolean switches are recorded as taking no value."""
        table = self._flag_table()
        boolean_flags = ["--tui", "--dev", "--yolo", "--ignore-user-config", "--ignore-rules"]
        for flag in boolean_flags:
            assert table[flag] is False, f"{flag} should not take a value"

    def test_value_flags_take_value(self):
        """Option flags are recorded as taking a value."""
        table = self._flag_table()
        value_flags = ["--profile", "--model", "--provider", "--skills"]
        for flag in value_flags:
            assert table[flag] is True, f"{flag} should take a value"

    def test_excluded_flags_are_not_inherited(self):
        """Flags that must not propagate are absent from the table.

        ``--worktree`` creates a new worktree per process, so inheriting
        it would orphan the parent's. The chat-only flags
        (``--quiet``/``-Q``, ``--verbose``/``-v``, ``--source``) cannot be
        in argv at the existing relaunch callsites.
        """
        table = self._flag_table()
        excluded = ["-w", "--worktree", "-Q", "--quiet", "-v", "--verbose", "--source"]
        for flag in excluded:
            assert flag not in table, f"{flag} should not be inherited"
|
||||
|
||||
|
||||
class TestBuildRelaunchArgv:
    """``build_relaunch_argv`` composes the argv used to re-exec hermes.

    Tests that do not pass ``original_argv`` pin ``sys.argv`` to a neutral
    value first. NOTE(review): the fallback presumably reads the live
    ``sys.argv``; under pytest that holds pytest's own options, and
    ``-p`` / ``-m`` / ``-s`` are also inheritable hermes flags, so an
    invocation such as ``pytest -m unit`` could otherwise leak into the
    result and break the exact-equality assertions.
    """

    def test_uses_bin_when_available(self, monkeypatch):
        """The resolved hermes binary becomes argv[0]."""
        monkeypatch.setattr(sys, "argv", ["hermes"])
        monkeypatch.setattr(relaunch_mod, "resolve_hermes_bin", lambda: "/usr/bin/hermes")

        argv = relaunch_mod.build_relaunch_argv(["--resume", "abc"])

        assert argv[0] == "/usr/bin/hermes"

    def test_falls_back_to_python_module(self, monkeypatch):
        """Without a resolvable binary, ``python -m hermes_cli.main`` is used."""
        monkeypatch.setattr(sys, "argv", ["hermes"])
        monkeypatch.setattr(relaunch_mod, "resolve_hermes_bin", lambda: None)

        argv = relaunch_mod.build_relaunch_argv(["--resume", "abc"])

        assert argv == [sys.executable, "-m", "hermes_cli.main", "--resume", "abc"]

    def test_preserves_inherited_flags(self, monkeypatch):
        """Inheritable flags from the original argv survive; its subcommand does not."""
        monkeypatch.setattr(relaunch_mod, "resolve_hermes_bin", lambda: "/usr/bin/hermes")
        original = ["--tui", "--dev", "--profile", "work", "sessions", "browse"]

        argv = relaunch_mod.build_relaunch_argv(["--resume", "abc"], original_argv=original)

        assert "--tui" in argv
        assert "--dev" in argv
        assert "--profile" in argv
        assert "work" in argv
        assert "--resume" in argv
        assert "abc" in argv
        # The original subcommand should not survive
        assert "sessions" not in argv
        assert "browse" not in argv

    def test_can_disable_preserve(self, monkeypatch):
        """``preserve_inherited=False`` drops every flag from the original argv."""
        monkeypatch.setattr(relaunch_mod, "resolve_hermes_bin", lambda: "/usr/bin/hermes")
        original = ["--tui", "chat"]

        argv = relaunch_mod.build_relaunch_argv(
            ["--resume", "abc"], preserve_inherited=False, original_argv=original
        )

        assert "--tui" not in argv
        assert argv == ["/usr/bin/hermes", "--resume", "abc"]
|
||||
|
||||
|
||||
class TestRelaunch:
    """``relaunch`` hands the composed argv to ``os.execvp``."""

    def test_calls_execvp(self, monkeypatch):
        """``execvp`` receives the hermes binary and the full argv exactly once."""
        calls = []

        def fake_execvp(path, argv):
            calls.append((path, argv))
            # A real execvp never returns; SystemExit stands in for that.
            raise SystemExit(0)

        # Pin sys.argv so pytest's own options cannot be picked up as
        # inherited hermes flags (``-p`` / ``-m`` / ``-s`` overlap), which
        # would break the exact-equality assertion below.
        # NOTE(review): assumes relaunch() derives inherited flags from
        # sys.argv when none are passed — confirm against hermes_cli.relaunch.
        monkeypatch.setattr(sys, "argv", ["hermes"])
        monkeypatch.setattr(relaunch_mod.os, "execvp", fake_execvp)
        monkeypatch.setattr(relaunch_mod, "resolve_hermes_bin", lambda: "/usr/bin/hermes")

        with pytest.raises(SystemExit):
            relaunch_mod.relaunch(["--resume", "abc"])

        assert calls == [("/usr/bin/hermes", ["/usr/bin/hermes", "--resume", "abc"])]
|
||||
@@ -1998,6 +1998,7 @@ class TestAzureAnthropicEnvVarHint:
|
||||
|
||||
assert resolved["api_key"] == "fallback-works"
|
||||
|
||||
|
||||
def test_no_key_anywhere_raises_helpful_error(self, monkeypatch):
|
||||
"""When nothing resolves, the error message mentions key_env as an option."""
|
||||
monkeypatch.delenv("AZURE_ANTHROPIC_KEY", raising=False)
|
||||
@@ -2168,3 +2169,67 @@ class TestTencentTokenhubRuntimeResolution:
|
||||
assert resolved["base_url"] == "https://explicit-proxy.example.com/v1"
|
||||
assert resolved["source"] == "explicit"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# minimax-oauth runtime resolution tests (added by feat/minimax-oauth-provider)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_minimax_oauth_runtime_returns_anthropic_messages_mode(monkeypatch):
    """minimax-oauth must resolve with ``api_mode == 'anthropic_messages'``."""
    import hermes_cli.auth as auth_mod

    # Route provider resolution to minimax-oauth and neutralise the pool,
    # named-custom and explicit runtime lookups.
    runtime_stubs = {
        "resolve_provider": lambda *a, **k: "minimax-oauth",
        "_get_model_config": lambda: {"provider": "minimax-oauth"},
        "load_pool": lambda provider: None,
        "_resolve_named_custom_runtime": lambda **k: None,
        "_resolve_explicit_runtime": lambda **k: None,
    }
    for attr_name, stub in runtime_stubs.items():
        monkeypatch.setattr(rp, attr_name, stub)

    oauth_creds = {
        "provider": "minimax-oauth",
        "api_key": "mock-access-token",
        "base_url": auth_mod.MINIMAX_OAUTH_GLOBAL_INFERENCE.rstrip("/"),
        "source": "oauth",
    }
    monkeypatch.setattr(
        auth_mod,
        "resolve_minimax_oauth_runtime_credentials",
        lambda **k: oauth_creds,
    )

    runtime = rp.resolve_runtime_provider(requested="minimax-oauth")

    assert runtime["provider"] == "minimax-oauth"
    assert runtime["api_mode"] == "anthropic_messages"
    assert runtime["api_key"] == "mock-access-token"
|
||||
|
||||
|
||||
def test_minimax_oauth_runtime_uses_inference_base_url(monkeypatch):
    """The resolved base URL contains the one carried by the OAuth credentials."""
    import hermes_cli.auth as auth_mod

    cn_base_url = auth_mod.MINIMAX_OAUTH_CN_INFERENCE.rstrip("/")

    # Same stubbing as the api_mode test: force minimax-oauth and disable
    # the pool / custom / explicit runtime paths.
    runtime_stubs = {
        "resolve_provider": lambda *a, **k: "minimax-oauth",
        "_get_model_config": lambda: {"provider": "minimax-oauth"},
        "load_pool": lambda provider: None,
        "_resolve_named_custom_runtime": lambda **k: None,
        "_resolve_explicit_runtime": lambda **k: None,
    }
    for attr_name, stub in runtime_stubs.items():
        monkeypatch.setattr(rp, attr_name, stub)

    oauth_creds = {
        "provider": "minimax-oauth",
        "api_key": "cn-token",
        "base_url": cn_base_url,
        "source": "oauth",
    }
    monkeypatch.setattr(
        auth_mod,
        "resolve_minimax_oauth_runtime_credentials",
        lambda **k: oauth_creds,
    )

    runtime = rp.resolve_runtime_provider(requested="minimax-oauth")

    assert cn_base_url in runtime["base_url"]
|
||||
|
||||
@@ -127,6 +127,13 @@ class TestConfigYamlRouting:
|
||||
or "TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE=True" in env_content
|
||||
)
|
||||
|
||||
def test_terminal_vercel_runtime_goes_to_config_and_env(self, _isolated_hermes_home):
    """Setting ``terminal.vercel_runtime`` lands in both config.yaml and .env."""
    home = _isolated_hermes_home
    set_config_value("terminal.vercel_runtime", "python3.13")

    yaml_text = _read_config(home)
    dotenv_text = _read_env(home)

    assert "vercel_runtime: python3.13" in yaml_text
    assert "TERMINAL_VERCEL_RUNTIME=python3.13" in dotenv_text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Empty / falsy values — regression tests for #4277
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for setup.py configuration flows."""
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
|
||||
@@ -29,6 +30,17 @@ def _clear_provider_env(monkeypatch):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
|
||||
def _clear_vercel_env(monkeypatch):
|
||||
for key in (
|
||||
"TERMINAL_VERCEL_RUNTIME",
|
||||
"VERCEL_OIDC_TOKEN",
|
||||
"VERCEL_TOKEN",
|
||||
"VERCEL_PROJECT_ID",
|
||||
"VERCEL_TEAM_ID",
|
||||
):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
|
||||
def _stub_tts(monkeypatch):
|
||||
"""Stub out TTS prompts so setup_model_provider doesn't block."""
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda q, c, d=0: (
|
||||
@@ -162,12 +174,13 @@ def test_setup_gateway_skips_service_install_when_systemctl_missing(monkeypatch,
|
||||
"WEBHOOK_ENABLED": "",
|
||||
}
|
||||
|
||||
import hermes_cli.gateway as gateway_mod
|
||||
|
||||
monkeypatch.setattr(setup_mod, "get_env_value", lambda key: env.get(key, ""))
|
||||
monkeypatch.setattr(gateway_mod, "get_env_value", lambda key: env.get(key, ""))
|
||||
monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("platform.system", lambda: "Linux")
|
||||
|
||||
import hermes_cli.gateway as gateway_mod
|
||||
|
||||
monkeypatch.setattr(gateway_mod, "supports_systemd_services", lambda: False)
|
||||
monkeypatch.setattr(gateway_mod, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(gateway_mod, "_is_service_installed", lambda: False)
|
||||
@@ -200,12 +213,13 @@ def test_setup_gateway_in_container_shows_docker_guidance(monkeypatch, capsys):
|
||||
"WEBHOOK_ENABLED": "",
|
||||
}
|
||||
|
||||
import hermes_cli.gateway as gateway_mod
|
||||
|
||||
monkeypatch.setattr(setup_mod, "get_env_value", lambda key: env.get(key, ""))
|
||||
monkeypatch.setattr(gateway_mod, "get_env_value", lambda key: env.get(key, ""))
|
||||
monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("platform.system", lambda: "Linux")
|
||||
|
||||
import hermes_cli.gateway as gateway_mod
|
||||
|
||||
monkeypatch.setattr(gateway_mod, "supports_systemd_services", lambda: False)
|
||||
monkeypatch.setattr(gateway_mod, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(gateway_mod, "_is_service_installed", lambda: False)
|
||||
@@ -480,28 +494,91 @@ def test_modal_setup_persists_direct_mode_when_user_chooses_their_own_account(tm
|
||||
assert config["terminal"]["modal_mode"] == "direct"
|
||||
|
||||
|
||||
def test_resolve_hermes_chat_argv_prefers_which(monkeypatch):
    """When ``shutil.which`` finds ``hermes``, the argv is that path plus ``chat``."""
    from hermes_cli import setup as setup_mod

    def fake_which(name):
        # Only the ``hermes`` executable is discoverable on the fake PATH.
        if name == "hermes":
            return "/usr/local/bin/hermes"
        return None

    monkeypatch.setattr(setup_mod.shutil, "which", fake_which)

    resolved = setup_mod._resolve_hermes_chat_argv()
    assert resolved == ["/usr/local/bin/hermes", "chat"]
|
||||
|
||||
|
||||
def test_resolve_hermes_chat_argv_falls_back_to_module(monkeypatch):
    """With no ``hermes`` on PATH, the argv runs ``hermes_cli.main`` via ``-m``."""
    from hermes_cli import setup as setup_mod

    def fake_find_spec(name):
        # Pretend only the ``hermes_cli`` package is importable.
        if name == "hermes_cli":
            return object()
        return None

    monkeypatch.setattr(setup_mod.shutil, "which", lambda _name: None)
    monkeypatch.setattr(setup_mod.importlib.util, "find_spec", fake_find_spec)

    expected_argv = [sys.executable, "-m", "hermes_cli.main", "chat"]
    assert setup_mod._resolve_hermes_chat_argv() == expected_argv
|
||||
|
||||
|
||||
def test_offer_launch_chat_execs_fresh_process(monkeypatch):
|
||||
def test_vercel_setup_configures_access_token_auth(tmp_path, monkeypatch):
    """Scripted Vercel setup stores token/project/team and drops the OIDC token.

    The prompts are answered from a fixed script; afterwards both the config
    dict and the process environment are inspected.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    _clear_vercel_env(monkeypatch)
    # Seed a stale OIDC token so the test can verify it is removed.
    monkeypatch.setenv("VERCEL_OIDC_TOKEN", "old-oidc")
    monkeypatch.setitem(sys.modules, "vercel", types.ModuleType("vercel"))
    config = load_config()

    def fake_prompt_choice(question, choices, default=0):
        # Only the backend selection menu is expected to appear.
        if question != "Select terminal backend:":
            raise AssertionError(f"Unexpected prompt_choice call: {question}")
        return 5

    scripted_answers = iter(
        ["python3.13", "yes", "2", "4096", "token", "project", "team"]
    )

    def fake_prompt(*args, **kwargs):
        return next(scripted_answers)

    monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
    monkeypatch.setattr("hermes_cli.setup.prompt", fake_prompt)

    from hermes_cli.setup import setup_terminal_backend

    setup_terminal_backend(config)

    terminal_cfg = config["terminal"]
    assert terminal_cfg["backend"] == "vercel_sandbox"
    assert terminal_cfg["vercel_runtime"] == "python3.13"
    assert terminal_cfg["container_disk"] == 51200

    assert os.environ["TERMINAL_VERCEL_RUNTIME"] == "python3.13"
    assert "VERCEL_OIDC_TOKEN" not in os.environ
    assert os.environ["VERCEL_TOKEN"] == "token"
    assert os.environ["VERCEL_PROJECT_ID"] == "project"
    assert os.environ["VERCEL_TEAM_ID"] == "team"
|
||||
|
||||
|
||||
def test_vercel_setup_prefills_project_and_team_from_link_file(tmp_path, monkeypatch):
    """Project and team prompts default to IDs read from ``.vercel/project.json``.

    The link file lives two directories above the working directory; blank
    answers to the project/team prompts fall through to the offered defaults.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    _clear_vercel_env(monkeypatch)

    # Layout: <tmp>/project/.vercel/project.json, cwd = <tmp>/project/app/src
    project_root = tmp_path / "project"
    nested = project_root / "app" / "src"
    nested.mkdir(parents=True)
    vercel_dir = project_root / ".vercel"
    vercel_dir.mkdir()
    link_payload = json.dumps({"projectId": "linked-project", "orgId": "linked-team"})
    (vercel_dir / "project.json").write_text(link_payload, encoding="utf-8")
    monkeypatch.chdir(nested)

    monkeypatch.setitem(sys.modules, "vercel", types.ModuleType("vercel"))
    config = load_config()
    config["terminal"]["container_disk"] = 999

    def fake_prompt_choice(question, choices, default=0):
        # Only the backend selection menu is expected to appear.
        if question != "Select terminal backend:":
            raise AssertionError(f"Unexpected prompt_choice call: {question}")
        return 5

    scripted_answers = iter(["node24", "no", "1", "5120", "token", "", ""])
    defaults = {}

    def fake_prompt(message, default="", **kwargs):
        # Record the default offered for each prompt, then answer from the script.
        defaults[message] = default
        answer = next(scripted_answers)
        if answer:
            return answer
        return default

    monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
    monkeypatch.setattr("hermes_cli.setup.prompt", fake_prompt)

    from hermes_cli.setup import setup_terminal_backend

    setup_terminal_backend(config)

    terminal_cfg = config["terminal"]
    assert terminal_cfg["backend"] == "vercel_sandbox"
    assert terminal_cfg["container_persistent"] is False
    assert terminal_cfg["container_disk"] == 51200

    assert "VERCEL_OIDC_TOKEN" not in os.environ
    assert os.environ["VERCEL_TOKEN"] == "token"
    assert os.environ["VERCEL_PROJECT_ID"] == "linked-project"
    assert os.environ["VERCEL_TEAM_ID"] == "linked-team"

    assert defaults[" Vercel project ID"] == "linked-project"
    assert defaults[" Vercel team ID"] == "linked-team"
|
||||
|
||||
|
||||
def test_offer_launch_chat_relaunches_via_bin(monkeypatch):
|
||||
from hermes_cli import setup as setup_mod
|
||||
from hermes_cli import relaunch as relaunch_mod
|
||||
|
||||
monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *_args, **_kwargs: True)
|
||||
monkeypatch.setattr(setup_mod, "_resolve_hermes_chat_argv", lambda: ["/usr/local/bin/hermes", "chat"])
|
||||
monkeypatch.setattr(relaunch_mod, "resolve_hermes_bin", lambda: "/usr/local/bin/hermes")
|
||||
|
||||
exec_calls = []
|
||||
|
||||
@@ -509,7 +586,7 @@ def test_offer_launch_chat_execs_fresh_process(monkeypatch):
|
||||
exec_calls.append((path, argv))
|
||||
raise SystemExit(0)
|
||||
|
||||
monkeypatch.setattr(setup_mod.os, "execvp", fake_execvp)
|
||||
monkeypatch.setattr(relaunch_mod.os, "execvp", fake_execvp)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
setup_mod._offer_launch_chat()
|
||||
@@ -517,13 +594,22 @@ def test_offer_launch_chat_execs_fresh_process(monkeypatch):
|
||||
assert exec_calls == [("/usr/local/bin/hermes", ["/usr/local/bin/hermes", "chat"])]
|
||||
|
||||
|
||||
def test_offer_launch_chat_manual_fallback_when_unresolvable(monkeypatch, capsys):
|
||||
def test_offer_launch_chat_falls_back_to_module(monkeypatch):
|
||||
from hermes_cli import setup as setup_mod
|
||||
from hermes_cli import relaunch as relaunch_mod
|
||||
|
||||
monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *_args, **_kwargs: True)
|
||||
monkeypatch.setattr(setup_mod, "_resolve_hermes_chat_argv", lambda: None)
|
||||
monkeypatch.setattr(relaunch_mod, "resolve_hermes_bin", lambda: None)
|
||||
|
||||
setup_mod._offer_launch_chat()
|
||||
exec_calls = []
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Run 'hermes chat' manually" in captured.out
|
||||
def fake_execvp(path, argv):
|
||||
exec_calls.append((path, argv))
|
||||
raise SystemExit(0)
|
||||
|
||||
monkeypatch.setattr(relaunch_mod.os, "execvp", fake_execvp)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
setup_mod._offer_launch_chat()
|
||||
|
||||
assert exec_calls == [(sys.executable, [sys.executable, "-m", "hermes_cli.main", "chat"])]
|
||||
|
||||
245
tests/hermes_cli/test_setup_irc.py
Normal file
245
tests/hermes_cli/test_setup_irc.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Tests for IRC gateway configuration via `hermes setup gateway` UI.
|
||||
|
||||
Covers the full plugin-platform discovery → status → configure flow so that
|
||||
a fresh Hermes install (no state, no env vars) can set up IRC through the
|
||||
interactive setup menus.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from gateway.platform_registry import PlatformEntry, platform_registry
|
||||
|
||||
|
||||
def _register_irc_platform(**overrides):
    """Register an IRC ``PlatformEntry`` in the shared registry and describe it.

    Keyword arguments replace the matching default field of the entry.
    Returns a dict with the keys ``key``, ``label``, ``emoji``, ``token_var``
    (the first required env var, or ``""``), ``install_hint`` and
    ``_registry_entry`` (the registered entry itself).
    """

    def _irc_env_ready():
        # Only server and channel are checked here.
        return bool(os.getenv("IRC_SERVER", "") and os.getenv("IRC_CHANNEL", ""))

    def _no_adapter(cfg):
        return None

    def _noop_setup():
        return None

    fields = {
        "name": "irc",
        "label": "IRC",
        "adapter_factory": _no_adapter,
        "check_fn": _irc_env_ready,
        "validate_config": None,
        "required_env": ["IRC_SERVER", "IRC_CHANNEL", "IRC_NICKNAME"],
        "install_hint": "No extra packages needed (stdlib only)",
        "setup_fn": _noop_setup,
        "source": "plugin",
        "plugin_name": "irc_platform",
        "allowed_users_env": "IRC_ALLOWED_USERS",
        "allow_all_env": "IRC_ALLOW_ALL_USERS",
        "max_message_length": 450,
        "pii_safe": False,
        "emoji": "💬",
        "allow_update_command": True,
        "platform_hint": "You are chatting via IRC.",
    }
    # Caller-supplied overrides win over the defaults above.
    fields = {**fields, **overrides}

    entry = PlatformEntry(**fields)
    platform_registry.register(entry)

    if entry.required_env:
        token_var = entry.required_env[0]
    else:
        token_var = ""

    return {
        "key": entry.name,
        "label": entry.label,
        "emoji": entry.emoji,
        "token_var": token_var,
        "install_hint": entry.install_hint,
        "_registry_entry": entry,
    }
|
||||
|
||||
|
||||
def _unregister_irc_platform():
    """Remove the ``irc`` entry from the shared platform registry."""
    platform_key = "irc"
    platform_registry.unregister(platform_key)
|
||||
|
||||
|
||||
# ── Fresh-install discovery ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestIRCFreshInstallDiscovery:
    """Discovery and status reporting for IRC when nothing is set up yet."""

    def test_irc_appears_in_all_platforms(self, monkeypatch):
        """A registered IRC plugin is listed by ``_all_platforms()``."""
        import hermes_cli.gateway as gateway_mod

        _register_irc_platform()
        try:
            # Start from a clean environment so stale values cannot interfere.
            for var in ("IRC_SERVER", "IRC_CHANNEL", "IRC_NICKNAME"):
                monkeypatch.delenv(var, raising=False)

            platforms = gateway_mod._all_platforms()
            irc_matches = [plat for plat in platforms if plat["key"] == "irc"]
            assert irc_matches

            irc_plat = irc_matches[0]
            assert irc_plat["label"] == "IRC"
            assert irc_plat["emoji"] == "💬"
        finally:
            _unregister_irc_platform()

    def test_irc_status_not_configured_when_fresh(self, monkeypatch):
        """With no IRC env vars set, the status is 'not configured'."""
        import hermes_cli.gateway as gateway_mod

        plat = _register_irc_platform()
        try:
            for var in ("IRC_SERVER", "IRC_CHANNEL", "IRC_NICKNAME"):
                monkeypatch.delenv(var, raising=False)

            assert gateway_mod._platform_status(plat) == "not configured"
        finally:
            _unregister_irc_platform()

    def test_irc_status_configured_when_env_set(self, monkeypatch):
        """With server, channel and nickname set, the status is 'configured'."""
        import hermes_cli.gateway as gateway_mod

        plat = _register_irc_platform()
        try:
            irc_env = (
                ("IRC_SERVER", "irc.libera.chat"),
                ("IRC_CHANNEL", "#hermes"),
                ("IRC_NICKNAME", "hermes-bot"),
            )
            for var, value in irc_env:
                monkeypatch.setenv(var, value)

            assert gateway_mod._platform_status(plat) == "configured"
        finally:
            _unregister_irc_platform()

    def test_irc_status_partial_when_only_server_set(self, monkeypatch):
        """Setting only IRC_SERVER leaves the status at 'not configured'."""
        import hermes_cli.gateway as gateway_mod

        plat = _register_irc_platform()
        try:
            for var in ("IRC_CHANNEL", "IRC_NICKNAME"):
                monkeypatch.delenv(var, raising=False)
            monkeypatch.setenv("IRC_SERVER", "irc.libera.chat")

            assert gateway_mod._platform_status(plat) == "not configured"
        finally:
            _unregister_irc_platform()
|
||||
|
||||
|
||||
# ── Interactive setup dispatch ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestIRCInteractiveSetup:
    """How ``_configure_platform()`` handles the IRC plugin entry."""

    def test_configure_platform_dispatches_to_irc_setup_fn(self, monkeypatch, capsys):
        """The entry's ``setup_fn`` is invoked and its output reaches stdout."""
        import hermes_cli.gateway as gateway_mod

        invocations = []

        def fake_setup():
            invocations.append("setup_called")
            print("IRC setup complete!")

        plat = _register_irc_platform(setup_fn=fake_setup)
        try:
            gateway_mod._configure_platform(plat)
        finally:
            _unregister_irc_platform()

        assert "setup_called" in invocations
        printed = capsys.readouterr().out
        assert "IRC setup complete!" in printed

    def test_configure_platform_fallback_when_no_setup_fn(self, monkeypatch, capsys):
        """Without a ``setup_fn`` the output still names IRC and IRC_SERVER."""
        import hermes_cli.gateway as gateway_mod

        plat = _register_irc_platform(setup_fn=None)
        try:
            gateway_mod._configure_platform(plat)
        finally:
            _unregister_irc_platform()

        printed = capsys.readouterr().out
        for fragment in ("IRC", "IRC_SERVER"):
            assert fragment in printed
|
||||
|
||||
|
||||
# ── End-to-end fresh-install gateway setup ──────────────────────────────────
|
||||
|
||||
|
||||
class TestIRCGatewaySetupFreshInstall:
    """Runs ``setup_gateway`` with the IRC plugin registered and prompts stubbed."""

    def test_setup_gateway_shows_irc_in_platform_menu(self, monkeypatch, capsys, tmp_path):
        """The platform checklist offered by ``setup_gateway`` includes IRC."""
        import hermes_cli.gateway as gateway_mod
        from hermes_cli import setup as setup_mod

        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        _register_irc_platform()
        try:
            for var in ("IRC_SERVER", "IRC_CHANNEL", "IRC_NICKNAME"):
                monkeypatch.delenv(var, raising=False)

            # Precondition: the registry entry is visible before running setup.
            platforms = gateway_mod._all_platforms()
            assert any(p["key"] == "irc" for p in platforms), \
                f"IRC not in platforms: {[p['key'] for p in platforms]}"

            # Record every checklist the setup flow tries to display.
            checklist_calls = []

            def capture_prompt_checklist(question, choices, pre_selected=None):
                checklist_calls.append({"question": question, "choices": choices})
                # Selecting nothing lets the flow finish without configuring.
                return []

            monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *a, **kw: False)
            monkeypatch.setattr(setup_mod, "prompt_checklist", capture_prompt_checklist)
            for helper in (
                "supports_systemd_services",
                "is_macos",
                "_is_service_installed",
                "_is_service_running",
            ):
                monkeypatch.setattr(gateway_mod, helper, lambda: False)

            setup_mod.setup_gateway({})

            # Pick the first recorded checklist whose question mentions platforms.
            platform_prompt = None
            for call in checklist_calls:
                if "platform" in call["question"].lower():
                    platform_prompt = call
                    break
            assert platform_prompt is not None, \
                f"No platform prompt found in {checklist_calls}"

            choices_text = "\n".join(platform_prompt["choices"])
            assert "IRC" in choices_text
            assert "💬" in choices_text
            assert "not configured" in choices_text.lower()
        finally:
            _unregister_irc_platform()

    def test_setup_gateway_irc_counts_as_messaging_platform(self, monkeypatch, capsys, tmp_path):
        """With IRC env vars set, the output reports messaging platforms configured."""
        import hermes_cli.gateway as gateway_mod
        from hermes_cli import setup as setup_mod

        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        _register_irc_platform()
        try:
            irc_env = (
                ("IRC_SERVER", "irc.libera.chat"),
                ("IRC_CHANNEL", "#hermes"),
                ("IRC_NICKNAME", "hermes-bot"),
            )
            for var, value in irc_env:
                monkeypatch.setenv(var, value)

            monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *a, **kw: False)
            monkeypatch.setattr(setup_mod, "prompt_choice", lambda *a, **kw: 0)
            for helper in (
                "supports_systemd_services",
                "is_macos",
                "_is_service_installed",
                "_is_service_running",
            ):
                monkeypatch.setattr(gateway_mod, helper, lambda: False)

            setup_mod.setup_gateway({})

            printed = capsys.readouterr().out
            assert "Messaging platforms configured!" in printed
        finally:
            _unregister_irc_platform()
|
||||
@@ -419,7 +419,12 @@ class TestGetSectionConfigSummary:
|
||||
return "disc456"
|
||||
return ""
|
||||
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
# Also patch gateway module's binding since _platform_status()
|
||||
# reads from hermes_cli.gateway.get_env_value after the setup
|
||||
# flows were unified via platform_registry.
|
||||
import hermes_cli.gateway as gateway_mod
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side), \
|
||||
patch.object(gateway_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary({}, "gateway")
|
||||
assert "Telegram" in result
|
||||
assert "Discord" in result
|
||||
@@ -471,7 +476,9 @@ class TestGetSectionConfigSummary:
|
||||
def env_side(key):
|
||||
return "true" if key == "WHATSAPP_ENABLED" else ""
|
||||
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
import hermes_cli.gateway as gateway_mod
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side), \
|
||||
patch.object(gateway_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary({}, "gateway")
|
||||
assert result is not None
|
||||
assert "WhatsApp" in result
|
||||
@@ -481,7 +488,9 @@ class TestGetSectionConfigSummary:
|
||||
def env_side(key):
|
||||
return "http://signal.local" if key == "SIGNAL_HTTP_URL" else ""
|
||||
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
import hermes_cli.gateway as gateway_mod
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side), \
|
||||
patch.object(gateway_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary({}, "gateway")
|
||||
assert result is not None
|
||||
assert "Signal" in result
|
||||
@@ -529,13 +538,28 @@ class TestGetSectionConfigSummary:
|
||||
assert result == "gpt-5"
|
||||
|
||||
def test_gateway_matches_platform_registry(self):
|
||||
"""Every platform in _GATEWAY_PLATFORMS should be recognised by its
|
||||
own env-var sentinel — i.e. the summary must not drift from the
|
||||
"""Every built-in platform should be recognised by its primary
|
||||
env-var sentinel — i.e. the summary must not drift from the
|
||||
registry used by the setup checklist."""
|
||||
for label, env_var, _fn in setup_mod._GATEWAY_PLATFORMS:
|
||||
from hermes_cli.gateway import _PLATFORMS
|
||||
|
||||
for plat in _PLATFORMS:
|
||||
label = plat["label"]
|
||||
env_var = plat.get("token_var")
|
||||
if not env_var:
|
||||
continue
|
||||
# Some platforms require a specific value shape (e.g. WhatsApp
|
||||
# needs the literal "true"). Use a sentinel that satisfies every
|
||||
# real validator _platform_status() currently checks.
|
||||
def env_side(key, _target=env_var):
|
||||
return "x" if key == _target else ""
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
if key != _target:
|
||||
return ""
|
||||
if _target == "WHATSAPP_ENABLED":
|
||||
return "true"
|
||||
return "x"
|
||||
import hermes_cli.gateway as gateway_mod
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side), \
|
||||
patch.object(gateway_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary({}, "gateway")
|
||||
expected = setup_mod._gateway_platform_short_label(label)
|
||||
assert result is not None, f"{label} ({env_var}) not recognised"
|
||||
|
||||
@@ -79,3 +79,33 @@ def test_show_status_reports_nous_auth_error(monkeypatch, capsys, tmp_path):
|
||||
assert "Error: Refresh session has been revoked" in output
|
||||
assert "Access exp:" in output
|
||||
assert "Key exp:" in output
|
||||
|
||||
|
||||
def test_show_status_reports_vercel_backend_contract(monkeypatch, capsys, tmp_path):
    """``show_status`` prints the Vercel backend, runtime and OIDC auth details.

    The token value itself must not appear in the output.
    """
    from hermes_cli import status as status_mod
    import hermes_cli.auth as auth_mod
    import hermes_cli.gateway as gateway_mod

    env_values = (
        ("HERMES_HOME", str(tmp_path)),
        ("TERMINAL_ENV", "vercel_sandbox"),
        ("TERMINAL_VERCEL_RUNTIME", "python3.13"),
        ("TERMINAL_CONTAINER_PERSISTENT", "true"),
        ("VERCEL_OIDC_TOKEN", "oidc-token"),
    )
    for name, value in env_values:
        monkeypatch.setenv(name, value)

    def fake_find_spec(name):
        # Report only the ``vercel`` package as importable.
        if name == "vercel":
            return object()
        return None

    monkeypatch.setattr(status_mod.importlib.util, "find_spec", fake_find_spec)
    monkeypatch.setattr(
        status_mod,
        "load_config",
        lambda: {"terminal": {"backend": "vercel_sandbox"}},
        raising=False,
    )
    # Every auth provider reports an empty status.
    for getter in ("get_nous_auth_status", "get_codex_auth_status", "get_qwen_auth_status"):
        monkeypatch.setattr(auth_mod, getter, lambda: {}, raising=False)
    monkeypatch.setattr(
        gateway_mod, "find_gateway_pids", lambda exclude_pids=None: [], raising=False
    )

    status_mod.show_status(SimpleNamespace(all=False, deep=False))

    output = capsys.readouterr().out
    assert "Backend: vercel_sandbox" in output
    assert "Runtime: python3.13" in output
    assert "Auth:" in output and "OIDC token via VERCEL_OIDC_TOKEN" in output
    assert "Auth detail: mode: OIDC" in output
    assert "Auth detail: active env: VERCEL_OIDC_TOKEN" in output
    # The secret value must never be echoed.
    assert "oidc-token" not in output
    assert "snapshot filesystem" in output
    assert "live processes do not survive" in output
|
||||
|
||||
@@ -12,6 +12,7 @@ def _args(**overrides):
|
||||
"model": None,
|
||||
"provider": None,
|
||||
"resume": None,
|
||||
"toolsets": None,
|
||||
"tui": True,
|
||||
"tui_dev": False,
|
||||
}
|
||||
@@ -35,7 +36,7 @@ def test_cmd_chat_tui_continue_uses_latest_tui_session(monkeypatch, main_mod):
|
||||
calls.append(source)
|
||||
return "20260408_235959_a1b2c3" if source == "tui" else None
|
||||
|
||||
def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None):
|
||||
def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None, toolsets=None):
|
||||
captured["resume"] = resume_session_id
|
||||
raise SystemExit(0)
|
||||
|
||||
@@ -62,7 +63,7 @@ def test_cmd_chat_tui_continue_falls_back_to_latest_cli_session(monkeypatch, mai
|
||||
return "20260408_235959_d4e5f6"
|
||||
return None
|
||||
|
||||
def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None):
|
||||
def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None, toolsets=None):
|
||||
captured["resume"] = resume_session_id
|
||||
raise SystemExit(0)
|
||||
|
||||
@@ -80,7 +81,7 @@ def test_cmd_chat_tui_continue_falls_back_to_latest_cli_session(monkeypatch, mai
|
||||
def test_cmd_chat_tui_resume_resolves_title_before_launch(monkeypatch, main_mod):
|
||||
captured = {}
|
||||
|
||||
def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None):
|
||||
def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None, toolsets=None):
|
||||
captured["resume"] = resume_session_id
|
||||
raise SystemExit(0)
|
||||
|
||||
@@ -98,12 +99,13 @@ def test_cmd_chat_tui_resume_resolves_title_before_launch(monkeypatch, main_mod)
|
||||
def test_cmd_chat_tui_passes_model_and_provider(monkeypatch, main_mod):
|
||||
captured = {}
|
||||
|
||||
def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None):
|
||||
def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None, toolsets=None):
|
||||
captured.update(
|
||||
{
|
||||
"model": model,
|
||||
"provider": provider,
|
||||
"resume": resume_session_id,
|
||||
"toolsets": toolsets,
|
||||
"tui_dev": tui_dev,
|
||||
}
|
||||
)
|
||||
@@ -120,11 +122,193 @@ def test_cmd_chat_tui_passes_model_and_provider(monkeypatch, main_mod):
|
||||
"model": "anthropic/claude-sonnet-4.6",
|
||||
"provider": "anthropic",
|
||||
"resume": None,
|
||||
"toolsets": None,
|
||||
"tui_dev": False,
|
||||
}
|
||||
|
||||
|
||||
def test_launch_tui_exports_model_and_provider(monkeypatch, main_mod):
|
||||
def test_cmd_chat_tui_passes_toolsets(monkeypatch, main_mod):
    """``cmd_chat`` hands the ``toolsets`` argument through to ``_launch_tui``."""
    seen = {}

    def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None, toolsets=None):
        seen["toolsets"] = toolsets
        # Stop right after recording, as a real launch would not return.
        raise SystemExit(0)

    monkeypatch.setattr(main_mod, "_launch_tui", fake_launch)

    chat_args = _args(toolsets="web,terminal")
    with pytest.raises(SystemExit):
        main_mod.cmd_chat(chat_args)

    assert seen["toolsets"] == "web,terminal"
|
||||
|
||||
|
||||
def test_main_top_level_tui_accepts_toolsets(monkeypatch, main_mod):
    """``hermes --tui --toolsets ...`` reaches ``cmd_chat`` with both values set."""
    captured = {}

    import hermes_cli.config as config_mod

    def fake_register_from_config(_cfg, accept_hooks=False):
        return None

    def fake_cmd_chat(args):
        captured.update({"toolsets": args.toolsets, "tui": args.tui})

    monkeypatch.setattr(sys, "argv", ["hermes", "--tui", "--toolsets", "web,terminal"])

    # Replace discovery modules with inert stand-ins.
    plugins_stub = types.SimpleNamespace(discover_plugins=lambda: None)
    mcp_stub = types.SimpleNamespace(discover_mcp_tools=lambda: None)
    monkeypatch.setitem(sys.modules, "hermes_cli.plugins", plugins_stub)
    monkeypatch.setitem(sys.modules, "tools.mcp_tool", mcp_stub)

    monkeypatch.setattr(config_mod, "load_config", lambda: {})
    monkeypatch.setattr(config_mod, "get_container_exec_info", lambda: None)

    hooks_stub = types.SimpleNamespace(register_from_config=fake_register_from_config)
    monkeypatch.setitem(sys.modules, "agent.shell_hooks", hooks_stub)

    monkeypatch.setattr(main_mod, "cmd_chat", fake_cmd_chat)

    main_mod.main()

    assert captured == {"toolsets": "web,terminal", "tui": True}
|
||||
|
||||
|
||||
def test_main_top_level_oneshot_accepts_toolsets(monkeypatch, main_mod):
    """``hermes -z <prompt> --toolsets ...`` forwards toolsets to ``run_oneshot``."""
    captured = {}

    import hermes_cli.config as config_mod

    def fake_register_from_config(_cfg, accept_hooks=False):
        return None

    def fake_run_oneshot(prompt, **kwargs):
        # Record the call and report success as exit status 0.
        captured.update({"prompt": prompt, **kwargs})
        return 0

    monkeypatch.setattr(sys, "argv", ["hermes", "-z", "hello", "--toolsets", "web,terminal"])

    # Replace discovery modules with inert stand-ins.
    plugins_stub = types.SimpleNamespace(discover_plugins=lambda: None)
    mcp_stub = types.SimpleNamespace(discover_mcp_tools=lambda: None)
    monkeypatch.setitem(sys.modules, "hermes_cli.plugins", plugins_stub)
    monkeypatch.setitem(sys.modules, "tools.mcp_tool", mcp_stub)

    monkeypatch.setattr(config_mod, "load_config", lambda: {})
    monkeypatch.setattr(config_mod, "get_container_exec_info", lambda: None)

    hooks_stub = types.SimpleNamespace(register_from_config=fake_register_from_config)
    monkeypatch.setitem(sys.modules, "agent.shell_hooks", hooks_stub)

    oneshot_stub = types.SimpleNamespace(run_oneshot=fake_run_oneshot)
    monkeypatch.setitem(sys.modules, "hermes_cli.oneshot", oneshot_stub)

    with pytest.raises(SystemExit) as exc:
        main_mod.main()

    assert exc.value.code == 0
    assert captured == {"prompt": "hello", "model": None, "provider": None, "toolsets": "web,terminal"}
|
||||
|
||||
|
||||
def _stub_plugin_discovery(monkeypatch):
    """Install a stand-in ``hermes_cli.plugins`` module whose discovery is a no-op."""
    plugins_stub = types.SimpleNamespace(discover_plugins=lambda: None)
    monkeypatch.setitem(sys.modules, "hermes_cli.plugins", plugins_stub)
|
||||
|
||||
|
||||
def test_oneshot_rejects_invalid_only_toolsets(monkeypatch, capsys):
    """A ``toolsets`` value with no valid entry makes ``run_oneshot`` return 2."""
    _stub_plugin_discovery(monkeypatch)
    from hermes_cli.oneshot import run_oneshot

    exit_code = run_oneshot("hello", toolsets="nope")
    assert exit_code == 2

    stderr_text = capsys.readouterr().err
    for fragment in ("nope", "did not contain any valid toolsets"):
        assert fragment in stderr_text
|
||||
|
||||
|
||||
def test_oneshot_filters_invalid_toolsets_before_redirect(monkeypatch, capsys):
    """Unknown entries are dropped with a stderr note; valid ones are kept."""
    _stub_plugin_discovery(monkeypatch)
    from hermes_cli.oneshot import _validate_explicit_toolsets

    outcome = _validate_explicit_toolsets("web,nope")
    valid, error = outcome

    assert valid == ["web"]
    assert error is None
    stderr_text = capsys.readouterr().err
    assert "nope" in stderr_text
|
||||
|
||||
|
||||
def test_oneshot_all_toolsets_means_all_not_configured_cli():
    """The literal ``all`` validates to ``(None, None)``: no list and no error."""
    from hermes_cli.oneshot import _validate_explicit_toolsets

    outcome = _validate_explicit_toolsets("all")
    valid, error = outcome

    assert valid is None
    assert error is None
|
||||
|
||||
|
||||
def test_oneshot_all_toolsets_warns_about_ignored_extra_entries(monkeypatch, capsys):
    """Entries listed alongside ``all`` are ignored and reported on stderr."""
    _stub_plugin_discovery(monkeypatch)
    from hermes_cli.oneshot import _validate_explicit_toolsets

    outcome = _validate_explicit_toolsets("all,nope")
    valid, error = outcome

    assert valid is None
    assert error is None
    stderr_text = capsys.readouterr().err
    assert "ignoring additional entries: nope" in stderr_text
|
||||
|
||||
|
||||
def test_oneshot_accepts_plugin_toolset_after_discovery(monkeypatch):
    """A toolset that only validates after plugin discovery is still accepted."""
    import toolsets

    from hermes_cli.oneshot import _validate_explicit_toolsets

    discovered = {"ready": False}
    original_validate = toolsets.validate_toolset

    def fake_validate(name):
        # ``plugin_demo`` becomes valid only once discovery has run.
        if name == "plugin_demo" and discovered["ready"]:
            return discovered["ready"]
        return original_validate(name)

    def fake_discover_plugins():
        discovered.update({"ready": True})

    monkeypatch.setattr(toolsets, "validate_toolset", fake_validate)
    plugins_stub = types.SimpleNamespace(discover_plugins=fake_discover_plugins)
    monkeypatch.setitem(sys.modules, "hermes_cli.plugins", plugins_stub)

    outcome = _validate_explicit_toolsets("plugin_demo")
    valid, error = outcome

    assert valid == ["plugin_demo"]
    assert error is None
|
||||
|
||||
|
||||
def test_oneshot_rejects_disabled_mcp_toolset(monkeypatch, capsys):
    """Naming only a disabled MCP server yields the 'no valid toolsets' error."""
    _stub_plugin_discovery(monkeypatch)
    import hermes_cli.config as config_mod

    from hermes_cli.oneshot import _validate_explicit_toolsets

    def fake_read_raw_config():
        # One MCP server is configured, and it is switched off.
        return {"mcp_servers": {"mcp-off": {"enabled": False}}}

    monkeypatch.setattr(config_mod, "read_raw_config", fake_read_raw_config)

    outcome = _validate_explicit_toolsets("mcp-off")
    valid, error = outcome

    assert valid is None
    assert error == "hermes -z: --toolsets did not contain any valid toolsets.\n"
    stderr_text = capsys.readouterr().err
    assert "ignoring disabled MCP servers" in stderr_text
    assert "mcp-off" in stderr_text
|
||||
|
||||
|
||||
def test_oneshot_distinguishes_disabled_mcp_from_unknown(monkeypatch, capsys):
    """Disabled MCP servers and unknown names get separate stderr warnings."""
    _stub_plugin_discovery(monkeypatch)
    import hermes_cli.config as config_mod

    from hermes_cli.oneshot import _validate_explicit_toolsets

    def fake_read_raw_config():
        # One MCP server is configured, and it is switched off.
        return {"mcp_servers": {"mcp-off": {"enabled": False}}}

    monkeypatch.setattr(config_mod, "read_raw_config", fake_read_raw_config)

    outcome = _validate_explicit_toolsets("web,mcp-off,nope")
    valid, error = outcome

    assert valid == ["web"]
    assert error is None
    stderr_text = capsys.readouterr().err
    assert "ignoring unknown --toolsets entries: nope" in stderr_text
    assert "ignoring disabled MCP servers" in stderr_text
    assert "mcp-off" in stderr_text
|
||||
|
||||
|
||||
def test_launch_tui_exports_model_provider_and_toolsets(monkeypatch, main_mod):
|
||||
captured = {}
|
||||
active_path_during_call = None
|
||||
|
||||
@@ -144,13 +328,14 @@ def test_launch_tui_exports_model_and_provider(monkeypatch, main_mod):
|
||||
monkeypatch.setattr(main_mod.subprocess, "call", fake_call)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
main_mod._launch_tui(model="nous/hermes-test", provider="nous")
|
||||
main_mod._launch_tui(model="nous/hermes-test", provider="nous", toolsets="web, terminal")
|
||||
|
||||
env = captured["env"]
|
||||
assert env["HERMES_MODEL"] == "nous/hermes-test"
|
||||
assert env["HERMES_INFERENCE_MODEL"] == "nous/hermes-test"
|
||||
assert env["HERMES_TUI_PROVIDER"] == "nous"
|
||||
assert env["HERMES_INFERENCE_PROVIDER"] == "nous"
|
||||
assert env["HERMES_TUI_TOOLSETS"] == "web,terminal"
|
||||
active_path = Path(env["HERMES_TUI_ACTIVE_SESSION_FILE"])
|
||||
assert active_path.name.startswith("hermes-tui-active-session-")
|
||||
assert active_path.suffix == ".json"
|
||||
|
||||
@@ -333,7 +333,10 @@ def test_cmd_update_retries_optional_extras_individually_when_all_fails(monkeypa
|
||||
raise CalledProcessError(returncode=1, cmd=cmd)
|
||||
if cmd == ["/usr/bin/uv", "pip", "install", "-e", ".[mcp]", "--quiet"]:
|
||||
return SimpleNamespace(returncode=0)
|
||||
return SimpleNamespace(returncode=0)
|
||||
# Catch-all must include stdout/stderr so consumers that parse
|
||||
# output (e.g. the dashboard-restart `ps -A` scan added in the
|
||||
# updater) don't crash on AttributeError.
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", fake_run)
|
||||
|
||||
@@ -370,7 +373,7 @@ def test_cmd_update_succeeds_with_extras(monkeypatch, tmp_path):
|
||||
return SimpleNamespace(stdout="1\n", stderr="", returncode=0)
|
||||
if cmd == ["git", "pull", "origin", "main"]:
|
||||
return SimpleNamespace(stdout="Updating\n", stderr="", returncode=0)
|
||||
return SimpleNamespace(returncode=0)
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", fake_run)
|
||||
|
||||
|
||||
@@ -1,19 +1,62 @@
|
||||
"""Tests for _warn_stale_dashboard_processes — stale dashboard detection.
|
||||
"""Tests for the stale-dashboard handling run at the end of ``hermes update``.
|
||||
|
||||
Ensures ``hermes update`` warns the user when dashboard processes from a
|
||||
previous version are still running after files on disk have been replaced.
|
||||
See #16872.
|
||||
``hermes update`` detects ``hermes dashboard`` processes left over from the
|
||||
previous version and kills them (SIGTERM + SIGKILL grace, or ``taskkill /F``
|
||||
on Windows). Without this, the running backend silently serves stale Python
|
||||
against a freshly-updated JS bundle, producing 401s / empty data.
|
||||
|
||||
History:
|
||||
- #16872 introduced the warn-only helper (``_warn_stale_dashboard_processes``).
|
||||
- #17049 fixed a Windows wmic UnicodeDecodeError crash on non-UTF-8 locales.
|
||||
- This file now also covers the kill semantics that replaced the warning.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import patch, MagicMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.main import _warn_stale_dashboard_processes
|
||||
from hermes_cli.main import (
|
||||
_find_stale_dashboard_pids,
|
||||
_kill_stale_dashboard_processes,
|
||||
_warn_stale_dashboard_processes, # back-compat alias
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _refresh_bindings_against_live_module():
    """Point this module's helper names at the *current* ``hermes_cli.main``.

    Other tests in the suite (notably ``test_env_loader.py`` and
    ``test_skills_subparser.py``) reload or delete ``hermes_cli.main`` from
    ``sys.modules``. When that happens on the same xdist worker before these
    tests run, the names imported at the top of this file still refer to
    functions of the *old* module object, whereas
    ``patch("hermes_cli.main.X")`` patches the *new* one. The function under
    test then looks up ``_find_stale_dashboard_pids`` through its stale
    ``__globals__``, every patch turns into a no-op, and the kill path
    silently returns early.

    Re-reading the three helpers from whatever module is live right now keeps
    the bindings and the patch target consistent, so test ordering within a
    worker no longer matters. The workaround lives here because the two
    polluting test modules need their reload behaviour for their own tests.
    """
    live = sys.modules.get("hermes_cli.main")
    if live is None:
        live = importlib.import_module("hermes_cli.main")

    # Rebind the module-level names one by one from the live module.
    module_namespace = globals()
    for helper_name in (
        "_find_stale_dashboard_pids",
        "_kill_stale_dashboard_processes",
        "_warn_stale_dashboard_processes",
    ):
        module_namespace[helper_name] = getattr(live, helper_name)
    yield
|
||||
|
||||
|
||||
def _ps_line(pid: int, cmd: str) -> str:
|
||||
@@ -21,11 +64,26 @@ def _ps_line(pid: int, cmd: str) -> str:
|
||||
return f"{pid:>7} {cmd}"
|
||||
|
||||
|
||||
class TestWarnStaleDashboardProcesses:
|
||||
"""Unit tests for the stale dashboard process warning."""
|
||||
def _ps_runner(stdout: str):
|
||||
"""Build a subprocess.run side_effect that only stubs ps -A calls.
|
||||
|
||||
def test_no_warning_when_no_dashboard_running(self, capsys):
|
||||
"""ps returns no matching processes — no warning should be printed."""
|
||||
Any other subprocess.run invocation (e.g. taskkill on Windows) is
|
||||
handed back as a successful no-op. This lets tests exercise the real
|
||||
scan path without having to re-stub every unrelated subprocess call
|
||||
made later in ``_kill_stale_dashboard_processes``.
|
||||
"""
|
||||
def _side_effect(args, *a, **kw):
|
||||
if isinstance(args, (list, tuple)) and args and args[0] == "ps":
|
||||
return MagicMock(returncode=0, stdout=stdout, stderr="")
|
||||
# Any other subprocess.run (e.g. taskkill) — benign success stub.
|
||||
return MagicMock(returncode=0, stdout="", stderr="")
|
||||
return _side_effect
|
||||
|
||||
|
||||
class TestFindStaleDashboardPids:
|
||||
"""Unit tests for the ps/wmic-based detection step."""
|
||||
|
||||
def test_no_matches_returns_empty(self):
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0,
|
||||
@@ -35,26 +93,18 @@ class TestWarnStaleDashboardProcesses:
|
||||
+ "\n",
|
||||
stderr="",
|
||||
)
|
||||
_warn_stale_dashboard_processes()
|
||||
output = capsys.readouterr().out
|
||||
assert "dashboard process" not in output
|
||||
assert _find_stale_dashboard_pids() == []
|
||||
|
||||
def test_warning_printed_for_running_dashboard(self, capsys):
|
||||
"""ps finds a dashboard PID — warning with PID should appear."""
|
||||
def test_matches_running_dashboard(self):
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0,
|
||||
stdout=_ps_line(12345, "python3 -m hermes_cli.main dashboard --port 9119") + "\n",
|
||||
stderr="",
|
||||
)
|
||||
_warn_stale_dashboard_processes()
|
||||
output = capsys.readouterr().out
|
||||
assert "1 dashboard process" in output
|
||||
assert "PID 12345" in output
|
||||
assert "kill <pid>" in output
|
||||
assert _find_stale_dashboard_pids() == [12345]
|
||||
|
||||
def test_multiple_dashboard_pids(self, capsys):
|
||||
"""Multiple dashboard processes — all PIDs listed."""
|
||||
def test_multiple_matches(self):
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0,
|
||||
@@ -65,15 +115,9 @@ class TestWarnStaleDashboardProcesses:
|
||||
]) + "\n",
|
||||
stderr="",
|
||||
)
|
||||
_warn_stale_dashboard_processes()
|
||||
output = capsys.readouterr().out
|
||||
assert "3 dashboard process" in output
|
||||
assert "PID 12345" in output
|
||||
assert "PID 12346" in output
|
||||
assert "PID 12347" in output
|
||||
assert sorted(_find_stale_dashboard_pids()) == [12345, 12346, 12347]
|
||||
|
||||
def test_self_pid_excluded(self, capsys):
|
||||
"""The current process PID should not be reported."""
|
||||
def test_self_pid_excluded(self):
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0,
|
||||
@@ -83,41 +127,51 @@ class TestWarnStaleDashboardProcesses:
|
||||
]) + "\n",
|
||||
stderr="",
|
||||
)
|
||||
_warn_stale_dashboard_processes()
|
||||
output = capsys.readouterr().out
|
||||
# The self PID may still appear inside an unrelated context, so anchor
|
||||
# the check to "PID <self>" which is how the warning prints.
|
||||
assert f"PID {os.getpid()}" not in output
|
||||
assert "PID 12345" in output
|
||||
pids = _find_stale_dashboard_pids()
|
||||
assert os.getpid() not in pids
|
||||
assert 12345 in pids
|
||||
|
||||
def test_ps_not_found_silently_ignored(self, capsys):
|
||||
"""If ps is missing (FileNotFoundError), no crash, no warning."""
|
||||
def test_ps_not_found_returns_empty(self):
|
||||
with patch("subprocess.run", side_effect=FileNotFoundError):
|
||||
_warn_stale_dashboard_processes()
|
||||
output = capsys.readouterr().out
|
||||
assert output == ""
|
||||
assert _find_stale_dashboard_pids() == []
|
||||
|
||||
def test_ps_timeout_silently_ignored(self, capsys):
|
||||
"""If ps times out, no crash, no warning."""
|
||||
def test_ps_timeout_returns_empty(self):
|
||||
import subprocess as sp
|
||||
|
||||
with patch("subprocess.run", side_effect=sp.TimeoutExpired("ps", 10)):
|
||||
_warn_stale_dashboard_processes()
|
||||
output = capsys.readouterr().out
|
||||
assert output == ""
|
||||
assert _find_stale_dashboard_pids() == []
|
||||
|
||||
def test_empty_ps_output_no_warning(self, capsys):
|
||||
"""ps returns 0 but empty stdout — no warning."""
|
||||
def test_unrelated_process_containing_word_dashboard_not_matched(self):
|
||||
"""Guards against greedy pgrep-style matching catching chat sessions
|
||||
or unrelated processes whose cmdline happens to contain 'dashboard'.
|
||||
"""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0, stdout="\n", stderr=""
|
||||
returncode=0,
|
||||
stdout="\n".join([
|
||||
_ps_line(12345, "python3 -m hermes_cli.main dashboard --port 9119"),
|
||||
_ps_line(22222, "python3 -m hermes_cli.main chat -q 'rewrite my dashboard'"),
|
||||
_ps_line(33333, "node /opt/grafana/dashboard-server.js"),
|
||||
]) + "\n",
|
||||
stderr="",
|
||||
)
|
||||
_warn_stale_dashboard_processes()
|
||||
output = capsys.readouterr().out
|
||||
assert "dashboard process" not in output
|
||||
pids = _find_stale_dashboard_pids()
|
||||
assert pids == [12345]
|
||||
|
||||
def test_invalid_pid_lines_skipped(self, capsys):
|
||||
"""Malformed ps lines should be skipped gracefully."""
|
||||
def test_grep_lines_ignored(self):
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0,
|
||||
stdout="\n".join([
|
||||
_ps_line(99999, "grep hermes dashboard"),
|
||||
_ps_line(12345, "hermes dashboard --port 9119"),
|
||||
]) + "\n",
|
||||
stderr="",
|
||||
)
|
||||
pids = _find_stale_dashboard_pids()
|
||||
assert 99999 not in pids
|
||||
assert 12345 in pids
|
||||
|
||||
def test_invalid_pid_lines_skipped(self):
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0,
|
||||
@@ -128,50 +182,213 @@ class TestWarnStaleDashboardProcesses:
|
||||
]) + "\n",
|
||||
stderr="",
|
||||
)
|
||||
_warn_stale_dashboard_processes()
|
||||
output = capsys.readouterr().out
|
||||
assert "PID 12345" in output
|
||||
assert "1 dashboard process" in output
|
||||
pids = _find_stale_dashboard_pids()
|
||||
assert pids == [12345]
|
||||
|
||||
def test_unrelated_process_containing_word_dashboard_not_matched(self, capsys):
|
||||
"""A process whose cmdline contains 'dashboard' but isn't a hermes
|
||||
dashboard process must NOT be flagged. This guards against the old
|
||||
``pgrep -f "hermes.*dashboard"`` greedy regex that matched e.g. a
|
||||
chat session argv containing both words.
|
||||
"""
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="POSIX kill semantics")
class TestKillStaleDashboardPosix:
    """POSIX kill path: SIGTERM first, then SIGKILL for any survivors."""

    def test_no_stale_processes_is_a_noop(self, capsys):
        """With no stale PIDs detected, the kill helper prints nothing."""
        no_stale_pids = patch(
            "hermes_cli.main._find_stale_dashboard_pids", return_value=[]
        )
        with no_stale_pids:
            _kill_stale_dashboard_processes()
        captured = capsys.readouterr()
        assert captured.out == ""

    def test_sigterm_graceful_exit(self, capsys):
        """A process that is gone at the post-SIGTERM probe is reported as
        stopped, and SIGKILL is never delivered."""
        from signal import SIGKILL, SIGTERM

        deliveries: list[tuple[int, int]] = []

        def fake_kill(pid, sig):
            deliveries.append((pid, sig))
            # Signal 0 is the liveness probe: failing it means "process gone".
            # Any other signal (SIGTERM here) is accepted silently.
            if sig == 0:
                raise ProcessLookupError

        with patch("hermes_cli.main._find_stale_dashboard_pids",
                   return_value=[12345, 12346]):
            with patch("os.kill", side_effect=fake_kill), patch("time.sleep"):
                _kill_stale_dashboard_processes()

        # Every detected PID received SIGTERM ...
        terminated = sorted(pid for pid, sig in deliveries if sig == SIGTERM)
        assert terminated == [12345, 12346]
        # ... and nobody needed SIGKILL.
        assert all(sig != SIGKILL for _, sig in deliveries)

        out = capsys.readouterr().out
        for fragment in (
            "Stopping 2 dashboard",
            "✓ stopped PID 12345",
            "✓ stopped PID 12346",
            "Restart the dashboard",
        ):
            assert fragment in out

    def test_sigkill_fallback_for_survivors(self, capsys):
        """A process still alive after SIGTERM and the grace window gets
        SIGKILL."""
        from signal import SIGKILL, SIGTERM

        delivered_signals: list[int] = []

        def fake_kill(pid, sig):
            # Stubborn process: the probe (signal 0) always succeeds, SIGTERM
            # has no effect, and SIGKILL is where it "dies". Every signal is
            # simply recorded and accepted.
            delivered_signals.append(sig)

        # The clock reads 0.0 once and 10.0 afterwards, so it jumps past the
        # 3s deadline on the second read and the grace loop exits after a
        # single iteration.
        clock_readings = [0.0, *([10.0] * 20)]

        with patch("hermes_cli.main._find_stale_dashboard_pids",
                   return_value=[99999]):
            with patch("os.kill", side_effect=fake_kill), patch("time.sleep"):
                with patch("time.monotonic", side_effect=clock_readings):
                    _kill_stale_dashboard_processes()

        assert SIGTERM in delivered_signals
        assert SIGKILL in delivered_signals

        out = capsys.readouterr().out
        assert "✓ stopped PID 99999" in out

    def test_permission_error_is_reported_not_raised(self, capsys):
        """A PermissionError from os.kill (e.g. another user's process) is
        printed as a failure instead of aborting hermes update."""
        denied = PermissionError("Operation not permitted")

        with patch("hermes_cli.main._find_stale_dashboard_pids",
                   return_value=[12345]):
            with patch("os.kill", side_effect=denied), patch("time.sleep"):
                _kill_stale_dashboard_processes()  # must not raise

        out = capsys.readouterr().out
        assert "✗ failed to stop PID 12345" in out
        assert "Operation not permitted" in out

    def test_process_already_gone_counts_as_stopped(self, capsys):
        """ProcessLookupError on the very first SIGTERM means the process
        exited between detection and the kill, which counts as success."""
        with patch("hermes_cli.main._find_stale_dashboard_pids",
                   return_value=[12345]):
            with patch("os.kill", side_effect=ProcessLookupError), \
                    patch("time.sleep"):
                _kill_stale_dashboard_processes()

        out = capsys.readouterr().out
        assert "✓ stopped PID 12345" in out
        assert "failed to stop" not in out
|
||||
|
||||
|
||||
class TestKillStaleDashboardWindows:
    """Windows kill path: each stale PID is stopped with ``taskkill /F``."""

    def test_taskkill_invoked_for_each_pid(self, monkeypatch, capsys):
        """Every detected PID produces one ``taskkill /PID <n> /F`` call and
        is reported as stopped."""
        monkeypatch.setattr(sys, "platform", "win32")

        def fake_run(args, *a, **kw):
            # Exit code 0 is what taskkill returns on success.
            return MagicMock(returncode=0, stdout="", stderr="")

        with patch("hermes_cli.main._find_stale_dashboard_pids",
                   return_value=[12345, 12346]):
            with patch("subprocess.run", side_effect=fake_run) as mock_run:
                _kill_stale_dashboard_processes()

        # Collect the argv of every subprocess.run call that ran taskkill.
        taskkill_argvs = []
        for recorded in mock_run.call_args_list:
            if not recorded.args:
                continue
            argv = recorded.args[0]
            if isinstance(argv, list) and argv[:1] == ["taskkill"]:
                taskkill_argvs.append(argv)

        assert len(taskkill_argvs) == 2
        assert ["taskkill", "/PID", "12345", "/F"] in taskkill_argvs
        assert ["taskkill", "/PID", "12346", "/F"] in taskkill_argvs

        out = capsys.readouterr().out
        assert "✓ stopped PID 12345" in out
        assert "✓ stopped PID 12346" in out

    def test_taskkill_failure_is_reported(self, monkeypatch, capsys):
        """A non-zero taskkill exit is printed as a failure, with its stderr
        text, rather than raised."""
        monkeypatch.setattr(sys, "platform", "win32")

        def fake_run(args, *a, **kw):
            return MagicMock(
                returncode=128,
                stdout="",
                stderr="ERROR: Access is denied.",
            )

        with patch("hermes_cli.main._find_stale_dashboard_pids",
                   return_value=[12345]):
            with patch("subprocess.run", side_effect=fake_run):
                _kill_stale_dashboard_processes()  # must not raise

        out = capsys.readouterr().out
        assert "✗ failed to stop PID 12345" in out
        assert "Access is denied" in out
|
||||
|
||||
|
||||
class TestBackCompatAlias:
    """The legacy ``_warn_stale_dashboard_processes`` name stays bound to the
    new kill function so that old imports keep working."""

    def test_alias_is_the_kill_function(self):
        """Both names must refer to the very same function object."""
        legacy_binding = _warn_stale_dashboard_processes
        current_binding = _kill_stale_dashboard_processes
        assert legacy_binding is current_binding
|
||||
|
||||
|
||||
class TestWindowsWmicEncoding:
|
||||
"""Regression tests for #17049 — the Windows wmic branch must not crash
|
||||
`hermes update` on non-UTF-8 system locales (e.g. cp936 on zh-CN).
|
||||
"""
|
||||
|
||||
def test_wmic_invoked_with_utf8_ignore_errors(self, monkeypatch):
|
||||
"""The wmic subprocess.run call must pass encoding='utf-8' and
|
||||
errors='ignore' so the subprocess reader thread cannot raise
|
||||
UnicodeDecodeError on non-UTF-8 wmic output."""
|
||||
monkeypatch.setattr(sys, "platform", "win32")
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0,
|
||||
stdout="\n".join([
|
||||
# Legitimate dashboard — should match.
|
||||
_ps_line(12345, "python3 -m hermes_cli.main dashboard --port 9119"),
|
||||
# hermes running something else, with "dashboard" as a
|
||||
# substring of an unrelated arg — should NOT match.
|
||||
_ps_line(22222, "python3 -m hermes_cli.main chat -q 'rewrite my dashboard'"),
|
||||
# Completely unrelated process mentioning dashboard.
|
||||
_ps_line(33333, "node /opt/grafana/dashboard-server.js"),
|
||||
]) + "\n",
|
||||
stdout=(
|
||||
"CommandLine=python -m hermes_cli.main dashboard\n"
|
||||
"ProcessId=12345\n"
|
||||
),
|
||||
stderr="",
|
||||
)
|
||||
_warn_stale_dashboard_processes()
|
||||
output = capsys.readouterr().out
|
||||
assert "1 dashboard process" in output
|
||||
assert "PID 12345" in output
|
||||
assert "PID 22222" not in output
|
||||
assert "PID 33333" not in output
|
||||
_find_stale_dashboard_pids()
|
||||
|
||||
def test_grep_lines_ignored(self, capsys):
|
||||
"""Lines containing 'grep' (from a pipe in ps output) are ignored."""
|
||||
# The wmic call is the first subprocess.run invocation.
|
||||
assert mock_run.called, "subprocess.run was not invoked"
|
||||
wmic_call = mock_run.call_args_list[0]
|
||||
kwargs = wmic_call.kwargs
|
||||
assert kwargs.get("encoding") == "utf-8", (
|
||||
"encoding kwarg must be 'utf-8' so wmic output is decoded "
|
||||
"deterministically rather than via the implicit reader-thread "
|
||||
"default that crashes on non-UTF-8 locales (#17049)."
|
||||
)
|
||||
assert kwargs.get("errors") == "ignore", (
|
||||
"errors kwarg must be 'ignore' so undecodable bytes don't take "
|
||||
"down the reader thread (#17049)."
|
||||
)
|
||||
|
||||
def test_wmic_returns_none_stdout_does_not_crash(self, monkeypatch):
|
||||
"""If subprocess.run returns successfully but stdout is None — which
|
||||
is what Python 3.11 leaves behind when the reader thread silently
|
||||
crashed on UnicodeDecodeError before this fix landed — detection
|
||||
must short-circuit instead of raising AttributeError on
|
||||
``None.split('\\n')`` and aborting `hermes update` (#17049)."""
|
||||
monkeypatch.setattr(sys, "platform", "win32")
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0,
|
||||
stdout="\n".join([
|
||||
_ps_line(99999, "grep hermes dashboard"),
|
||||
_ps_line(12345, "hermes dashboard --port 9119"),
|
||||
]) + "\n",
|
||||
stderr="",
|
||||
returncode=0, stdout=None, stderr=""
|
||||
)
|
||||
_warn_stale_dashboard_processes()
|
||||
output = capsys.readouterr().out
|
||||
assert "PID 99999" not in output
|
||||
assert "PID 12345" in output
|
||||
# Must not raise.
|
||||
assert _find_stale_dashboard_pids() == []
|
||||
|
||||
@@ -453,6 +453,142 @@ def test_list_authenticated_providers_no_duplicate_labels_across_schemas(monkeyp
|
||||
)
|
||||
|
||||
|
||||
def test_list_authenticated_providers_hides_custom_shadowing_builtin_endpoint(monkeypatch):
    """#16970: hide a custom provider whose ``base_url`` equals a built-in's.

    The built-in row already represents that endpoint with its canonical
    slug, curated model list, and auth wiring.

    Repro: ``DASHSCOPE_API_KEY`` is set (which triggers the built-in
    ``alibaba`` row pointing at the static ``inference_base_url``) and a
    ``my-alibaba`` custom provider is defined with the same URL. Before the
    fix the picker listed both rows for a single endpoint.
    """

    def _models_dev_catalog():
        return {
            "alibaba": {
                "name": "Alibaba Cloud (DashScope)",
                "env": ["DASHSCOPE_API_KEY"],
            }
        }

    monkeypatch.setenv("DASHSCOPE_API_KEY", "sk-test")
    monkeypatch.setattr("agent.models_dev.fetch_models_dev", _models_dev_catalog)
    monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {})

    shadowing_entry = {
        "name": "my-alibaba",
        # Matches PROVIDER_REGISTRY['alibaba'].inference_base_url exactly.
        "base_url": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
        "api_key": "sk-sp-test",
        "model": "qwen3.6-plus",
        "models": {"qwen3.6-plus": {"context_length": 500000}},
    }

    slugs = [
        row["slug"]
        for row in list_authenticated_providers(
            current_provider="my-alibaba",
            user_providers={},
            custom_providers=[shadowing_entry],
            max_models=50,
        )
    ]

    # The built-in alibaba row must be listed.
    assert "alibaba" in slugs, (
        f"Expected built-in alibaba row, got slugs: {slugs}"
    )
    # The custom shadow row must be gone: its base_url matches the built-in's.
    assert all("my-alibaba" not in s for s in slugs), (
        f"Custom my-alibaba should have been dedup'd against the built-in "
        f"alibaba endpoint, got slugs: {slugs}"
    )
|
||||
|
||||
|
||||
def test_list_authenticated_providers_keeps_custom_with_distinct_endpoint(monkeypatch):
    """Dedup applies only when the endpoint matches a built-in.

    A custom provider on a genuinely different endpoint must stay visible
    even while a built-in provider is authenticated as well.
    """

    def _models_dev_catalog():
        return {
            "alibaba": {
                "name": "Alibaba Cloud (DashScope)",
                "env": ["DASHSCOPE_API_KEY"],
            }
        }

    monkeypatch.setenv("DASHSCOPE_API_KEY", "sk-test")
    monkeypatch.setattr("agent.models_dev.fetch_models_dev", _models_dev_catalog)
    monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {})

    relay_entry = {
        "name": "my-private-relay",
        "base_url": "https://relay.example.internal/v1",
        "api_key": "sk-relay-test",
        "model": "qwen3.6-plus",
        "models": {"qwen3.6-plus": {}},
    }

    slugs = [
        row["slug"]
        for row in list_authenticated_providers(
            current_provider="my-private-relay",
            user_providers={},
            custom_providers=[relay_entry],
            max_models=50,
        )
    ]

    assert any("my-private-relay" in s for s in slugs), (
        f"Custom provider on distinct endpoint must stay visible, got: {slugs}"
    )
|
||||
|
||||
|
||||
def test_list_authenticated_providers_dedup_honors_base_url_env_override(monkeypatch):
    """Dedup must follow the EFFECTIVE endpoint of the built-in provider.

    When ``DASHSCOPE_BASE_URL`` overrides the static ``inference_base_url``,
    a custom provider pointing at the overridden URL (not the static one)
    must still be recognised as a duplicate.
    """
    override_url = "https://custom-dashscope.example.com/v1"

    def _models_dev_catalog():
        return {
            "alibaba": {
                "name": "Alibaba Cloud (DashScope)",
                "env": ["DASHSCOPE_API_KEY"],
            }
        }

    monkeypatch.setenv("DASHSCOPE_API_KEY", "sk-test")
    monkeypatch.setenv("DASHSCOPE_BASE_URL", override_url)
    monkeypatch.setattr("agent.models_dev.fetch_models_dev", _models_dev_catalog)
    monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {})

    override_entry = {
        "name": "my-dashscope-override",
        # Same URL as the DASHSCOPE_BASE_URL env override set above.
        "base_url": override_url,
        "api_key": "sk-test",
        "model": "qwen3.6-plus",
    }

    slugs = [
        row["slug"]
        for row in list_authenticated_providers(
            current_provider="alibaba",
            user_providers={},
            custom_providers=[override_entry],
            max_models=50,
        )
    ]

    assert all("my-dashscope-override" not in s for s in slugs), (
        f"Custom entry matching env-overridden built-in endpoint should be "
        f"dedup'd, got: {slugs}"
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for _get_named_custom_provider with providers: dict
|
||||
# =============================================================================
|
||||
|
||||
@@ -29,7 +29,7 @@ class TestReloadEnv:
|
||||
"""reload_env() adds vars from .env that are not in os.environ."""
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text("TEST_RELOAD_VAR=hello123\n")
|
||||
with patch("hermes_cli.config.get_env_path", return_value=env_file):
|
||||
with patch.dict(reload_env.__globals__, {"get_env_path": lambda: env_file}):
|
||||
os.environ.pop("TEST_RELOAD_VAR", None)
|
||||
count = reload_env()
|
||||
assert count >= 1
|
||||
@@ -40,7 +40,7 @@ class TestReloadEnv:
|
||||
"""reload_env() updates vars whose value changed on disk."""
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text("TEST_RELOAD_VAR=old_value\n")
|
||||
with patch("hermes_cli.config.get_env_path", return_value=env_file):
|
||||
with patch.dict(reload_env.__globals__, {"get_env_path": lambda: env_file}):
|
||||
os.environ["TEST_RELOAD_VAR"] = "old_value"
|
||||
# Now change the file
|
||||
env_file.write_text("TEST_RELOAD_VAR=new_value\n")
|
||||
@@ -55,7 +55,7 @@ class TestReloadEnv:
|
||||
env_file.write_text("") # empty .env
|
||||
# Pick a known key from OPTIONAL_ENV_VARS
|
||||
known_key = next(iter(OPTIONAL_ENV_VARS.keys()))
|
||||
with patch("hermes_cli.config.get_env_path", return_value=env_file):
|
||||
with patch.dict(reload_env.__globals__, {"get_env_path": lambda: env_file}):
|
||||
os.environ[known_key] = "stale_value"
|
||||
count = reload_env()
|
||||
assert known_key not in os.environ
|
||||
@@ -65,7 +65,7 @@ class TestReloadEnv:
|
||||
"""reload_env() preserves non-Hermes env vars even when absent from .env."""
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text("")
|
||||
with patch("hermes_cli.config.get_env_path", return_value=env_file):
|
||||
with patch.dict(reload_env.__globals__, {"get_env_path": lambda: env_file}):
|
||||
os.environ["MY_CUSTOM_UNRELATED_VAR"] = "keep_me"
|
||||
reload_env()
|
||||
assert os.environ.get("MY_CUSTOM_UNRELATED_VAR") == "keep_me"
|
||||
@@ -371,6 +371,12 @@ class TestBuildSchemaFromConfig:
|
||||
assert entry["type"] == "select"
|
||||
assert "options" in entry
|
||||
assert "local" in entry["options"]
|
||||
assert "vercel_sandbox" in entry["options"]
|
||||
runtime_entry = CONFIG_SCHEMA["terminal.vercel_runtime"]
|
||||
assert runtime_entry["type"] == "select"
|
||||
assert "node24" in runtime_entry["options"]
|
||||
assert "python3.13" in runtime_entry["options"]
|
||||
assert len(runtime_entry["options"]) >= 3
|
||||
|
||||
def test_empty_prefix_produces_correct_keys(self):
|
||||
from hermes_cli.web_server import _build_schema_from_config
|
||||
@@ -671,8 +677,12 @@ class TestNewEndpoints:
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["command"] == "hermes setup"
|
||||
|
||||
def test_profiles_create_creates_wrapper_alias_when_safe(self):
|
||||
from pathlib import Path
|
||||
def test_profiles_create_creates_wrapper_alias_when_safe(self, monkeypatch, tmp_path):
|
||||
import hermes_cli.profiles as profiles_mod
|
||||
|
||||
wrapper_dir = tmp_path / "bin"
|
||||
wrapper_dir.mkdir()
|
||||
monkeypatch.setattr(profiles_mod, "_get_wrapper_dir", lambda: wrapper_dir)
|
||||
|
||||
resp = self.client.post(
|
||||
"/api/profiles",
|
||||
@@ -680,7 +690,7 @@ class TestNewEndpoints:
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
wrapper_path = Path.home() / ".local" / "bin" / "writer"
|
||||
wrapper_path = wrapper_dir / "writer"
|
||||
assert wrapper_path.exists()
|
||||
assert wrapper_path.read_text() == '#!/bin/sh\nexec hermes -p writer "$@"\n'
|
||||
|
||||
@@ -2057,14 +2067,24 @@ class TestPtyWebSocket:
|
||||
assert b"round-trip-payload" in buf
|
||||
|
||||
def test_resize_escape_is_forwarded(self, monkeypatch):
|
||||
# Resize escape gets intercepted and applied via TIOCSWINSZ,
|
||||
# then ``tput cols/lines`` reports the new dimensions back.
|
||||
# Resize escape gets intercepted and applied via TIOCSWINSZ, then the
|
||||
# child reads the TTY ioctl directly. Avoid tput because CI may not set
|
||||
# TERM for non-interactive shells.
|
||||
import sys
|
||||
|
||||
winsize_script = (
|
||||
"import fcntl, struct, termios, time; "
|
||||
"time.sleep(0.15); "
|
||||
"rows, cols, *_ = struct.unpack('HHHH', "
|
||||
"fcntl.ioctl(0, termios.TIOCGWINSZ, b'\\0' * 8)); "
|
||||
"print(cols); print(rows)"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
self.ws_module,
|
||||
"_resolve_chat_argv",
|
||||
# sleep gives the test time to push the resize before tput runs
|
||||
# sleep gives the test time to push the resize before the child reads the ioctl.
|
||||
lambda resume=None, sidecar_url=None: (
|
||||
["/bin/sh", "-c", "sleep 0.15; tput cols; tput lines"],
|
||||
[sys.executable, "-c", winsize_script],
|
||||
None,
|
||||
None,
|
||||
),
|
||||
@@ -2153,13 +2173,30 @@ class TestPtyWebSocket:
|
||||
def test_pub_broadcasts_to_events_subscribers(self, monkeypatch):
|
||||
"""Frame written to /api/pub is rebroadcast verbatim to every
|
||||
/api/events subscriber on the same channel."""
|
||||
import time
|
||||
from urllib.parse import urlencode
|
||||
from hermes_cli import web_server as ws_mod
|
||||
|
||||
qs = urlencode({"token": self.token, "channel": "broadcast-test"})
|
||||
pub_path = f"/api/pub?{qs}"
|
||||
sub_path = f"/api/events?{qs}"
|
||||
|
||||
with self.client.websocket_connect(sub_path) as sub:
|
||||
# Wait for the subscriber to be registered on the server side.
|
||||
# websocket_connect returns when ws.accept() completes, but the
|
||||
# server adds us to ``_event_channels`` in a follow-up await,
|
||||
# so a publish immediately after connect can race ahead of the
|
||||
# subscriber registration and the message is dropped.
|
||||
deadline = time.monotonic() + 5.0
|
||||
while time.monotonic() < deadline:
|
||||
if ws_mod._event_channels.get("broadcast-test"):
|
||||
break
|
||||
time.sleep(0.01)
|
||||
else:
|
||||
raise AssertionError(
|
||||
"subscriber did not register on channel within 5s"
|
||||
)
|
||||
|
||||
with self.client.websocket_connect(pub_path) as pub:
|
||||
pub.send_text('{"type":"tool.start","payload":{"tool_id":"t1"}}')
|
||||
received = sub.receive_text()
|
||||
|
||||
233
tests/openviking_plugin/test_openviking.py
Normal file
233
tests/openviking_plugin/test_openviking.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""Tests for plugins/memory/openviking/__init__.py — URI normalization and payload handling."""
|
||||
|
||||
import json
|
||||
|
||||
from plugins.memory.openviking import OpenVikingMemoryProvider
|
||||
|
||||
|
||||
class FakeVikingClient:
    """Canned-response stand-in for the OpenViking client used by these tests.

    ``responses`` maps ``(path, params_key)`` to what :meth:`get` should
    produce, where ``params_key`` is ``tuple(sorted(params.items()))``.
    A mapped value that is an ``Exception`` instance is raised instead of
    returned. Every request is appended to ``calls`` as ``(path, params)``.
    """

    def __init__(self, responses):
        self.responses = responses
        self.calls = []

    def get(self, path, params=None, **kwargs):
        """Record the request and replay the response registered for it.

        Args:
            path: Request path used as the first half of the lookup key.
            params: Optional mapping of query parameters.
            **kwargs: Accepted and ignored.

        Returns:
            The registered response object.

        Raises:
            KeyError: No response is registered for ``(path, params)``.
            Exception: The registered response itself, when it is an exception.
        """
        # Snapshot the params: storing the caller's dict by reference would let
        # a later mutation silently rewrite the recorded call history.
        recorded = dict(params or {})
        self.calls.append((path, recorded))
        key = (path, tuple(sorted(recorded.items())))
        try:
            response = self.responses[key]
        except KeyError:
            raise KeyError(f"no canned response registered for {key!r}") from None
        if isinstance(response, Exception):
            raise response
        return response
|
||||
|
||||
|
||||
class TestOpenVikingSummaryUriNormalization:
    """Checks for ``OpenVikingMemoryProvider._normalize_summary_uri``."""

    def test_normalize_summary_uri_maps_pseudo_files_to_parent_directory(self):
        """``.overview.md`` / ``.abstract.md`` map to the parent; other URIs pass through."""
        normalize = OpenVikingMemoryProvider._normalize_summary_uri
        expectations = (
            ("viking://user/hermes/.overview.md", "viking://user/hermes"),
            ("viking://resources/.abstract.md", "viking://resources"),
            ("viking://", "viking://"),
            (
                "viking://user/hermes/memories/profile.md",
                "viking://user/hermes/memories/profile.md",
            ),
        )
        for raw_uri, expected_uri in expectations:
            assert normalize(raw_uri) == expected_uri
|
||||
|
||||
|
||||
class TestOpenVikingRead:
    """Exercises ``OpenVikingMemoryProvider._tool_read`` against a canned client.

    Each test registers the exact requests it expects on a
    ``FakeVikingClient`` and then asserts both the decoded JSON result and
    the ordered list of requests the provider issued.
    """

    def test_overview_read_normalizes_uri_and_unwraps_result(self):
        """A ``.overview.md`` pseudo-URI is read from the overview endpoint of its directory."""
        provider = OpenVikingMemoryProvider()
        provider._client = FakeVikingClient(
            {
                (
                    "/api/v1/content/overview",
                    (("uri", "viking://user/hermes"),),
                ): {"result": {"content": "overview text"}},
            }
        )

        result = json.loads(provider._tool_read({"uri": "viking://user/hermes/.overview.md", "level": "overview"}))

        assert result["uri"] == "viking://user/hermes/.overview.md"
        assert result["resolved_uri"] == "viking://user/hermes"
        assert result["level"] == "overview"
        assert result["content"] == "overview text"
        assert provider._client.calls == [(
            "/api/v1/content/overview",
            {"uri": "viking://user/hermes"},
        )]

    def test_full_read_keeps_original_uri(self):
        """``level="full"`` hits content/read with the URI unchanged."""
        provider = OpenVikingMemoryProvider()
        provider._client = FakeVikingClient(
            {
                (
                    "/api/v1/content/read",
                    (("uri", "viking://user/hermes/memories/profile.md"),),
                ): {"result": "full text"},
            }
        )

        result = json.loads(provider._tool_read({"uri": "viking://user/hermes/memories/profile.md", "level": "full"}))

        assert result["uri"] == "viking://user/hermes/memories/profile.md"
        assert result["resolved_uri"] == "viking://user/hermes/memories/profile.md"
        assert result["level"] == "full"
        assert result["content"] == "full text"
        assert provider._client.calls == [(
            "/api/v1/content/read",
            {"uri": "viking://user/hermes/memories/profile.md"},
        )]

    def test_overview_file_uri_routes_straight_to_content_read_via_stat_probe(self):
        """Pre-check via fs/stat: file URIs skip the directory-only endpoint entirely."""
        provider = OpenVikingMemoryProvider()
        file_uri = "viking://user/hermes/memories/entities/mem_abc.md"
        provider._client = FakeVikingClient(
            {
                (
                    "/api/v1/fs/stat",
                    (("uri", file_uri),),
                ): {"result": {"isDir": False}},
                (
                    "/api/v1/content/read",
                    (("uri", file_uri),),
                ): {"result": {"content": "full content"}},
            }
        )

        result = json.loads(provider._tool_read({"uri": file_uri, "level": "overview"}))

        assert result["uri"] == file_uri
        assert result["resolved_uri"] == file_uri
        assert result["level"] == "overview"
        assert result["fallback"] == "content/read"
        assert result["content"] == "full content"
        assert provider._client.calls == [
            ("/api/v1/fs/stat", {"uri": file_uri}),
            ("/api/v1/content/read", {"uri": file_uri}),
        ]

    def test_overview_dir_uri_skips_stat_when_pseudo_summary(self):
        """Pseudo-URI path already resolves to dir, so no stat probe needed."""
        provider = OpenVikingMemoryProvider()
        provider._client = FakeVikingClient(
            {
                (
                    "/api/v1/content/overview",
                    (("uri", "viking://user/hermes"),),
                ): {"result": "overview"},
            }
        )

        result = json.loads(provider._tool_read({"uri": "viking://user/hermes/.overview.md", "level": "overview"}))

        assert result["content"] == "overview"
        # No fs/stat call — normalization already determined it's a directory.
        assert provider._client.calls == [
            ("/api/v1/content/overview", {"uri": "viking://user/hermes"}),
        ]

    def test_overview_directory_uri_uses_stat_probe_then_overview(self):
        """Non-pseudo directory URI: stat → isDir=True → summary endpoint."""
        provider = OpenVikingMemoryProvider()
        dir_uri = "viking://user/hermes/memories"
        provider._client = FakeVikingClient(
            {
                (
                    "/api/v1/fs/stat",
                    (("uri", dir_uri),),
                ): {"result": {"isDir": True}},
                (
                    "/api/v1/content/overview",
                    (("uri", dir_uri),),
                ): {"result": "dir overview"},
            }
        )

        result = json.loads(provider._tool_read({"uri": dir_uri, "level": "overview"}))

        assert result["content"] == "dir overview"
        assert "fallback" not in result
        assert provider._client.calls == [
            ("/api/v1/fs/stat", {"uri": dir_uri}),
            ("/api/v1/content/overview", {"uri": dir_uri}),
        ]

    def test_overview_file_uri_falls_back_via_exception_when_stat_indeterminate(self):
        """If fs/stat raises or returns unknown shape, legacy exception fallback still kicks in."""
        provider = OpenVikingMemoryProvider()
        file_uri = "viking://user/hermes/memories/entities/mem_abc.md"
        provider._client = FakeVikingClient(
            {
                (
                    "/api/v1/fs/stat",
                    (("uri", file_uri),),
                ): RuntimeError("stat unavailable"),
                (
                    "/api/v1/content/overview",
                    (("uri", file_uri),),
                ): RuntimeError("500 Internal Server Error"),
                (
                    "/api/v1/content/read",
                    (("uri", file_uri),),
                ): {"result": {"content": "fallback full content"}},
            }
        )

        result = json.loads(provider._tool_read({"uri": file_uri, "level": "overview"}))

        assert result["uri"] == file_uri
        assert result["level"] == "overview"
        assert result["fallback"] == "content/read"
        assert result["content"] == "fallback full content"
        assert provider._client.calls == [
            ("/api/v1/fs/stat", {"uri": file_uri}),
            ("/api/v1/content/overview", {"uri": file_uri}),
            ("/api/v1/content/read", {"uri": file_uri}),
        ]

    def test_summary_uri_error_does_not_fallback_and_raises(self):
        """An overview-endpoint failure for a pseudo summary URI propagates to the caller."""
        provider = OpenVikingMemoryProvider()
        provider._client = FakeVikingClient(
            {
                (
                    "/api/v1/content/overview",
                    (("uri", "viking://user/hermes"),),
                ): RuntimeError("500 Internal Server Error"),
            }
        )

        try:
            provider._tool_read({"uri": "viking://user/hermes/.overview.md", "level": "overview"})
        except RuntimeError:
            pass
        else:
            # Explicit raise rather than ``assert False``: a bare assert is
            # stripped under ``python -O`` and the test would silently pass.
            raise AssertionError("Expected summary endpoint error to be raised")

        assert provider._client.calls == [
            ("/api/v1/content/overview", {"uri": "viking://user/hermes"}),
        ]
|
||||
|
||||
|
||||
class TestOpenVikingBrowse:
    """Exercises ``OpenVikingMemoryProvider._tool_browse`` against a canned client."""

    def test_list_browse_unwraps_and_normalizes_entry_shapes(self):
        """``list`` unwraps ``result.entries`` and emits uniform name/uri/type/abstract rows."""
        root_uri = "viking://user/hermes"
        raw_entries = [
            {"name": "memories", "uri": "viking://user/hermes/memories", "type": "dir"},
            {"rel_path": "profile.md", "uri": "viking://user/hermes/memories/profile.md", "isDir": False, "abstract": "Profile"},
        ]
        canned = {
            ("/api/v1/fs/ls", (("uri", root_uri),)): {"result": {"entries": raw_entries}},
        }
        provider = OpenVikingMemoryProvider()
        provider._client = FakeVikingClient(canned)

        raw_output = provider._tool_browse({"action": "list", "path": root_uri})
        result = json.loads(raw_output)

        expected_entries = [
            {"name": "memories", "uri": "viking://user/hermes/memories", "type": "dir", "abstract": ""},
            {"name": "profile.md", "uri": "viking://user/hermes/memories/profile.md", "type": "file", "abstract": "Profile"},
        ]
        assert result["path"] == "viking://user/hermes"
        assert result["entries"] == expected_entries
        # Exactly one listing request, addressed by URI.
        assert provider._client.calls == [("/api/v1/fs/ls", {"uri": "viking://user/hermes"})]
|
||||
@@ -669,7 +669,7 @@ class TestSyncTurn:
|
||||
p._client = _make_mock_client()
|
||||
|
||||
p.sync_turn("hello", "hi there")
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
|
||||
p._client.aretain_batch.assert_called_once()
|
||||
call_kwargs = p._client.aretain_batch.call_args.kwargs
|
||||
@@ -710,8 +710,7 @@ class TestSyncTurn:
|
||||
def test_sync_turn_with_tags(self, provider_with_config):
|
||||
p = provider_with_config(retain_tags=["conv", "session1"])
|
||||
p.sync_turn("hello", "hi")
|
||||
if p._sync_thread:
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
item = p._client.aretain_batch.call_args.kwargs["items"][0]
|
||||
assert "conv" in item["tags"]
|
||||
assert "session1" in item["tags"]
|
||||
@@ -720,8 +719,7 @@ class TestSyncTurn:
|
||||
def test_sync_turn_uses_aretain_batch(self, provider):
|
||||
"""sync_turn should use aretain_batch with retain_async."""
|
||||
provider.sync_turn("hello", "hi")
|
||||
if provider._sync_thread:
|
||||
provider._sync_thread.join(timeout=5.0)
|
||||
provider._retain_queue.join()
|
||||
provider._client.aretain_batch.assert_called_once()
|
||||
call_kwargs = provider._client.aretain_batch.call_args.kwargs
|
||||
assert call_kwargs["document_id"].startswith("test-session-")
|
||||
@@ -732,8 +730,7 @@ class TestSyncTurn:
|
||||
def test_sync_turn_custom_context(self, provider_with_config):
|
||||
p = provider_with_config(retain_context="my-agent")
|
||||
p.sync_turn("hello", "hi")
|
||||
if p._sync_thread:
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
item = p._client.aretain_batch.call_args.kwargs["items"][0]
|
||||
assert item["context"] == "my-agent"
|
||||
|
||||
@@ -744,7 +741,7 @@ class TestSyncTurn:
|
||||
p.sync_turn("turn2-user", "turn2-asst")
|
||||
assert p._sync_thread is None
|
||||
p.sync_turn("turn3-user", "turn3-asst")
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
p._client.aretain_batch.assert_called_once()
|
||||
call_kwargs = p._client.aretain_batch.call_args.kwargs
|
||||
assert call_kwargs["document_id"].startswith("test-session-")
|
||||
@@ -765,15 +762,13 @@ class TestSyncTurn:
|
||||
|
||||
p.sync_turn("turn1-user", "turn1-asst")
|
||||
p.sync_turn("turn2-user", "turn2-asst")
|
||||
if p._sync_thread:
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
|
||||
p._client.aretain_batch.reset_mock()
|
||||
|
||||
p.sync_turn("turn3-user", "turn3-asst")
|
||||
p.sync_turn("turn4-user", "turn4-asst")
|
||||
if p._sync_thread:
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
|
||||
content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"]
|
||||
# Should contain ALL turns from the session
|
||||
@@ -785,8 +780,7 @@ class TestSyncTurn:
|
||||
def test_sync_turn_passes_document_id(self, provider):
|
||||
"""sync_turn should pass document_id (session_id + per-startup ts)."""
|
||||
provider.sync_turn("hello", "hi")
|
||||
if provider._sync_thread:
|
||||
provider._sync_thread.join(timeout=5.0)
|
||||
provider._retain_queue.join()
|
||||
call_kwargs = provider._client.aretain_batch.call_args.kwargs
|
||||
# Format: {session_id}-{YYYYMMDD_HHMMSS_microseconds}
|
||||
assert call_kwargs["document_id"].startswith("test-session-")
|
||||
@@ -819,8 +813,7 @@ class TestSyncTurn:
|
||||
def test_sync_turn_session_tag(self, provider):
|
||||
"""Each retain should be tagged with session:<id> for filtering."""
|
||||
provider.sync_turn("hello", "hi")
|
||||
if provider._sync_thread:
|
||||
provider._sync_thread.join(timeout=5.0)
|
||||
provider._retain_queue.join()
|
||||
item = provider._client.aretain_batch.call_args.kwargs["items"][0]
|
||||
assert "session:test-session" in item["tags"]
|
||||
|
||||
@@ -841,8 +834,7 @@ class TestSyncTurn:
|
||||
)
|
||||
p._client = _make_mock_client()
|
||||
p.sync_turn("hello", "hi")
|
||||
if p._sync_thread:
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
|
||||
item = p._client.aretain_batch.call_args.kwargs["items"][0]
|
||||
assert "session:child-session" in item["tags"]
|
||||
@@ -851,15 +843,14 @@ class TestSyncTurn:
|
||||
def test_sync_turn_error_does_not_raise(self, provider):
|
||||
provider._client.aretain_batch.side_effect = RuntimeError("network error")
|
||||
provider.sync_turn("hello", "hi")
|
||||
if provider._sync_thread:
|
||||
provider._sync_thread.join(timeout=5.0)
|
||||
provider._retain_queue.join()
|
||||
|
||||
def test_sync_turn_preserves_unicode(self, provider_with_config):
|
||||
"""Non-ASCII text (CJK, ZWJ emoji) must survive JSON round-trip intact."""
|
||||
p = provider_with_config()
|
||||
p._client = _make_mock_client()
|
||||
p.sync_turn("안녕 こんにちは 你好", "👨👩👧👦 family")
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
p._client.aretain_batch.assert_called_once()
|
||||
item = p._client.aretain_batch.call_args.kwargs["items"][0]
|
||||
# ensure_ascii=False means non-ASCII chars appear as-is in the raw JSON,
|
||||
@@ -871,6 +862,216 @@ class TestSyncTurn:
|
||||
assert "👨👩👧👦" in raw_json
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shutdown / writer tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestShutdownRace:
    """Writer-thread lifecycle: one writer, drain on shutdown, drop work after shutdown."""

    def test_sync_turn_uses_single_writer_thread(self, provider):
        """All retains run through one long-lived writer thread."""
        provider.sync_turn("a", "b")
        provider._retain_queue.join()
        writer = provider._writer_thread
        assert writer is not None
        assert writer.is_alive()

        provider.sync_turn("c", "d")
        provider._retain_queue.join()
        # The second retain must reuse the same thread object rather than
        # spawning a fresh thread per call.
        assert provider._writer_thread is writer
        assert provider._client.aretain_batch.call_count == 2

    def test_sync_turn_after_shutdown_is_dropped(self, provider):
        """Once shutdown has fired, new sync_turn() calls are no-ops.

        This is the core of the fix: the plugin must not enqueue a retain
        during interpreter teardown — that's what causes the
        'cannot schedule new futures' RuntimeError + unclosed aiohttp
        sessions on CLI exit.
        """
        # Keep a handle on the client up front; the original comments note
        # that shutdown nulls ``provider._client``.
        captured_client = provider._client
        provider.shutdown()
        calls_before = captured_client.aretain_batch.call_count
        provider.sync_turn("late", "turn")
        # Nothing may have been enqueued by the late turn.
        assert provider._retain_queue.empty()
        # And the captured client saw no additional retain call.
        assert captured_client.aretain_batch.call_count == calls_before

    def test_queue_prefetch_after_shutdown_is_dropped(self, provider):
        """A prefetch requested after shutdown must not start a prefetch thread."""
        provider.shutdown()
        provider.queue_prefetch("late query")
        assert provider._prefetch_thread is None

    def test_shutdown_drains_pending_retains(self, provider):
        """Shutdown must wait for queued retains to complete, not abandon them.

        Otherwise the LAST in-flight turn — typically the most important —
        is silently lost.
        """
        captured_client = provider._client
        for user_text, assistant_text in (("a", "b"), ("c", "d")):
            provider.sync_turn(user_text, assistant_text)
        provider.shutdown()
        # Both retains must have completed by the time shutdown returns.
        assert captured_client.aretain_batch.call_count == 2
        assert provider._retain_queue.empty()

    def test_shutdown_is_idempotent(self, provider):
        """Calling shutdown twice is harmless and leaves the shutdown flag set."""
        provider.sync_turn("a", "b")
        provider.shutdown()
        # Second shutdown shouldn't blow up or re-close the client.
        provider.shutdown()
        assert provider._shutting_down.is_set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# on_session_switch — flush + prefetch reset behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionSwitchBufferFlush:
    """``on_session_switch`` behavior: flush buffered turns, reset prefetch state."""

    def test_buffered_turns_flushed_before_clear(self, provider_with_config):
        """retain_every_n_turns > 1 must not silently drop partial buffers
        on session switch. Whatever's in _session_turns at switch time
        should land in the OLD document under the OLD session id."""
        p = provider_with_config(retain_every_n_turns=3, retain_async=False)
        old_doc = p._document_id

        # Two turns buffered, no retain yet (boundary is at turn 3). The
        # writer hasn't been started either — sync_turn's early return
        # skips _ensure_writer when no retain is due.
        p.sync_turn("turn1-user", "turn1-asst")
        p.sync_turn("turn2-user", "turn2-asst")
        assert p._sync_thread is None
        p._client.aretain_batch.assert_not_called()

        # Switch — flush should fire under OLD document_id via the writer queue.
        p.on_session_switch("new-sid", parent_session_id="test-session", reset=True)
        p._retain_queue.join()

        p._client.aretain_batch.assert_called_once()
        kw = p._client.aretain_batch.call_args.kwargs
        assert kw["document_id"] == old_doc
        item = kw["items"][0]
        # Both buffered turns must be present in the flushed payload.
        content = json.loads(item["content"])
        flat = json.dumps(content)
        assert "turn1-user" in flat
        assert "turn2-user" in flat
        # Old session id must appear in lineage tags / metadata.
        assert "session:test-session" in item["tags"]
        assert item["metadata"]["session_id"] == "test-session"

        # And the new session must start with a clean slate.
        assert p._session_id == "new-sid"
        assert p._session_turns == []
        assert p._turn_counter == 0
        assert p._document_id != old_doc
        assert p._document_id.startswith("new-sid-")

    def test_no_flush_when_buffer_empty(self, provider):
        """Switch with no buffered turns must not fire a spurious retain."""
        provider.on_session_switch("new-sid")
        # Nothing enqueued — join is immediate.
        provider._retain_queue.join()
        provider._client.aretain_batch.assert_not_called()
        assert provider._session_id == "new-sid"

    def test_prefetch_result_cleared_on_switch(self, provider):
        """Stale recall text from the old session must not leak into the
        next session's first prefetch read."""
        provider._prefetch_result = "old-session recall: User likes Rust"
        provider.on_session_switch("new-sid")
        assert provider._prefetch_result == ""
        # And subsequent prefetch() should now report empty, not the leftover.
        assert provider.prefetch("anything") == ""

    def test_in_flight_prefetch_thread_drained_on_switch(self, provider, monkeypatch):
        """on_session_switch must wait for an in-flight prefetch from the
        old session to settle before clearing _prefetch_result, otherwise
        the thread can race and re-populate the field after the clear."""
        import threading

        gate = threading.Event()
        finished = threading.Event()

        def _slow_prefetch():
            # Stand-in for a prefetch worker: block until released, then
            # write a stale result under the provider's prefetch lock.
            gate.wait(timeout=5.0)
            with provider._prefetch_lock:
                provider._prefetch_result = "old-session recall"
            finished.set()

        provider._prefetch_thread = threading.Thread(target=_slow_prefetch, daemon=True)
        provider._prefetch_thread.start()

        # Release the prefetch worker so it writes _prefetch_result, then
        # call on_session_switch — it must join the thread before clearing.
        gate.set()
        provider.on_session_switch("new-sid")

        assert finished.is_set(), "switch returned before prefetch thread settled"
        assert provider._prefetch_result == ""

    def test_flush_serializes_behind_pending_retains_via_writer_queue(
        self, provider_with_config
    ):
        """The flush closure must ride the same _retain_queue sync_turn
        uses, so it lands FIFO behind any still-queued old-session
        retains rather than racing them on a separate thread.

        Regression guard: an earlier draft spawned a raw threading.Thread
        for flush, overwriting _sync_thread and racing the writer against
        the same document_id.
        """
        import threading as _threading

        p = provider_with_config(retain_every_n_turns=2, retain_async=False)

        # Block the first writer job until we've enqueued the flush
        # behind it. This proves ordering — the flush MUST wait.
        gate = _threading.Event()
        call_order: list[str] = []

        def _aretain_batch_tracking(**kw):
            # Record each retain by its ``turn_index`` metadata (stringified).
            # NOTE(review): the gate only blocks when ``turn_index`` is the
            # string "2" — confirm the provider emits it as a string.
            idx = kw["items"][0]["metadata"].get("turn_index", "")
            call_order.append(str(idx))
            if idx == "2":
                # First retain blocks until we've enqueued the flush.
                gate.wait(timeout=5.0)

        p._client.aretain_batch = AsyncMock(side_effect=_aretain_batch_tracking)

        # Turn 1+2 → boundary hit → retain enqueued (will block).
        p.sync_turn("turn1-user", "turn1-asst")
        p.sync_turn("turn2-user", "turn2-asst")

        # One more buffered turn so flush has something to land.
        p.sync_turn("turn3-user", "turn3-asst")

        # Switch while the first retain is still blocked on `gate`.
        p.on_session_switch("new-sid", parent_session_id="test-session")

        # Release the first retain. Flush must have been enqueued
        # BEHIND it, and run second.
        gate.set()
        p._retain_queue.join()

        # The flush carries all buffered turns; sync_turn's retain #2
        # carried the batch at boundary time. Two distinct calls.
        assert p._client.aretain_batch.call_count == 2
        # First call landed while buffer was [t1, t2]; flush landed
        # after we added t3. So the second call must be strictly after.
        assert call_order[0] == "2"
        # Flush retain has turn_index matching the buffered count at
        # switch time (3 turns accumulated, _turn_index was set to 3
        # by the last sync_turn).
        assert call_order[1] == "3"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System prompt tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
377
tests/plugins/test_achievements_plugin.py
Normal file
377
tests/plugins/test_achievements_plugin.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""Tests for the bundled hermes-achievements dashboard plugin.
|
||||
|
||||
These target the two behaviors that matter for official integration:
|
||||
|
||||
* The 200-session scan cap is removed — the plugin now walks the entire
|
||||
session history by default. Lifetime badges (tens of thousands of
|
||||
tool calls) were unreachable before this fix on long-running installs.
|
||||
* First-ever scans run in a background thread so the dashboard request
|
||||
path never blocks, even on 8000+ session databases where a cold scan
|
||||
takes minutes.
|
||||
|
||||
The upstream repo ships its own unittest suite under
|
||||
``plugins/hermes-achievements/tests/`` covering the achievement engine
|
||||
internals (tier math, secret-state handling, catalog invariants). These
|
||||
tests live at the hermes-agent level and focus on the integration
|
||||
contract: the plugin scans ALL of your sessions, not the first 200.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
# Location of the bundled plugin's backend module, resolved relative to this
# test file: two directories above the file's own directory, then
# plugins/hermes-achievements/dashboard/plugin_api.py.
PLUGIN_MODULE_PATH = Path(__file__).resolve().parents[2].joinpath(
    "plugins",
    "hermes-achievements",
    "dashboard",
    "plugin_api.py",
)
|
||||
|
||||
|
||||
@pytest.fixture
def plugin_api(tmp_path, monkeypatch):
    """Load plugin_api with isolated ~/.hermes so state/snapshot files don't collide.

    We load the module fresh per test because the plugin keeps module-level
    caches (``_SNAPSHOT_CACHE``, ``_SCAN_STATUS``, background thread handle).
    Reloading gives each test a clean world.

    Yields:
        The freshly executed plugin module, with the test's ``monkeypatch``
        stashed on it as ``_test_monkeypatch``.

    Raises:
        ImportError: If no import spec/loader can be built for
            ``PLUGIN_MODULE_PATH``.
    """
    # Redirect the home directory before executing the module so anything it
    # derives from ``Path.home()`` lands under tmp_path.
    monkeypatch.setattr(Path, "home", lambda: tmp_path)

    spec = importlib.util.spec_from_file_location(
        f"plugin_api_test_{id(tmp_path)}", PLUGIN_MODULE_PATH
    )
    # spec_from_file_location may return None (and a spec may carry no
    # loader); fail with a clear message instead of an AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load plugin module from {PLUGIN_MODULE_PATH}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # Stash monkeypatch so ``_install_fake_session_db`` can use it to
    # swap ``sys.modules['hermes_state']`` with auto-restoration. Without
    # this, a raw ``sys.modules[...] = fake`` assignment would leak the
    # fake into later tests in the same xdist worker — breaking every
    # test that does ``from hermes_state import SessionDB``.
    module._test_monkeypatch = monkeypatch
    yield module
|
||||
|
||||
|
||||
class _FakeSessionDB:
    """Stand-in for hermes_state.SessionDB that records scan calls.

    It fabricates ``session_count`` sessions and remembers the ``limit`` and
    ``include_children`` arguments of the most recent listing call, plus how
    many times each method was invoked.
    """

    def __init__(self, session_count: int):
        self.session_count = session_count
        self.last_limit: Optional[int] = None
        self.last_include_children: Optional[bool] = None
        self.list_calls = 0
        self.messages_calls = 0

    def list_sessions_rich(
        self,
        source: Optional[str] = None,
        exclude_sources: Optional[List[str]] = None,
        limit: int = 20,
        offset: int = 0,
        include_children: bool = False,
        project_compression_tips: bool = True,
    ) -> List[Dict[str, Any]]:
        """Return up to ``limit`` fabricated session rows and record the call.

        Only ``limit`` and ``include_children`` are recorded; ``source``,
        ``exclude_sources``, ``offset`` and ``project_compression_tips`` are
        accepted and ignored by this fake.
        """
        self.last_limit = limit
        self.last_include_children = include_children
        self.list_calls += 1
        # SQLite semantics: a negative LIMIT means "no upper bound" — any
        # negative value, not just -1. Honor that here.
        effective = self.session_count if limit < 0 else min(self.session_count, limit)
        now = int(time.time())
        return [
            {
                "id": f"sess-{i}",
                "title": f"Session {i}",
                "preview": f"preview {i}",
                # Sessions are spaced one minute apart, oldest first.
                "started_at": now - (self.session_count - i) * 60,
                "last_active": now - (self.session_count - i) * 60 + 30,
                "source": "cli",
                "model": "test-model",
            }
            for i in range(effective)
        ]

    def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
        """Return a fixed user / assistant-tool-call / tool-result transcript."""
        self.messages_calls += 1
        return [
            {"role": "user", "content": f"ask {session_id}"},
            {
                "role": "assistant",
                "tool_calls": [{"function": {"name": "terminal"}}],
            },
            {"role": "tool", "tool_name": "terminal", "content": "ok"},
        ]

    def close(self) -> None:
        """No-op; the fake holds no resources."""
        pass
|
||||
|
||||
|
||||
def _install_fake_session_db(plugin_api, fake_db):
    """Register a stub ``hermes_state`` module whose ``SessionDB()`` yields ``fake_db``.

    ``scan_sessions`` imports ``SessionDB`` locally, so the stub has to live
    in ``sys.modules``. The swap goes through the monkeypatch that the
    ``plugin_api`` fixture stashed on the module, which restores the original
    entry at teardown so the fake cannot leak into unrelated tests in the
    same xdist worker.
    """
    module_type = type(sys)
    stub = module_type("hermes_state")
    stub.SessionDB = lambda: fake_db
    patcher = plugin_api._test_monkeypatch
    patcher.setitem(sys.modules, "hermes_state", stub)
|
||||
|
||||
|
||||
def test_scan_sessions_default_scans_all_history_not_first_200(plugin_api):
    """Regression: a default ``scan_sessions()`` must not stop at the old 200 cap.

    With the cap, a user with 8000+ sessions saw only ~2% of their history
    in achievement totals, so lifetime badges were unreachable. The default
    call is expected to pass ``LIMIT -1`` (SQLite "unlimited") to
    ``list_sessions_rich`` and to include child sessions.
    """
    total_sessions = 500  # > old 200 cap
    session_db = _FakeSessionDB(session_count=total_sessions)
    _install_fake_session_db(plugin_api, session_db)

    scan = plugin_api.scan_sessions()

    limit_seen = session_db.last_limit
    assert limit_seen == -1, (
        "scan_sessions() must pass LIMIT=-1 (unlimited) to list_sessions_rich "
        f"by default, got {limit_seen}"
    )
    children_included = session_db.last_include_children
    assert children_included is True, (
        "scan_sessions() must include subagent/compression child sessions so "
        "tool calls made in delegated agents still count toward achievements"
    )
    assert len(scan["sessions"]) == 500
    assert scan["scan_meta"]["sessions_total"] == 500
|
||||
|
||||
|
||||
def test_scan_sessions_explicit_positive_limit_is_honored(plugin_api):
    """An explicit small positive limit is forwarded unchanged and caps the result."""
    session_db = _FakeSessionDB(session_count=500)
    _install_fake_session_db(plugin_api, session_db)

    scan = plugin_api.scan_sessions(limit=10)

    scanned = scan["sessions"]
    assert session_db.last_limit == 10
    assert len(scanned) == 10
|
||||
|
||||
|
||||
def test_scan_sessions_zero_or_negative_limit_means_unlimited(plugin_api):
    """Both ``limit=0`` and ``limit=-1`` must reach the DB as the unlimited ``-1``."""
    session_db = _FakeSessionDB(session_count=300)
    _install_fake_session_db(plugin_api, session_db)

    for requested_limit in (0, -1):
        plugin_api.scan_sessions(limit=requested_limit)
        assert session_db.last_limit == -1
|
||||
|
||||
|
||||
def test_evaluate_all_first_run_returns_pending_and_starts_background_scan(plugin_api):
    """First-ever evaluate_all with no cache returns a pending placeholder
    immediately and kicks off a background scan thread. Cold scans on
    large DBs take minutes — blocking the dashboard request path is not
    acceptable.
    """
    fake_db = _FakeSessionDB(session_count=50)
    _install_fake_session_db(plugin_api, fake_db)

    # Wrap _run_scan_and_update_cache so we can release it on demand,
    # simulating a slow cold scan without actually waiting.
    scan_started = threading.Event()
    allow_scan_finish = threading.Event()
    original_run = plugin_api._run_scan_and_update_cache

    def gated_run(*args, **kwargs):
        # Announce that the scan began, then hold until the test releases it.
        scan_started.set()
        allow_scan_finish.wait(timeout=5)
        # Pass the wrapped function's result through so the wrapper stays
        # transparent to whatever invokes it.
        return original_run(*args, **kwargs)

    # Plain assignment is fine here: the fixture loads a fresh module per
    # test, so nothing needs restoring afterwards.
    plugin_api._run_scan_and_update_cache = gated_run

    # Monotonic clock: wall-clock time can jump (NTP, DST) and would make
    # the latency assertion below flaky.
    t0 = time.monotonic()
    result = plugin_api.evaluate_all()
    elapsed = time.monotonic() - t0

    # Immediate return — should not block waiting for the scan.
    assert elapsed < 1.0, f"evaluate_all blocked for {elapsed:.2f}s on first run"
    assert result["scan_meta"]["mode"] == "pending"
    assert result["unlocked_count"] == 0
    # Catalog still rendered so UI has something to draw.
    assert result["total_count"] >= 60

    # Background scan is running.
    assert scan_started.wait(timeout=2), "background scan did not start"

    # Let the scan complete, then a second call returns real data.
    allow_scan_finish.set()
    # Wait for thread to finish.
    thread = plugin_api._BACKGROUND_SCAN_THREAD
    assert thread is not None
    thread.join(timeout=5)
    assert not thread.is_alive()

    second = plugin_api.evaluate_all()
    assert second["scan_meta"]["mode"] != "pending"
    assert second["scan_meta"].get("sessions_total") == 50
|
||||
|
||||
|
||||
def test_evaluate_all_stale_cache_serves_stale_and_refreshes_in_background(plugin_api):
    """When the snapshot is on-disk but older than TTL, evaluate_all returns
    the stale data immediately and kicks a background refresh. Users don't
    stare at a loading spinner every time TTL expires.
    """
    fake_db = _FakeSessionDB(session_count=10)
    _install_fake_session_db(plugin_api, fake_db)

    # Seed a stale snapshot on disk. ``generated_at`` is an epoch timestamp,
    # so wall-clock time.time() is the right source for this value.
    stale_generated_at = int(time.time()) - plugin_api.SNAPSHOT_TTL_SECONDS - 60
    stale_payload = {
        "achievements": [],
        "sessions": [],
        "aggregate": {},
        "scan_meta": {"mode": "full", "sessions_total": 1, "sessions_rescanned": 1, "sessions_reused": 0},
        "error": None,
        "unlocked_count": 0,
        "discovered_count": 0,
        "secret_count": 0,
        "total_count": 0,
        "generated_at": stale_generated_at,
    }
    plugin_api.save_snapshot(stale_payload)

    # The latency check uses the monotonic clock: unlike time.time() it cannot
    # jump backwards or forwards while the call is in flight.
    t0 = time.monotonic()
    result = plugin_api.evaluate_all()
    elapsed = time.monotonic() - t0

    assert elapsed < 1.0, f"evaluate_all blocked for {elapsed:.2f}s serving stale data"
    assert result["generated_at"] == stale_generated_at

    # Background scan should be running or have completed.
    thread = plugin_api._BACKGROUND_SCAN_THREAD
    assert thread is not None
    thread.join(timeout=5)

    fresh = plugin_api.evaluate_all()
    assert fresh["generated_at"] >= stale_generated_at
|
||||
|
||||
|
||||
def test_evaluate_all_force_runs_synchronously(plugin_api):
    """``evaluate_all(force=True)`` must hand back a finished scan.

    This is the manual /rescan path: the caller expects current data in the
    returned snapshot rather than a placeholder.
    """
    session_total = 25
    _install_fake_session_db(plugin_api, _FakeSessionDB(session_count=session_total))

    snapshot = plugin_api.evaluate_all(force=True)

    # The returned snapshot already reflects every session in the fake DB.
    scan_meta = snapshot["scan_meta"]
    assert scan_meta.get("sessions_total") == session_total
    assert scan_meta["mode"] in ("full", "incremental")
|
||||
|
||||
|
||||
def test_start_background_scan_is_idempotent_while_running(plugin_api):
    """Repeated ``_start_background_scan`` calls keep the thread already running.

    Guards against concurrent dashboard requests spawning duplicate scans.
    """
    _install_fake_session_db(plugin_api, _FakeSessionDB(session_count=5))

    # Hold the scan open so the first thread is still alive for the repeat calls.
    scan_gate = threading.Event()
    real_run = plugin_api._run_scan_and_update_cache

    def blocked_run(*args, **kwargs):
        scan_gate.wait(timeout=5)
        real_run(*args, **kwargs)

    plugin_api._run_scan_and_update_cache = blocked_run

    plugin_api._start_background_scan()
    running = plugin_api._BACKGROUND_SCAN_THREAD
    assert running is not None and running.is_alive()

    # Two further start requests while the first scan is still blocked.
    for _ in range(2):
        plugin_api._start_background_scan()

    assert plugin_api._BACKGROUND_SCAN_THREAD is running

    scan_gate.set()
    running.join(timeout=5)
|
||||
|
||||
|
||||
def test_background_scan_publishes_partial_snapshots(plugin_api):
    """A long scan publishes intermediate snapshots before the final one.

    With partial publishing enabled, ``_compute_from_scan`` is expected to be
    called with ``is_partial=True`` at least twice during a 750-session scan,
    so dashboard refreshes see progress instead of zeros until the very end.
    """
    _install_fake_session_db(plugin_api, _FakeSessionDB(session_count=750))

    # Capture every snapshot computed in partial mode.
    published: List[Dict[str, Any]] = []
    real_compute = plugin_api._compute_from_scan

    def recording_compute(scan, *, is_partial=False):
        snapshot = real_compute(scan, is_partial=is_partial)
        if is_partial:
            published.append(snapshot)
        return snapshot

    plugin_api._compute_from_scan = recording_compute

    # 750 sessions with progress_every=250: intermediate publications are
    # expected at 250 and 500; the final call takes the non-partial path.
    plugin_api._run_scan_and_update_cache(publish_partial_snapshots=True)

    assert len(published) >= 2, (
        f"expected at least 2 partial publications on a 750-session scan with "
        f"progress_every=250, got {len(published)}"
    )

    # Progress must grow and stay below the final total.
    progress = [snap["scan_meta"].get("sessions_scanned_so_far") for snap in published]
    assert progress == sorted(progress), f"partial session counts not monotonic: {progress}"
    assert progress[0] < 750 and progress[-1] < 750, (
        f"partial counts should be less than the final total; got {progress}"
    )

    # Each partial carries the expected end-state total (for a progress bar).
    for snap in published:
        assert snap["scan_meta"].get("sessions_expected_total") == 750

    # What ends up in the cache is the finished, non-partial snapshot.
    cached = plugin_api._SNAPSHOT_CACHE
    assert cached is not None
    assert cached["scan_meta"].get("mode") != "in_progress"
    assert cached["scan_meta"].get("sessions_total") == 750
|
||||
|
||||
|
||||
def test_partial_snapshots_do_not_persist_unlock_timestamps(plugin_api):
    """Computing a partial snapshot must leave the persisted unlock state empty.

    An achievement that looks unlocked part-way through a scan may not hold
    once later sessions are aggregated, so only the final snapshot is allowed
    to record ``unlocked_at``.
    """
    _install_fake_session_db(plugin_api, _FakeSessionDB(session_count=10))

    # Start from an empty persisted state, then compute a partial directly.
    plugin_api.save_state({"unlocks": {}})
    huge_count = 99999
    in_progress_scan = {
        "sessions": [{"session_id": "x", "tool_call_count": huge_count, "tool_names": set()}],
        "aggregate": {"max_tool_calls_in_session": huge_count, "total_tool_calls": huge_count},
        "scan_meta": {"mode": "in_progress"},
    }
    snapshot = plugin_api._compute_from_scan(in_progress_scan, is_partial=True)

    # The aggregate is large enough that something evaluates as unlocked...
    assert any(badge["unlocked"] for badge in snapshot["achievements"])

    # ...yet the reloaded state still has no recorded unlocks.
    on_disk = plugin_api.load_state()
    assert on_disk.get("unlocks", {}) == {}, (
        "partial scans must not record unlock timestamps — a later session "
        "could change whether the badge deserves to be unlocked yet"
    )
|
||||
@@ -89,15 +89,75 @@ class TestThirdPartyAnthropicGateway:
|
||||
assert should is True, "Third-party Anthropic gateway with Claude must cache"
|
||||
assert native is True, "Third-party Anthropic gateway uses native cache_control layout"
|
||||
|
||||
def test_third_party_without_claude_name_does_not_cache(self):
|
||||
# A provider exposing e.g. GLM via anthropic_messages transport — we
|
||||
# don't know whether it supports cache_control, so stay conservative.
|
||||
def test_third_party_anthropic_non_claude_unknown_provider_does_not_cache(self):
    """A non-Claude model behind an unrecognized Anthropic-wire host stays uncached.

    Example: GLM served over ``anthropic_messages`` from an unknown gateway.
    Whether that host supports cache_control is unknown, so the policy is
    expected to stay conservative.
    """
    unknown_gateway_agent = _make_agent(
        provider="custom",
        base_url="https://some-unknown-gateway.example.com/anthropic",
        api_mode="anthropic_messages",
        model="glm-4.5",
    )
    policy = unknown_gateway_agent._anthropic_prompt_cache_policy()
    assert policy == (False, False)
|
||||
|
||||
|
||||
class TestMiniMaxAnthropicWire:
    """MiniMax's own model family on its Anthropic-compatible endpoint.

    MiniMax documents cache_control support on ``/anthropic`` (0.1× read
    pricing, 5-minute TTL). Issue #17332: the blanket ``is_claude`` gate on
    the third-party-gateway branch left MiniMax-M2.7 etc. paying full input
    cost every turn. Allowlist MiniMax explicitly via provider id or host.
    """

    def test_minimax_m27_on_provider_minimax_caches_native_layout(self):
        """The ``minimax`` provider on the global ``/anthropic`` URL caches."""
        minimax_agent = _make_agent(
            provider="minimax",
            base_url="https://api.minimax.io/anthropic",
            api_mode="anthropic_messages",
            model="minimax-m2.7",
        )
        policy = minimax_agent._anthropic_prompt_cache_policy()
        assert policy == (True, True)

    def test_minimax_m25_on_provider_minimax_cn_caches_native_layout(self):
        """The ``minimax-cn`` provider on the China ``/anthropic`` URL caches."""
        minimax_cn_agent = _make_agent(
            provider="minimax-cn",
            base_url="https://api.minimaxi.com/anthropic",
            api_mode="anthropic_messages",
            model="minimax-m2.5",
        )
        policy = minimax_cn_agent._anthropic_prompt_cache_policy()
        assert policy == (True, True)

    def test_custom_provider_pointed_at_minimax_host_caches(self):
        """A hand-wired ``custom`` provider caches on the host match alone."""
        # The provider id is generic; only the base URL identifies MiniMax.
        custom_agent = _make_agent(
            provider="custom",
            base_url="https://api.minimax.io/anthropic",
            api_mode="anthropic_messages",
            model="minimax-m2.7",
        )
        policy = custom_agent._anthropic_prompt_cache_policy()
        assert policy == (True, True)

    def test_minimax_host_china_endpoint_caches(self):
        """The China host is recognized for a ``custom`` provider as well."""
        custom_cn_agent = _make_agent(
            provider="custom",
            base_url="https://api.minimaxi.com/anthropic",
            api_mode="anthropic_messages",
            model="minimax-m2.1",
        )
        policy = custom_cn_agent._anthropic_prompt_cache_policy()
        assert policy == (True, True)

    def test_minimax_provider_on_openai_wire_does_not_cache(self):
        """On the ``chat_completions`` transport MiniMax stays uncached."""
        # cache_control support is documented only for the /anthropic
        # endpoint, so the OpenAI-style wire must stay off.
        openai_wire_agent = _make_agent(
            provider="minimax",
            base_url="https://api.minimax.io/v1",
            api_mode="chat_completions",
            model="minimax-m2.7",
        )
        policy = openai_wire_agent._anthropic_prompt_cache_policy()
        assert policy == (False, False)
|
||||
|
||||
|
||||
|
||||
@@ -8,12 +8,10 @@ effects (terminal, send_message, delegate_task, etc.).
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
def _make_agent_stub():
|
||||
def _make_agent_stub(agent_cls):
|
||||
"""Create a minimal AIAgent-like object with just enough state for _spawn_background_review."""
|
||||
agent = object.__new__(AIAgent)
|
||||
agent = object.__new__(agent_cls)
|
||||
agent.model = "test-model"
|
||||
agent.platform = "test"
|
||||
agent.provider = "openai"
|
||||
@@ -45,14 +43,16 @@ class _SyncThread:
|
||||
|
||||
def test_background_review_agent_uses_restricted_toolsets():
|
||||
"""The review agent must only have access to 'memory' and 'skills' toolsets."""
|
||||
agent = _make_agent_stub()
|
||||
import run_agent
|
||||
|
||||
agent = _make_agent_stub(run_agent.AIAgent)
|
||||
captured = {}
|
||||
|
||||
def _capture_init(self, *args, **kwargs):
|
||||
captured["enabled_toolsets"] = kwargs.get("enabled_toolsets")
|
||||
raise RuntimeError("stop after capturing init args")
|
||||
|
||||
with patch.object(AIAgent, "__init__", _capture_init), \
|
||||
with patch.object(run_agent.AIAgent, "__init__", _capture_init), \
|
||||
patch("threading.Thread", _SyncThread):
|
||||
agent._spawn_background_review(
|
||||
messages_snapshot=[],
|
||||
|
||||
@@ -31,6 +31,10 @@ def _bare_agent():
|
||||
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
agent._memory_manager = MagicMock()
|
||||
# session_id is now propagated into sync_all / queue_prefetch_all so
|
||||
# providers that cache per-session state can update it mid-process
|
||||
# (see #6672).
|
||||
agent.session_id = "test_session_001"
|
||||
return agent
|
||||
|
||||
|
||||
@@ -80,9 +84,11 @@ class TestSyncExternalMemoryForTurn:
|
||||
)
|
||||
agent._memory_manager.sync_all.assert_called_once_with(
|
||||
"What's the weather in Paris?", "It's sunny and 22°C.",
|
||||
session_id="test_session_001",
|
||||
)
|
||||
agent._memory_manager.queue_prefetch_all.assert_called_once_with(
|
||||
"What's the weather in Paris?",
|
||||
session_id="test_session_001",
|
||||
)
|
||||
|
||||
# --- Edge cases (pre-existing behaviour preserved) ------------------
|
||||
|
||||
@@ -144,6 +144,36 @@ class TestBuildApiKwargsOpenRouter:
|
||||
assert messages[1]["tool_calls"][0]["response_item_id"] == "fc_123"
|
||||
assert "codex_reasoning_items" in messages[1]
|
||||
|
||||
def test_gemini_native_passes_base_url_for_top_level_thinking_config(self, monkeypatch):
    """Native Gemini URL: ``thinking_config`` sits directly in ``extra_body``.

    The camelCase keys are expected, and no nested ``extra_body`` wrapper
    may be present.
    """
    agent = _make_agent(
        monkeypatch,
        "gemini",
        base_url="https://generativelanguage.googleapis.com/v1beta",
        model="gemini-3-flash-preview",
    )
    agent.reasoning_config = {"enabled": True, "effort": "high"}

    api_kwargs = agent._build_api_kwargs([{"role": "user", "content": "hi"}])

    expected_thinking = {
        "includeThoughts": True,
        "thinkingLevel": "high",
    }
    assert api_kwargs["extra_body"]["thinking_config"] == expected_thinking
    assert "extra_body" not in api_kwargs["extra_body"]
|
||||
|
||||
def test_gemini_openai_compat_passes_base_url_for_nested_google_thinking_config(self, monkeypatch):
    """OpenAI-compat Gemini URL: thinking config nests under ``extra_body.google``.

    The snake_case keys are expected, and no top-level ``thinking_config``
    may be present.
    """
    agent = _make_agent(
        monkeypatch,
        "gemini",
        base_url="https://generativelanguage.googleapis.com/v1beta/openai",
        model="gemini-3.1-pro-preview",
    )
    agent.reasoning_config = {"enabled": True, "effort": "high"}

    api_kwargs = agent._build_api_kwargs([{"role": "user", "content": "hi"}])

    assert "thinking_config" not in api_kwargs["extra_body"]
    expected_thinking = {
        "include_thoughts": True,
        "thinking_level": "high",
    }
    nested = api_kwargs["extra_body"]["extra_body"]["google"]["thinking_config"]
    assert nested == expected_thinking
|
||||
|
||||
def test_should_sanitize_tool_calls_codex_vs_chat(self, monkeypatch):
|
||||
"""Codex API should NOT sanitize, all other APIs should sanitize."""
|
||||
# Codex mode should NOT need sanitization
|
||||
@@ -936,17 +966,25 @@ class TestAuxiliaryClientProviderPriority:
|
||||
client, model = get_text_auxiliary_client()
|
||||
assert mock.call_args.kwargs["base_url"] == "http://localhost:1234/v1"
|
||||
|
||||
def test_codex_fallback_last_resort(self, monkeypatch):
|
||||
def test_codex_not_in_auto_fallback(self, monkeypatch):
    """Codex is deliberately NOT part of the auto fallback chain.

    ChatGPT-account Codex gates which models it accepts via an
    undocumented, shifting allow-list, so falling through to Codex with
    a hardcoded default model breaks silently whenever OpenAI rotates
    the list. When nothing else is available, ``get_text_auxiliary_client``
    now returns (None, None) rather than guessing a Codex model.
    """
    # Remove every other credential source so only the Codex token remains.
    monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
    monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    from agent.auxiliary_client import get_text_auxiliary_client

    # A Codex token IS available here; the test asserts it is not used.
    with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
         patch("agent.auxiliary_client._read_codex_access_token", return_value="codex-tok"), \
         patch("agent.auxiliary_client.OpenAI"):
        client, model = get_text_auxiliary_client()

    # Only the (None, None) expectation is kept; the interleaved checks for
    # a "gpt-5.2-codex" CodexAuxiliaryClient contradicted it and the docstring.
    assert client is None
    assert model is None
|
||||
|
||||
|
||||
# ── Provider routing tests ───────────────────────────────────────────────────
|
||||
|
||||
@@ -862,6 +862,26 @@ class TestBuildSystemPrompt:
|
||||
prompt = agent._build_system_prompt()
|
||||
assert DEFAULT_AGENT_IDENTITY in prompt
|
||||
|
||||
def test_can_use_soul_identity_even_when_context_files_are_skipped(self):
    """``load_soul_identity=True`` still applies with ``skip_context_files=True``.

    With ``load_soul_md`` patched to return a marker string, the built
    system prompt must contain that marker and not the default identity.
    """
    soul_marker = "SOUL IDENTITY"
    with (
        patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("terminal")),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
        patch("run_agent.load_soul_md", return_value=soul_marker),
    ):
        soul_agent = AIAgent(
            api_key="test-k...7890",
            base_url="https://openrouter.ai/api/v1",
            quiet_mode=True,
            skip_context_files=True,
            load_soul_identity=True,
            skip_memory=True,
        )
        # Build while the patches are active so the stubbed loader is used.
        built_prompt = soul_agent._build_system_prompt()

    assert soul_marker in built_prompt
    assert DEFAULT_AGENT_IDENTITY not in built_prompt
|
||||
|
||||
def test_includes_system_message(self, agent):
    """A caller-supplied ``system_message`` must appear in the built prompt."""
    custom_text = "Custom instruction"
    built_prompt = agent._build_system_prompt(system_message=custom_text)
    assert custom_text in built_prompt
|
||||
|
||||
@@ -96,7 +96,7 @@ class TestCompactBannerSkinIntegration:
|
||||
set_active_skin("default")
|
||||
|
||||
with patch("cli.shutil.get_terminal_size", return_value=SimpleNamespace(columns=90)), \
|
||||
patch("cli.format_banner_version_label", return_value="Hermes Agent v0.1.0 (test)"):
|
||||
patch.dict(_build_compact_banner.__globals__, {"format_banner_version_label": lambda: "Hermes Agent v0.1.0 (test)"}):
|
||||
banner = _build_compact_banner()
|
||||
|
||||
assert "NOUS HERMES" in banner
|
||||
@@ -105,7 +105,7 @@ class TestCompactBannerSkinIntegration:
|
||||
set_active_skin("poseidon")
|
||||
|
||||
with patch("cli.shutil.get_terminal_size", return_value=SimpleNamespace(columns=90)), \
|
||||
patch("cli.format_banner_version_label", return_value="Hermes Agent v0.1.0 (test)"):
|
||||
patch.dict(_build_compact_banner.__globals__, {"format_banner_version_label": lambda: "Hermes Agent v0.1.0 (test)"}):
|
||||
banner = _build_compact_banner()
|
||||
|
||||
assert "Poseidon Agent" in banner
|
||||
@@ -116,7 +116,7 @@ class TestCompactBannerSkinIntegration:
|
||||
skin = get_active_skin()
|
||||
|
||||
with patch("cli.shutil.get_terminal_size", return_value=SimpleNamespace(columns=90)), \
|
||||
patch("cli.format_banner_version_label", return_value="Hermes Agent v0.1.0 (test)"):
|
||||
patch.dict(_build_compact_banner.__globals__, {"format_banner_version_label": lambda: "Hermes Agent v0.1.0 (test)"}):
|
||||
banner = _build_compact_banner()
|
||||
|
||||
assert skin.get_color("banner_border") in banner
|
||||
@@ -127,7 +127,7 @@ class TestCompactBannerSkinIntegration:
|
||||
set_active_skin("default")
|
||||
|
||||
with patch("cli.shutil.get_terminal_size", return_value=SimpleNamespace(columns=90)), \
|
||||
patch("cli.format_banner_version_label", return_value="Hermes Agent v1.0 (test) · upstream abc12345"):
|
||||
patch.dict(_build_compact_banner.__globals__, {"format_banner_version_label": lambda: "Hermes Agent v1.0 (test) · upstream abc12345"}):
|
||||
banner = _build_compact_banner()
|
||||
|
||||
assert "upstream abc12345" in banner
|
||||
|
||||
466
tests/test_minimax_oauth.py
Normal file
466
tests/test_minimax_oauth.py
Normal file
@@ -0,0 +1,466 @@
|
||||
"""Tests for MiniMax OAuth provider (hermes_cli/auth.py).
|
||||
|
||||
Covers:
|
||||
- PKCE pair generation (S256 challenge)
|
||||
- _minimax_request_user_code happy path and state-mismatch error
|
||||
- _minimax_poll_token: pending→success flow, error status, timeout
|
||||
- _refresh_minimax_oauth_state: skip when not expired, update on success,
|
||||
re-login required on invalid_grant
|
||||
- resolve_minimax_oauth_runtime_credentials: error when not logged in
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.auth import (
|
||||
PROVIDER_REGISTRY,
|
||||
AuthError,
|
||||
MINIMAX_OAUTH_CLIENT_ID,
|
||||
MINIMAX_OAUTH_GLOBAL_BASE,
|
||||
MINIMAX_OAUTH_GLOBAL_INFERENCE,
|
||||
MINIMAX_OAUTH_CN_BASE,
|
||||
MINIMAX_OAUTH_CN_INFERENCE,
|
||||
MINIMAX_OAUTH_REFRESH_SKEW_SECONDS,
|
||||
_minimax_pkce_pair,
|
||||
_minimax_request_user_code,
|
||||
_minimax_poll_token,
|
||||
_refresh_minimax_oauth_state,
|
||||
resolve_minimax_oauth_runtime_credentials,
|
||||
get_minimax_oauth_auth_status,
|
||||
get_provider_auth_state,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_httpx_response(status_code: int, body: dict | None = None, text: str = ""):
|
||||
"""Return a minimal mock that quacks like httpx.Response."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = status_code
|
||||
if body is not None:
|
||||
resp.json.return_value = body
|
||||
resp.text = json.dumps(body)
|
||||
else:
|
||||
resp.json.side_effect = Exception("No body")
|
||||
resp.text = text
|
||||
resp.reason_phrase = "OK" if status_code == 200 else "Error"
|
||||
return resp
|
||||
|
||||
|
||||
def _future_iso(seconds_from_now: int = 3600) -> str:
    """Return a UTC ISO-8601 timestamp ``seconds_from_now`` seconds after now."""
    target_epoch = time.time() + seconds_from_now
    moment = datetime.fromtimestamp(target_epoch, tz=timezone.utc)
    return moment.isoformat()
|
||||
|
||||
|
||||
def _past_iso(seconds_ago: int = 3600) -> str:
    """Return a UTC ISO-8601 timestamp ``seconds_ago`` seconds before now."""
    target_epoch = time.time() - seconds_ago
    moment = datetime.fromtimestamp(target_epoch, tz=timezone.utc)
    return moment.isoformat()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. test_pkce_pair_produces_valid_s256
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_pkce_pair_produces_valid_s256():
    """``_minimax_pkce_pair`` returns (verifier, S256 challenge, state).

    The challenge must equal the unpadded URL-safe base64 of the verifier's
    SHA-256 digest, and consecutive calls must not repeat verifier or state.
    """
    verifier, challenge, state = _minimax_pkce_pair()

    # Verifier: a string of at least 32 characters.
    assert isinstance(verifier, str)
    assert len(verifier) >= 32

    # Challenge: a string with base64 padding stripped.
    assert isinstance(challenge, str)
    assert "=" not in challenge

    # Recompute the S256 challenge independently and compare.
    digest = hashlib.sha256(verifier.encode()).digest()
    recomputed = base64.urlsafe_b64encode(digest).decode().rstrip("=")
    assert challenge == recomputed

    # State: a string of at least 8 characters.
    assert isinstance(state, str)
    assert len(state) >= 8

    # A second pair must differ (values are random).
    other_verifier, _other_challenge, other_state = _minimax_pkce_pair()
    assert verifier != other_verifier
    assert state != other_state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. test_request_user_code_happy_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_request_user_code_happy_path():
    """A 200 response echoing the caller's state is returned to the caller.

    Also checks the request went to an ``/oauth/code`` URL and carried an
    ``x-request-id`` header.
    """
    expected_state = "test-state-abc"
    payload = {
        "user_code": "ABC-123",
        "verification_uri": "https://minimax.io/verify",
        "expired_in": int(time.time() * 1000) + 300_000,
        "state": expected_state,
    }
    http_client = MagicMock()
    http_client.post.return_value = _make_httpx_response(200, payload)

    result = _minimax_request_user_code(
        http_client,
        portal_base_url=MINIMAX_OAUTH_GLOBAL_BASE,
        client_id=MINIMAX_OAUTH_CLIENT_ID,
        code_challenge="test-challenge",
        state=expected_state,
    )

    assert result["user_code"] == "ABC-123"
    assert result["verification_uri"] == "https://minimax.io/verify"
    assert result["state"] == expected_state

    # Inspect the single POST: first positional arg is the URL.
    positional, keyword = http_client.post.call_args
    assert "/oauth/code" in positional[0]
    sent_headers = keyword.get("headers", {})
    assert "x-request-id" in sent_headers
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. test_request_user_code_state_mismatch_raises
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_request_user_code_state_mismatch_raises():
    """A response whose ``state`` differs from the request raises ``state_mismatch``."""
    tampered_payload = {
        "user_code": "XYZ",
        "verification_uri": "https://minimax.io/verify",
        "expired_in": 300,
        "state": "wrong-state",  # deliberately not the state we send below
    }
    http_client = MagicMock()
    http_client.post.return_value = _make_httpx_response(200, tampered_payload)

    with pytest.raises(AuthError) as exc_info:
        _minimax_request_user_code(
            http_client,
            portal_base_url=MINIMAX_OAUTH_GLOBAL_BASE,
            client_id=MINIMAX_OAUTH_CLIENT_ID,
            code_challenge="challenge",
            state="correct-state",
        )

    raised = exc_info.value
    assert raised.code == "state_mismatch"
    # The message must mention CSRF or, case-insensitively, "mismatch".
    assert "CSRF" in str(raised) or "mismatch" in str(raised).lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. test_request_user_code_non_200_raises
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_request_user_code_non_200_raises():
    """A non-200 response without a JSON body raises ``authorization_failed``."""
    failed_response = _make_httpx_response(400, text="Bad Request")
    # Pin the failure shape explicitly: JSON parsing raises, raw text remains.
    failed_response.json.side_effect = Exception("no json")
    failed_response.text = "Bad Request"

    http_client = MagicMock()
    http_client.post.return_value = failed_response

    with pytest.raises(AuthError) as exc_info:
        _minimax_request_user_code(
            http_client,
            portal_base_url=MINIMAX_OAUTH_GLOBAL_BASE,
            client_id=MINIMAX_OAUTH_CLIENT_ID,
            code_challenge="challenge",
            state="state",
        )

    assert exc_info.value.code == "authorization_failed"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. test_poll_token_pending_then_success
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_poll_token_pending_then_success():
    """Two ``pending`` replies followed by ``success`` yield the token payload.

    The poll must hit the endpoint exactly three times.
    """
    # Absolute deadline in epoch milliseconds, 60 seconds ahead.
    deadline_ms = int(time.time() * 1000) + 60_000

    still_pending = _make_httpx_response(200, {"status": "pending"})
    granted = _make_httpx_response(
        200,
        {
            "status": "success",
            "access_token": "access-abc",
            "refresh_token": "refresh-xyz",
            "expired_in": 3600,
            "token_type": "Bearer",
        },
    )

    http_client = MagicMock()
    http_client.post.side_effect = [still_pending, still_pending, granted]

    # Stub out sleeping so the poll interval costs no real time.
    with patch("time.sleep"):
        token = _minimax_poll_token(
            http_client,
            portal_base_url=MINIMAX_OAUTH_GLOBAL_BASE,
            client_id=MINIMAX_OAUTH_CLIENT_ID,
            user_code="USER-CODE",
            code_verifier="verifier",
            expired_in=deadline_ms,
            interval_ms=2000,
        )

    assert token["status"] == "success"
    assert token["access_token"] == "access-abc"
    assert token["refresh_token"] == "refresh-xyz"
    assert http_client.post.call_count == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. test_poll_token_error_raises
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_poll_token_error_raises():
    """A ``status: error`` reply from the token endpoint raises ``authorization_denied``."""
    # Absolute deadline in epoch milliseconds, 60 seconds ahead.
    deadline_ms = int(time.time() * 1000) + 60_000

    http_client = MagicMock()
    http_client.post.return_value = _make_httpx_response(200, {"status": "error"})

    with pytest.raises(AuthError) as exc_info:
        _minimax_poll_token(
            http_client,
            portal_base_url=MINIMAX_OAUTH_GLOBAL_BASE,
            client_id=MINIMAX_OAUTH_CLIENT_ID,
            user_code="U",
            code_verifier="v",
            expired_in=deadline_ms,
            interval_ms=2000,
        )

    assert exc_info.value.code == "authorization_denied"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. test_poll_token_timeout_raises
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_poll_token_timeout_raises():
    """Polling against a deadline that has already passed raises ``timeout``.

    ``expired_in`` is passed as an epoch-milliseconds deadline one second in
    the past, the same absolute-deadline form the other polling tests use,
    so no clock patching is needed to reach the timeout path.
    """
    client = MagicMock()
    # The endpoint only ever answers "pending"; the expired deadline alone
    # must end the poll.
    client.post.return_value = _make_httpx_response(200, {"status": "pending"})

    past_deadline_ms = int((time.time() - 1) * 1000)  # 1 second ago

    with pytest.raises(AuthError) as exc_info:
        _minimax_poll_token(
            client,
            portal_base_url=MINIMAX_OAUTH_GLOBAL_BASE,
            client_id=MINIMAX_OAUTH_CLIENT_ID,
            user_code="U",
            code_verifier="v",
            expired_in=past_deadline_ms,
            interval_ms=2000,
        )

    assert exc_info.value.code == "timeout"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. test_refresh_skip_when_not_expired
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_refresh_skip_when_not_expired():
    """Leave the stored state untouched while the token is far from expiry."""
    far_from_expiry = _future_iso(3600)  # 1 hour in the future
    state = {
        "access_token": "old-access",
        "refresh_token": "refresh-token",
        "portal_base_url": MINIMAX_OAUTH_GLOBAL_BASE,
        "client_id": MINIMAX_OAUTH_CLIENT_ID,
        "inference_base_url": MINIMAX_OAUTH_GLOBAL_INFERENCE,
        "expires_at": far_from_expiry,
    }

    refreshed = _refresh_minimax_oauth_state(state)

    assert refreshed["access_token"] == "old-access"
    # Identity, not just equality: the very same object comes back (no refresh).
    assert refreshed is state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 9. test_refresh_updates_access_token
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_refresh_updates_access_token():
    """A token inside the refresh-skew window is replaced by the refreshed one."""
    # Expiry sits one second inside the skew window, i.e. close to expiry.
    expiring_state = {
        "access_token": "old-access",
        "refresh_token": "my-refresh",
        "portal_base_url": MINIMAX_OAUTH_GLOBAL_BASE,
        "client_id": MINIMAX_OAUTH_CLIENT_ID,
        "inference_base_url": MINIMAX_OAUTH_GLOBAL_INFERENCE,
        "expires_at": _future_iso(MINIMAX_OAUTH_REFRESH_SKEW_SECONDS - 1),
    }
    token_response = _make_httpx_response(
        200,
        {
            "status": "success",
            "access_token": "new-access",
            "refresh_token": "new-refresh",
            "expired_in": 7200,
        },
    )

    with patch("httpx.Client") as client_cls:
        # The fake client doubles as its own context manager.
        http_client = MagicMock()
        http_client.__enter__.return_value = http_client
        http_client.__exit__.return_value = False
        http_client.post.return_value = token_response
        client_cls.return_value = http_client

        # Stub the persistence helper so the auth store is never touched.
        with patch("hermes_cli.auth._minimax_save_auth_state"):
            result = _refresh_minimax_oauth_state(expiring_state)

    assert result["access_token"] == "new-access"
    assert result["refresh_token"] == "new-refresh"
    assert result["expires_in"] == 7200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 10. test_refresh_reuse_triggers_relogin_required
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_refresh_reuse_triggers_relogin_required():
    """A 400 response whose body is ``invalid_grant`` demands a fresh login."""
    expired_state = {
        "access_token": "old-access",
        "refresh_token": "old-refresh",
        "portal_base_url": MINIMAX_OAUTH_GLOBAL_BASE,
        "client_id": MINIMAX_OAUTH_CLIENT_ID,
        "inference_base_url": MINIMAX_OAUTH_GLOBAL_INFERENCE,
        "expires_at": _past_iso(100),  # already expired
    }

    # Non-JSON error body: .json() blows up, only .text carries the reason.
    rejected = _make_httpx_response(400, text="invalid_grant")
    rejected.json.side_effect = Exception("no json")
    rejected.text = "invalid_grant"
    rejected.reason_phrase = "Bad Request"

    with patch("httpx.Client") as client_cls:
        http_client = MagicMock()
        http_client.__enter__.return_value = http_client
        http_client.__exit__.return_value = False
        http_client.post.return_value = rejected
        client_cls.return_value = http_client

        with pytest.raises(AuthError) as exc_info:
            _refresh_minimax_oauth_state(expired_state)

    raised = exc_info.value
    assert raised.code == "refresh_failed"
    assert raised.relogin_required is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 11. test_resolve_credentials_requires_login
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_resolve_credentials_requires_login():
    """Resolving runtime credentials with no stored state raises ``not_logged_in``."""
    no_stored_state = patch(
        "hermes_cli.auth.get_provider_auth_state", return_value=None
    )
    with no_stored_state, pytest.raises(AuthError) as exc_info:
        resolve_minimax_oauth_runtime_credentials()

    raised = exc_info.value
    assert raised.code == "not_logged_in"
    assert raised.relogin_required is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 12. test_provider_registry_contains_minimax_oauth
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_provider_registry_contains_minimax_oauth():
    """The ``minimax-oauth`` registry entry carries the expected OAuth settings."""
    assert "minimax-oauth" in PROVIDER_REGISTRY

    entry = PROVIDER_REGISTRY["minimax-oauth"]
    assert entry.auth_type == "oauth_minimax"
    assert entry.client_id == MINIMAX_OAUTH_CLIENT_ID
    assert MINIMAX_OAUTH_GLOBAL_BASE in entry.portal_base_url
    assert MINIMAX_OAUTH_GLOBAL_INFERENCE in entry.inference_base_url

    # Both CN-region overrides must be present in the extras.
    for extra_key in ("cn_portal_base_url", "cn_inference_base_url"):
        assert extra_key in entry.extra
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 13. test_minimax_oauth_alias_resolves
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_minimax_oauth_alias_resolves():
    """``minimax-oauth`` resolves to itself.

    Only the canonical name is checked here; alias resolution is tested in
    models.
    """
    from hermes_cli.auth import resolve_provider

    assert resolve_provider("minimax-oauth") == "minimax-oauth"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 14. test_get_minimax_oauth_auth_status_not_logged_in
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_minimax_oauth_auth_status_not_logged_in():
    """With no stored provider state the status reports a logged-out provider."""
    no_stored_state = patch(
        "hermes_cli.auth.get_provider_auth_state", return_value=None
    )
    with no_stored_state:
        status = get_minimax_oauth_auth_status()

    assert status["logged_in"] is False
    assert status["provider"] == "minimax-oauth"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 15. test_get_minimax_oauth_auth_status_logged_in
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_minimax_oauth_auth_status_logged_in():
    """A stored token with a future expiry reports logged-in plus its region."""
    stored = {
        "access_token": "tok",
        "expires_at": _future_iso(3600),
        "region": "global",
    }
    stored_state = patch(
        "hermes_cli.auth.get_provider_auth_state", return_value=stored
    )

    with stored_state:
        status = get_minimax_oauth_auth_status()

    assert status["logged_in"] is True
    assert status["region"] == "global"
|
||||
@@ -193,8 +193,15 @@ class TestPreToolCallBlocking:
|
||||
result = json.loads(handle_function_call("read_file", {"path": "test.txt"}, task_id="t1"))
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_skip_flag_prevents_double_block_check(self, monkeypatch):
|
||||
"""When skip_pre_tool_call_hook=True, blocking is not checked (caller did it)."""
|
||||
def test_skip_flag_prevents_double_fire(self, monkeypatch):
|
||||
"""When skip_pre_tool_call_hook=True, the hook does not fire again.
|
||||
|
||||
The caller (e.g. run_agent._invoke_tool) has already called
|
||||
get_pre_tool_call_block_message(), which fires the hook once.
|
||||
handle_function_call must NOT fire it a second time — that was
|
||||
the classic double-fire bug where observer hooks logged every
|
||||
tool call twice.
|
||||
"""
|
||||
hook_calls = []
|
||||
|
||||
def fake_invoke_hook(hook_name, **kwargs):
|
||||
@@ -208,10 +215,58 @@ class TestPreToolCallBlocking:
|
||||
handle_function_call("web_search", {"q": "test"}, task_id="t1",
|
||||
skip_pre_tool_call_hook=True)
|
||||
|
||||
# Hook still fires for observer notification, but get_pre_tool_call_block_message
|
||||
# is not called — invoke_hook fires directly in the skip=True branch.
|
||||
assert "pre_tool_call" in hook_calls
|
||||
# Single-fire contract: when skip=True the caller already fired
|
||||
# pre_tool_call, so handle_function_call must not fire it again.
|
||||
assert hook_calls.count("pre_tool_call") == 0, (
|
||||
f"pre_tool_call fired {hook_calls.count('pre_tool_call')} times "
|
||||
f"with skip_pre_tool_call_hook=True; expected 0 "
|
||||
f"(caller already fired it). hook_calls={hook_calls}"
|
||||
)
|
||||
# post_tool_call and transform_tool_result still fire — only the
|
||||
# pre-call block-check path is suppressed by the skip flag.
|
||||
assert "post_tool_call" in hook_calls
|
||||
assert "transform_tool_result" in hook_calls
|
||||
|
||||
def test_run_agent_pattern_fires_pre_tool_call_exactly_once(self, monkeypatch):
    """End-to-end regression for the pre_tool_call double-fire bug.

    Replays the two-step sequence used by run_agent._invoke_tool: poll
    get_pre_tool_call_block_message() for a block directive, then dispatch
    through handle_function_call(skip_pre_tool_call_hook=True). Across both
    steps the recorded hook invocations must contain ``pre_tool_call``
    exactly once.
    """
    from hermes_cli.plugins import get_pre_tool_call_block_message

    hook_calls = []

    def recording_invoke_hook(hook_name, **kwargs):
        # Record the hook name and report "no directives".
        hook_calls.append(hook_name)
        return []

    def fake_dispatch(*a, **kw):
        return json.dumps({"ok": True})

    monkeypatch.setattr("hermes_cli.plugins.invoke_hook", recording_invoke_hook)
    monkeypatch.setattr("model_tools.registry.dispatch", fake_dispatch)

    tool_name = "web_search"
    tool_args = {"q": "test"}

    # Step 1: the caller's block-directive check.
    block = get_pre_tool_call_block_message(tool_name, tool_args, task_id="t1")
    assert block is None

    # Step 2: dispatch with skip=True so the hook isn't re-fired.
    handle_function_call(
        tool_name,
        tool_args,
        task_id="t1",
        skip_pre_tool_call_hook=True,
    )

    pre_fires = hook_calls.count("pre_tool_call")
    assert pre_fires == 1, (
        f"pre_tool_call fired {pre_fires} times "
        f"across the run_agent (block-check + dispatch) path; "
        f"expected exactly 1. hook_calls={hook_calls}"
    )
|
||||
|
||||
|
||||
# =========================================================================
|
||||
|
||||
@@ -199,20 +199,22 @@ class TestRunAsyncWithRunningLoop:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_uses_nonblocking_executor_shutdown(self, monkeypatch):
|
||||
"""A timeout in the running-loop branch must not wait for the worker.
|
||||
"""A timeout in the running-loop branch must not block the caller.
|
||||
|
||||
ThreadPoolExecutor's context manager performs shutdown(wait=True).
|
||||
If _run_async relies on that path after future.result(timeout=...)
|
||||
times out, the timeout does not bound wall-clock time because the
|
||||
caller still waits for the stuck coroutine's thread to finish.
|
||||
If shutdown ever waits for a stuck worker, a tool coroutine that
|
||||
ignores (or can't observe) cancellation would hang the whole agent.
|
||||
Guard: the caller must raise TimeoutError and pool.shutdown must be
|
||||
called with wait=False. The worker's own event loop handles cleanup
|
||||
(cancellation is scheduled via call_soon_threadsafe before the
|
||||
caller returns).
|
||||
"""
|
||||
import concurrent.futures
|
||||
from model_tools import _run_async
|
||||
|
||||
events = {
|
||||
"cancelled": False,
|
||||
"result_timeout": None,
|
||||
"shutdown_calls": [],
|
||||
"submitted_fn": None,
|
||||
}
|
||||
|
||||
class TimeoutFuture:
|
||||
@@ -221,7 +223,6 @@ class TestRunAsyncWithRunningLoop:
|
||||
raise concurrent.futures.TimeoutError()
|
||||
|
||||
def cancel(self):
|
||||
events["cancelled"] = True
|
||||
return True
|
||||
|
||||
class FakeExecutor:
|
||||
@@ -236,8 +237,10 @@ class TestRunAsyncWithRunningLoop:
|
||||
return False
|
||||
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
if args and hasattr(args[0], "close"):
|
||||
args[0].close()
|
||||
# Record which function got submitted -- should be the
|
||||
# in-function worker wrapper, not bare asyncio.run, so we
|
||||
# know _run_async is using a loop it owns and can cancel.
|
||||
events["submitted_fn"] = getattr(fn, "__name__", repr(fn))
|
||||
return TimeoutFuture()
|
||||
|
||||
def shutdown(self, wait=True, cancel_futures=False):
|
||||
@@ -256,8 +259,82 @@ class TestRunAsyncWithRunningLoop:
|
||||
_run_async(_never_finishes())
|
||||
|
||||
assert events["result_timeout"] == 300
|
||||
assert events["cancelled"] is True
|
||||
assert events["shutdown_calls"] == [(False, True)]
|
||||
# The worker wrapper creates its own event loop so _run_async can
|
||||
# cancel the task on timeout — this must NOT be bare asyncio.run.
|
||||
assert events["submitted_fn"] != "run", (
|
||||
"_run_async submitted asyncio.run directly — it must submit a "
|
||||
"worker wrapper that owns the event loop so timeouts can cancel "
|
||||
"the task"
|
||||
)
|
||||
# Critical: shutdown must NOT wait. If wait=True, a stuck coroutine
|
||||
# would freeze the caller (converts a thread leak into a hang).
|
||||
assert events["shutdown_calls"], "shutdown was never called"
|
||||
for wait, _cancel in events["shutdown_calls"]:
|
||||
assert wait is False, (
|
||||
f"shutdown called with wait={wait} — a stuck tool coroutine "
|
||||
f"would hang the caller indefinitely"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
async def test_timeout_cancels_coroutine_in_worker_loop(self, monkeypatch):
    """On timeout the submitted coroutine must observe ``CancelledError``.

    Two things are checked: ``_run_async`` raises ``TimeoutError`` promptly
    instead of blocking until the coroutine finishes, and the coroutine
    itself receives a cancellation shortly afterwards (i.e. the worker is
    not left running it to completion).
    """
    import concurrent.futures as _cf
    import time as _time

    from model_tools import _run_async

    # Shrink the 300s internal timeout to 1s by patching Future.result.
    # Everything else runs for real so the coroutine can actually observe
    # cancellation.
    real_result = _cf.Future.result

    def fast_result(self, timeout=None):
        return real_result(self, timeout=1.0 if timeout == 300 else timeout)

    monkeypatch.setattr(_cf.Future, "result", fast_result)

    cancel_observed = threading.Event()

    async def _slow_cancellable():
        try:
            await asyncio.sleep(60)
        except asyncio.CancelledError:
            cancel_observed.set()
            raise

    # Monotonic clock: the elapsed measurement must not be skewed by
    # wall-clock adjustments during the test.
    t0 = _time.monotonic()
    with pytest.raises(_cf.TimeoutError):
        _run_async(_slow_cancellable())
    elapsed = _time.monotonic() - t0

    # Caller must return fast (no hang waiting for the coro).
    assert elapsed < 3.0, (
        f"_run_async blocked caller for {elapsed:.1f}s — should return "
        f"on timeout regardless of whether the coroutine has finished"
    )

    # Give the worker up to 5s to deliver the cancellation.
    assert cancel_observed.wait(timeout=5), (
        "Coroutine never received CancelledError — worker thread leaked "
        "(ThreadPoolExecutor.cancel() is a no-op on a running future; "
        "_run_async must cancel the task inside its worker loop)"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -59,6 +59,147 @@ def test_write_json_returns_false_on_broken_pipe(monkeypatch):
|
||||
assert server.write_json({"ok": True}) is False
|
||||
|
||||
|
||||
def test_load_enabled_toolsets_prefers_tui_env(monkeypatch):
    """HERMES_TUI_TOOLSETS is split on commas; blanks and padding are dropped."""
    monkeypatch.setenv("HERMES_TUI_TOOLSETS", "web, terminal, ,memory")

    enabled = server._load_enabled_toolsets()

    assert enabled == ["web", "terminal", "memory"]
|
||||
|
||||
|
||||
def test_load_enabled_toolsets_filters_invalid_tui_env(monkeypatch, capsys):
    """Unknown env entries are dropped and named on stderr."""
    monkeypatch.setenv("HERMES_TUI_TOOLSETS", "web, nope")
    # Stand-in plugins module whose discovery is a no-op.
    plugins_stub = types.SimpleNamespace(discover_plugins=lambda: None)
    monkeypatch.setitem(sys.modules, "hermes_cli.plugins", plugins_stub)

    enabled = server._load_enabled_toolsets()

    assert enabled == ["web"]
    assert "nope" in capsys.readouterr().err
|
||||
|
||||
|
||||
def test_load_enabled_toolsets_accepts_plugin_env_after_discovery(monkeypatch):
    """A toolset that only validates after plugin discovery is still accepted."""
    monkeypatch.setenv("HERMES_TUI_TOOLSETS", "plugin_demo")

    import toolsets

    discovered = {"ready": False}
    original_validate = toolsets.validate_toolset

    def fake_validate(name):
        # "plugin_demo" becomes valid only once discovery has run;
        # everything else defers to the real validator.
        if name == "plugin_demo" and discovered["ready"]:
            return True
        return original_validate(name)

    def fake_discover_plugins():
        discovered["ready"] = True

    monkeypatch.setattr(toolsets, "validate_toolset", fake_validate)
    monkeypatch.setitem(
        sys.modules,
        "hermes_cli.plugins",
        types.SimpleNamespace(discover_plugins=fake_discover_plugins),
    )

    assert server._load_enabled_toolsets() == ["plugin_demo"]
|
||||
|
||||
|
||||
def test_load_enabled_toolsets_rejects_disabled_mcp_env(monkeypatch, capsys):
    """A disabled MCP server named in the env is ignored in favour of CLI config."""
    monkeypatch.setenv("HERMES_TUI_TOOLSETS", "mcp-off")
    plugins_stub = types.SimpleNamespace(discover_plugins=lambda: None)
    monkeypatch.setitem(sys.modules, "hermes_cli.plugins", plugins_stub)

    import hermes_cli.config as config_mod

    def fake_raw_config():
        return {"mcp_servers": {"mcp-off": {"enabled": False}}}

    def fake_load_config():
        return {"platform_toolsets": {"cli": ["memory"]}}

    monkeypatch.setattr(config_mod, "read_raw_config", fake_raw_config)
    monkeypatch.setattr(config_mod, "load_config", fake_load_config)

    assert server._load_enabled_toolsets() == ["memory"]

    err = capsys.readouterr().err
    for fragment in (
        "ignoring disabled MCP servers",
        "mcp-off",
        "using configured CLI toolsets",
    ):
        assert fragment in err
|
||||
|
||||
|
||||
def test_load_enabled_toolsets_falls_back_when_tui_env_invalid(monkeypatch, capsys):
    """An env value with no valid entries falls back to the configured CLI toolsets."""
    monkeypatch.setenv("HERMES_TUI_TOOLSETS", "nope")
    plugins_stub = types.SimpleNamespace(discover_plugins=lambda: None)
    monkeypatch.setitem(sys.modules, "hermes_cli.plugins", plugins_stub)

    import hermes_cli.config as config_mod

    def fake_load_config():
        return {"platform_toolsets": {"cli": ["memory"]}}

    monkeypatch.setattr(config_mod, "load_config", fake_load_config)

    enabled = server._load_enabled_toolsets()

    assert enabled == ["memory"]
    assert "using configured CLI toolsets" in capsys.readouterr().err
|
||||
|
||||
|
||||
def test_load_enabled_toolsets_warns_when_config_fallback_fails(monkeypatch, capsys):
    """If the env is invalid and config loading raises, return None and warn."""
    monkeypatch.setenv("HERMES_TUI_TOOLSETS", "nope")
    plugins_stub = types.SimpleNamespace(discover_plugins=lambda: None)
    monkeypatch.setitem(sys.modules, "hermes_cli.plugins", plugins_stub)

    import hermes_cli.config as config_mod

    def exploding_load_config():
        raise RuntimeError("boom")

    monkeypatch.setattr(config_mod, "load_config", exploding_load_config)

    enabled = server._load_enabled_toolsets()

    assert enabled is None
    assert "could not be loaded" in capsys.readouterr().err
|
||||
|
||||
|
||||
def test_load_enabled_toolsets_honors_builtin_env_if_config_fails(monkeypatch):
    """A valid built-in env entry is honoured even when load_config raises."""
    monkeypatch.setenv("HERMES_TUI_TOOLSETS", "web")

    import hermes_cli.config as config_mod

    def exploding_load_config():
        raise RuntimeError("boom")

    monkeypatch.setattr(config_mod, "load_config", exploding_load_config)

    assert server._load_enabled_toolsets() == ["web"]
|
||||
|
||||
|
||||
def test_load_enabled_toolsets_all_env_means_all(monkeypatch):
    """The literal env value ``all`` yields None from the loader."""
    monkeypatch.setenv("HERMES_TUI_TOOLSETS", "all")

    enabled = server._load_enabled_toolsets()

    assert enabled is None
|
||||
|
||||
|
||||
def test_load_enabled_toolsets_all_env_warns_about_ignored_extra_entries(monkeypatch, capsys):
    """Entries listed next to ``all`` are ignored, with a warning on stderr."""
    monkeypatch.setenv("HERMES_TUI_TOOLSETS", "all,nope")

    enabled = server._load_enabled_toolsets()
    stderr_text = capsys.readouterr().err

    assert enabled is None
    assert "ignoring additional entries: nope" in stderr_text
|
||||
|
||||
|
||||
def test_load_enabled_toolsets_reports_disabled_mcp_separately(monkeypatch, capsys):
    """Unknown entries and disabled MCP servers get separate stderr warnings."""
    monkeypatch.setenv("HERMES_TUI_TOOLSETS", "web,mcp-off,nope")
    plugins_stub = types.SimpleNamespace(discover_plugins=lambda: None)
    monkeypatch.setitem(sys.modules, "hermes_cli.plugins", plugins_stub)

    import hermes_cli.config as config_mod

    def fake_raw_config():
        return {"mcp_servers": {"mcp-off": {"enabled": False}}}

    monkeypatch.setattr(config_mod, "read_raw_config", fake_raw_config)

    assert server._load_enabled_toolsets() == ["web"]

    err = capsys.readouterr().err
    for fragment in (
        "ignoring unknown HERMES_TUI_TOOLSETS entries: nope",
        "ignoring disabled MCP servers",
        "mcp-off",
    ):
        assert fragment in err
|
||||
|
||||
|
||||
def test_history_to_messages_preserves_tool_calls_for_resume_display():
|
||||
history = [
|
||||
{"role": "user", "content": "first prompt"},
|
||||
@@ -879,6 +1020,36 @@ def test_config_set_statusbar_survives_non_dict_display(tmp_path, monkeypatch):
|
||||
assert saved["display"]["tui_statusbar"] == "bottom"
|
||||
|
||||
|
||||
def test_config_set_details_mode_pins_all_sections(tmp_path, monkeypatch):
    """Setting ``details_mode`` overwrites every per-section value on disk."""
    import yaml

    # Seed a config whose sections disagree with the mode about to be set.
    initial_cfg = {"display": {"sections": {"tools": "expanded", "activity": "hidden"}}}
    cfg_path = tmp_path / "config.yaml"
    cfg_path.write_text(yaml.safe_dump(initial_cfg))
    monkeypatch.setattr(server, "_hermes_home", tmp_path)

    request = {
        "id": "1",
        "method": "config.set",
        "params": {"key": "details_mode", "value": "collapsed"},
    }
    resp = server.handle_request(request)

    assert resp["result"] == {"key": "details_mode", "value": "collapsed"}

    display = yaml.safe_load(cfg_path.read_text())["display"]
    assert display["details_mode"] == "collapsed"
    assert display["sections"] == dict.fromkeys(
        ("thinking", "tools", "subagents", "activity"), "collapsed"
    )
|
||||
|
||||
|
||||
def test_config_set_section_writes_per_section_override(tmp_path, monkeypatch):
|
||||
import yaml
|
||||
|
||||
@@ -1066,6 +1237,18 @@ def test_config_set_reasoning_updates_live_session_and_agent(tmp_path, monkeypat
|
||||
)
|
||||
assert resp_show["result"]["value"] == "show"
|
||||
assert server._sessions["sid"]["show_reasoning"] is True
|
||||
assert server._load_cfg()["display"]["sections"]["thinking"] == "expanded"
|
||||
|
||||
resp_hide = server.handle_request(
|
||||
{
|
||||
"id": "3",
|
||||
"method": "config.set",
|
||||
"params": {"session_id": "sid", "key": "reasoning", "value": "hide"},
|
||||
}
|
||||
)
|
||||
assert resp_hide["result"]["value"] == "hide"
|
||||
assert server._sessions["sid"]["show_reasoning"] is False
|
||||
assert server._load_cfg()["display"]["sections"]["thinking"] == "hidden"
|
||||
|
||||
|
||||
def test_config_set_verbose_updates_session_mode_and_agent(tmp_path, monkeypatch):
|
||||
@@ -1383,7 +1566,7 @@ def test_session_compress_uses_compress_helper(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
server,
|
||||
"_compress_session_history",
|
||||
lambda session, focus_topic=None: (2, {"total": 42}),
|
||||
lambda session, focus_topic=None, **_kw: (2, {"total": 42}),
|
||||
)
|
||||
monkeypatch.setattr(server, "_session_info", lambda _agent: {"model": "x"})
|
||||
|
||||
@@ -1394,7 +1577,52 @@ def test_session_compress_uses_compress_helper(monkeypatch):
|
||||
|
||||
assert resp["result"]["removed"] == 2
|
||||
assert resp["result"]["usage"]["total"] == 42
|
||||
emit.assert_called_once_with("session.info", "sid", {"model": "x"})
|
||||
emit.assert_any_call("session.info", "sid", {"model": "x"})
|
||||
# Final status.update clears the pinned "compressing" indicator so the
|
||||
# status bar can revert to the neutral state when compaction finishes.
|
||||
emit.assert_any_call(
|
||||
"status.update", "sid", {"kind": "status", "text": "ready"}
|
||||
)
|
||||
|
||||
|
||||
def test_session_compress_syncs_session_key_after_rotation(monkeypatch):
    """After ``session.compress`` the gateway session follows the agent's id.

    The stub agent already reports ``session_id="rotated-id"`` while the
    session still holds ``"old-key"`` and a pending title. Handling the
    request must update ``session_key`` to the agent's id, clear
    ``pending_title`` and restart the slash worker exactly once.
    """
    rotated_agent = types.SimpleNamespace(session_id="rotated-id")
    server._sessions["sid"] = _session(agent=rotated_agent)
    server._sessions["sid"]["session_key"] = "old-key"
    server._sessions["sid"]["pending_title"] = "stale title"

    def fake_compress(session, focus_topic=None, **_kw):
        return (2, {"total": 42})

    restart_calls = []

    monkeypatch.setattr(server, "_compress_session_history", fake_compress)
    monkeypatch.setattr(server, "_session_info", lambda _agent: {"model": "x"})
    monkeypatch.setattr(server, "_restart_slash_worker", restart_calls.append)

    request = {
        "id": "1",
        "method": "session.compress",
        "params": {"session_id": "sid"},
    }

    try:
        with patch("tui_gateway.server._emit"):
            server.handle_request(request)

        live_session = server._sessions["sid"]
        assert live_session["session_key"] == "rotated-id"
        assert live_session["pending_title"] is None
        assert len(restart_calls) == 1
    finally:
        # Always drop the session that was installed directly above.
        server._sessions.pop("sid", None)
|
||||
|
||||
|
||||
def test_prompt_submit_sets_approval_session_key(monkeypatch):
|
||||
@@ -2240,6 +2468,39 @@ def test_mirror_slash_side_effects_allowed_when_idle(monkeypatch):
|
||||
assert applied["model"]
|
||||
|
||||
|
||||
def test_mirror_slash_compress_does_not_prelock_history(monkeypatch):
    """Regression guard for the ``/compress`` slash side effect.

    ``history_lock`` must not be held when ``_compress_session_history`` is
    invoked; the stub asserts this at call time. The test also checks that
    the session-key sync runs and a ``session.info`` event is emitted.
    """
    import types

    seen = {"compress": False, "sync": False}
    emitted = []

    def _fake_compress(session, focus_topic=None, **_kw):
        seen["compress"] = True
        # The guard itself: the lock must be free at this point.
        assert not session["history_lock"].locked()
        return (0, {"total": 0})

    def _fake_sync(_sid, _session):
        seen["sync"] = True

    def _record_emit(*args):
        emitted.append(args)

    monkeypatch.setattr(server, "_compress_session_history", _fake_compress)
    monkeypatch.setattr(server, "_sync_session_key_after_compress", _fake_sync)
    monkeypatch.setattr(server, "_session_info", lambda _agent: {"model": "x"})
    monkeypatch.setattr(server, "_emit", _record_emit)

    session = _session(running=False)
    session["agent"] = types.SimpleNamespace(model="x")

    warning = server._mirror_slash_side_effects("sid", session, "/compress")

    assert warning == ""
    assert seen["compress"]
    assert seen["sync"]
    assert ("session.info", "sid", {"model": "x"}) in emitted
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session.create / session.close race: fast /new churn must not orphan the
|
||||
# slash_worker subprocess or the global approval-notify registration.
|
||||
@@ -2274,10 +2535,20 @@ def test_session_create_close_race_does_not_orphan_worker(monkeypatch):
|
||||
self.base_url = ""
|
||||
self.api_key = ""
|
||||
|
||||
# Make _build block until we release it — simulates slow agent init
|
||||
# Make _build block until we release it — simulates slow agent init.
|
||||
# Also signal when _build actually reaches _make_agent so the test
|
||||
# can close the session at the right moment: session.create now
|
||||
# defers _start_agent_build behind a 50ms timer (see the
|
||||
# `_deferred_build` path in @method("session.create")), so closing
|
||||
# before the build thread has even started would skip the orphan
|
||||
# detection entirely and the test would race a non-event.
|
||||
build_started = threading.Event()
|
||||
release_build = threading.Event()
|
||||
build_entered = threading.Event()
|
||||
|
||||
def _slow_make_agent(sid, key):
|
||||
def _slow_make_agent(sid, key, session_id=None):
|
||||
build_started.set()
|
||||
build_entered.set()
|
||||
release_build.wait(timeout=3.0)
|
||||
return _FakeAgent()
|
||||
|
||||
@@ -2315,6 +2586,13 @@ def test_session_create_close_race_does_not_orphan_worker(monkeypatch):
|
||||
)
|
||||
assert resp.get("result"), f"got error: {resp.get('error')}"
|
||||
sid = resp["result"]["session_id"]
|
||||
assert build_entered.wait(timeout=1.0), "deferred build did not start"
|
||||
|
||||
# Wait until the (deferred) build thread has actually entered
|
||||
# _make_agent — otherwise session.close pops _sessions[sid] before
|
||||
# _build ever runs, _start_agent_build never calls _build, and we
|
||||
# never exercise the orphan-cleanup path.
|
||||
assert build_started.wait(timeout=2.0), "build thread never entered _make_agent"
|
||||
|
||||
# Build thread is blocked in _slow_make_agent. Close the session
|
||||
# NOW — this pops _sessions[sid] before _build can install the
|
||||
@@ -2497,6 +2775,155 @@ def test_session_list_returns_clean_error_when_state_db_is_unavailable(monkeypat
|
||||
assert "state.db unavailable: locking protocol" in resp["error"]["message"]
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# session.delete — TUI resume picker `d` key
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_session_delete_requires_session_id(monkeypatch):
    """Empty / missing session_id is a 4006 client error (no DB call)."""
    delete_calls: list[tuple] = []

    class _RecordingDB:
        def delete_session(self, *a, **kw):
            delete_calls.append((a, kw))
            return True

    monkeypatch.setattr(server, "_get_db", lambda: _RecordingDB())

    request = {"id": "1", "method": "session.delete", "params": {}}
    resp = server.handle_request(request)

    assert "error" in resp
    assert resp["error"]["code"] == 4006
    assert delete_calls == []
|
||||
|
||||
|
||||
def test_session_delete_returns_db_unavailable_when_no_db(monkeypatch):
    """With no DB handle the handler answers 5036 "state.db unavailable"."""
    monkeypatch.setattr(server, "_get_db", lambda: None)
    monkeypatch.setattr(server, "_db_error", "locked")

    request = {"id": "1", "method": "session.delete", "params": {"session_id": "abc"}}
    resp = server.handle_request(request)

    assert "error" in resp
    error = resp["error"]
    assert error["code"] == 5036
    assert "state.db unavailable" in error["message"]
|
||||
|
||||
|
||||
def test_session_delete_refuses_active_session(monkeypatch):
    """Cannot delete a session currently bound to a live TUI session."""
    deleted_ids: list[str] = []

    class _RecordingDB:
        def delete_session(self, sid, sessions_dir=None):
            deleted_ids.append(sid)
            return True

    monkeypatch.setattr(server, "_get_db", lambda: _RecordingDB())
    # Register a live session whose key is the id we then try to delete.
    monkeypatch.setitem(server._sessions, "live", {"session_key": "key-live"})

    request = {
        "id": "1",
        "method": "session.delete",
        "params": {"session_id": "key-live"},
    }
    try:
        resp = server.handle_request(request)
    finally:
        server._sessions.pop("live", None)

    assert "error" in resp
    error = resp["error"]
    assert error["code"] == 4023
    assert "active session" in error["message"]
    assert deleted_ids == [], "delete_session must not be called for active sessions"
|
||||
|
||||
|
||||
def test_session_delete_fails_closed_when_active_snapshot_raises(monkeypatch):
    """The delete is refused when active sessions cannot be enumerated.

    Another RPC thread mutating ``_sessions`` can make iteration raise
    ``RuntimeError: dictionary changed size during iteration``. This is
    simulated with a stand-in whose ``values()`` always raises. The handler
    must fail closed (error 5036) instead of falling through to the delete.
    """

    class _ForbiddenDB:
        def delete_session(self, *args, **kwargs):
            raise AssertionError("delete must not run when active snapshot fails")

    class _RaisingSessions:
        def values(self):
            raise RuntimeError("dictionary changed size during iteration")

    monkeypatch.setattr(server, "_get_db", lambda: _ForbiddenDB())
    monkeypatch.setattr(server, "_sessions", _RaisingSessions())

    request = {
        "id": "1",
        "method": "session.delete",
        "params": {"session_id": "x"},
    }
    response = server.handle_request(request)

    assert "error" in response
    failure = response["error"]
    assert failure["code"] == 5036
    assert "enumerate active sessions" in failure["message"]
|
||||
|
||||
|
||||
def test_session_delete_returns_4007_when_missing(monkeypatch):
    """A falsy ``delete_session`` result is reported as error 4007."""

    class _EmptyDB:
        def delete_session(self, sid, sessions_dir=None):
            # Nothing was deleted for this id.
            return False

    monkeypatch.setattr(server, "_get_db", lambda: _EmptyDB())

    request = {
        "id": "1",
        "method": "session.delete",
        "params": {"session_id": "ghost"},
    }
    response = server.handle_request(request)

    assert "error" in response
    assert response["error"]["code"] == 4007
|
||||
|
||||
|
||||
def test_session_delete_propagates_db_exception(monkeypatch):
    """An exception raised by ``delete_session`` becomes a 5036 error.

    The exception text must appear in the error message.
    """

    class _FailingDB:
        def delete_session(self, sid, sessions_dir=None):
            raise RuntimeError("disk full")

    monkeypatch.setattr(server, "_get_db", lambda: _FailingDB())

    request = {
        "id": "1",
        "method": "session.delete",
        "params": {"session_id": "x"},
    }
    response = server.handle_request(request)

    assert "error" in response
    failure = response["error"]
    assert failure["code"] == 5036
    assert "disk full" in failure["message"]
|
||||
|
||||
|
||||
def test_session_delete_success_returns_deleted_id(monkeypatch):
    """Happy path: the response names the deleted id.

    The stub DB also captures the ``sessions_dir`` argument, which the handler
    must forward so that on-disk transcript files are removed along with the
    database row.
    """
    seen: dict = {}

    class _CapturingDB:
        def delete_session(self, sid, sessions_dir=None):
            seen["sid"] = sid
            seen["sessions_dir"] = sessions_dir
            return True

    monkeypatch.setattr(server, "_get_db", lambda: _CapturingDB())

    request = {
        "id": "1",
        "method": "session.delete",
        "params": {"session_id": "old-1"},
    }
    response = server.handle_request(request)

    assert "result" in response, response
    assert response["result"] == {"deleted": "old-1"}
    assert seen["sid"] == "old-1"
    # The autouse _isolate_hermes_home fixture pins HERMES_HOME to a temp
    # dir, and the handler is expected to append /sessions to it. A missing
    # sessions_dir would mean only the SQLite row is removed.
    forwarded_dir = seen["sessions_dir"]
    assert forwarded_dir is not None
    assert str(forwarded_dir).endswith("sessions")
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# model.options — curated-list parity with `hermes model` and classic /model
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
@@ -127,7 +127,11 @@ class TestReadTrackerCaps:
|
||||
td = ft._read_tracker["long-session"]
|
||||
assert len(td["read_history"]) <= 3
|
||||
assert len(td["dedup"]) <= 3
|
||||
assert len(td["read_timestamps"]) <= 3
|
||||
# read_timestamps is populated lazily (via setdefault) only
|
||||
# when os.path.getmtime() succeeds. On some CI filesystems
|
||||
# that stat can race with file creation — skip rather than
|
||||
# hard-error if the dict hasn't been created yet.
|
||||
assert len(td.get("read_timestamps", {})) <= 3
|
||||
|
||||
|
||||
class TestCompletionConsumedPrune:
|
||||
|
||||
@@ -131,15 +131,15 @@ class TestApprovalHeartbeat:
|
||||
"""Polling slices don't delay responsiveness — resolve is near-instant."""
|
||||
from tools.approval import (
|
||||
check_all_command_guards,
|
||||
has_blocking_approval,
|
||||
register_gateway_notify,
|
||||
resolve_gateway_approval,
|
||||
)
|
||||
|
||||
register_gateway_notify(self.SESSION_KEY, lambda _payload: None)
|
||||
|
||||
start_time = time.monotonic()
|
||||
result_holder: dict = {}
|
||||
|
||||
register_gateway_notify(self.SESSION_KEY, lambda _payload: None)
|
||||
|
||||
def _run_check():
|
||||
result_holder["result"] = check_all_command_guards(
|
||||
"rm -rf /tmp/nonexistent-fast-target", "local"
|
||||
@@ -148,9 +148,18 @@ class TestApprovalHeartbeat:
|
||||
thread = threading.Thread(target=_run_check, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Wait until the worker has actually enqueued the approval. Resolving
|
||||
# before registration is a test race, not a responsiveness signal.
|
||||
deadline = time.monotonic() + 5.0
|
||||
while time.monotonic() < deadline:
|
||||
if has_blocking_approval(self.SESSION_KEY):
|
||||
break
|
||||
time.sleep(0.01)
|
||||
assert has_blocking_approval(self.SESSION_KEY)
|
||||
|
||||
# Resolve almost immediately — the wait loop should return within
|
||||
# its current 1s poll slice.
|
||||
time.sleep(0.1)
|
||||
start_time = time.monotonic()
|
||||
resolve_gateway_approval(self.SESSION_KEY, "once")
|
||||
thread.join(timeout=5)
|
||||
elapsed = time.monotonic() - start_time
|
||||
|
||||
@@ -354,6 +354,7 @@ class TestOwnerPidCrossProcess:
|
||||
monkeypatch.setattr(
|
||||
bt, "_requires_real_termux_browser_install", lambda *a: False
|
||||
)
|
||||
monkeypatch.setattr(bt, "_chromium_installed", lambda: True)
|
||||
monkeypatch.setattr(
|
||||
bt, "_get_session_info",
|
||||
lambda task_id: {"session_name": session_name},
|
||||
|
||||
@@ -205,36 +205,53 @@ class TestMacosOsascript:
|
||||
|
||||
class TestIsWsl:
|
||||
def setup_method(self):
|
||||
# _is_wsl is now hermes_constants.is_wsl — reset its cache
|
||||
# _is_wsl is hermes_constants.is_wsl; reset the function's own module
|
||||
# globals so this stays stable even if hermes_constants was imported
|
||||
# through a different module object earlier in a large xdist run.
|
||||
import hermes_constants
|
||||
hermes_constants._wsl_detected = None
|
||||
_is_wsl.__globals__["_wsl_detected"] = None
|
||||
|
||||
def teardown_method(self):
|
||||
# Reset again after the test so we don't leak a cached value
|
||||
# (True/False) into whichever test the xdist worker runs next.
|
||||
import hermes_constants
|
||||
hermes_constants._wsl_detected = None
|
||||
_is_wsl.__globals__["_wsl_detected"] = None
|
||||
|
||||
def test_wsl2_detected(self):
|
||||
content = "Linux version 5.15.0 (microsoft-standard-WSL2)"
|
||||
with patch("builtins.open", mock_open(read_data=content)):
|
||||
with patch.dict(_is_wsl.__globals__, {"open": mock_open(read_data=content)}):
|
||||
assert _is_wsl() is True
|
||||
|
||||
def test_wsl1_detected(self):
|
||||
content = "Linux version 4.4.0-microsoft-standard"
|
||||
with patch("builtins.open", mock_open(read_data=content)):
|
||||
with patch.dict(_is_wsl.__globals__, {"open": mock_open(read_data=content)}):
|
||||
assert _is_wsl() is True
|
||||
|
||||
def test_regular_linux(self):
|
||||
# GHA hosted runners are Azure VMs whose real /proc/version often
|
||||
# contains "microsoft". Patching builtins.open with mock_open is
|
||||
# supposed to intercept hermes_constants.is_wsl's `open` call,
|
||||
# but if another test on the same xdist worker already cached
|
||||
# _wsl_detected=True, the mock never runs because the function
|
||||
# short-circuits on the cache. setup_method resets, so we just
|
||||
# need to be sure the patched `open` is actually reached.
|
||||
content = "Linux version 6.14.0-37-generic (buildd@lcy02-amd64-049)"
|
||||
with patch("builtins.open", mock_open(read_data=content)):
|
||||
with patch.dict(_is_wsl.__globals__, {"open": mock_open(read_data=content)}):
|
||||
assert _is_wsl() is False
|
||||
|
||||
def test_proc_version_missing(self):
|
||||
with patch("builtins.open", side_effect=FileNotFoundError):
|
||||
with patch.dict(_is_wsl.__globals__, {"open": MagicMock(side_effect=FileNotFoundError)}):
|
||||
assert _is_wsl() is False
|
||||
|
||||
def test_result_is_cached(self):
|
||||
import hermes_constants
|
||||
content = "Linux version 5.15.0 (microsoft-standard-WSL2)"
|
||||
with patch("builtins.open", mock_open(read_data=content)) as m:
|
||||
opener = mock_open(read_data=content)
|
||||
with patch.dict(_is_wsl.__globals__, {"open": opener}):
|
||||
assert _is_wsl() is True
|
||||
assert _is_wsl() is True
|
||||
m.assert_called_once() # only read once
|
||||
opener.assert_called_once() # only read once
|
||||
|
||||
|
||||
# ── WSL (powershell.exe) ────────────────────────────────────────────────
|
||||
|
||||
@@ -770,11 +770,19 @@ class TestLoadConfig(unittest.TestCase):
|
||||
|
||||
def test_returns_code_execution_section(self):
|
||||
from tools.code_execution_tool import _load_config
|
||||
mock_cli = MagicMock()
|
||||
mock_cli.CLI_CONFIG = {"code_execution": {"timeout": 120, "max_tool_calls": 10}}
|
||||
with patch.dict("sys.modules", {"cli": mock_cli}):
|
||||
with patch("hermes_cli.config.read_raw_config",
|
||||
return_value={"code_execution": {"timeout": 120, "max_tool_calls": 10}}):
|
||||
result = _load_config()
|
||||
self.assertIsInstance(result, dict)
|
||||
self.assertEqual(result, {"timeout": 120, "max_tool_calls": 10})
|
||||
|
||||
def test_does_not_import_interactive_cli(self):
    """``_load_config`` must ignore a ``cli`` module's ``CLI_CONFIG``.

    A fake ``cli`` module carrying a ``code_execution`` section is planted in
    ``sys.modules`` while ``read_raw_config`` is patched to return ``{}``.
    The result must be ``{}``, so the planted value is never read.
    """
    from tools.code_execution_tool import _load_config

    fake_cli = MagicMock()
    fake_cli.CLI_CONFIG = {"code_execution": {"timeout": 999}}

    with patch.dict("sys.modules", {"cli": fake_cli}):
        with patch("hermes_cli.config.read_raw_config", return_value={}):
            loaded = _load_config()

    self.assertEqual(loaded, {})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -73,6 +73,10 @@ class TestContainerSkip:
|
||||
result = check_all_command_guards("rm -rf /", "daytona")
|
||||
assert result["approved"] is True
|
||||
|
||||
def test_vercel_sandbox_skips_both(self):
    """``rm -rf /`` is approved when the backend is ``vercel_sandbox``.

    NOTE(review): "both" presumably means both guard layers applied by
    check_all_command_guards; confirm against that function.
    """
    outcome = check_all_command_guards("rm -rf /", "vercel_sandbox")
    assert outcome["approved"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tirith allow + safe command
|
||||
|
||||
@@ -231,3 +231,60 @@ class TestUnifiedCronjobTool:
|
||||
assert updated["success"] is True
|
||||
assert updated["job"]["skills"] == []
|
||||
assert updated["job"]["skill"] is None
|
||||
|
||||
def test_create_normalizes_list_form_deliver(self):
    """Creating a job with ``deliver=['telegram']`` stores ``'telegram'``.

    Regression for #17139: MCP clients and scripts sometimes send ``deliver``
    as an array. Before the fix the list was written verbatim to
    ``jobs.json``. The scheduler then tried to resolve the literal text
    ``"['telegram']"`` as a platform and failed with
    "no delivery target resolved".
    """
    from cron.jobs import get_job

    raw_reply = cronjob(
        action="create",
        prompt="Daily briefing",
        schedule="every 1h",
        deliver=["telegram"],
    )
    reply = json.loads(raw_reply)
    assert reply["success"] is True

    persisted = get_job(reply["job_id"])
    assert persisted["deliver"] == "telegram"
|
||||
|
||||
def test_create_normalizes_multi_element_list_deliver(self):
    """Creating with ``deliver=['telegram', 'discord']`` stores ``'telegram,discord'``."""
    from cron.jobs import get_job

    raw_reply = cronjob(
        action="create",
        prompt="Daily briefing",
        schedule="every 1h",
        deliver=["telegram", "discord"],
    )
    reply = json.loads(raw_reply)
    assert reply["success"] is True

    persisted = get_job(reply["job_id"])
    assert persisted["deliver"] == "telegram,discord"
|
||||
|
||||
def test_update_normalizes_list_form_deliver(self):
    """Updating a job with ``deliver=['telegram']`` stores ``'telegram'``."""
    from cron.jobs import get_job

    # Start from a job created without any deliver argument.
    raw_created = cronjob(action="create", prompt="x", schedule="every 1h")
    created = json.loads(raw_created)
    job_id = created["job_id"]

    raw_updated = cronjob(
        action="update",
        job_id=job_id,
        deliver=["telegram"],
    )
    updated = json.loads(raw_updated)
    assert updated["success"] is True

    persisted = get_job(created["job_id"])
    assert persisted["deliver"] == "telegram"
|
||||
|
||||
@@ -45,6 +45,7 @@ def _make_dummy_env(**kwargs):
|
||||
host_cwd=kwargs.get("host_cwd"),
|
||||
auto_mount_cwd=kwargs.get("auto_mount_cwd", False),
|
||||
env=kwargs.get("env"),
|
||||
run_as_host_user=kwargs.get("run_as_host_user", False),
|
||||
)
|
||||
|
||||
|
||||
@@ -384,9 +385,10 @@ def test_normalize_env_dict_rejects_complex_values():
|
||||
assert result == {"GOOD": "string"}
|
||||
|
||||
|
||||
def test_security_args_include_setuid_setgid_for_gosu_drop():
|
||||
"""_SECURITY_ARGS must include SETUID and SETGID so the image entrypoint
|
||||
can drop from root to the non-root `hermes` user via gosu.
|
||||
def test_security_args_include_setuid_setgid_for_gosu_drop(monkeypatch):
|
||||
"""The default (run_as_host_user=False) invocation must include SETUID and
|
||||
SETGID caps so the image entrypoint can drop from root to the non-root
|
||||
`hermes` user via gosu.
|
||||
|
||||
Without these caps gosu exits with
|
||||
``error: failed switching to 'hermes': operation not permitted``
|
||||
@@ -396,17 +398,117 @@ def test_security_args_include_setuid_setgid_for_gosu_drop():
|
||||
after the drop — the drop is a one-way transition performed before the
|
||||
`no_new_privs` bit is enforced on the exec boundary.
|
||||
"""
|
||||
args = docker_env._SECURITY_ARGS
|
||||
monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker")
|
||||
calls = _mock_subprocess_run(monkeypatch)
|
||||
|
||||
_make_dummy_env()
|
||||
|
||||
run_calls = [c for c in calls if isinstance(c[0], list) and len(c[0]) >= 2 and c[0][1] == "run"]
|
||||
assert run_calls, "docker run should have been called"
|
||||
run_args = run_calls[0][0]
|
||||
|
||||
# Flatten to set of added caps for clarity.
|
||||
added = {
|
||||
args[i + 1]
|
||||
for i, flag in enumerate(args[:-1])
|
||||
run_args[i + 1]
|
||||
for i, flag in enumerate(run_args[:-1])
|
||||
if flag == "--cap-add"
|
||||
}
|
||||
assert "SETUID" in added, "SETUID cap missing — gosu drop in entrypoint will fail"
|
||||
assert "SETGID" in added, "SETGID cap missing — gosu drop in entrypoint will fail"
|
||||
|
||||
# Sanity: the hardening posture is still in place.
|
||||
assert "--cap-drop" in args and "ALL" in args
|
||||
assert "--security-opt" in args and "no-new-privileges" in args
|
||||
|
||||
# ── run_as_host_user tests ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_run_as_host_user_passes_uid_gid(monkeypatch):
    """``run_as_host_user=True`` adds ``--user <uid>:<gid>`` to ``docker run``.

    ``os.getuid`` / ``os.getgid`` are pinned to 1234 / 5678 so the expected
    flag value is deterministic.
    """
    monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker")
    monkeypatch.setattr(docker_env.os, "getuid", lambda: 1234, raising=False)
    monkeypatch.setattr(docker_env.os, "getgid", lambda: 5678, raising=False)
    recorded = _mock_subprocess_run(monkeypatch)

    _make_dummy_env(run_as_host_user=True)

    # Keep only recorded invocations whose argv is a list with "run" second.
    docker_runs = [
        call
        for call in recorded
        if isinstance(call[0], list) and len(call[0]) >= 2 and call[0][1] == "run"
    ]
    assert docker_runs, "docker run should have been called"
    argv = docker_runs[0][0]

    # --user must be present and must be paired with "1234:5678"
    assert "--user" in argv, f"--user flag missing from docker run args: {argv}"
    user_pos = argv.index("--user")
    assert argv[user_pos + 1] == "1234:5678", (
        f"expected --user 1234:5678, got --user {argv[user_pos + 1]}"
    )
|
||||
|
||||
|
||||
def test_run_as_host_user_drops_setuid_setgid_caps(monkeypatch):
    """With ``--user`` in effect, SETUID/SETGID caps are not added.

    The container never needs the gosu privilege drop when it already runs as
    the host user, so those two capabilities are omitted for a tighter
    security posture. The remaining caps must still be granted.
    """
    monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker")
    monkeypatch.setattr(docker_env.os, "getuid", lambda: 1000, raising=False)
    monkeypatch.setattr(docker_env.os, "getgid", lambda: 1000, raising=False)
    calls = _mock_subprocess_run(monkeypatch)

    _make_dummy_env(run_as_host_user=True)

    run_calls = [
        c for c in calls
        if isinstance(c[0], list) and len(c[0]) >= 2 and c[0][1] == "run"
    ]
    # Fail with a readable message rather than an IndexError below.
    assert run_calls, "docker run should have been called"
    run_args = run_calls[0][0]

    # Collect the value that follows every "--cap-add" flag.
    added = {
        run_args[i + 1]
        for i, flag in enumerate(run_args[:-1])
        if flag == "--cap-add"
    }
    assert "SETUID" not in added, (
        "SETUID cap should be dropped when running as host user — no gosu drop is needed"
    )
    assert "SETGID" not in added, (
        "SETGID cap should be dropped when running as host user — no gosu drop is needed"
    )
    # Core non-privilege-drop caps must still be there (pip/npm/apt need them).
    assert "DAC_OVERRIDE" in added
    assert "CHOWN" in added
    assert "FOWNER" in added
|
||||
|
||||
|
||||
def test_run_as_host_user_default_off(monkeypatch):
    """Without the opt-in, no ``--user`` flag is emitted.

    This preserves the existing behavior for callers that never set
    ``run_as_host_user``.
    """
    monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker")
    calls = _mock_subprocess_run(monkeypatch)

    _make_dummy_env()  # run_as_host_user defaults to False

    run_calls = [
        c for c in calls
        if isinstance(c[0], list) and len(c[0]) >= 2 and c[0][1] == "run"
    ]
    # Without this guard a missing `docker run` shows up as an IndexError.
    assert run_calls, "docker run should have been called"
    run_args = run_calls[0][0]
    assert "--user" not in run_args, (
        f"--user should not be in docker run args when opt-in is off: {run_args}"
    )
|
||||
|
||||
|
||||
def test_run_as_host_user_warns_and_skips_when_no_posix_ids(monkeypatch, caplog):
    """Without POSIX ``getuid``/``getgid`` the opt-in degrades gracefully.

    Expected behavior: a warning is logged, no ``--user`` flag is passed, and
    the full capability set (including SETUID/SETGID) is kept because the
    container still starts as its image default user.
    """
    monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker")
    # Simulate a platform where os.getuid is absent (e.g. Windows host).
    monkeypatch.delattr(docker_env.os, "getuid", raising=False)
    monkeypatch.delattr(docker_env.os, "getgid", raising=False)
    calls = _mock_subprocess_run(monkeypatch)

    with caplog.at_level(logging.WARNING):
        _make_dummy_env(run_as_host_user=True)

    run_calls = [
        c for c in calls
        if isinstance(c[0], list) and len(c[0]) >= 2 and c[0][1] == "run"
    ]
    # Fail with a readable message rather than an IndexError below.
    assert run_calls, "docker run should have been called"
    run_args = run_calls[0][0]

    assert "--user" not in run_args
    # Fall back to the full cap set since the container still starts as root.
    added = {
        run_args[i + 1]
        for i, flag in enumerate(run_args[:-1])
        if flag == "--cap-add"
    }
    assert "SETUID" in added
    assert "SETGID" in added
    assert any(
        "does not expose POSIX uid/gid" in rec.getMessage()
        for rec in caplog.records
    ), "expected a warning when POSIX ids are unavailable"
|
||||
|
||||
@@ -241,7 +241,7 @@ def test_container_backends_still_bypass(clean_session):
|
||||
|
||||
Hardline only protects environments with real host impact (local, ssh).
|
||||
"""
|
||||
for env in ("docker", "singularity", "modal", "daytona"):
|
||||
for env in ("docker", "singularity", "modal", "daytona", "vercel_sandbox"):
|
||||
r1 = check_dangerous_command("rm -rf /", env)
|
||||
assert r1["approved"] is True, f"container {env} should still bypass"
|
||||
r2 = check_all_command_guards("rm -rf /", env)
|
||||
|
||||
@@ -132,6 +132,10 @@ class TestProviderEnvBlocklist:
|
||||
"MODAL_TOKEN_ID": "modal-id",
|
||||
"MODAL_TOKEN_SECRET": "modal-secret",
|
||||
"DAYTONA_API_KEY": "daytona-key",
|
||||
"VERCEL_OIDC_TOKEN": "vercel-oidc-token",
|
||||
"VERCEL_TOKEN": "vercel-token",
|
||||
"VERCEL_PROJECT_ID": "vercel-project",
|
||||
"VERCEL_TEAM_ID": "vercel-team",
|
||||
}
|
||||
result_env = _run_with_env(extra_os_env=leaked_vars)
|
||||
|
||||
@@ -287,6 +291,10 @@ class TestBlocklistCoverage:
|
||||
"MODAL_TOKEN_ID",
|
||||
"MODAL_TOKEN_SECRET",
|
||||
"DAYTONA_API_KEY",
|
||||
"VERCEL_OIDC_TOKEN",
|
||||
"VERCEL_TOKEN",
|
||||
"VERCEL_PROJECT_ID",
|
||||
"VERCEL_TEAM_ID",
|
||||
}
|
||||
assert extras.issubset(_HERMES_PROVIDER_ENV_BLOCKLIST)
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import signal
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -37,6 +38,58 @@ def _pgid_still_alive(pgid: int) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _process_group_snapshot(pgid: int) -> str:
|
||||
"""Return a process-table snapshot for diagnostics."""
|
||||
return subprocess.run(
|
||||
["ps", "-o", "pid,ppid,pgid,stat,cmd", "-g", str(pgid)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
).stdout.strip()
|
||||
|
||||
|
||||
def _wait_for_pgid_exit(pgid: int, timeout: float = 10.0) -> bool:
    """Poll until process group *pgid* is gone or *timeout* seconds elapse.

    The long default timeout is there for loaded xdist hosts. Returns
    ``True`` when the group disappeared. One last liveness check is made
    after the deadline so a group that died during the final sleep still
    counts as gone.
    """
    give_up_at = time.monotonic() + timeout
    while True:
        if time.monotonic() >= give_up_at:
            break
        if not _pgid_still_alive(pgid):
            return True
        time.sleep(0.1)
    return not _pgid_still_alive(pgid)
|
||||
|
||||
|
||||
def test_kill_process_uses_cached_pgid_if_wrapper_already_exited(monkeypatch):
    """Cleanup must still signal the process group after the wrapper exits.

    ``os.getpgid(proc.pid)`` raises once the shell wrapper is dead. Without
    the cached ``_hermes_pgid`` fallback, cleanup degrades to ``proc.kill()``,
    which cannot reach orphaned grandchildren still running in the original
    process group.
    """
    # Bypass __init__; only _kill_process is exercised here.
    environment = object.__new__(LocalEnvironment)
    dead_wrapper = SimpleNamespace(
        pid=12345,
        _hermes_pgid=67890,
        poll=lambda: 0,
        kill=lambda: None,
    )
    signals_sent = []

    def _getpgid_raises(_pid):
        raise ProcessLookupError

    def _record_killpg(pgid, sig):
        signals_sent.append((pgid, sig))
        # Signal 0 is the liveness probe; report the group as gone.
        if sig == 0:
            raise ProcessLookupError

    monkeypatch.setattr(os, "getpgid", _getpgid_raises)
    monkeypatch.setattr(os, "killpg", _record_killpg)

    environment._kill_process(dead_wrapper)

    assert signals_sent == [(67890, signal.SIGTERM), (67890, 0)]
|
||||
|
||||
|
||||
def test_wait_for_process_kills_subprocess_on_keyboardinterrupt():
|
||||
"""When KeyboardInterrupt arrives mid-poll, the subprocess group must be
|
||||
killed before the exception is re-raised."""
|
||||
@@ -118,19 +171,15 @@ def test_wait_for_process_kills_subprocess_on_keyboardinterrupt():
|
||||
assert not t.is_alive(), "worker didn't exit within 5 s of the interrupt"
|
||||
|
||||
# The critical assertion: the subprocess GROUP must be dead. Not
|
||||
# just the bash wrapper — the 'sleep 30' child too.
|
||||
# Give the SIGTERM+1s wait+SIGKILL escalation a moment to complete.
|
||||
deadline = time.monotonic() + 3.0
|
||||
while time.monotonic() < deadline:
|
||||
if not _pgid_still_alive(pgid):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
assert not _pgid_still_alive(pgid), (
|
||||
# just the bash wrapper — the 'sleep 30' child too. Under xdist load,
|
||||
# process-group disappearance can lag briefly after the worker exits,
|
||||
# especially if the process is already dying or waiting to be reaped.
|
||||
assert _wait_for_pgid_exit(pgid), (
|
||||
f"subprocess group {pgid} is STILL ALIVE after worker received "
|
||||
f"KeyboardInterrupt — orphan bug regressed. This is the "
|
||||
f"sleep-300-survives-SIGTERM scenario from Physikal's Apr 2026 "
|
||||
f"report. See tools/environments/base.py _wait_for_process "
|
||||
f"except-block."
|
||||
f"except-block.\n{_process_group_snapshot(pgid)}"
|
||||
)
|
||||
# And the worker should have observed the KeyboardInterrupt (i.e.
|
||||
# it re-raised cleanly, not silently swallowed).
|
||||
|
||||
@@ -88,24 +88,29 @@ class TestMessageHandler:
|
||||
from mcp.types import ServerNotification, ToolListChangedNotification
|
||||
|
||||
server = MCPServerTask("notif_srv")
|
||||
with patch.object(MCPServerTask, "_refresh_tools", new_callable=AsyncMock) as mock_refresh:
|
||||
# Product now schedules the refresh as a background task (see
|
||||
# _schedule_tools_refresh in mcp_tool.py ~L918) rather than awaiting
|
||||
# it directly, to avoid wedging the stdio JSON-RPC stream. Patch at
|
||||
# the scheduler seam so we can still assert dispatch happened without
|
||||
# reaching into asyncio.create_task internals.
|
||||
with patch.object(MCPServerTask, "_schedule_tools_refresh") as mock_schedule:
|
||||
handler = server._make_message_handler()
|
||||
notification = ServerNotification(
|
||||
root=ToolListChangedNotification(method="notifications/tools/list_changed")
|
||||
)
|
||||
await handler(notification)
|
||||
mock_refresh.assert_awaited_once()
|
||||
mock_schedule.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignores_exceptions_and_other_messages(self):
|
||||
server = MCPServerTask("notif_srv")
|
||||
with patch.object(MCPServerTask, "_refresh_tools", new_callable=AsyncMock) as mock_refresh:
|
||||
with patch.object(MCPServerTask, "_schedule_tools_refresh") as mock_schedule:
|
||||
handler = server._make_message_handler()
|
||||
# Exceptions should not trigger refresh
|
||||
await handler(RuntimeError("connection dead"))
|
||||
# Unknown message types should not trigger refresh
|
||||
await handler({"jsonrpc": "2.0", "result": "ok"})
|
||||
mock_refresh.assert_not_awaited()
|
||||
mock_schedule.assert_not_called()
|
||||
|
||||
|
||||
class TestDeregister:
|
||||
|
||||
@@ -35,7 +35,15 @@ def _fake_run_on_mcp_loop(coro, timeout=30):
|
||||
"""Run an MCP coroutine directly in a fresh event loop."""
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
# `_rpc_lock` must be created inside the loop that awaits it, or asyncio
|
||||
# raises "attached to a different loop". Build it here and attach it to
|
||||
# whatever fake server is currently registered under _servers.
|
||||
async def _install_lock_and_run():
|
||||
for srv in list(mcp_tool._servers.values()):
|
||||
if getattr(srv, "_rpc_lock", None) is None:
|
||||
srv._rpc_lock = asyncio.Lock()
|
||||
return await coro
|
||||
return loop.run_until_complete(_install_lock_and_run())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
@@ -44,7 +52,10 @@ def _fake_run_on_mcp_loop(coro, timeout=30):
|
||||
def _patch_mcp_server():
|
||||
"""Patch _servers and the MCP event loop so _make_tool_handler can run."""
|
||||
fake_session = MagicMock()
|
||||
fake_server = SimpleNamespace(session=fake_session)
|
||||
# `_rpc_lock` is acquired by _make_tool_handler's call path (mcp_tool.py
|
||||
# ~L2008) to serialize JSON-RPC against the server — build it inside the
|
||||
# fresh loop that _fake_run_on_mcp_loop spins up, not at fixture import.
|
||||
fake_server = SimpleNamespace(session=fake_session, _rpc_lock=None)
|
||||
with patch.dict(mcp_tool._servers, {"test-server": fake_server}), \
|
||||
patch("tools.mcp_tool._run_on_mcp_loop", side_effect=_fake_run_on_mcp_loop):
|
||||
yield fake_session
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user