fix(context): honor custom_providers context_length on /model switch + bump probe tier to 256K (#15844)
Fixes #15779. Custom-provider per-model context_length (`custom_providers[].models.<id>.context_length`) is now honored across every resolution path, not just agent startup. Also adds 256K as the top probe tier and default fallback. ## What changed New helper `hermes_cli.config.get_custom_provider_context_length()` — single source of truth for the per-model override lookup, with trailing-slash-insensitive base-url matching. `agent.model_metadata.get_model_context_length()` gains an optional `custom_providers=` kwarg (step 0b — runs after explicit `config_context_length` but before every other probe). Wired through five call sites that previously either duplicated the lookup or ignored it entirely: - `run_agent.py` startup — refactored to use the new helper (dedups legacy inline loop, keeps invalid-value warning) - `AIAgent.switch_model()` — re-reads custom_providers from live config on every /model switch - `hermes_cli.model_switch.resolve_display_context_length()` — new `custom_providers=` kwarg - `gateway/run.py` /model confirmation (picker callback + text path) - `gateway/run.py` `_format_session_info` (/info) ## Context probe tiers `CONTEXT_PROBE_TIERS = [256_000, 128_000, 64_000, 32_000, 16_000, 8_000]` — was `[128_000, ...]`. `DEFAULT_FALLBACK_CONTEXT` follows tier[0], so unknown models now default to 256K. The stale `128000` literal in the OpenRouter metadata-miss path is replaced with `DEFAULT_FALLBACK_CONTEXT` for consistency. ## Repro (from #15779) ```yaml custom_providers: - name: my-custom-endpoint base_url: https://example.invalid/v1 model: gpt-5.5 models: gpt-5.5: context_length: 1050000 ``` `/model gpt-5.5 --provider custom:my-custom-endpoint` → previously "Context: 128,000", now "Context: 1,050,000". ## Tests - `tests/hermes_cli/test_custom_provider_context_length.py` — new file, 19 tests covering the helper, step-0b integration, and the 256K tier invariants - `tests/hermes_cli/test_model_switch_context_display.py` — added regression tests for #15779 through the display resolver - `tests/gateway/test_session_info.py` — updated default-fallback assertion (128K → 256K) - `tests/agent/test_model_metadata.py` — updated tier assertions for the new top tier
This commit is contained in:
@@ -106,9 +106,11 @@ _endpoint_model_metadata_cache_time: Dict[str, float] = {}
|
||||
_ENDPOINT_MODEL_CACHE_TTL = 300
|
||||
|
||||
# Descending tiers for context length probing when the model is unknown.
|
||||
# We start at 128K (a safe default for most modern models) and step down
|
||||
# on context-length errors until one works.
|
||||
# We start at 256K (covers GPT-5.x, many current large-context models) and
|
||||
# step down on context-length errors until one works. Tier[0] is also the
|
||||
# default fallback when no detection method succeeds.
|
||||
CONTEXT_PROBE_TIERS = [
|
||||
256_000,
|
||||
128_000,
|
||||
64_000,
|
||||
32_000,
|
||||
@@ -1193,6 +1195,7 @@ def get_model_context_length(
|
||||
api_key: str = "",
|
||||
config_context_length: int | None = None,
|
||||
provider: str = "",
|
||||
custom_providers: list | None = None,
|
||||
) -> int:
|
||||
"""Get the context length for a model.
|
||||
|
||||
@@ -1213,6 +1216,23 @@ def get_model_context_length(
|
||||
if config_context_length is not None and isinstance(config_context_length, int) and config_context_length > 0:
|
||||
return config_context_length
|
||||
|
||||
# 0b. custom_providers per-model override — check before any probe.
|
||||
# This closes the gap where /model switch and display paths used to fall
|
||||
# back to 128K despite the user having a per-model context_length set.
|
||||
# See #15779.
|
||||
if custom_providers and base_url and model:
|
||||
try:
|
||||
from hermes_cli.config import get_custom_provider_context_length
|
||||
cp_ctx = get_custom_provider_context_length(
|
||||
model=model,
|
||||
base_url=base_url,
|
||||
custom_providers=custom_providers,
|
||||
)
|
||||
if cp_ctx:
|
||||
return cp_ctx
|
||||
except Exception:
|
||||
pass # fall through to probing
|
||||
|
||||
# Normalise provider-prefixed model names (e.g. "local:model-name" →
|
||||
# "model-name") so cache lookups and server queries use the bare ID that
|
||||
# local servers actually know about. Ollama "model:tag" colons are preserved.
|
||||
@@ -1352,7 +1372,7 @@ def get_model_context_length(
|
||||
# 6. OpenRouter live API metadata (provider-unaware fallback)
|
||||
metadata = fetch_model_metadata()
|
||||
if model in metadata:
|
||||
return metadata[model].get("context_length", 128000)
|
||||
return metadata[model].get("context_length", DEFAULT_FALLBACK_CONTEXT)
|
||||
|
||||
# 8. Hardcoded defaults (fuzzy match — longest key first for specificity)
|
||||
# Only check `default_model in model` (is the key a substring of the input).
|
||||
|
||||
@@ -4891,6 +4891,7 @@ class GatewayRunner:
|
||||
provider = None
|
||||
base_url = None
|
||||
api_key = None
|
||||
custom_provs = None
|
||||
|
||||
try:
|
||||
cfg_path = _hermes_home / "config.yaml"
|
||||
@@ -4908,6 +4909,11 @@ class GatewayRunner:
|
||||
pass
|
||||
provider = model_cfg.get("provider") or None
|
||||
base_url = model_cfg.get("base_url") or None
|
||||
try:
|
||||
from hermes_cli.config import get_compatible_custom_providers
|
||||
custom_provs = get_compatible_custom_providers(data)
|
||||
except Exception:
|
||||
custom_provs = data.get("custom_providers")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -4926,6 +4932,7 @@ class GatewayRunner:
|
||||
api_key=api_key or "",
|
||||
config_context_length=config_context_length,
|
||||
provider=provider or "",
|
||||
custom_providers=custom_provs,
|
||||
)
|
||||
|
||||
# Format context source hint
|
||||
@@ -5601,6 +5608,7 @@ class GatewayRunner:
|
||||
base_url=result.base_url or current_base_url or "",
|
||||
api_key=result.api_key or current_api_key or "",
|
||||
model_info=mi,
|
||||
custom_providers=custom_provs,
|
||||
)
|
||||
if ctx:
|
||||
lines.append(f"Context: {ctx:,} tokens")
|
||||
@@ -5748,6 +5756,7 @@ class GatewayRunner:
|
||||
base_url=result.base_url or current_base_url or "",
|
||||
api_key=result.api_key or current_api_key or "",
|
||||
model_info=mi,
|
||||
custom_providers=custom_provs,
|
||||
)
|
||||
if ctx:
|
||||
lines.append(f"Context: {ctx:,} tokens")
|
||||
|
||||
@@ -2206,6 +2206,71 @@ def get_compatible_custom_providers(
|
||||
return compatible
|
||||
|
||||
|
||||
def get_custom_provider_context_length(
|
||||
model: str,
|
||||
base_url: str,
|
||||
custom_providers: Optional[List[Dict[str, Any]]] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[int]:
|
||||
"""Look up a per-model ``context_length`` override from ``custom_providers``.
|
||||
|
||||
Matches any entry whose ``base_url`` equals ``base_url`` (trailing-slash
|
||||
insensitive) and returns ``custom_providers[i].models.<model>.context_length``
|
||||
if present and valid. Returns ``None`` when no override applies.
|
||||
|
||||
This is the single source of truth for custom-provider context overrides,
|
||||
used by:
|
||||
* ``AIAgent.__init__`` (startup resolution)
|
||||
* ``AIAgent.switch_model`` (mid-session ``/model`` switch)
|
||||
* ``hermes_cli.model_switch.resolve_display_context_length`` (``/model`` confirmation display)
|
||||
* ``gateway.run._format_session_info`` (``/info`` display)
|
||||
* ``agent.model_metadata.get_model_context_length`` (when custom_providers is threaded through)
|
||||
|
||||
Before this helper existed, the lookup was duplicated in ``run_agent.py``'s
|
||||
startup path only; every other path (notably ``/model`` switch) fell back
|
||||
to the 128K default. See #15779.
|
||||
"""
|
||||
if not model or not base_url:
|
||||
return None
|
||||
if custom_providers is None:
|
||||
try:
|
||||
custom_providers = get_compatible_custom_providers(config)
|
||||
except Exception:
|
||||
if config is None:
|
||||
return None
|
||||
raw = config.get("custom_providers")
|
||||
custom_providers = raw if isinstance(raw, list) else []
|
||||
if not isinstance(custom_providers, list):
|
||||
return None
|
||||
|
||||
target_url = (base_url or "").rstrip("/")
|
||||
if not target_url:
|
||||
return None
|
||||
|
||||
for entry in custom_providers:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
entry_url = (entry.get("base_url") or "").rstrip("/")
|
||||
if not entry_url or entry_url != target_url:
|
||||
continue
|
||||
models = entry.get("models")
|
||||
if not isinstance(models, dict):
|
||||
continue
|
||||
model_cfg = models.get(model)
|
||||
if not isinstance(model_cfg, dict):
|
||||
continue
|
||||
raw_ctx = model_cfg.get("context_length")
|
||||
if raw_ctx is None:
|
||||
continue
|
||||
try:
|
||||
ctx = int(raw_ctx)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if ctx > 0:
|
||||
return ctx
|
||||
return None
|
||||
|
||||
|
||||
def check_config_version() -> Tuple[int, int]:
|
||||
"""
|
||||
Check config version.
|
||||
|
||||
@@ -533,6 +533,7 @@ def resolve_display_context_length(
|
||||
base_url: str = "",
|
||||
api_key: str = "",
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
custom_providers: list | None = None,
|
||||
) -> Optional[int]:
|
||||
"""Resolve the context length to show in /model output.
|
||||
|
||||
@@ -543,6 +544,11 @@ def resolve_display_context_length(
|
||||
about Codex OAuth, Copilot, Nous, and falls back to models.dev for the
|
||||
rest.
|
||||
|
||||
When ``custom_providers`` is provided, per-model ``context_length``
|
||||
overrides from ``custom_providers[].models.<id>.context_length`` are
|
||||
honored — this closes #15779 where ``/model`` switch ignored user-set
|
||||
overrides.
|
||||
|
||||
Prefer the provider-aware value; fall back to ``model_info.context_window``
|
||||
only if the resolver returns nothing.
|
||||
"""
|
||||
@@ -553,6 +559,7 @@ def resolve_display_context_length(
|
||||
base_url=base_url or "",
|
||||
api_key=api_key or "",
|
||||
provider=provider or None,
|
||||
custom_providers=custom_providers,
|
||||
)
|
||||
if ctx:
|
||||
return int(ctx)
|
||||
|
||||
101
run_agent.py
101
run_agent.py
@@ -1765,43 +1765,64 @@ class AIAgent:
|
||||
# Store for reuse in switch_model (so config override persists across model switches)
|
||||
self._config_context_length = _config_context_length
|
||||
|
||||
# Resolve custom_providers list once for reuse below (startup
|
||||
# context-length override and plugin context-engine init).
|
||||
try:
|
||||
from hermes_cli.config import get_compatible_custom_providers
|
||||
_custom_providers = get_compatible_custom_providers(_agent_cfg)
|
||||
except Exception:
|
||||
_custom_providers = _agent_cfg.get("custom_providers")
|
||||
if not isinstance(_custom_providers, list):
|
||||
_custom_providers = []
|
||||
|
||||
# Check custom_providers per-model context_length
|
||||
if _config_context_length is None:
|
||||
if _config_context_length is None and _custom_providers:
|
||||
try:
|
||||
from hermes_cli.config import get_compatible_custom_providers
|
||||
_custom_providers = get_compatible_custom_providers(_agent_cfg)
|
||||
from hermes_cli.config import get_custom_provider_context_length
|
||||
_cp_ctx_resolved = get_custom_provider_context_length(
|
||||
model=self.model,
|
||||
base_url=self.base_url,
|
||||
custom_providers=_custom_providers,
|
||||
)
|
||||
if _cp_ctx_resolved:
|
||||
_config_context_length = int(_cp_ctx_resolved)
|
||||
except Exception:
|
||||
_custom_providers = _agent_cfg.get("custom_providers")
|
||||
if not isinstance(_custom_providers, list):
|
||||
_custom_providers = []
|
||||
for _cp_entry in _custom_providers:
|
||||
if not isinstance(_cp_entry, dict):
|
||||
continue
|
||||
_cp_url = (_cp_entry.get("base_url") or "").rstrip("/")
|
||||
if _cp_url and _cp_url == self.base_url.rstrip("/"):
|
||||
_cp_models = _cp_entry.get("models", {})
|
||||
if isinstance(_cp_models, dict):
|
||||
_cp_model_cfg = _cp_models.get(self.model, {})
|
||||
if isinstance(_cp_model_cfg, dict):
|
||||
_cp_ctx = _cp_model_cfg.get("context_length")
|
||||
if _cp_ctx is not None:
|
||||
try:
|
||||
_config_context_length = int(_cp_ctx)
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(
|
||||
"Invalid context_length for model %r in "
|
||||
"custom_providers: %r — must be a plain "
|
||||
"integer (e.g. 256000, not '256K'). "
|
||||
"Falling back to auto-detection.",
|
||||
self.model, _cp_ctx,
|
||||
)
|
||||
print(
|
||||
f"\n⚠ Invalid context_length for model {self.model!r} in custom_providers: {_cp_ctx!r}\n"
|
||||
f" Must be a plain integer (e.g. 256000, not '256K').\n"
|
||||
f" Falling back to auto-detected context window.\n",
|
||||
file=sys.stderr,
|
||||
)
|
||||
break
|
||||
_cp_ctx_resolved = None
|
||||
|
||||
# Surface a clear warning if the user set a context_length but it
|
||||
# wasn't a valid positive int — the helper silently skips those.
|
||||
if _config_context_length is None:
|
||||
_target = self.base_url.rstrip("/") if self.base_url else ""
|
||||
for _cp_entry in _custom_providers:
|
||||
if not isinstance(_cp_entry, dict):
|
||||
continue
|
||||
_cp_url = (_cp_entry.get("base_url") or "").rstrip("/")
|
||||
if _target and _cp_url == _target:
|
||||
_cp_models = _cp_entry.get("models", {})
|
||||
if isinstance(_cp_models, dict):
|
||||
_cp_model_cfg = _cp_models.get(self.model, {})
|
||||
if isinstance(_cp_model_cfg, dict):
|
||||
_cp_ctx = _cp_model_cfg.get("context_length")
|
||||
if _cp_ctx is not None:
|
||||
try:
|
||||
_parsed = int(_cp_ctx)
|
||||
if _parsed <= 0:
|
||||
raise ValueError
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(
|
||||
"Invalid context_length for model %r in "
|
||||
"custom_providers: %r — must be a positive "
|
||||
"integer (e.g. 256000, not '256K'). "
|
||||
"Falling back to auto-detection.",
|
||||
self.model, _cp_ctx,
|
||||
)
|
||||
print(
|
||||
f"\n⚠ Invalid context_length for model {self.model!r} in custom_providers: {_cp_ctx!r}\n"
|
||||
f" Must be a positive integer (e.g. 256000, not '256K').\n"
|
||||
f" Falling back to auto-detected context window.\n",
|
||||
file=sys.stderr,
|
||||
)
|
||||
break
|
||||
|
||||
# Select context engine: config-driven (like memory providers).
|
||||
# 1. Check config.yaml context.engine setting
|
||||
@@ -1851,6 +1872,7 @@ class AIAgent:
|
||||
api_key=getattr(self, "api_key", ""),
|
||||
config_context_length=_config_context_length,
|
||||
provider=self.provider,
|
||||
custom_providers=_custom_providers,
|
||||
)
|
||||
self.context_compressor.update_model(
|
||||
model=self.model,
|
||||
@@ -2141,12 +2163,23 @@ class AIAgent:
|
||||
# ── Update context compressor ──
|
||||
if hasattr(self, "context_compressor") and self.context_compressor:
|
||||
from agent.model_metadata import get_model_context_length
|
||||
# Re-read custom_providers from live config so per-model
|
||||
# context_length overrides are honored when switching to a
|
||||
# custom provider mid-session (closes #15779).
|
||||
_sm_custom_providers = None
|
||||
try:
|
||||
from hermes_cli.config import load_config, get_compatible_custom_providers
|
||||
_sm_cfg = load_config()
|
||||
_sm_custom_providers = get_compatible_custom_providers(_sm_cfg)
|
||||
except Exception:
|
||||
_sm_custom_providers = None
|
||||
new_context_length = get_model_context_length(
|
||||
self.model,
|
||||
base_url=self.base_url,
|
||||
api_key=self.api_key,
|
||||
provider=self.provider,
|
||||
config_context_length=getattr(self, "_config_context_length", None),
|
||||
custom_providers=_sm_custom_providers,
|
||||
)
|
||||
self.context_compressor.update_model(
|
||||
model=self.model,
|
||||
|
||||
@@ -459,9 +459,10 @@ class TestGetModelContextLength:
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_api_missing_context_length_key(self, mock_fetch):
|
||||
"""Model in API but without context_length → defaults to 128000."""
|
||||
"""Model in API but without context_length → defaults to the top
|
||||
probe tier (currently 256K)."""
|
||||
mock_fetch.return_value = {"test/model": {"name": "Test"}}
|
||||
assert get_model_context_length("test/model") == 128000
|
||||
assert get_model_context_length("test/model") == CONTEXT_PROBE_TIERS[0]
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_cache_takes_priority_over_api(self, mock_fetch, tmp_path):
|
||||
@@ -814,14 +815,17 @@ class TestContextProbeTiers:
|
||||
for i in range(len(CONTEXT_PROBE_TIERS) - 1):
|
||||
assert CONTEXT_PROBE_TIERS[i] > CONTEXT_PROBE_TIERS[i + 1]
|
||||
|
||||
def test_first_tier_is_128k(self):
|
||||
assert CONTEXT_PROBE_TIERS[0] == 128_000
|
||||
def test_first_tier_is_256k(self):
|
||||
assert CONTEXT_PROBE_TIERS[0] == 256_000
|
||||
|
||||
def test_last_tier_is_8k(self):
|
||||
assert CONTEXT_PROBE_TIERS[-1] == 8_000
|
||||
|
||||
|
||||
class TestGetNextProbeTier:
|
||||
def test_from_256k(self):
|
||||
assert get_next_probe_tier(256_000) == 128_000
|
||||
|
||||
def test_from_128k(self):
|
||||
assert get_next_probe_tier(128_000) == 64_000
|
||||
|
||||
@@ -841,8 +845,8 @@ class TestGetNextProbeTier:
|
||||
assert get_next_probe_tier(100_000) == 64_000
|
||||
|
||||
def test_above_max_tier(self):
|
||||
"""Value above 128K should return 128K."""
|
||||
assert get_next_probe_tier(500_000) == 128_000
|
||||
"""Value above 256K should return 256K."""
|
||||
assert get_next_probe_tier(500_000) == 256_000
|
||||
|
||||
def test_zero_returns_none(self):
|
||||
assert get_next_probe_tier(0) is None
|
||||
|
||||
@@ -58,7 +58,7 @@ class TestFormatSessionInfo:
|
||||
{"provider": "", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "128K" in info
|
||||
assert "256K" in info
|
||||
assert "model.context_length" in info
|
||||
|
||||
def test_local_endpoint_shown(self, runner, tmp_path):
|
||||
|
||||
240
tests/hermes_cli/test_custom_provider_context_length.py
Normal file
240
tests/hermes_cli/test_custom_provider_context_length.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Regression tests for custom_providers per-model context_length resolution.
|
||||
|
||||
Covers the fix for #15779 — mid-session /model switch to a named custom
|
||||
provider must honor ``custom_providers[].models.<id>.context_length`` the
|
||||
same way startup already does.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from hermes_cli.config import get_custom_provider_context_length
|
||||
|
||||
|
||||
class TestGetCustomProviderContextLength:
|
||||
def test_returns_override_for_matching_entry(self):
|
||||
custom = [
|
||||
{
|
||||
"name": "my-endpoint",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"models": {"gpt-5.5": {"context_length": 1_050_000}},
|
||||
}
|
||||
]
|
||||
assert (
|
||||
get_custom_provider_context_length(
|
||||
"gpt-5.5", "https://example.invalid/v1", custom
|
||||
)
|
||||
== 1_050_000
|
||||
)
|
||||
|
||||
def test_trailing_slash_insensitive(self):
|
||||
custom = [
|
||||
{
|
||||
"base_url": "https://example.invalid/v1/",
|
||||
"models": {"m": {"context_length": 500_000}},
|
||||
}
|
||||
]
|
||||
# config has trailing slash, runtime doesn't — must match
|
||||
assert (
|
||||
get_custom_provider_context_length(
|
||||
"m", "https://example.invalid/v1", custom
|
||||
)
|
||||
== 500_000
|
||||
)
|
||||
# and the reverse
|
||||
custom2 = [
|
||||
{
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"models": {"m": {"context_length": 500_000}},
|
||||
}
|
||||
]
|
||||
assert (
|
||||
get_custom_provider_context_length(
|
||||
"m", "https://example.invalid/v1/", custom2
|
||||
)
|
||||
== 500_000
|
||||
)
|
||||
|
||||
def test_returns_none_when_url_does_not_match(self):
|
||||
custom = [
|
||||
{
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"models": {"m": {"context_length": 400_000}},
|
||||
}
|
||||
]
|
||||
assert (
|
||||
get_custom_provider_context_length(
|
||||
"m", "https://other.invalid/v1", custom
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_when_model_does_not_match(self):
|
||||
custom = [
|
||||
{
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"models": {"gpt-5.5": {"context_length": 400_000}},
|
||||
}
|
||||
]
|
||||
assert (
|
||||
get_custom_provider_context_length(
|
||||
"different-model", "https://example.invalid/v1", custom
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_string_value(self):
|
||||
"""'256K' string is not a valid int — skip silently.
|
||||
|
||||
(The inline startup path still emits a user-visible warning; the
|
||||
helper itself returns None so downstream fallbacks can run.)
|
||||
"""
|
||||
custom = [
|
||||
{
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"models": {"m": {"context_length": "256K"}},
|
||||
}
|
||||
]
|
||||
assert (
|
||||
get_custom_provider_context_length(
|
||||
"m", "https://example.invalid/v1", custom
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_zero_or_negative(self):
|
||||
for bad in (0, -1, -100):
|
||||
custom = [
|
||||
{
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"models": {"m": {"context_length": bad}},
|
||||
}
|
||||
]
|
||||
assert (
|
||||
get_custom_provider_context_length(
|
||||
"m", "https://example.invalid/v1", custom
|
||||
)
|
||||
is None
|
||||
), f"value {bad!r} should be rejected"
|
||||
|
||||
def test_empty_inputs_return_none(self):
|
||||
assert get_custom_provider_context_length("", "http://x", [{"base_url": "http://x", "models": {"": {"context_length": 1}}}]) is None
|
||||
assert get_custom_provider_context_length("m", "", [{"base_url": "", "models": {"m": {"context_length": 1}}}]) is None
|
||||
assert get_custom_provider_context_length("m", "http://x", None) is None
|
||||
assert get_custom_provider_context_length("m", "http://x", []) is None
|
||||
|
||||
def test_ignores_non_dict_entries(self):
|
||||
"""Malformed entries must not crash the lookup."""
|
||||
custom = [
|
||||
"not a dict",
|
||||
None,
|
||||
{"base_url": "https://example.invalid/v1", "models": "not a dict"},
|
||||
{"base_url": "https://example.invalid/v1", "models": {"m": "not a dict"}},
|
||||
{
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"models": {"m": {"context_length": 400_000}},
|
||||
},
|
||||
]
|
||||
assert (
|
||||
get_custom_provider_context_length(
|
||||
"m", "https://example.invalid/v1", custom
|
||||
)
|
||||
== 400_000
|
||||
)
|
||||
|
||||
|
||||
class TestGetModelContextLengthHonorsOverride:
|
||||
"""agent.model_metadata.get_model_context_length must honor the
|
||||
custom_providers override at step 0b — before any probe, cache hit,
|
||||
or models.dev lookup can override it.
|
||||
"""
|
||||
|
||||
def _mock_all_probes(self):
|
||||
"""Context manager that disables every downstream resolution step."""
|
||||
from agent import model_metadata as _mm
|
||||
return [
|
||||
patch.object(_mm, "get_cached_context_length", return_value=None),
|
||||
patch.object(_mm, "fetch_endpoint_model_metadata", return_value={}),
|
||||
patch.object(_mm, "fetch_model_metadata", return_value={}),
|
||||
patch.object(_mm, "is_local_endpoint", return_value=False),
|
||||
patch.object(_mm, "_is_known_provider_base_url", return_value=False),
|
||||
]
|
||||
|
||||
def test_custom_providers_override_wins_over_default_fallback(self):
|
||||
from agent.model_metadata import get_model_context_length
|
||||
custom = [
|
||||
{
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"models": {"gpt-5.5": {"context_length": 1_050_000}},
|
||||
}
|
||||
]
|
||||
patches = self._mock_all_probes()
|
||||
for p in patches:
|
||||
p.start()
|
||||
try:
|
||||
ctx = get_model_context_length(
|
||||
"gpt-5.5",
|
||||
base_url="https://example.invalid/v1",
|
||||
provider="custom",
|
||||
custom_providers=custom,
|
||||
)
|
||||
finally:
|
||||
for p in patches:
|
||||
p.stop()
|
||||
assert ctx == 1_050_000
|
||||
|
||||
def test_explicit_config_context_length_still_wins(self):
|
||||
"""Top-level model.context_length (step 0) outranks custom_providers (step 0b).
|
||||
|
||||
Users who set both should see the top-level value — that's the
|
||||
documented precedence and matches the long-standing step-0 behavior.
|
||||
"""
|
||||
from agent.model_metadata import get_model_context_length
|
||||
custom = [
|
||||
{
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"models": {"m": {"context_length": 1_050_000}},
|
||||
}
|
||||
]
|
||||
ctx = get_model_context_length(
|
||||
"m",
|
||||
base_url="https://example.invalid/v1",
|
||||
provider="custom",
|
||||
config_context_length=500_000, # explicit top-level wins
|
||||
custom_providers=custom,
|
||||
)
|
||||
assert ctx == 500_000
|
||||
|
||||
def test_no_override_falls_through_to_default(self):
|
||||
"""With custom_providers=None and all probes disabled, resolver
|
||||
returns DEFAULT_FALLBACK_CONTEXT (256K after the stepdown bump).
|
||||
"""
|
||||
from agent.model_metadata import get_model_context_length, DEFAULT_FALLBACK_CONTEXT
|
||||
patches = self._mock_all_probes()
|
||||
for p in patches:
|
||||
p.start()
|
||||
try:
|
||||
ctx = get_model_context_length(
|
||||
"unknown-model",
|
||||
base_url="https://example.invalid/v1",
|
||||
provider="custom",
|
||||
custom_providers=None,
|
||||
)
|
||||
finally:
|
||||
for p in patches:
|
||||
p.stop()
|
||||
assert ctx == DEFAULT_FALLBACK_CONTEXT
|
||||
|
||||
|
||||
class TestContextProbeTiers:
|
||||
def test_256k_is_top_tier_and_default(self):
|
||||
"""The stepdown probe starts at 256K and 256K is the new default."""
|
||||
from agent.model_metadata import CONTEXT_PROBE_TIERS, DEFAULT_FALLBACK_CONTEXT
|
||||
|
||||
assert CONTEXT_PROBE_TIERS[0] == 256_000
|
||||
assert DEFAULT_FALLBACK_CONTEXT == 256_000
|
||||
# Tiers still descend monotonically
|
||||
for a, b in zip(CONTEXT_PROBE_TIERS, CONTEXT_PROBE_TIERS[1:]):
|
||||
assert a > b, f"tiers must strictly descend, got {a} then {b}"
|
||||
# 128K is still a tier (users relying on it probe-down get there)
|
||||
assert 128_000 in CONTEXT_PROBE_TIERS
|
||||
@@ -88,3 +88,61 @@ class TestResolveDisplayContextLength:
|
||||
model_info=fake_mi,
|
||||
)
|
||||
assert ctx == 128_000
|
||||
|
||||
def test_custom_providers_override_honored(self):
|
||||
"""Regression for #15779: /model switch onto a custom provider must
|
||||
surface the configured per-model context_length, not the 128K/256K
|
||||
fallback.
|
||||
"""
|
||||
custom_provs = [
|
||||
{
|
||||
"name": "my-custom-endpoint",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"models": {"gpt-5.5": {"context_length": 1_050_000}},
|
||||
}
|
||||
]
|
||||
# Real resolver call — no mock — so the override path is exercised
|
||||
# through agent.model_metadata.get_model_context_length.
|
||||
from unittest.mock import patch as _p
|
||||
from agent import model_metadata as _mm
|
||||
with _p.object(_mm, "get_cached_context_length", return_value=None), \
|
||||
_p.object(_mm, "fetch_endpoint_model_metadata", return_value={}), \
|
||||
_p.object(_mm, "fetch_model_metadata", return_value={}), \
|
||||
_p.object(_mm, "is_local_endpoint", return_value=False), \
|
||||
_p.object(_mm, "_is_known_provider_base_url", return_value=False):
|
||||
ctx = resolve_display_context_length(
|
||||
"gpt-5.5",
|
||||
"custom",
|
||||
base_url="https://example.invalid/v1",
|
||||
api_key="k",
|
||||
custom_providers=custom_provs,
|
||||
)
|
||||
assert ctx == 1_050_000, (
|
||||
"custom_providers[].models.gpt-5.5.context_length=1.05M must win "
|
||||
"over probe-down fallback"
|
||||
)
|
||||
|
||||
def test_custom_providers_trailing_slash_insensitive(self):
|
||||
"""Base URL comparison must tolerate trailing-slash differences
|
||||
between config.yaml and the runtime value.
|
||||
"""
|
||||
custom_provs = [
|
||||
{
|
||||
"base_url": "https://example.invalid/v1/",
|
||||
"models": {"m": {"context_length": 400_000}},
|
||||
}
|
||||
]
|
||||
from unittest.mock import patch as _p
|
||||
from agent import model_metadata as _mm
|
||||
with _p.object(_mm, "get_cached_context_length", return_value=None), \
|
||||
_p.object(_mm, "fetch_endpoint_model_metadata", return_value={}), \
|
||||
_p.object(_mm, "fetch_model_metadata", return_value={}), \
|
||||
_p.object(_mm, "is_local_endpoint", return_value=False), \
|
||||
_p.object(_mm, "_is_known_provider_base_url", return_value=False):
|
||||
ctx = resolve_display_context_length(
|
||||
"m",
|
||||
"custom",
|
||||
base_url="https://example.invalid/v1", # no trailing slash
|
||||
custom_providers=custom_provs,
|
||||
)
|
||||
assert ctx == 400_000
|
||||
|
||||
Reference in New Issue
Block a user