fix(tui): share static model detection

This commit is contained in:
Brooklyn Nicholson
2026-04-25 13:56:16 -05:00
parent 2dfcc8087a
commit e48a497d16
3 changed files with 66 additions and 102 deletions

View File

@@ -1379,27 +1379,35 @@ def curated_models_for_provider(
return [(m, "") for m in models]
def detect_provider_for_model(
def _provider_keys(provider: str) -> set[str]:
key = (provider or "").strip().lower()
normalized = normalize_provider(provider)
return {k for k in (key, normalized) if k}
def _model_in_provider_catalog(name_lower: str, providers: set[str]) -> bool:
return any(
name_lower == model.lower()
for provider in providers
for model in _PROVIDER_MODELS.get(provider, [])
)
def detect_static_provider_for_model(
model_name: str,
current_provider: str,
) -> Optional[tuple[str, str]]:
"""Auto-detect the best provider for a model name.
"""Auto-detect a provider from static catalogs only.
Returns ``(provider_id, model_name)`` — the model name may be remapped
(e.g. bare ``deepseek-chat`` → ``deepseek/deepseek-chat`` for OpenRouter).
Returns ``None`` when no confident match is found.
Priority:
0. Bare provider name → switch to that provider's default model
1. Direct provider with credentials (highest)
2. Direct provider without credentials → remap to OpenRouter slug
3. OpenRouter catalog match
"""
name = (model_name or "").strip()
if not name:
return None
name_lower = name.lower()
current_keys = _provider_keys(current_provider)
# --- Step 0: bare provider name typed as model ---
# If someone types `/model nous` or `/model anthropic`, treat it as a
@@ -1412,7 +1420,7 @@ def detect_provider_for_model(
if (
resolved_provider in _PROVIDER_LABELS
and default_models
and resolved_provider != normalize_provider(current_provider)
and resolved_provider not in current_keys
):
return (resolved_provider, default_models[0])
@@ -1420,56 +1428,43 @@ def detect_provider_for_model(
_AGGREGATORS = {"nous", "openrouter", "ai-gateway", "copilot", "kilocode"}
# If the model belongs to the current provider's catalog, don't suggest switching
current_models = _PROVIDER_MODELS.get(current_provider, [])
if any(name_lower == m.lower() for m in current_models):
if _model_in_provider_catalog(name_lower, current_keys):
return None
# --- Step 1: check static provider catalogs for a direct match ---
direct_match: Optional[str] = None
for pid, models in _PROVIDER_MODELS.items():
if pid == current_provider or pid in _AGGREGATORS:
if pid in current_keys or pid in _AGGREGATORS:
continue
if any(name_lower == m.lower() for m in models):
direct_match = pid
break
return (pid, name)
if direct_match:
# Check if we have credentials for this provider — env vars,
# credential pool, or auth store entries.
has_creds = False
try:
from hermes_cli.auth import PROVIDER_REGISTRY
pconfig = PROVIDER_REGISTRY.get(direct_match)
if pconfig:
for env_var in pconfig.api_key_env_vars:
if os.getenv(env_var, "").strip():
has_creds = True
break
except Exception:
pass
# Also check credential pool and auth store — covers OAuth,
# Claude Code tokens, and other non-env-var credentials (#10300).
if not has_creds:
try:
from agent.credential_pool import load_pool
pool = load_pool(direct_match)
if pool.has_credentials():
has_creds = True
except Exception:
pass
if not has_creds:
try:
from hermes_cli.auth import _load_auth_store
store = _load_auth_store()
if direct_match in store.get("providers", {}) or direct_match in store.get("credential_pool", {}):
has_creds = True
except Exception:
pass
return None
# Always return the direct provider match. If credentials are
# missing, the client init will give a clear error rather than
# silently routing through the wrong provider (#10300).
return (direct_match, name)
def detect_provider_for_model(
model_name: str,
current_provider: str,
) -> Optional[tuple[str, str]]:
"""Auto-detect the best provider for a model name.
Returns ``(provider_id, model_name)`` — the model name may be remapped
(e.g. bare ``deepseek-chat`` → ``deepseek/deepseek-chat`` for OpenRouter).
Returns ``None`` when no confident match is found.
Priority:
0. Bare provider name → switch to that provider's default model
1. Direct provider static catalog match
2. OpenRouter catalog match
"""
name = (model_name or "").strip()
if not name:
return None
static_match = detect_static_provider_for_model(name, current_provider)
if static_match:
return static_match
if _model_in_provider_catalog(name.lower(), _provider_keys(current_provider)):
return None
# --- Step 2: check OpenRouter catalog ---
# First try exact match (handles provider/model format)