Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor
This commit is contained in:
@@ -134,11 +134,11 @@ FAL_MODELS: Dict[str, Dict[str, Any]] = {
|
||||
},
|
||||
"upscale": False,
|
||||
},
|
||||
"fal-ai/nano-banana": {
|
||||
"display": "Nano Banana (Gemini 2.5 Flash Image)",
|
||||
"speed": "~6s",
|
||||
"strengths": "Gemini 2.5, consistency",
|
||||
"price": "$0.08/image",
|
||||
"fal-ai/nano-banana-pro": {
|
||||
"display": "Nano Banana Pro (Gemini 3 Pro Image)",
|
||||
"speed": "~8s",
|
||||
"strengths": "Gemini 3 Pro, reasoning depth, text rendering",
|
||||
"price": "$0.15/image (1K)",
|
||||
"size_style": "aspect_ratio",
|
||||
"sizes": {
|
||||
"landscape": "16:9",
|
||||
@@ -149,10 +149,14 @@ FAL_MODELS: Dict[str, Dict[str, Any]] = {
|
||||
"num_images": 1,
|
||||
"output_format": "png",
|
||||
"safety_tolerance": "5",
|
||||
# "1K" is the cheapest tier; 4K doubles the per-image cost.
|
||||
# Users on Nous Subscription should stay at 1K for predictable billing.
|
||||
"resolution": "1K",
|
||||
},
|
||||
"supports": {
|
||||
"prompt", "aspect_ratio", "num_images", "output_format",
|
||||
"safety_tolerance", "seed", "sync_mode",
|
||||
"safety_tolerance", "seed", "sync_mode", "resolution",
|
||||
"enable_web_search", "limit_generations",
|
||||
},
|
||||
"upscale": False,
|
||||
},
|
||||
@@ -202,11 +206,11 @@ FAL_MODELS: Dict[str, Dict[str, Any]] = {
|
||||
},
|
||||
"upscale": False,
|
||||
},
|
||||
"fal-ai/recraft-v3": {
|
||||
"display": "Recraft V3",
|
||||
"fal-ai/recraft/v4/pro/text-to-image": {
|
||||
"display": "Recraft V4 Pro",
|
||||
"speed": "~8s",
|
||||
"strengths": "Vector, brand styles",
|
||||
"price": "$0.04/image",
|
||||
"strengths": "Design, brand systems, production-ready",
|
||||
"price": "$0.25/image",
|
||||
"size_style": "image_size_preset",
|
||||
"sizes": {
|
||||
"landscape": "landscape_16_9",
|
||||
@@ -214,10 +218,12 @@ FAL_MODELS: Dict[str, Dict[str, Any]] = {
|
||||
"portrait": "portrait_16_9",
|
||||
},
|
||||
"defaults": {
|
||||
"style": "realistic_image",
|
||||
# V4 Pro dropped V3's required `style` enum — defaults handle taste now.
|
||||
"enable_safety_checker": False,
|
||||
},
|
||||
"supports": {
|
||||
"prompt", "image_size", "style",
|
||||
"prompt", "image_size", "enable_safety_checker",
|
||||
"colors", "background_color",
|
||||
},
|
||||
"upscale": False,
|
||||
},
|
||||
|
||||
@@ -375,6 +375,103 @@ def remove_oauth_tokens(server_name: str) -> None:
|
||||
logger.info("OAuth tokens removed for '%s'", server_name)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extracted helpers (Task 3 of MCP OAuth consolidation)
|
||||
#
|
||||
# These compose into ``build_oauth_auth`` below, and are also used by
|
||||
# ``tools.mcp_oauth_manager.MCPOAuthManager._build_provider`` so the two
|
||||
# construction paths share one implementation.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _configure_callback_port(cfg: dict) -> int:
|
||||
"""Pick or validate the OAuth callback port.
|
||||
|
||||
Stores the resolved port into ``cfg['_resolved_port']`` so sibling
|
||||
helpers (and the manager) can read it from the same dict. Returns the
|
||||
resolved port.
|
||||
|
||||
NOTE: also sets the legacy module-level ``_oauth_port`` so existing
|
||||
calls to ``_wait_for_callback`` keep working. The legacy global is
|
||||
the root cause of issue #5344 (port collision on concurrent OAuth
|
||||
flows); replacing it with a ContextVar is out of scope for this
|
||||
consolidation PR.
|
||||
"""
|
||||
global _oauth_port
|
||||
requested = int(cfg.get("redirect_port", 0))
|
||||
port = _find_free_port() if requested == 0 else requested
|
||||
cfg["_resolved_port"] = port
|
||||
_oauth_port = port # legacy consumer: _wait_for_callback reads this
|
||||
return port
|
||||
|
||||
|
||||
def _build_client_metadata(cfg: dict) -> "OAuthClientMetadata":
|
||||
"""Build OAuthClientMetadata from the oauth config dict.
|
||||
|
||||
Requires ``cfg['_resolved_port']`` to have been populated by
|
||||
:func:`_configure_callback_port` first.
|
||||
"""
|
||||
port = cfg.get("_resolved_port")
|
||||
if port is None:
|
||||
raise ValueError(
|
||||
"_configure_callback_port() must be called before _build_client_metadata()"
|
||||
)
|
||||
client_name = cfg.get("client_name", "Hermes Agent")
|
||||
scope = cfg.get("scope")
|
||||
redirect_uri = f"http://127.0.0.1:{port}/callback"
|
||||
|
||||
metadata_kwargs: dict[str, Any] = {
|
||||
"client_name": client_name,
|
||||
"redirect_uris": [AnyUrl(redirect_uri)],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"token_endpoint_auth_method": "none",
|
||||
}
|
||||
if scope:
|
||||
metadata_kwargs["scope"] = scope
|
||||
if cfg.get("client_secret"):
|
||||
metadata_kwargs["token_endpoint_auth_method"] = "client_secret_post"
|
||||
|
||||
return OAuthClientMetadata.model_validate(metadata_kwargs)
|
||||
|
||||
|
||||
def _maybe_preregister_client(
|
||||
storage: "HermesTokenStorage",
|
||||
cfg: dict,
|
||||
client_metadata: "OAuthClientMetadata",
|
||||
) -> None:
|
||||
"""If cfg has a pre-registered client_id, persist it to storage."""
|
||||
client_id = cfg.get("client_id")
|
||||
if not client_id:
|
||||
return
|
||||
port = cfg["_resolved_port"]
|
||||
redirect_uri = f"http://127.0.0.1:{port}/callback"
|
||||
|
||||
info_dict: dict[str, Any] = {
|
||||
"client_id": client_id,
|
||||
"redirect_uris": [redirect_uri],
|
||||
"grant_types": client_metadata.grant_types,
|
||||
"response_types": client_metadata.response_types,
|
||||
"token_endpoint_auth_method": client_metadata.token_endpoint_auth_method,
|
||||
}
|
||||
if cfg.get("client_secret"):
|
||||
info_dict["client_secret"] = cfg["client_secret"]
|
||||
if cfg.get("client_name"):
|
||||
info_dict["client_name"] = cfg["client_name"]
|
||||
if cfg.get("scope"):
|
||||
info_dict["scope"] = cfg["scope"]
|
||||
|
||||
client_info = OAuthClientInformationFull.model_validate(info_dict)
|
||||
_write_json(storage._client_info_path(), client_info.model_dump(exclude_none=True))
|
||||
logger.debug("Pre-registered client_id=%s for '%s'", client_id, storage._server_name)
|
||||
|
||||
|
||||
def _parse_base_url(server_url: str) -> str:
|
||||
"""Strip path component from server URL, returning the base origin."""
|
||||
parsed = urlparse(server_url)
|
||||
return f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
|
||||
def build_oauth_auth(
|
||||
server_name: str,
|
||||
server_url: str,
|
||||
@@ -382,7 +479,9 @@ def build_oauth_auth(
|
||||
) -> "OAuthClientProvider | None":
|
||||
"""Build an ``httpx.Auth``-compatible OAuth handler for an MCP server.
|
||||
|
||||
Called from ``mcp_tool.py`` when a server has ``auth: oauth`` in config.
|
||||
Public API preserved for backwards compatibility. New code should use
|
||||
:func:`tools.mcp_oauth_manager.get_manager` so OAuth state is shared
|
||||
across config-time, runtime, and reconnect paths.
|
||||
|
||||
Args:
|
||||
server_name: Server key in mcp_servers config (used for storage).
|
||||
@@ -396,87 +495,32 @@ def build_oauth_auth(
|
||||
if not _OAUTH_AVAILABLE:
|
||||
logger.warning(
|
||||
"MCP OAuth requested for '%s' but SDK auth types are not available. "
|
||||
"Install with: pip install 'mcp>=1.10.0'",
|
||||
"Install with: pip install 'mcp>=1.26.0'",
|
||||
server_name,
|
||||
)
|
||||
return None
|
||||
|
||||
global _oauth_port
|
||||
|
||||
cfg = oauth_config or {}
|
||||
|
||||
# --- Storage ---
|
||||
cfg = dict(oauth_config or {}) # copy — we mutate _resolved_port
|
||||
storage = HermesTokenStorage(server_name)
|
||||
|
||||
# --- Non-interactive warning ---
|
||||
if not _is_interactive() and not storage.has_cached_tokens():
|
||||
logger.warning(
|
||||
"MCP OAuth for '%s': non-interactive environment and no cached tokens found. "
|
||||
"The OAuth flow requires browser authorization. Run interactively first "
|
||||
"to complete the initial authorization, then cached tokens will be reused.",
|
||||
"MCP OAuth for '%s': non-interactive environment and no cached tokens "
|
||||
"found. The OAuth flow requires browser authorization. Run "
|
||||
"interactively first to complete the initial authorization, then "
|
||||
"cached tokens will be reused.",
|
||||
server_name,
|
||||
)
|
||||
|
||||
# --- Pick callback port ---
|
||||
redirect_port = int(cfg.get("redirect_port", 0))
|
||||
if redirect_port == 0:
|
||||
redirect_port = _find_free_port()
|
||||
_oauth_port = redirect_port
|
||||
_configure_callback_port(cfg)
|
||||
client_metadata = _build_client_metadata(cfg)
|
||||
_maybe_preregister_client(storage, cfg, client_metadata)
|
||||
|
||||
# --- Client metadata ---
|
||||
client_name = cfg.get("client_name", "Hermes Agent")
|
||||
scope = cfg.get("scope")
|
||||
redirect_uri = f"http://127.0.0.1:{redirect_port}/callback"
|
||||
|
||||
metadata_kwargs: dict[str, Any] = {
|
||||
"client_name": client_name,
|
||||
"redirect_uris": [AnyUrl(redirect_uri)],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"token_endpoint_auth_method": "none",
|
||||
}
|
||||
if scope:
|
||||
metadata_kwargs["scope"] = scope
|
||||
|
||||
client_secret = cfg.get("client_secret")
|
||||
if client_secret:
|
||||
metadata_kwargs["token_endpoint_auth_method"] = "client_secret_post"
|
||||
|
||||
client_metadata = OAuthClientMetadata.model_validate(metadata_kwargs)
|
||||
|
||||
# --- Pre-registered client ---
|
||||
client_id = cfg.get("client_id")
|
||||
if client_id:
|
||||
info_dict: dict[str, Any] = {
|
||||
"client_id": client_id,
|
||||
"redirect_uris": [redirect_uri],
|
||||
"grant_types": client_metadata.grant_types,
|
||||
"response_types": client_metadata.response_types,
|
||||
"token_endpoint_auth_method": client_metadata.token_endpoint_auth_method,
|
||||
}
|
||||
if client_secret:
|
||||
info_dict["client_secret"] = client_secret
|
||||
if client_name:
|
||||
info_dict["client_name"] = client_name
|
||||
if scope:
|
||||
info_dict["scope"] = scope
|
||||
|
||||
client_info = OAuthClientInformationFull.model_validate(info_dict)
|
||||
_write_json(storage._client_info_path(), client_info.model_dump(exclude_none=True))
|
||||
logger.debug("Pre-registered client_id=%s for '%s'", client_id, server_name)
|
||||
|
||||
# --- Base URL for discovery ---
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
# --- Build provider ---
|
||||
provider = OAuthClientProvider(
|
||||
server_url=base_url,
|
||||
return OAuthClientProvider(
|
||||
server_url=_parse_base_url(server_url),
|
||||
client_metadata=client_metadata,
|
||||
storage=storage,
|
||||
redirect_handler=_redirect_handler,
|
||||
callback_handler=_wait_for_callback,
|
||||
timeout=float(cfg.get("timeout", 300)),
|
||||
)
|
||||
|
||||
return provider
|
||||
|
||||
413
tools/mcp_oauth_manager.py
Normal file
413
tools/mcp_oauth_manager.py
Normal file
@@ -0,0 +1,413 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Central manager for per-server MCP OAuth state.
|
||||
|
||||
One instance shared across the process. Holds per-server OAuth provider
|
||||
instances and coordinates:
|
||||
|
||||
- **Cross-process token reload** via mtime-based disk watch. When an external
|
||||
process (e.g. a user cron job) refreshes tokens on disk, the next auth flow
|
||||
picks them up without requiring a process restart.
|
||||
- **401 deduplication** via in-flight futures. When N concurrent tool calls
|
||||
all hit 401 with the same access_token, only one recovery attempt fires;
|
||||
the rest await the same result.
|
||||
- **Reconnect signalling** for long-lived MCP sessions. The manager itself
|
||||
does not drive reconnection — the `MCPServerTask` in `mcp_tool.py` does —
|
||||
but the manager is the single source of truth that decides when reconnect
|
||||
is warranted.
|
||||
|
||||
Replaces what used to be scattered across eight call sites in `mcp_oauth.py`,
|
||||
`mcp_tool.py`, and `hermes_cli/mcp_config.py`. This module is the ONLY place
|
||||
that instantiates the MCP SDK's `OAuthClientProvider` — all other code paths
|
||||
go through `get_manager()`.
|
||||
|
||||
Design reference:
|
||||
|
||||
- Claude Code's ``invalidateOAuthCacheIfDiskChanged``
|
||||
(``claude-code/src/utils/auth.ts:1320``, CC-1096 / GH#24317). Identical
|
||||
external-refresh staleness bug class.
|
||||
- Codex's ``refresh_oauth_if_needed`` / ``persist_if_needed``
|
||||
(``codex-rs/rmcp-client/src/rmcp_client.rs:805``). We lean on the MCP SDK's
|
||||
lazy refresh rather than calling refresh before every op, because one
|
||||
``stat()`` per tool call is cheaper than an ``await`` + potential refresh
|
||||
round-trip, and the SDK's in-memory expiry path is already correct.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-server entry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ProviderEntry:
|
||||
"""Per-server OAuth state tracked by the manager.
|
||||
|
||||
Fields:
|
||||
server_url: The MCP server URL used to build the provider. Tracked
|
||||
so we can discard a cached provider if the URL changes.
|
||||
oauth_config: Optional dict from ``mcp_servers.<name>.oauth``.
|
||||
provider: The ``httpx.Auth``-compatible provider wrapping the MCP
|
||||
SDK. None until first use.
|
||||
last_mtime_ns: Last-seen ``st_mtime_ns`` of the on-disk tokens file.
|
||||
Zero if never read. Used by :meth:`MCPOAuthManager.invalidate_if_disk_changed`
|
||||
to detect external refreshes.
|
||||
lock: Serialises concurrent access to this entry's state. Bound to
|
||||
whichever asyncio loop first awaits it (the MCP event loop).
|
||||
pending_401: In-flight 401-handler futures keyed by the failed
|
||||
access_token, for deduplicating thundering-herd 401s. Mirrors
|
||||
Claude Code's ``pending401Handlers`` map.
|
||||
"""
|
||||
|
||||
server_url: str
|
||||
oauth_config: Optional[dict]
|
||||
provider: Optional[Any] = None
|
||||
last_mtime_ns: int = 0
|
||||
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
pending_401: dict[str, "asyncio.Future[bool]"] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HermesMCPOAuthProvider — OAuthClientProvider subclass with disk-watch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_hermes_provider_class() -> Optional[type]:
|
||||
"""Lazy-import the SDK base class and return our subclass.
|
||||
|
||||
Wrapped in a function so this module imports cleanly even when the
|
||||
MCP SDK's OAuth module is unavailable (e.g. older mcp versions).
|
||||
"""
|
||||
try:
|
||||
from mcp.client.auth.oauth2 import OAuthClientProvider
|
||||
except ImportError: # pragma: no cover — SDK required in CI
|
||||
return None
|
||||
|
||||
class HermesMCPOAuthProvider(OAuthClientProvider):
|
||||
"""OAuthClientProvider with pre-flow disk-mtime reload.
|
||||
|
||||
Before every ``async_auth_flow`` invocation, asks the manager to
|
||||
check whether the tokens file on disk has been modified externally.
|
||||
If so, the manager resets ``_initialized`` so the next flow
|
||||
re-reads from storage.
|
||||
|
||||
This makes external-process refreshes (cron, another CLI instance)
|
||||
visible to the running MCP session without requiring a restart.
|
||||
|
||||
Reference: Claude Code's ``invalidateOAuthCacheIfDiskChanged``
|
||||
(``src/utils/auth.ts:1320``, CC-1096 / GH#24317).
|
||||
"""
|
||||
|
||||
def __init__(self, *args: Any, server_name: str = "", **kwargs: Any):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._hermes_server_name = server_name
|
||||
|
||||
async def async_auth_flow(self, request): # type: ignore[override]
|
||||
# Pre-flow hook: ask the manager to refresh from disk if needed.
|
||||
# Any failure here is non-fatal — we just log and proceed with
|
||||
# whatever state the SDK already has.
|
||||
try:
|
||||
await get_manager().invalidate_if_disk_changed(
|
||||
self._hermes_server_name
|
||||
)
|
||||
except Exception as exc: # pragma: no cover — defensive
|
||||
logger.debug(
|
||||
"MCP OAuth '%s': pre-flow disk-watch failed (non-fatal): %s",
|
||||
self._hermes_server_name, exc,
|
||||
)
|
||||
|
||||
# Delegate to the SDK's auth flow
|
||||
async for item in super().async_auth_flow(request):
|
||||
yield item
|
||||
|
||||
return HermesMCPOAuthProvider
|
||||
|
||||
|
||||
# Cached at import time. Tested and used by :class:`MCPOAuthManager`.
|
||||
_HERMES_PROVIDER_CLS: Optional[type] = _make_hermes_provider_class()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MCPOAuthManager:
    """Single source of truth for per-server MCP OAuth state.

    Thread-safe: the ``_entries`` dict is guarded by ``_entries_lock`` for
    get-or-create semantics. Per-entry state is guarded by the entry's own
    ``asyncio.Lock`` (used from the MCP event loop thread).
    """

    def __init__(self) -> None:
        self._entries: dict[str, _ProviderEntry] = {}
        self._entries_lock = threading.Lock()
        # Strong references to in-flight 401 recovery tasks. The event loop
        # only keeps weak references to tasks, so an unreferenced task could
        # be garbage-collected before it resolves its future.
        self._recovery_tasks: set[asyncio.Task] = set()

    # -- Provider construction / caching -------------------------------------

    def get_or_build_provider(
        self,
        server_name: str,
        server_url: str,
        oauth_config: Optional[dict],
    ) -> Optional[Any]:
        """Return a cached OAuth provider for ``server_name`` or build one.

        Idempotent: repeat calls with the same name return the same instance.
        If ``server_url`` changes for a given name, the cached entry is
        discarded and a fresh provider is built.

        Args:
            server_name: Server key used for caching and token storage.
            server_url: MCP server URL the provider is built for.
            oauth_config: Optional OAuth settings dict for this server. Only
                consulted when a new entry is created.

        Returns:
            The provider, or None if the MCP SDK's OAuth support is
            unavailable.
        """
        with self._entries_lock:
            entry = self._entries.get(server_name)
            if entry is not None and entry.server_url != server_url:
                logger.info(
                    "MCP OAuth '%s': URL changed from %s to %s, discarding cache",
                    server_name, entry.server_url, server_url,
                )
                entry = None

            if entry is None:
                entry = _ProviderEntry(
                    server_url=server_url,
                    oauth_config=oauth_config,
                )
                self._entries[server_name] = entry

            if entry.provider is None:
                entry.provider = self._build_provider(server_name, entry)

            return entry.provider

    def _build_provider(
        self,
        server_name: str,
        entry: _ProviderEntry,
    ) -> Optional[Any]:
        """Build the underlying OAuth provider.

        Constructs :class:`HermesMCPOAuthProvider` directly using the helpers
        extracted from ``tools.mcp_oauth``. The subclass injects a pre-flow
        disk-watch hook so external token refreshes (cron, other CLI
        instances) are visible to running MCP sessions.

        Returns None if the MCP SDK's OAuth support is unavailable.
        """
        if _HERMES_PROVIDER_CLS is None:
            logger.warning(
                "MCP OAuth '%s': SDK auth module unavailable", server_name,
            )
            return None

        # Local imports avoid circular deps at module import time.
        from tools.mcp_oauth import (
            HermesTokenStorage,
            _OAUTH_AVAILABLE,
            _build_client_metadata,
            _configure_callback_port,
            _is_interactive,
            _maybe_preregister_client,
            _parse_base_url,
            _redirect_handler,
            _wait_for_callback,
        )

        if not _OAUTH_AVAILABLE:
            return None

        # Copy: _configure_callback_port writes ``_resolved_port`` into cfg.
        cfg = dict(entry.oauth_config or {})
        storage = HermesTokenStorage(server_name)

        if not _is_interactive() and not storage.has_cached_tokens():
            logger.warning(
                "MCP OAuth for '%s': non-interactive environment and no "
                "cached tokens found. Run interactively first to complete "
                "initial authorization.",
                server_name,
            )

        _configure_callback_port(cfg)
        client_metadata = _build_client_metadata(cfg)
        _maybe_preregister_client(storage, cfg, client_metadata)

        return _HERMES_PROVIDER_CLS(
            server_name=server_name,
            server_url=_parse_base_url(entry.server_url),
            client_metadata=client_metadata,
            storage=storage,
            redirect_handler=_redirect_handler,
            callback_handler=_wait_for_callback,
            timeout=float(cfg.get("timeout", 300)),
        )

    def remove(self, server_name: str) -> None:
        """Evict the provider from cache AND delete tokens from disk.

        Called by ``hermes mcp remove <name>`` and (indirectly) by
        ``hermes mcp login <name>`` during forced re-auth.
        """
        with self._entries_lock:
            self._entries.pop(server_name, None)

        from tools.mcp_oauth import remove_oauth_tokens
        remove_oauth_tokens(server_name)
        logger.info(
            "MCP OAuth '%s': evicted from cache and removed from disk",
            server_name,
        )

    # -- Disk watch ----------------------------------------------------------

    async def invalidate_if_disk_changed(self, server_name: str) -> bool:
        """Force a provider reload when the on-disk tokens file has changed.

        Compares the tokens file's ``st_mtime_ns`` with the value last seen
        for this server. When they differ, records the new value and resets
        the provider's ``_initialized`` flag so the SDK re-reads storage on
        its next auth flow.

        This is the core fix for the external-refresh workflow: a cron job
        writes fresh tokens to disk, and on the next tool call the running
        MCP session picks them up without a restart.

        Returns:
            True if the mtime differed (cache invalidated). False if there
            is no cached provider, the file cannot be stat'ed, or nothing
            changed.
        """
        entry = self._entries.get(server_name)
        if entry is None or entry.provider is None:
            return False

        from tools.mcp_oauth import _get_token_dir, _safe_filename

        async with entry.lock:
            tokens_path = _get_token_dir() / f"{_safe_filename(server_name)}.json"
            try:
                mtime_ns = tokens_path.stat().st_mtime_ns
            except OSError:  # includes FileNotFoundError
                return False

            if mtime_ns == entry.last_mtime_ns:
                return False

            old = entry.last_mtime_ns
            entry.last_mtime_ns = mtime_ns
            # Force the SDK's OAuthClientProvider to reload from storage
            # on its next auth flow. `_initialized` is private API but
            # stable across the MCP SDK versions we pin (>=1.26.0).
            if hasattr(entry.provider, "_initialized"):
                entry.provider._initialized = False  # noqa: SLF001
            logger.info(
                "MCP OAuth '%s': tokens file changed (mtime %d -> %d), "
                "forcing reload",
                server_name, old, mtime_ns,
            )
            return True

    # -- 401 handler (dedup'd) -----------------------------------------------

    async def handle_401(
        self,
        server_name: str,
        failed_access_token: Optional[str] = None,
    ) -> bool:
        """Handle a 401 from a tool call, deduplicated across concurrent callers.

        Args:
            server_name: Server whose tool call was rejected.
            failed_access_token: The token that was rejected; used as the
                dedup key. ``None`` shares the ``"<unknown>"`` key.

        Returns:
            True if a (possibly new) access token is now available — caller
            should trigger a reconnect and retry the operation.
            False if no recovery path exists — caller should surface a
            ``needs_reauth`` error to the model so it stops hallucinating
            manual refresh attempts.

        Thundering-herd protection: if N concurrent tool calls hit 401 with
        the same ``failed_access_token``, only one recovery attempt fires.
        Others await the same future.
        """
        entry = self._entries.get(server_name)
        if entry is None or entry.provider is None:
            return False

        key = failed_access_token or "<unknown>"
        loop = asyncio.get_running_loop()

        async with entry.lock:
            pending = entry.pending_401.get(key)
            if pending is None:
                pending = loop.create_future()
                entry.pending_401[key] = pending
                # The recovery task re-acquires entry.lock inside
                # invalidate_if_disk_changed, so it only makes progress once
                # this ``async with`` block has been left.
                task = asyncio.create_task(
                    self._resolve_401(server_name, entry, key, pending)
                )
                self._recovery_tasks.add(task)
                task.add_done_callback(self._recovery_tasks.discard)

        try:
            # shield(): cancelling ONE waiter must not cancel the future that
            # every other deduplicated waiter is awaiting.
            return await asyncio.shield(pending)
        except Exception as exc:  # pragma: no cover — defensive
            logger.warning(
                "MCP OAuth '%s': awaiting 401 handler failed: %s",
                server_name, exc,
            )
            return False

    async def _resolve_401(
        self,
        server_name: str,
        entry: _ProviderEntry,
        key: str,
        pending: "asyncio.Future[bool]",
    ) -> None:
        """Run one recovery attempt and publish its outcome on ``pending``."""
        recovered = False
        try:
            # Step 1: Did disk change? Picks up external refresh.
            recovered = await self.invalidate_if_disk_changed(server_name)
            if not recovered:
                # Step 2: No disk change — if the SDK can refresh in-place,
                # let the caller retry. The SDK's httpx.Auth flow will issue
                # the refresh on the next request.
                recovered = self._can_refresh_in_place(entry.provider)
        except Exception as exc:  # pragma: no cover — defensive
            logger.warning(
                "MCP OAuth '%s': 401 handler failed: %s",
                server_name, exc,
            )
            recovered = False
        finally:
            entry.pending_401.pop(key, None)
            # Always resolve, even when this task is cancelled, so waiters
            # never block forever on an orphaned future.
            if not pending.done():
                pending.set_result(recovered)

    @staticmethod
    def _can_refresh_in_place(provider: Any) -> bool:
        """Return whether ``provider.context.can_refresh_token()`` is truthy."""
        ctx = getattr(provider, "context", None)
        if ctx is None:
            return False
        can_refresh_fn = getattr(ctx, "can_refresh_token", None)
        if not callable(can_refresh_fn):
            return False
        try:
            return bool(can_refresh_fn())
        except Exception:
            # Best-effort probe: a failing check means "cannot refresh".
            return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_MANAGER: Optional[MCPOAuthManager] = None
|
||||
_MANAGER_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def get_manager() -> MCPOAuthManager:
    """Return the shared :class:`MCPOAuthManager`, creating it on first use.

    Creation happens under ``_MANAGER_LOCK`` so concurrent first callers
    all receive the same instance.
    """
    global _MANAGER
    with _MANAGER_LOCK:
        manager = _MANAGER
        if manager is None:
            manager = _MANAGER = MCPOAuthManager()
    return manager
|
||||
|
||||
|
||||
def reset_manager_for_tests() -> None:
    """Discard the cached manager so the next ``get_manager()`` builds a new one.

    Intended for test fixtures only. The reset happens under
    ``_MANAGER_LOCK``.
    """
    global _MANAGER
    with _MANAGER_LOCK:
        # Dropping the reference is enough; get_manager() recreates lazily.
        _MANAGER = None
|
||||
@@ -783,7 +783,8 @@ class MCPServerTask:
|
||||
|
||||
__slots__ = (
|
||||
"name", "session", "tool_timeout",
|
||||
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
|
||||
"_task", "_ready", "_shutdown_event", "_reconnect_event",
|
||||
"_tools", "_error", "_config",
|
||||
"_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock",
|
||||
)
|
||||
|
||||
@@ -794,6 +795,12 @@ class MCPServerTask:
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._ready = asyncio.Event()
|
||||
self._shutdown_event = asyncio.Event()
|
||||
# Set by tool handlers on auth failure after manager.handle_401()
|
||||
# confirms recovery is viable. When set, _run_http / _run_stdio
|
||||
# exit their async-with blocks cleanly (no exception), and the
|
||||
# outer run() loop re-enters the transport so the MCP session is
|
||||
# rebuilt with fresh credentials.
|
||||
self._reconnect_event = asyncio.Event()
|
||||
self._tools: list = []
|
||||
self._error: Optional[Exception] = None
|
||||
self._config: dict = {}
|
||||
@@ -887,6 +894,40 @@ class MCPServerTask:
|
||||
self.name, len(self._registered_tool_names),
|
||||
)
|
||||
|
||||
async def _wait_for_lifecycle_event(self) -> str:
|
||||
"""Block until either _shutdown_event or _reconnect_event fires.
|
||||
|
||||
Returns:
|
||||
"shutdown" if the server should exit the run loop entirely.
|
||||
"reconnect" if the server should tear down the current MCP
|
||||
session and re-enter the transport (fresh OAuth
|
||||
tokens, new session ID, etc.). The reconnect event
|
||||
is cleared before return so the next cycle starts
|
||||
with a fresh signal.
|
||||
|
||||
Shutdown takes precedence if both events are set simultaneously.
|
||||
"""
|
||||
shutdown_task = asyncio.create_task(self._shutdown_event.wait())
|
||||
reconnect_task = asyncio.create_task(self._reconnect_event.wait())
|
||||
try:
|
||||
await asyncio.wait(
|
||||
{shutdown_task, reconnect_task},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
finally:
|
||||
for t in (shutdown_task, reconnect_task):
|
||||
if not t.done():
|
||||
t.cancel()
|
||||
try:
|
||||
await t
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
|
||||
if self._shutdown_event.is_set():
|
||||
return "shutdown"
|
||||
self._reconnect_event.clear()
|
||||
return "reconnect"
|
||||
|
||||
async def _run_stdio(self, config: dict):
|
||||
"""Run the server using stdio transport."""
|
||||
command = config.get("command")
|
||||
@@ -932,7 +973,10 @@ class MCPServerTask:
|
||||
self.session = session
|
||||
await self._discover_tools()
|
||||
self._ready.set()
|
||||
await self._shutdown_event.wait()
|
||||
# stdio transport does not use OAuth, but we still honor
|
||||
# _reconnect_event (e.g. future manual /mcp refresh) for
|
||||
# consistency with _run_http.
|
||||
await self._wait_for_lifecycle_event()
|
||||
# Context exited cleanly — subprocess was terminated by the SDK.
|
||||
if new_pids:
|
||||
with _lock:
|
||||
@@ -951,16 +995,18 @@ class MCPServerTask:
|
||||
headers = dict(config.get("headers") or {})
|
||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
|
||||
# OAuth 2.1 PKCE: build httpx.Auth handler using the MCP SDK.
|
||||
# If OAuth setup fails (e.g. non-interactive environment without
|
||||
# cached tokens), re-raise so this server is reported as failed
|
||||
# without blocking other MCP servers from connecting.
|
||||
# OAuth 2.1 PKCE: route through the central MCPOAuthManager so the
|
||||
# same provider instance is reused across reconnects, pre-flow
|
||||
# disk-watch is active, and config-time CLI code paths share state.
|
||||
# If OAuth setup fails (e.g. non-interactive env without cached
|
||||
# tokens), re-raise so this server is reported as failed without
|
||||
# blocking other MCP servers from connecting.
|
||||
_oauth_auth = None
|
||||
if self._auth_type == "oauth":
|
||||
try:
|
||||
from tools.mcp_oauth import build_oauth_auth
|
||||
_oauth_auth = build_oauth_auth(
|
||||
self.name, url, config.get("oauth")
|
||||
from tools.mcp_oauth_manager import get_manager
|
||||
_oauth_auth = get_manager().get_or_build_provider(
|
||||
self.name, url, config.get("oauth"),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc)
|
||||
@@ -995,7 +1041,12 @@ class MCPServerTask:
|
||||
self.session = session
|
||||
await self._discover_tools()
|
||||
self._ready.set()
|
||||
await self._shutdown_event.wait()
|
||||
reason = await self._wait_for_lifecycle_event()
|
||||
if reason == "reconnect":
|
||||
logger.info(
|
||||
"MCP server '%s': reconnect requested — "
|
||||
"tearing down HTTP session", self.name,
|
||||
)
|
||||
else:
|
||||
# Deprecated API (mcp < 1.24.0): manages httpx client internally.
|
||||
_http_kwargs: dict = {
|
||||
@@ -1012,7 +1063,12 @@ class MCPServerTask:
|
||||
self.session = session
|
||||
await self._discover_tools()
|
||||
self._ready.set()
|
||||
await self._shutdown_event.wait()
|
||||
reason = await self._wait_for_lifecycle_event()
|
||||
if reason == "reconnect":
|
||||
logger.info(
|
||||
"MCP server '%s': reconnect requested — "
|
||||
"tearing down legacy HTTP session", self.name,
|
||||
)
|
||||
|
||||
async def _discover_tools(self):
|
||||
"""Discover tools from the connected session."""
|
||||
@@ -1060,8 +1116,25 @@ class MCPServerTask:
|
||||
await self._run_http(config)
|
||||
else:
|
||||
await self._run_stdio(config)
|
||||
# Normal exit (shutdown requested) -- break out
|
||||
break
|
||||
# Transport returned cleanly. Two cases:
|
||||
# - _shutdown_event was set: exit the run loop entirely.
|
||||
# - _reconnect_event was set (auth recovery): loop back and
|
||||
# rebuild the MCP session with fresh credentials. Do NOT
|
||||
# touch the retry counters — this is not a failure.
|
||||
if self._shutdown_event.is_set():
|
||||
break
|
||||
logger.info(
|
||||
"MCP server '%s': reconnecting (OAuth recovery or "
|
||||
"manual refresh)",
|
||||
self.name,
|
||||
)
|
||||
# Reset the session reference; _run_http/_run_stdio will
|
||||
# repopulate it on successful re-entry.
|
||||
self.session = None
|
||||
# Keep _ready set across reconnects so tool handlers can
|
||||
# still detect a transient in-flight state — it'll be
|
||||
# re-set after the fresh session initializes.
|
||||
continue
|
||||
except Exception as exc:
|
||||
self.session = None
|
||||
|
||||
@@ -1141,6 +1214,12 @@ class MCPServerTask:
|
||||
from tools.registry import registry
|
||||
|
||||
self._shutdown_event.set()
|
||||
# Defensive: if _wait_for_lifecycle_event is blocking, we need ANY
|
||||
# event to unblock it. _shutdown_event alone is sufficient (the
|
||||
# helper checks shutdown first), but setting reconnect too ensures
|
||||
# there's no race where the helper misses the shutdown flag after
|
||||
# returning "reconnect".
|
||||
self._reconnect_event.set()
|
||||
if self._task and not self._task.done():
|
||||
try:
|
||||
await asyncio.wait_for(self._task, timeout=10)
|
||||
@@ -1174,6 +1253,175 @@ _servers: Dict[str, MCPServerTask] = {}
|
||||
_server_error_counts: Dict[str, int] = {}
|
||||
_CIRCUIT_BREAKER_THRESHOLD = 3
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth-failure detection helpers (Task 6 of MCP OAuth consolidation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Cached tuple of auth-related exception types. Lazy so this module
|
||||
# imports cleanly when the MCP SDK OAuth module is missing.
|
||||
_AUTH_ERROR_TYPES: tuple = ()
|
||||
|
||||
|
||||
def _get_auth_error_types() -> tuple:
    """Collect the exception classes that may signal an MCP OAuth failure.

    The tuple is assembled on first use and stored in the module-level
    ``_AUTH_ERROR_TYPES``; later calls return the stored tuple as long as
    it is non-empty (an empty result is falsy, so it is rebuilt next time).
    Every class is an optional import — a failed import simply leaves that
    class out of the tuple:

    - ``mcp.client.auth.OAuthFlowError`` / ``OAuthTokenError`` — raised by
      the SDK's auth flow when discovery, refresh, or full re-auth fails.
    - ``mcp.client.auth.UnauthorizedError`` — only exported by some MCP SDK
      versions, hence imported separately.
    - ``tools.mcp_oauth.OAuthNonInteractiveError`` — raised by our callback
      handler when no user is present to complete a browser flow.
    - ``httpx.HTTPStatusError`` — callers must additionally check
      ``status_code == 401`` via :func:`_is_auth_error`.

    Returns:
        Tuple of exception classes, suitable for ``isinstance``.
    """
    global _AUTH_ERROR_TYPES
    if _AUTH_ERROR_TYPES:
        # Already resolved on an earlier call.
        return _AUTH_ERROR_TYPES

    collected: list = []

    # SDK auth-flow errors (imported together: both or neither).
    try:
        from mcp.client.auth import OAuthFlowError, OAuthTokenError
    except ImportError:
        pass
    else:
        collected += [OAuthFlowError, OAuthTokenError]

    # Not every MCP SDK version exports this one.
    try:
        from mcp.client.auth import UnauthorizedError  # type: ignore
    except ImportError:
        pass
    else:
        collected.append(UnauthorizedError)

    # Project-local error for "no user available to finish the flow".
    try:
        from tools.mcp_oauth import OAuthNonInteractiveError
    except ImportError:
        pass
    else:
        collected.append(OAuthNonInteractiveError)

    # HTTP status errors; the 401 check itself lives in _is_auth_error.
    try:
        import httpx
    except ImportError:
        pass
    else:
        collected.append(httpx.HTTPStatusError)

    _AUTH_ERROR_TYPES = tuple(collected)
    return _AUTH_ERROR_TYPES
|
||||
|
||||
|
||||
def _is_auth_error(exc: BaseException) -> bool:
    """Decide whether ``exc`` should be handled as an MCP OAuth failure.

    An exception qualifies when it is an instance of one of the classes
    returned by :func:`_get_auth_error_types`.  ``httpx.HTTPStatusError``
    is the one exception to that rule: it only counts when the response
    status code is 401, so other HTTP errors take the tool handlers'
    generic error path.

    Args:
        exc: The exception raised by the failed MCP operation.

    Returns:
        True for an auth-related failure, False otherwise.
    """
    auth_types = _get_auth_error_types()
    if not (auth_types and isinstance(exc, auth_types)):
        return False

    try:
        import httpx
    except ImportError:
        # Without httpx there is no status code to inspect; the type
        # match above is all we have.
        return True

    if isinstance(exc, httpx.HTTPStatusError):
        # Only an explicit 401 means "credentials rejected".
        return getattr(exc.response, "status_code", None) == 401
    return True
|
||||
|
||||
|
||||
def _handle_auth_error_and_retry(
    server_name: str,
    exc: BaseException,
    retry_call,
    op_description: str,
):
    """Attempt auth recovery and one retry; return None to fall through.

    Called by the 5 MCP tool handlers when ``session.<op>()`` raises an
    auth-related exception. Workflow:

    1. Return None immediately if ``exc`` is not an auth error (per
       :func:`_is_auth_error`), signalling the caller to use its generic
       error path.
    2. Ask :meth:`tools.mcp_oauth_manager.MCPOAuthManager.handle_401`
       (run on the MCP loop, 10 s timeout) whether recovery is viable.
    3. If yes, set the server's ``_reconnect_event`` on the MCP loop so
       the server task tears down the current session and rebuilds it,
       then wait (at most 15 s) for a *new* session object to appear.
    4. Retry the operation once. Return the retry result if it is not a
       JSON payload carrying an ``"error"`` key; this also resets the
       server's error counter.
    5. Otherwise bump the server's error counter and return a structured
       ``needs_reauth`` error so the model stops retrying the tool.

    Args:
        server_name: Name of the MCP server that raised.
        exc: The exception from the failed tool call.
        retry_call: Zero-arg callable that re-runs the tool call, returning
            the same JSON string format as the handler.
        op_description: Human-readable name of the operation (for logs).

    Returns:
        A JSON string if auth recovery was attempted, or None to fall
        through to the caller's generic error path.
    """
    if not _is_auth_error(exc):
        return None

    from tools.mcp_oauth_manager import get_manager
    manager = get_manager()

    async def _recover():
        return await manager.handle_401(server_name, None)

    try:
        recovered = _run_on_mcp_loop(_recover(), timeout=10)
    except Exception as rec_exc:
        # Best-effort: a failed recovery attempt is reported as needs_reauth
        # below rather than propagated to the tool caller.
        logger.warning(
            "MCP OAuth '%s': recovery attempt failed: %s",
            server_name, rec_exc,
        )
        recovered = False

    if recovered:
        with _lock:
            srv = _servers.get(server_name)
        if srv is not None and hasattr(srv, "_reconnect_event"):
            loop = _mcp_loop
            if loop is not None and loop.is_running():
                # Remember the session that just failed.  The reconnect
                # request below is only *scheduled* on the MCP loop, and
                # _ready stays set across reconnects, so checking
                # "session is not None and _ready is set" would succeed
                # immediately against this stale session.
                stale_session = srv.session
                loop.call_soon_threadsafe(srv._reconnect_event.set)
                # Wait briefly for a fresh session to come back ready.
                # Bounded so that a stuck reconnect falls through to the
                # error path rather than hanging the caller.
                deadline = time.monotonic() + 15
                while time.monotonic() < deadline:
                    current_session = srv.session
                    if (
                        current_session is not None
                        and current_session is not stale_session
                        and srv._ready.is_set()
                    ):
                        break
                    time.sleep(0.25)

        try:
            result = retry_call()
            try:
                parsed = json.loads(result)
                if "error" not in parsed:
                    _server_error_counts[server_name] = 0
                    return result
            except (json.JSONDecodeError, TypeError):
                # Not JSON (or not a container): treat the raw result as
                # a successful payload.
                _server_error_counts[server_name] = 0
                return result
        except Exception as retry_exc:
            logger.warning(
                "MCP %s/%s retry after auth recovery failed: %s",
                server_name, op_description, retry_exc,
            )

    # No recovery available, or retry also failed: surface a structured
    # needs_reauth error. Bumps the circuit breaker so the model stops
    # retrying the tool.
    _server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
    return json.dumps({
        "error": (
            f"MCP server '{server_name}' requires re-authentication. "
            f"Run `hermes mcp login {server_name}` (or delete the tokens "
            f"file under ~/.hermes/mcp-tokens/ and restart). Do NOT retry "
            f"this tool — ask the user to re-authenticate."
        ),
        "needs_reauth": True,
        "server": server_name,
    }, ensure_ascii=False)
|
||||
|
||||
# Dedicated event loop running in a background daemon thread.
|
||||
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
_mcp_thread: Optional[threading.Thread] = None
|
||||
@@ -1420,8 +1668,11 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
||||
return json.dumps({"result": structured}, ensure_ascii=False)
|
||||
return json.dumps({"result": text_result}, ensure_ascii=False)
|
||||
|
||||
def _call_once():
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
|
||||
try:
|
||||
result = _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
result = _call_once()
|
||||
# Check if the MCP tool itself returned an error
|
||||
try:
|
||||
parsed = json.loads(result)
|
||||
@@ -1435,6 +1686,16 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
||||
except InterruptedError:
|
||||
return _interrupted_call_result()
|
||||
except Exception as exc:
|
||||
# Auth-specific recovery path: consult the manager, signal
|
||||
# reconnect if viable, retry once. Returns None to fall
|
||||
# through for non-auth exceptions.
|
||||
recovered = _handle_auth_error_and_retry(
|
||||
server_name, exc, _call_once,
|
||||
f"tools/call {tool_name}",
|
||||
)
|
||||
if recovered is not None:
|
||||
return recovered
|
||||
|
||||
_server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
|
||||
logger.error(
|
||||
"MCP tool %s/%s call failed: %s",
|
||||
@@ -1476,11 +1737,19 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
|
||||
resources.append(entry)
|
||||
return json.dumps({"resources": resources}, ensure_ascii=False)
|
||||
|
||||
try:
|
||||
def _call_once():
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
|
||||
try:
|
||||
return _call_once()
|
||||
except InterruptedError:
|
||||
return _interrupted_call_result()
|
||||
except Exception as exc:
|
||||
recovered = _handle_auth_error_and_retry(
|
||||
server_name, exc, _call_once, "resources/list",
|
||||
)
|
||||
if recovered is not None:
|
||||
return recovered
|
||||
logger.error(
|
||||
"MCP %s/list_resources failed: %s", server_name, exc,
|
||||
)
|
||||
@@ -1522,11 +1791,19 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
|
||||
parts.append(f"[binary data, {len(block.blob)} bytes]")
|
||||
return json.dumps({"result": "\n".join(parts) if parts else ""}, ensure_ascii=False)
|
||||
|
||||
try:
|
||||
def _call_once():
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
|
||||
try:
|
||||
return _call_once()
|
||||
except InterruptedError:
|
||||
return _interrupted_call_result()
|
||||
except Exception as exc:
|
||||
recovered = _handle_auth_error_and_retry(
|
||||
server_name, exc, _call_once, "resources/read",
|
||||
)
|
||||
if recovered is not None:
|
||||
return recovered
|
||||
logger.error(
|
||||
"MCP %s/read_resource failed: %s", server_name, exc,
|
||||
)
|
||||
@@ -1571,11 +1848,19 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
|
||||
prompts.append(entry)
|
||||
return json.dumps({"prompts": prompts}, ensure_ascii=False)
|
||||
|
||||
try:
|
||||
def _call_once():
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
|
||||
try:
|
||||
return _call_once()
|
||||
except InterruptedError:
|
||||
return _interrupted_call_result()
|
||||
except Exception as exc:
|
||||
recovered = _handle_auth_error_and_retry(
|
||||
server_name, exc, _call_once, "prompts/list",
|
||||
)
|
||||
if recovered is not None:
|
||||
return recovered
|
||||
logger.error(
|
||||
"MCP %s/list_prompts failed: %s", server_name, exc,
|
||||
)
|
||||
@@ -1628,11 +1913,19 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
|
||||
resp["description"] = result.description
|
||||
return json.dumps(resp, ensure_ascii=False)
|
||||
|
||||
try:
|
||||
def _call_once():
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
|
||||
try:
|
||||
return _call_once()
|
||||
except InterruptedError:
|
||||
return _interrupted_call_result()
|
||||
except Exception as exc:
|
||||
recovered = _handle_auth_error_and_retry(
|
||||
server_name, exc, _call_once, "prompts/get",
|
||||
)
|
||||
if recovered is not None:
|
||||
return recovered
|
||||
logger.error(
|
||||
"MCP %s/get_prompt failed: %s", server_name, exc,
|
||||
)
|
||||
|
||||
@@ -215,7 +215,27 @@ def _handle_send(args):
|
||||
|
||||
pconfig = config.platforms.get(platform)
|
||||
if not pconfig or not pconfig.enabled:
|
||||
return tool_error(f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/config.yaml or environment variables.")
|
||||
# Weixin can be configured purely via .env; synthesize a pconfig so
|
||||
# send_message and cron delivery work without a gateway.yaml entry.
|
||||
if platform_name == "weixin":
|
||||
import os
|
||||
wx_token = os.getenv("WEIXIN_TOKEN", "").strip()
|
||||
wx_account = os.getenv("WEIXIN_ACCOUNT_ID", "").strip()
|
||||
if wx_token and wx_account:
|
||||
from gateway.config import PlatformConfig
|
||||
pconfig = PlatformConfig(
|
||||
enabled=True,
|
||||
token=wx_token,
|
||||
extra={
|
||||
"account_id": wx_account,
|
||||
"base_url": os.getenv("WEIXIN_BASE_URL", "").strip(),
|
||||
"cdn_base_url": os.getenv("WEIXIN_CDN_BASE_URL", "").strip(),
|
||||
},
|
||||
)
|
||||
else:
|
||||
return tool_error(f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/config.yaml or environment variables.")
|
||||
else:
|
||||
return tool_error(f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/config.yaml or environment variables.")
|
||||
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
|
||||
@@ -225,6 +245,12 @@ def _handle_send(args):
|
||||
used_home_channel = False
|
||||
if not chat_id:
|
||||
home = config.get_home_channel(platform)
|
||||
if not home and platform_name == "weixin":
|
||||
import os
|
||||
wx_home = os.getenv("WEIXIN_HOME_CHANNEL", "").strip()
|
||||
if wx_home:
|
||||
from gateway.config import HomeChannel
|
||||
home = HomeChannel(platform=platform, chat_id=wx_home, name="Weixin Home")
|
||||
if home:
|
||||
chat_id = home.chat_id
|
||||
used_home_channel = True
|
||||
@@ -1274,7 +1300,7 @@ async def _send_qqbot(pconfig, chat_id, message):
|
||||
|
||||
# Step 2: Send message via REST
|
||||
headers = {
|
||||
"Authorization": f"QQBotAccessToken {access_token}",
|
||||
"Authorization": f"QQBot {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"https://api.sgroup.qq.com/channels/{chat_id}/messages"
|
||||
|
||||
@@ -301,6 +301,104 @@ def sync_skills(quiet: bool = False) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def reset_bundled_skill(name: str, restore: bool = False) -> dict:
    """
    Forget the manifest tracking of a bundled skill so syncing resumes.

    Once a user edits a bundled skill, later syncs flag it as
    ``user_modified`` and skip it indefinitely — copying the bundled
    version back does not help, because the manifest still records the
    *old* origin hash. Dropping the manifest entry breaks that loop.

    Args:
        name: The skill name (matches the manifest key / skill frontmatter name).
        restore: When True, additionally remove the user's copy in
            SKILLS_DIR so the following sync re-copies the current bundled
            version. When False (default), only the manifest entry is
            cleared and the user's copy is left in place.

    Returns:
        dict with keys:
          - ok: bool, whether the reset succeeded
          - action: one of "manifest_cleared", "restored", "not_in_manifest",
                    "bundled_missing"
          - message: human-readable description
          - synced: dict from sync_skills() if a sync was triggered, else None
    """

    def _outcome(ok: bool, action: str, message: str, synced=None) -> dict:
        # Single place that fixes the shape (and key order) of the result.
        return {"ok": ok, "action": action, "message": message, "synced": synced}

    manifest = _read_manifest()
    bundled_dir = _get_bundled_dir()
    bundled_by_name = dict(_discover_bundled_skills(bundled_dir))

    tracked = name in manifest
    has_bundled_source = name in bundled_by_name

    # Neither tracked nor shipped: there is nothing this function can reset.
    if not (tracked or has_bundled_source):
        return _outcome(
            False,
            "not_in_manifest",
            f"'{name}' is not a tracked bundled skill. Nothing to reset. "
            f"(Hub-installed skills use `hermes skills uninstall`.)",
        )

    # Step 1: forget the manifest entry so the next sync sees the skill as new.
    if tracked:
        del manifest[name]
        _write_manifest(manifest)

    # Step 2 (restore only): remove the user's copy so the sync re-copies it.
    removed_user_copy = False
    if restore:
        if not has_bundled_source:
            return _outcome(
                False,
                "bundled_missing",
                f"'{name}' has no bundled source — manifest entry cleared "
                f"but cannot restore from bundled (skill was removed upstream).",
            )
        # The user's copy sits at the bundled path relative to bundled_dir.
        dest = _compute_relative_dest(bundled_by_name[name], bundled_dir)
        if dest.exists():
            try:
                shutil.rmtree(dest)
            except OSError as err:
                return _outcome(
                    False,
                    "manifest_cleared",
                    f"Cleared manifest entry for '{name}' but could not "
                    f"delete user copy at {dest}: {err}",
                )
            else:
                removed_user_copy = True

    # Step 3: sync to re-baseline (or to re-copy what was just removed).
    synced = sync_skills(quiet=True)

    if not restore:
        action = "manifest_cleared"
        message = (
            f"Cleared manifest entry for '{name}'. Future `hermes update` runs "
            f"will re-baseline against your current copy and accept upstream changes."
        )
    elif removed_user_copy:
        action = "restored"
        message = f"Restored '{name}' from bundled source."
    else:
        # Nothing existed on disk; the sync behaves like a fresh install.
        action = "restored"
        message = f"Restored '{name}' (no prior user copy, re-copied from bundled)."

    return _outcome(True, action, message, synced)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Syncing bundled skills into ~/.hermes/skills/ ...")
|
||||
result = sync_skills(quiet=False)
|
||||
|
||||
@@ -29,6 +29,13 @@ _BLOCKED_HOSTNAMES = frozenset({
|
||||
"metadata.goog",
|
||||
})
|
||||
|
||||
# Exact HTTPS hostnames allowed to resolve to private/benchmark-space IPs.
|
||||
# This is intentionally narrow: QQ media downloads can legitimately resolve
|
||||
# to 198.18.0.0/15 behind local proxy/benchmark infrastructure.
|
||||
_TRUSTED_PRIVATE_IP_HOSTS = frozenset({
|
||||
"multimedia.nt.qq.com.cn",
|
||||
})
|
||||
|
||||
# 100.64.0.0/10 (CGNAT / Shared Address Space, RFC 6598) is NOT covered by
|
||||
# ipaddress.is_private — it returns False for both is_private and is_global.
|
||||
# Must be blocked explicitly. Used by carrier-grade NAT, Tailscale/WireGuard
|
||||
@@ -48,6 +55,11 @@ def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _allows_private_ip_resolution(hostname: str, scheme: str) -> bool:
|
||||
"""Return True when a trusted HTTPS hostname may bypass IP-class blocking."""
|
||||
return scheme == "https" and hostname in _TRUSTED_PRIVATE_IP_HOSTS
|
||||
|
||||
|
||||
def is_safe_url(url: str) -> bool:
|
||||
"""Return True if the URL target is not a private/internal address.
|
||||
|
||||
@@ -56,7 +68,8 @@ def is_safe_url(url: str) -> bool:
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
hostname = (parsed.hostname or "").strip().lower()
|
||||
hostname = (parsed.hostname or "").strip().lower().rstrip(".")
|
||||
scheme = (parsed.scheme or "").strip().lower()
|
||||
if not hostname:
|
||||
return False
|
||||
|
||||
@@ -65,6 +78,8 @@ def is_safe_url(url: str) -> bool:
|
||||
logger.warning("Blocked request to internal hostname: %s", hostname)
|
||||
return False
|
||||
|
||||
allow_private_ip = _allows_private_ip_resolution(hostname, scheme)
|
||||
|
||||
# Try to resolve and check IP
|
||||
try:
|
||||
addr_info = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
|
||||
@@ -81,13 +96,19 @@ def is_safe_url(url: str) -> bool:
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if _is_blocked_ip(ip):
|
||||
if not allow_private_ip and _is_blocked_ip(ip):
|
||||
logger.warning(
|
||||
"Blocked request to private/internal address: %s -> %s",
|
||||
hostname, ip_str,
|
||||
)
|
||||
return False
|
||||
|
||||
if allow_private_ip:
|
||||
logger.debug(
|
||||
"Allowing trusted hostname despite private/internal resolution: %s",
|
||||
hostname,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as exc:
|
||||
|
||||
Reference in New Issue
Block a user