fix(context): honor custom_providers context_length on /model switch + bump probe tier to 256K (#15844)

Fixes #15779. Custom-provider per-model context_length (`custom_providers[].models.<id>.context_length`) is now honored across every resolution path, not just agent startup. Also adds 256K as the top probe tier and default fallback.

## What changed

New helper `hermes_cli.config.get_custom_provider_context_length()` — single source of truth for the per-model override lookup, with trailing-slash-insensitive base-url matching.

`agent.model_metadata.get_model_context_length()` gains an optional `custom_providers=` kwarg (step 0b — runs after explicit `config_context_length` but before every other probe).

Wired through five call sites that previously either duplicated the lookup or ignored it entirely:
- `run_agent.py` startup — refactored to use the new helper (dedups legacy inline loop, keeps invalid-value warning)
- `AIAgent.switch_model()` — re-reads custom_providers from live config on every /model switch
- `hermes_cli.model_switch.resolve_display_context_length()` — new `custom_providers=` kwarg
- `gateway/run.py` /model confirmation (picker callback + text path)
- `gateway/run.py` `_format_session_info` (/info)

## Context probe tiers

`CONTEXT_PROBE_TIERS = [256_000, 128_000, 64_000, 32_000, 16_000, 8_000]` — was `[128_000, ...]`. `DEFAULT_FALLBACK_CONTEXT` follows tier[0], so unknown models now default to 256K. The stale `128000` literal in the OpenRouter metadata-miss path is replaced with `DEFAULT_FALLBACK_CONTEXT` for consistency.

## Repro (from #15779)

```yaml
custom_providers:
  - name: my-custom-endpoint
    base_url: https://example.invalid/v1
    model: gpt-5.5
    models:
      gpt-5.5:
        context_length: 1050000
```

`/model gpt-5.5 --provider custom:my-custom-endpoint` → previously "Context: 128,000", now "Context: 1,050,000".

## Tests

- `tests/hermes_cli/test_custom_provider_context_length.py` — new file, 19 tests covering the helper, step-0b integration, and the 256K tier invariants
- `tests/hermes_cli/test_model_switch_context_display.py` — added regression tests for #15779 through the display resolver
- `tests/gateway/test_session_info.py` — updated default-fallback assertion (128K → 256K)
- `tests/agent/test_model_metadata.py` — updated tier assertions for the new top tier
This commit is contained in:
Teknium
2026-04-25 18:47:53 -07:00
committed by GitHub
parent 4c591c2819
commit 125de02056
9 changed files with 480 additions and 44 deletions

View File

@@ -106,9 +106,11 @@ _endpoint_model_metadata_cache_time: Dict[str, float] = {}
_ENDPOINT_MODEL_CACHE_TTL = 300
# Descending tiers for context length probing when the model is unknown.
# We start at 128K (a safe default for most modern models) and step down
# on context-length errors until one works.
# We start at 256K (covers GPT-5.x, many current large-context models) and
# step down on context-length errors until one works. Tier[0] is also the
# default fallback when no detection method succeeds.
CONTEXT_PROBE_TIERS = [
256_000,
128_000,
64_000,
32_000,
@@ -1193,6 +1195,7 @@ def get_model_context_length(
api_key: str = "",
config_context_length: int | None = None,
provider: str = "",
custom_providers: list | None = None,
) -> int:
"""Get the context length for a model.
@@ -1213,6 +1216,23 @@ def get_model_context_length(
if config_context_length is not None and isinstance(config_context_length, int) and config_context_length > 0:
return config_context_length
# 0b. custom_providers per-model override — check before any probe.
# This closes the gap where /model switch and display paths used to fall
# back to 128K despite the user having a per-model context_length set.
# See #15779.
if custom_providers and base_url and model:
try:
from hermes_cli.config import get_custom_provider_context_length
cp_ctx = get_custom_provider_context_length(
model=model,
base_url=base_url,
custom_providers=custom_providers,
)
if cp_ctx:
return cp_ctx
except Exception:
pass # fall through to probing
# Normalise provider-prefixed model names (e.g. "local:model-name" →
# "model-name") so cache lookups and server queries use the bare ID that
# local servers actually know about. Ollama "model:tag" colons are preserved.
@@ -1352,7 +1372,7 @@ def get_model_context_length(
# 6. OpenRouter live API metadata (provider-unaware fallback)
metadata = fetch_model_metadata()
if model in metadata:
return metadata[model].get("context_length", 128000)
return metadata[model].get("context_length", DEFAULT_FALLBACK_CONTEXT)
# 8. Hardcoded defaults (fuzzy match — longest key first for specificity)
# Only check `default_model in model` (is the key a substring of the input).

View File

@@ -4891,6 +4891,7 @@ class GatewayRunner:
provider = None
base_url = None
api_key = None
custom_provs = None
try:
cfg_path = _hermes_home / "config.yaml"
@@ -4908,6 +4909,11 @@ class GatewayRunner:
pass
provider = model_cfg.get("provider") or None
base_url = model_cfg.get("base_url") or None
try:
from hermes_cli.config import get_compatible_custom_providers
custom_provs = get_compatible_custom_providers(data)
except Exception:
custom_provs = data.get("custom_providers")
except Exception:
pass
@@ -4926,6 +4932,7 @@ class GatewayRunner:
api_key=api_key or "",
config_context_length=config_context_length,
provider=provider or "",
custom_providers=custom_provs,
)
# Format context source hint
@@ -5601,6 +5608,7 @@ class GatewayRunner:
base_url=result.base_url or current_base_url or "",
api_key=result.api_key or current_api_key or "",
model_info=mi,
custom_providers=custom_provs,
)
if ctx:
lines.append(f"Context: {ctx:,} tokens")
@@ -5748,6 +5756,7 @@ class GatewayRunner:
base_url=result.base_url or current_base_url or "",
api_key=result.api_key or current_api_key or "",
model_info=mi,
custom_providers=custom_provs,
)
if ctx:
lines.append(f"Context: {ctx:,} tokens")

View File

@@ -2206,6 +2206,71 @@ def get_compatible_custom_providers(
return compatible
def get_custom_provider_context_length(
model: str,
base_url: str,
custom_providers: Optional[List[Dict[str, Any]]] = None,
config: Optional[Dict[str, Any]] = None,
) -> Optional[int]:
"""Look up a per-model ``context_length`` override from ``custom_providers``.
Matches any entry whose ``base_url`` equals ``base_url`` (trailing-slash
insensitive) and returns ``custom_providers[i].models.<model>.context_length``
if present and valid. Returns ``None`` when no override applies.
This is the single source of truth for custom-provider context overrides,
used by:
* ``AIAgent.__init__`` (startup resolution)
* ``AIAgent.switch_model`` (mid-session ``/model`` switch)
* ``hermes_cli.model_switch.resolve_display_context_length`` (``/model`` confirmation display)
* ``gateway.run._format_session_info`` (``/info`` display)
* ``agent.model_metadata.get_model_context_length`` (when custom_providers is threaded through)
Before this helper existed, the lookup was duplicated in ``run_agent.py``'s
startup path only; every other path (notably ``/model`` switch) fell back
to the 128K default. See #15779.
"""
if not model or not base_url:
return None
if custom_providers is None:
try:
custom_providers = get_compatible_custom_providers(config)
except Exception:
if config is None:
return None
raw = config.get("custom_providers")
custom_providers = raw if isinstance(raw, list) else []
if not isinstance(custom_providers, list):
return None
target_url = (base_url or "").rstrip("/")
if not target_url:
return None
for entry in custom_providers:
if not isinstance(entry, dict):
continue
entry_url = (entry.get("base_url") or "").rstrip("/")
if not entry_url or entry_url != target_url:
continue
models = entry.get("models")
if not isinstance(models, dict):
continue
model_cfg = models.get(model)
if not isinstance(model_cfg, dict):
continue
raw_ctx = model_cfg.get("context_length")
if raw_ctx is None:
continue
try:
ctx = int(raw_ctx)
except (TypeError, ValueError):
continue
if ctx > 0:
return ctx
return None
def check_config_version() -> Tuple[int, int]:
"""
Check config version.

View File

@@ -533,6 +533,7 @@ def resolve_display_context_length(
base_url: str = "",
api_key: str = "",
model_info: Optional[ModelInfo] = None,
custom_providers: list | None = None,
) -> Optional[int]:
"""Resolve the context length to show in /model output.
@@ -543,6 +544,11 @@ def resolve_display_context_length(
about Codex OAuth, Copilot, Nous, and falls back to models.dev for the
rest.
When ``custom_providers`` is provided, per-model ``context_length``
overrides from ``custom_providers[].models.<id>.context_length`` are
honored — this closes #15779 where ``/model`` switch ignored user-set
overrides.
Prefer the provider-aware value; fall back to ``model_info.context_window``
only if the resolver returns nothing.
"""
@@ -553,6 +559,7 @@ def resolve_display_context_length(
base_url=base_url or "",
api_key=api_key or "",
provider=provider or None,
custom_providers=custom_providers,
)
if ctx:
return int(ctx)

View File

@@ -1765,43 +1765,64 @@ class AIAgent:
# Store for reuse in switch_model (so config override persists across model switches)
self._config_context_length = _config_context_length
# Resolve custom_providers list once for reuse below (startup
# context-length override and plugin context-engine init).
try:
from hermes_cli.config import get_compatible_custom_providers
_custom_providers = get_compatible_custom_providers(_agent_cfg)
except Exception:
_custom_providers = _agent_cfg.get("custom_providers")
if not isinstance(_custom_providers, list):
_custom_providers = []
# Check custom_providers per-model context_length
if _config_context_length is None:
if _config_context_length is None and _custom_providers:
try:
from hermes_cli.config import get_compatible_custom_providers
_custom_providers = get_compatible_custom_providers(_agent_cfg)
from hermes_cli.config import get_custom_provider_context_length
_cp_ctx_resolved = get_custom_provider_context_length(
model=self.model,
base_url=self.base_url,
custom_providers=_custom_providers,
)
if _cp_ctx_resolved:
_config_context_length = int(_cp_ctx_resolved)
except Exception:
_custom_providers = _agent_cfg.get("custom_providers")
if not isinstance(_custom_providers, list):
_custom_providers = []
for _cp_entry in _custom_providers:
if not isinstance(_cp_entry, dict):
continue
_cp_url = (_cp_entry.get("base_url") or "").rstrip("/")
if _cp_url and _cp_url == self.base_url.rstrip("/"):
_cp_models = _cp_entry.get("models", {})
if isinstance(_cp_models, dict):
_cp_model_cfg = _cp_models.get(self.model, {})
if isinstance(_cp_model_cfg, dict):
_cp_ctx = _cp_model_cfg.get("context_length")
if _cp_ctx is not None:
try:
_config_context_length = int(_cp_ctx)
except (TypeError, ValueError):
logger.warning(
"Invalid context_length for model %r in "
"custom_providers: %r — must be a plain "
"integer (e.g. 256000, not '256K'). "
"Falling back to auto-detection.",
self.model, _cp_ctx,
)
print(
f"\n⚠ Invalid context_length for model {self.model!r} in custom_providers: {_cp_ctx!r}\n"
f" Must be a plain integer (e.g. 256000, not '256K').\n"
f" Falling back to auto-detected context window.\n",
file=sys.stderr,
)
break
_cp_ctx_resolved = None
# Surface a clear warning if the user set a context_length but it
# wasn't a valid positive int — the helper silently skips those.
if _config_context_length is None:
_target = self.base_url.rstrip("/") if self.base_url else ""
for _cp_entry in _custom_providers:
if not isinstance(_cp_entry, dict):
continue
_cp_url = (_cp_entry.get("base_url") or "").rstrip("/")
if _target and _cp_url == _target:
_cp_models = _cp_entry.get("models", {})
if isinstance(_cp_models, dict):
_cp_model_cfg = _cp_models.get(self.model, {})
if isinstance(_cp_model_cfg, dict):
_cp_ctx = _cp_model_cfg.get("context_length")
if _cp_ctx is not None:
try:
_parsed = int(_cp_ctx)
if _parsed <= 0:
raise ValueError
except (TypeError, ValueError):
logger.warning(
"Invalid context_length for model %r in "
"custom_providers: %r — must be a positive "
"integer (e.g. 256000, not '256K'). "
"Falling back to auto-detection.",
self.model, _cp_ctx,
)
print(
f"\n⚠ Invalid context_length for model {self.model!r} in custom_providers: {_cp_ctx!r}\n"
f" Must be a positive integer (e.g. 256000, not '256K').\n"
f" Falling back to auto-detected context window.\n",
file=sys.stderr,
)
break
# Select context engine: config-driven (like memory providers).
# 1. Check config.yaml context.engine setting
@@ -1851,6 +1872,7 @@ class AIAgent:
api_key=getattr(self, "api_key", ""),
config_context_length=_config_context_length,
provider=self.provider,
custom_providers=_custom_providers,
)
self.context_compressor.update_model(
model=self.model,
@@ -2141,12 +2163,23 @@ class AIAgent:
# ── Update context compressor ──
if hasattr(self, "context_compressor") and self.context_compressor:
from agent.model_metadata import get_model_context_length
# Re-read custom_providers from live config so per-model
# context_length overrides are honored when switching to a
# custom provider mid-session (closes #15779).
_sm_custom_providers = None
try:
from hermes_cli.config import load_config, get_compatible_custom_providers
_sm_cfg = load_config()
_sm_custom_providers = get_compatible_custom_providers(_sm_cfg)
except Exception:
_sm_custom_providers = None
new_context_length = get_model_context_length(
self.model,
base_url=self.base_url,
api_key=self.api_key,
provider=self.provider,
config_context_length=getattr(self, "_config_context_length", None),
custom_providers=_sm_custom_providers,
)
self.context_compressor.update_model(
model=self.model,

View File

@@ -459,9 +459,10 @@ class TestGetModelContextLength:
@patch("agent.model_metadata.fetch_model_metadata")
def test_api_missing_context_length_key(self, mock_fetch):
"""Model in API but without context_length → defaults to 128000."""
"""Model in API but without context_length → defaults to the top
probe tier (currently 256K)."""
mock_fetch.return_value = {"test/model": {"name": "Test"}}
assert get_model_context_length("test/model") == 128000
assert get_model_context_length("test/model") == CONTEXT_PROBE_TIERS[0]
@patch("agent.model_metadata.fetch_model_metadata")
def test_cache_takes_priority_over_api(self, mock_fetch, tmp_path):
@@ -814,14 +815,17 @@ class TestContextProbeTiers:
for i in range(len(CONTEXT_PROBE_TIERS) - 1):
assert CONTEXT_PROBE_TIERS[i] > CONTEXT_PROBE_TIERS[i + 1]
def test_first_tier_is_128k(self):
assert CONTEXT_PROBE_TIERS[0] == 128_000
def test_first_tier_is_256k(self):
assert CONTEXT_PROBE_TIERS[0] == 256_000
def test_last_tier_is_8k(self):
assert CONTEXT_PROBE_TIERS[-1] == 8_000
class TestGetNextProbeTier:
def test_from_256k(self):
assert get_next_probe_tier(256_000) == 128_000
def test_from_128k(self):
assert get_next_probe_tier(128_000) == 64_000
@@ -841,8 +845,8 @@ class TestGetNextProbeTier:
assert get_next_probe_tier(100_000) == 64_000
def test_above_max_tier(self):
"""Value above 128K should return 128K."""
assert get_next_probe_tier(500_000) == 128_000
"""Value above 256K should return 256K."""
assert get_next_probe_tier(500_000) == 256_000
def test_zero_returns_none(self):
assert get_next_probe_tier(0) is None

View File

@@ -58,7 +58,7 @@ class TestFormatSessionInfo:
{"provider": "", "base_url": "", "api_key": ""})
with p1, p2, p3:
info = runner._format_session_info()
assert "128K" in info
assert "256K" in info
assert "model.context_length" in info
def test_local_endpoint_shown(self, runner, tmp_path):

View File

@@ -0,0 +1,240 @@
"""Regression tests for custom_providers per-model context_length resolution.
Covers the fix for #15779 — mid-session /model switch to a named custom
provider must honor ``custom_providers[].models.<id>.context_length`` the
same way startup already does.
"""
from __future__ import annotations
from unittest.mock import patch
from hermes_cli.config import get_custom_provider_context_length
class TestGetCustomProviderContextLength:
def test_returns_override_for_matching_entry(self):
custom = [
{
"name": "my-endpoint",
"base_url": "https://example.invalid/v1",
"models": {"gpt-5.5": {"context_length": 1_050_000}},
}
]
assert (
get_custom_provider_context_length(
"gpt-5.5", "https://example.invalid/v1", custom
)
== 1_050_000
)
def test_trailing_slash_insensitive(self):
custom = [
{
"base_url": "https://example.invalid/v1/",
"models": {"m": {"context_length": 500_000}},
}
]
# config has trailing slash, runtime doesn't — must match
assert (
get_custom_provider_context_length(
"m", "https://example.invalid/v1", custom
)
== 500_000
)
# and the reverse
custom2 = [
{
"base_url": "https://example.invalid/v1",
"models": {"m": {"context_length": 500_000}},
}
]
assert (
get_custom_provider_context_length(
"m", "https://example.invalid/v1/", custom2
)
== 500_000
)
def test_returns_none_when_url_does_not_match(self):
custom = [
{
"base_url": "https://example.invalid/v1",
"models": {"m": {"context_length": 400_000}},
}
]
assert (
get_custom_provider_context_length(
"m", "https://other.invalid/v1", custom
)
is None
)
def test_returns_none_when_model_does_not_match(self):
custom = [
{
"base_url": "https://example.invalid/v1",
"models": {"gpt-5.5": {"context_length": 400_000}},
}
]
assert (
get_custom_provider_context_length(
"different-model", "https://example.invalid/v1", custom
)
is None
)
def test_returns_none_for_string_value(self):
"""'256K' string is not a valid int — skip silently.
(The inline startup path still emits a user-visible warning; the
helper itself returns None so downstream fallbacks can run.)
"""
custom = [
{
"base_url": "https://example.invalid/v1",
"models": {"m": {"context_length": "256K"}},
}
]
assert (
get_custom_provider_context_length(
"m", "https://example.invalid/v1", custom
)
is None
)
def test_returns_none_for_zero_or_negative(self):
for bad in (0, -1, -100):
custom = [
{
"base_url": "https://example.invalid/v1",
"models": {"m": {"context_length": bad}},
}
]
assert (
get_custom_provider_context_length(
"m", "https://example.invalid/v1", custom
)
is None
), f"value {bad!r} should be rejected"
def test_empty_inputs_return_none(self):
assert get_custom_provider_context_length("", "http://x", [{"base_url": "http://x", "models": {"": {"context_length": 1}}}]) is None
assert get_custom_provider_context_length("m", "", [{"base_url": "", "models": {"m": {"context_length": 1}}}]) is None
assert get_custom_provider_context_length("m", "http://x", None) is None
assert get_custom_provider_context_length("m", "http://x", []) is None
def test_ignores_non_dict_entries(self):
"""Malformed entries must not crash the lookup."""
custom = [
"not a dict",
None,
{"base_url": "https://example.invalid/v1", "models": "not a dict"},
{"base_url": "https://example.invalid/v1", "models": {"m": "not a dict"}},
{
"base_url": "https://example.invalid/v1",
"models": {"m": {"context_length": 400_000}},
},
]
assert (
get_custom_provider_context_length(
"m", "https://example.invalid/v1", custom
)
== 400_000
)
class TestGetModelContextLengthHonorsOverride:
"""agent.model_metadata.get_model_context_length must honor the
custom_providers override at step 0b — before any probe, cache hit,
or models.dev lookup can override it.
"""
def _mock_all_probes(self):
"""Context manager that disables every downstream resolution step."""
from agent import model_metadata as _mm
return [
patch.object(_mm, "get_cached_context_length", return_value=None),
patch.object(_mm, "fetch_endpoint_model_metadata", return_value={}),
patch.object(_mm, "fetch_model_metadata", return_value={}),
patch.object(_mm, "is_local_endpoint", return_value=False),
patch.object(_mm, "_is_known_provider_base_url", return_value=False),
]
def test_custom_providers_override_wins_over_default_fallback(self):
from agent.model_metadata import get_model_context_length
custom = [
{
"base_url": "https://example.invalid/v1",
"models": {"gpt-5.5": {"context_length": 1_050_000}},
}
]
patches = self._mock_all_probes()
for p in patches:
p.start()
try:
ctx = get_model_context_length(
"gpt-5.5",
base_url="https://example.invalid/v1",
provider="custom",
custom_providers=custom,
)
finally:
for p in patches:
p.stop()
assert ctx == 1_050_000
def test_explicit_config_context_length_still_wins(self):
"""Top-level model.context_length (step 0) outranks custom_providers (step 0b).
Users who set both should see the top-level value — that's the
documented precedence and matches the long-standing step-0 behavior.
"""
from agent.model_metadata import get_model_context_length
custom = [
{
"base_url": "https://example.invalid/v1",
"models": {"m": {"context_length": 1_050_000}},
}
]
ctx = get_model_context_length(
"m",
base_url="https://example.invalid/v1",
provider="custom",
config_context_length=500_000, # explicit top-level wins
custom_providers=custom,
)
assert ctx == 500_000
def test_no_override_falls_through_to_default(self):
"""With custom_providers=None and all probes disabled, resolver
returns DEFAULT_FALLBACK_CONTEXT (256K after the stepdown bump).
"""
from agent.model_metadata import get_model_context_length, DEFAULT_FALLBACK_CONTEXT
patches = self._mock_all_probes()
for p in patches:
p.start()
try:
ctx = get_model_context_length(
"unknown-model",
base_url="https://example.invalid/v1",
provider="custom",
custom_providers=None,
)
finally:
for p in patches:
p.stop()
assert ctx == DEFAULT_FALLBACK_CONTEXT
class TestContextProbeTiers:
def test_256k_is_top_tier_and_default(self):
"""The stepdown probe starts at 256K and 256K is the new default."""
from agent.model_metadata import CONTEXT_PROBE_TIERS, DEFAULT_FALLBACK_CONTEXT
assert CONTEXT_PROBE_TIERS[0] == 256_000
assert DEFAULT_FALLBACK_CONTEXT == 256_000
# Tiers still descend monotonically
for a, b in zip(CONTEXT_PROBE_TIERS, CONTEXT_PROBE_TIERS[1:]):
assert a > b, f"tiers must strictly descend, got {a} then {b}"
# 128K is still a tier (users relying on it probe-down get there)
assert 128_000 in CONTEXT_PROBE_TIERS

View File

@@ -88,3 +88,61 @@ class TestResolveDisplayContextLength:
model_info=fake_mi,
)
assert ctx == 128_000
def test_custom_providers_override_honored(self):
"""Regression for #15779: /model switch onto a custom provider must
surface the configured per-model context_length, not the 128K/256K
fallback.
"""
custom_provs = [
{
"name": "my-custom-endpoint",
"base_url": "https://example.invalid/v1",
"models": {"gpt-5.5": {"context_length": 1_050_000}},
}
]
# Real resolver call — no mock — so the override path is exercised
# through agent.model_metadata.get_model_context_length.
from unittest.mock import patch as _p
from agent import model_metadata as _mm
with _p.object(_mm, "get_cached_context_length", return_value=None), \
_p.object(_mm, "fetch_endpoint_model_metadata", return_value={}), \
_p.object(_mm, "fetch_model_metadata", return_value={}), \
_p.object(_mm, "is_local_endpoint", return_value=False), \
_p.object(_mm, "_is_known_provider_base_url", return_value=False):
ctx = resolve_display_context_length(
"gpt-5.5",
"custom",
base_url="https://example.invalid/v1",
api_key="k",
custom_providers=custom_provs,
)
assert ctx == 1_050_000, (
"custom_providers[].models.gpt-5.5.context_length=1.05M must win "
"over probe-down fallback"
)
def test_custom_providers_trailing_slash_insensitive(self):
"""Base URL comparison must tolerate trailing-slash differences
between config.yaml and the runtime value.
"""
custom_provs = [
{
"base_url": "https://example.invalid/v1/",
"models": {"m": {"context_length": 400_000}},
}
]
from unittest.mock import patch as _p
from agent import model_metadata as _mm
with _p.object(_mm, "get_cached_context_length", return_value=None), \
_p.object(_mm, "fetch_endpoint_model_metadata", return_value={}), \
_p.object(_mm, "fetch_model_metadata", return_value={}), \
_p.object(_mm, "is_local_endpoint", return_value=False), \
_p.object(_mm, "_is_known_provider_base_url", return_value=False):
ctx = resolve_display_context_length(
"m",
"custom",
base_url="https://example.invalid/v1", # no trailing slash
custom_providers=custom_provs,
)
assert ctx == 400_000