fix: forward auth when probing local model metadata
Pass the user's configured api_key through local-server detection and context-length probes (detect_local_server_type, _query_local_context_length, query_ollama_num_ctx) and use LM Studio's native /api/v1/models endpoint in fetch_endpoint_model_metadata when a loaded instance is present — so the probed context length is the actual runtime value the user loaded the model at, not just the model's theoretical max. Helps local-LLM users whose auto-detected context length was wrong, causing compression failures and context-overrun crashes.
This commit is contained in:
@@ -210,6 +210,13 @@ def _normalize_base_url(base_url: str) -> str:
|
||||
return (base_url or "").strip().rstrip("/")
|
||||
|
||||
|
||||
def _auth_headers(api_key: str = "") -> Dict[str, str]:
|
||||
token = str(api_key or "").strip()
|
||||
if not token:
|
||||
return {}
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
def _is_openrouter_base_url(base_url: str) -> bool:
|
||||
return "openrouter.ai" in _normalize_base_url(base_url).lower()
|
||||
|
||||
@@ -309,7 +316,7 @@ def is_local_endpoint(base_url: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def detect_local_server_type(base_url: str) -> Optional[str]:
|
||||
def detect_local_server_type(base_url: str, api_key: str = "") -> Optional[str]:
|
||||
"""Detect which local server is running at base_url by probing known endpoints.
|
||||
|
||||
Returns one of: "ollama", "lm-studio", "vllm", "llamacpp", or None.
|
||||
@@ -321,8 +328,10 @@ def detect_local_server_type(base_url: str) -> Optional[str]:
|
||||
if server_url.endswith("/v1"):
|
||||
server_url = server_url[:-3]
|
||||
|
||||
headers = _auth_headers(api_key)
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=2.0) as client:
|
||||
with httpx.Client(timeout=2.0, headers=headers) as client:
|
||||
# LM Studio exposes /api/v1/models — check first (most specific)
|
||||
try:
|
||||
r = client.get(f"{server_url}/api/v1/models")
|
||||
@@ -509,6 +518,59 @@ def fetch_endpoint_model_metadata(
|
||||
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
if is_local_endpoint(normalized):
|
||||
try:
|
||||
if detect_local_server_type(normalized, api_key=api_key) == "lm-studio":
|
||||
server_url = normalized[:-3].rstrip("/") if normalized.endswith("/v1") else normalized
|
||||
response = requests.get(
|
||||
server_url.rstrip("/") + "/api/v1/models",
|
||||
headers=headers,
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
cache: Dict[str, Dict[str, Any]] = {}
|
||||
for model in payload.get("models", []):
|
||||
if not isinstance(model, dict):
|
||||
continue
|
||||
model_id = model.get("key") or model.get("id")
|
||||
if not model_id:
|
||||
continue
|
||||
entry: Dict[str, Any] = {"name": model.get("name", model_id)}
|
||||
|
||||
context_length = None
|
||||
for inst in model.get("loaded_instances", []) or []:
|
||||
if not isinstance(inst, dict):
|
||||
continue
|
||||
cfg = inst.get("config", {})
|
||||
ctx = cfg.get("context_length") if isinstance(cfg, dict) else None
|
||||
if isinstance(ctx, int) and ctx > 0:
|
||||
context_length = ctx
|
||||
break
|
||||
if context_length is None:
|
||||
context_length = _extract_context_length(model)
|
||||
if context_length is not None:
|
||||
entry["context_length"] = context_length
|
||||
|
||||
max_completion_tokens = _extract_max_completion_tokens(model)
|
||||
if max_completion_tokens is not None:
|
||||
entry["max_completion_tokens"] = max_completion_tokens
|
||||
|
||||
pricing = _extract_pricing(model)
|
||||
if pricing:
|
||||
entry["pricing"] = pricing
|
||||
|
||||
_add_model_aliases(cache, model_id, entry)
|
||||
alt_id = model.get("id")
|
||||
if isinstance(alt_id, str) and alt_id and alt_id != model_id:
|
||||
_add_model_aliases(cache, alt_id, entry)
|
||||
|
||||
_endpoint_model_metadata_cache[normalized] = cache
|
||||
_endpoint_model_metadata_cache_time[normalized] = time.time()
|
||||
return cache
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
|
||||
for candidate in candidates:
|
||||
url = candidate.rstrip("/") + "/models"
|
||||
try:
|
||||
@@ -715,7 +777,7 @@ def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def query_ollama_num_ctx(model: str, base_url: str) -> Optional[int]:
|
||||
def query_ollama_num_ctx(model: str, base_url: str, api_key: str = "") -> Optional[int]:
|
||||
"""Query an Ollama server for the model's context length.
|
||||
|
||||
Returns the model's maximum context from GGUF metadata via ``/api/show``,
|
||||
@@ -733,14 +795,16 @@ def query_ollama_num_ctx(model: str, base_url: str) -> Optional[int]:
|
||||
server_url = server_url[:-3]
|
||||
|
||||
try:
|
||||
server_type = detect_local_server_type(base_url)
|
||||
server_type = detect_local_server_type(base_url, api_key=api_key)
|
||||
except Exception:
|
||||
return None
|
||||
if server_type != "ollama":
|
||||
return None
|
||||
|
||||
headers = _auth_headers(api_key)
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=3.0) as client:
|
||||
with httpx.Client(timeout=3.0, headers=headers) as client:
|
||||
resp = client.post(f"{server_url}/api/show", json={"name": bare_model})
|
||||
if resp.status_code != 200:
|
||||
return None
|
||||
@@ -768,7 +832,7 @@ def query_ollama_num_ctx(model: str, base_url: str) -> Optional[int]:
|
||||
return None
|
||||
|
||||
|
||||
def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
|
||||
def _query_local_context_length(model: str, base_url: str, api_key: str = "") -> Optional[int]:
|
||||
"""Query a local server for the model's context length."""
|
||||
import httpx
|
||||
|
||||
@@ -781,13 +845,15 @@ def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
|
||||
if server_url.endswith("/v1"):
|
||||
server_url = server_url[:-3]
|
||||
|
||||
headers = _auth_headers(api_key)
|
||||
|
||||
try:
|
||||
server_type = detect_local_server_type(base_url)
|
||||
server_type = detect_local_server_type(base_url, api_key=api_key)
|
||||
except Exception:
|
||||
server_type = None
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=3.0) as client:
|
||||
with httpx.Client(timeout=3.0, headers=headers) as client:
|
||||
# Ollama: /api/show returns model details with context info
|
||||
if server_type == "ollama":
|
||||
resp = client.post(f"{server_url}/api/show", json={"name": model})
|
||||
@@ -998,7 +1064,7 @@ def get_model_context_length(
|
||||
if not _is_known_provider_base_url(base_url):
|
||||
# 3. Try querying local server directly
|
||||
if is_local_endpoint(base_url):
|
||||
local_ctx = _query_local_context_length(model, base_url)
|
||||
local_ctx = _query_local_context_length(model, base_url, api_key=api_key)
|
||||
if local_ctx and local_ctx > 0:
|
||||
save_context_length(model, base_url, local_ctx)
|
||||
return local_ctx
|
||||
@@ -1068,7 +1134,7 @@ def get_model_context_length(
|
||||
|
||||
# 9. Query local server as last resort
|
||||
if base_url and is_local_endpoint(base_url):
|
||||
local_ctx = _query_local_context_length(model, base_url)
|
||||
local_ctx = _query_local_context_length(model, base_url, api_key=api_key)
|
||||
if local_ctx and local_ctx > 0:
|
||||
save_context_length(model, base_url, local_ctx)
|
||||
return local_ctx
|
||||
|
||||
@@ -3876,9 +3876,11 @@ class GatewayRunner:
|
||||
from agent.model_metadata import get_model_context_length
|
||||
|
||||
_msg_cwd = os.environ.get("TERMINAL_CWD", os.path.expanduser("~"))
|
||||
_msg_runtime = _resolve_runtime_agent_kwargs()
|
||||
_msg_ctx_len = get_model_context_length(
|
||||
self._model,
|
||||
base_url=self._base_url or "",
|
||||
base_url=self._base_url or _msg_runtime.get("base_url") or "",
|
||||
api_key=_msg_runtime.get("api_key") or "",
|
||||
)
|
||||
_ctx_result = await preprocess_context_references_async(
|
||||
message_text,
|
||||
|
||||
@@ -424,6 +424,68 @@ class TestQueryLocalContextLengthLmStudio:
|
||||
)
|
||||
|
||||
|
||||
class TestDetectLocalServerTypeAuth:
|
||||
def test_passes_bearer_token_to_probe_requests(self):
|
||||
from agent.model_metadata import detect_local_server_type
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
|
||||
client_mock = MagicMock()
|
||||
client_mock.__enter__ = lambda s: client_mock
|
||||
client_mock.__exit__ = MagicMock(return_value=False)
|
||||
client_mock.get.return_value = resp
|
||||
|
||||
with patch("httpx.Client", return_value=client_mock) as mock_client:
|
||||
result = detect_local_server_type("http://localhost:1234/v1", api_key="lm-token")
|
||||
|
||||
assert result == "lm-studio"
|
||||
assert mock_client.call_args.kwargs["headers"] == {
|
||||
"Authorization": "Bearer lm-token"
|
||||
}
|
||||
|
||||
|
||||
class TestFetchEndpointModelMetadataLmStudio:
|
||||
"""fetch_endpoint_model_metadata should use LM Studio's native models endpoint."""
|
||||
|
||||
def _make_resp(self, body):
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status.return_value = None
|
||||
resp.json.return_value = body
|
||||
return resp
|
||||
|
||||
def test_uses_native_models_endpoint_only(self):
|
||||
from agent.model_metadata import fetch_endpoint_model_metadata
|
||||
|
||||
native_resp = self._make_resp(
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"key": "lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf",
|
||||
"id": "lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf",
|
||||
"max_context_length": 131072,
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
|
||||
patch("agent.model_metadata.requests.get", return_value=native_resp) as mock_get:
|
||||
result = fetch_endpoint_model_metadata(
|
||||
"http://localhost:1234/v1",
|
||||
api_key="lm-token",
|
||||
force_refresh=True,
|
||||
)
|
||||
|
||||
assert mock_get.call_count == 1
|
||||
assert mock_get.call_args[0][0] == "http://localhost:1234/api/v1/models"
|
||||
assert mock_get.call_args.kwargs["headers"] == {
|
||||
"Authorization": "Bearer lm-token"
|
||||
}
|
||||
assert result["lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf"]["context_length"] == 131072
|
||||
assert result["Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf"]["context_length"] == 131072
|
||||
|
||||
|
||||
class TestQueryLocalContextLengthNetworkError:
|
||||
"""_query_local_context_length handles network failures gracefully."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user