From eeb8b4b00f8e6f549de6423075a31c145822d6da Mon Sep 17 00:00:00 2001 From: kshitijk4poor <82637225+kshitijk4poor@users.noreply.github.com> Date: Sat, 11 Apr 2026 12:37:53 +0530 Subject: [PATCH] fix(auxiliary): harden fallback behavior for non-OpenRouter users Four fixes to auxiliary_client.py: 1. Respect explicit provider as hard constraint (#7559) When auxiliary.{task}.provider is explicitly set (not 'auto'), connection/payment errors no longer silently fallback to cloud providers. Local-only users (Ollama, vLLM) will no longer get unexpected OpenRouter billing from auxiliary tasks. 2. Eliminate model='default' sentinel (#7512) _resolve_api_key_provider() no longer sends literal 'default' as model name to APIs. Providers without a known aux model in _API_KEY_PROVIDER_AUX_MODELS are skipped instead of producing model_not_supported errors. 3. Add payment/connection fallback to async_call_llm (#7512) async_call_llm now mirrors sync call_llm's fallback logic for payment (402) and connection errors. Previously, async consumers (session_search, web_tools, vision) got hard failures with no recovery. Also fixes hardcoded 'openrouter' fallback to use the full auto-detection chain. 4. Use accurate error reason in fallback logs (#7512) _try_payment_fallback() now accepts a reason parameter and uses it in log messages. Connection timeouts are no longer misleadingly logged as 'payment error'. Closes #7559 Closes #7512 --- agent/auxiliary_client.py | 73 ++++++-- tests/agent/test_auxiliary_client.py | 259 ++++++++++++++++++++++++++- 2 files changed, 305 insertions(+), 27 deletions(-) diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index aa823006f..32188b2e8 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -707,7 +707,9 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]: base_url = _to_openai_base_url( _pool_runtime_base_url(entry, pconfig.inference_base_url) or pconfig.inference_base_url ) - model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default") + model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id) + if model is None: + continue # skip provider if we don't know a valid aux model logger.debug("Auxiliary text client: %s (%s) via pool", pconfig.name, model) extra = {} if "api.kimi.com" in base_url.lower(): @@ -726,7 +728,9 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]: base_url = _to_openai_base_url( str(creds.get("base_url", "")).strip().rstrip("/") or pconfig.inference_base_url ) - model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default") + model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id) + if model is None: + continue # skip provider if we don't know a valid aux model logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model) extra = {} if "api.kimi.com" in base_url.lower(): @@ -1075,11 +1079,12 @@ def _is_connection_error(exc: Exception) -> bool: def _try_payment_fallback( failed_provider: str, task: str = None, + reason: str = "payment error", ) -> Tuple[Optional[Any], Optional[str], str]: - """Try alternative providers after a payment/credit error. + """Try alternative providers after a payment/credit or connection error. Iterates the standard auto-detection chain, skipping the provider that - returned a payment error. + failed. Returns: (client, model, provider_label) or (None, None, "") if no fallback. @@ -1105,15 +1110,15 @@ def _try_payment_fallback( client, model = try_fn() if client is not None: logger.info( - "Auxiliary %s: payment error on %s — falling back to %s (%s)", - task or "call", failed_provider, label, model or "default", + "Auxiliary %s: %s on %s — falling back to %s (%s)", + task or "call", reason, failed_provider, label, model or "default", ) return client, model, label tried.append(label) logger.warning( - "Auxiliary %s: payment error on %s and no fallback available (tried: %s)", - task or "call", failed_provider, ", ".join(tried), + "Auxiliary %s: %s on %s and no fallback available (tried: %s)", + task or "call", reason, failed_provider, ", ".join(tried), ) return None, None, "" @@ -2178,9 +2183,9 @@ def call_llm( try: return client.chat.completions.create(**kwargs) except Exception as retry_err: - # If the max_tokens retry also hits a payment error, - # fall through to the payment fallback below. - if not _is_payment_error(retry_err): + # If the max_tokens retry also hits a payment or connection + # error, fall through to the fallback chain below. + if not (_is_payment_error(retry_err) or _is_connection_error(retry_err)): raise first_err = retry_err @@ -2197,12 +2202,16 @@ def call_llm( # and providers the user never configured that got picked up by # the auto-detection chain. should_fallback = _is_payment_error(first_err) or _is_connection_error(first_err) - if should_fallback: + # Only try alternative providers when the user didn't explicitly + # configure this task's provider. Explicit provider = hard constraint; + # auto (the default) = best-effort fallback chain. (#7559) + is_auto = resolved_provider in ("auto", "", None) + if should_fallback and is_auto: reason = "payment error" if _is_payment_error(first_err) else "connection error" logger.info("Auxiliary %s: %s on %s (%s), trying fallback", task or "call", reason, resolved_provider, first_err) fb_client, fb_model, fb_label = _try_payment_fallback( - resolved_provider, task) + resolved_provider, task, reason=reason) if fb_client is not None: fb_kwargs = _build_call_kwargs( fb_label, fb_model, messages, @@ -2332,11 +2341,9 @@ async def async_call_llm( f"variable, or switch to a different provider with `hermes model`." ) if not resolved_base_url: - logger.warning("Provider %s unavailable, falling back to openrouter", - resolved_provider) - client, final_model = _get_cached_client( - "openrouter", resolved_model or _OPENROUTER_MODEL, - async_mode=True) + logger.info("Auxiliary %s: provider %s unavailable, trying auto-detection chain", + task or "call", resolved_provider) + client, final_model = _get_cached_client("auto", async_mode=True) if client is None: raise RuntimeError( f"No LLM provider configured for task={task} provider={resolved_provider}. " @@ -2357,5 +2364,33 @@ async def async_call_llm( if "max_tokens" in err_str or "unsupported_parameter" in err_str: kwargs.pop("max_tokens", None) kwargs["max_completion_tokens"] = max_tokens - return await client.chat.completions.create(**kwargs) + try: + return await client.chat.completions.create(**kwargs) + except Exception as retry_err: + # If the max_tokens retry also hits a payment or connection + # error, fall through to the fallback chain below. + if not (_is_payment_error(retry_err) or _is_connection_error(retry_err)): + raise + first_err = retry_err + + # ── Payment / connection fallback (mirrors sync call_llm) ───── + should_fallback = _is_payment_error(first_err) or _is_connection_error(first_err) + is_auto = resolved_provider in ("auto", "", None) + if should_fallback and is_auto: + reason = "payment error" if _is_payment_error(first_err) else "connection error" + logger.info("Auxiliary %s (async): %s on %s (%s), trying fallback", + task or "call", reason, resolved_provider, first_err) + fb_client, fb_model, fb_label = _try_payment_fallback( + resolved_provider, task, reason=reason) + if fb_client is not None: + fb_kwargs = _build_call_kwargs( + fb_label, fb_model, messages, + temperature=temperature, max_tokens=max_tokens, + tools=tools, timeout=effective_timeout, + extra_body=extra_body) + # Convert sync fallback client to async + async_fb, async_fb_model = _to_async_client(fb_client, fb_model or "") + if async_fb_model and async_fb_model != fb_kwargs.get("model"): + fb_kwargs["model"] = async_fb_model + return await async_fb.chat.completions.create(**fb_kwargs) raise diff --git a/tests/agent/test_auxiliary_client.py b/tests/agent/test_auxiliary_client.py index 61020e195..2d6a3fc7f 100644 --- a/tests/agent/test_auxiliary_client.py +++ b/tests/agent/test_auxiliary_client.py @@ -3,7 +3,7 @@ import json import os from pathlib import Path -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, AsyncMock import pytest @@ -14,6 +14,7 @@ from agent.auxiliary_client import ( resolve_provider_client, auxiliary_max_tokens_param, call_llm, + async_call_llm, _read_codex_access_token, _get_auxiliary_provider, _get_provider_chain, @@ -1122,8 +1123,8 @@ class TestCallLlmPaymentFallback: exc.status_code = 402 return exc - def test_402_triggers_fallback(self, monkeypatch): - """When the primary provider returns 402, call_llm tries the next one.""" + def test_402_triggers_fallback_when_auto(self, monkeypatch): + """When provider is auto and returns 402, call_llm tries the next one.""" monkeypatch.setenv("OPENROUTER_API_KEY", "or-key") primary_client = MagicMock() @@ -1136,7 +1137,7 @@ class TestCallLlmPaymentFallback: with patch("agent.auxiliary_client._get_cached_client", return_value=(primary_client, "google/gemini-3-flash-preview")), \ patch("agent.auxiliary_client._resolve_task_provider_model", - return_value=("openrouter", "google/gemini-3-flash-preview", None, None, None)), \ + return_value=("auto", "google/gemini-3-flash-preview", None, None, None)), \ patch("agent.auxiliary_client._try_payment_fallback", return_value=(fallback_client, "gpt-5.2-codex", "openai-codex")) as mock_fb: result = call_llm( @@ -1145,13 +1146,62 @@ class TestCallLlmPaymentFallback: ) assert result is fallback_response - mock_fb.assert_called_once_with("openrouter", "compression") + mock_fb.assert_called_once_with("auto", "compression", reason="payment error") # Fallback call should use the fallback model fb_kwargs = fallback_client.chat.completions.create.call_args.kwargs assert fb_kwargs["model"] == "gpt-5.2-codex" + def test_402_no_fallback_when_explicit_provider(self, monkeypatch): + """When provider is explicitly configured (not auto), 402 should NOT fallback (#7559).""" + monkeypatch.setenv("OPENROUTER_API_KEY", "or-key") + + primary_client = MagicMock() + primary_client.chat.completions.create.side_effect = self._make_402_error() + + with patch("agent.auxiliary_client._get_cached_client", + return_value=(primary_client, "local-model")), \ + patch("agent.auxiliary_client._resolve_task_provider_model", + return_value=("custom", "local-model", None, None, None)), \ + patch("agent.auxiliary_client._try_payment_fallback") as mock_fb: + with pytest.raises(Exception, match="insufficient credits"): + call_llm( + task="compression", + messages=[{"role": "user", "content": "hello"}], + ) + + # Fallback should NOT be attempted when provider is explicit + mock_fb.assert_not_called() + + def test_connection_error_triggers_fallback_when_auto(self, monkeypatch): + """Connection errors also trigger fallback when provider is auto.""" + monkeypatch.setenv("OPENROUTER_API_KEY", "or-key") + + primary_client = MagicMock() + conn_err = Exception("Connection refused") + conn_err.status_code = None + primary_client.chat.completions.create.side_effect = conn_err + + fallback_client = MagicMock() + fallback_response = MagicMock() + fallback_client.chat.completions.create.return_value = fallback_response + + with patch("agent.auxiliary_client._get_cached_client", + return_value=(primary_client, "model")), \ + patch("agent.auxiliary_client._resolve_task_provider_model", + return_value=("auto", "model", None, None, None)), \ + patch("agent.auxiliary_client._is_connection_error", return_value=True), \ + patch("agent.auxiliary_client._try_payment_fallback", + return_value=(fallback_client, "fb-model", "nous")) as mock_fb: + result = call_llm( + task="compression", + messages=[{"role": "user", "content": "hello"}], + ) + + assert result is fallback_response + mock_fb.assert_called_once_with("auto", "compression", reason="connection error") + def test_non_payment_error_not_caught(self, monkeypatch): - """Non-payment errors (500, connection, etc.) should NOT trigger fallback.""" + """Non-payment/non-connection errors (500) should NOT trigger fallback.""" monkeypatch.setenv("OPENROUTER_API_KEY", "or-key") primary_client = MagicMock() @@ -1162,7 +1212,7 @@ class TestCallLlmPaymentFallback: with patch("agent.auxiliary_client._get_cached_client", return_value=(primary_client, "google/gemini-3-flash-preview")), \ patch("agent.auxiliary_client._resolve_task_provider_model", - return_value=("openrouter", "google/gemini-3-flash-preview", None, None, None)): + return_value=("auto", "google/gemini-3-flash-preview", None, None, None)): with pytest.raises(Exception, match="Internal Server Error"): call_llm( task="compression", @@ -1179,7 +1229,7 @@ class TestCallLlmPaymentFallback: with patch("agent.auxiliary_client._get_cached_client", return_value=(primary_client, "google/gemini-3-flash-preview")), \ patch("agent.auxiliary_client._resolve_task_provider_model", - return_value=("openrouter", "google/gemini-3-flash-preview", None, None, None)), \ + return_value=("auto", "google/gemini-3-flash-preview", None, None, None)), \ patch("agent.auxiliary_client._try_payment_fallback", return_value=(None, None, "")): with pytest.raises(Exception, match="insufficient credits"): @@ -1229,3 +1279,196 @@ def test_resolve_api_key_provider_skips_unconfigured_anthropic(monkeypatch): assert "anthropic" not in called, \ "_try_anthropic() should not be called when anthropic is not explicitly configured" + + +# --------------------------------------------------------------------------- +# model="default" elimination (#7512) +# --------------------------------------------------------------------------- + + +class TestModelDefaultElimination: + """_resolve_api_key_provider must skip providers without known aux models.""" + + def test_unknown_provider_skipped(self, monkeypatch): + """Providers not in _API_KEY_PROVIDER_AUX_MODELS are skipped, not sent model='default'.""" + from agent.auxiliary_client import _API_KEY_PROVIDER_AUX_MODELS + + # Verify our known providers have entries + assert "gemini" in _API_KEY_PROVIDER_AUX_MODELS + assert "kimi-coding" in _API_KEY_PROVIDER_AUX_MODELS + + # A random provider_id not in the dict should return None + assert _API_KEY_PROVIDER_AUX_MODELS.get("totally-unknown-provider") is None + + def test_known_provider_gets_real_model(self): + """Known providers get a real model name, not 'default'.""" + from agent.auxiliary_client import _API_KEY_PROVIDER_AUX_MODELS + + for provider_id, model in _API_KEY_PROVIDER_AUX_MODELS.items(): + assert model != "default", f"{provider_id} should not map to 'default'" + assert isinstance(model, str) and model.strip(), \ + f"{provider_id} should have a non-empty model string" + + +# --------------------------------------------------------------------------- +# _try_payment_fallback reason parameter (#7512 bug 3) +# --------------------------------------------------------------------------- + + +class TestTryPaymentFallbackReason: + """_try_payment_fallback uses the reason parameter in log messages.""" + + def test_reason_parameter_passed_through(self, monkeypatch): + """The reason= parameter is accepted without error.""" + from agent.auxiliary_client import _try_payment_fallback + + # Mock the provider chain to return nothing + monkeypatch.setattr( + "agent.auxiliary_client._get_provider_chain", + lambda: [], + ) + monkeypatch.setattr( + "agent.auxiliary_client._read_main_provider", + lambda: "", + ) + + client, model, label = _try_payment_fallback( + "openrouter", task="compression", reason="connection error" + ) + assert client is None + assert label == "" + + +# --------------------------------------------------------------------------- +# _is_connection_error coverage +# --------------------------------------------------------------------------- + + +class TestIsConnectionError: + """Tests for _is_connection_error detection.""" + + def test_connection_refused(self): + from agent.auxiliary_client import _is_connection_error + err = Exception("Connection refused") + assert _is_connection_error(err) is True + + def test_timeout(self): + from agent.auxiliary_client import _is_connection_error + err = Exception("Request timed out.") + assert _is_connection_error(err) is True + + def test_dns_failure(self): + from agent.auxiliary_client import _is_connection_error + err = Exception("Name or service not known") + assert _is_connection_error(err) is True + + def test_normal_api_error_not_connection(self): + from agent.auxiliary_client import _is_connection_error + err = Exception("Bad Request: invalid model") + err.status_code = 400 + assert _is_connection_error(err) is False + + def test_500_not_connection(self): + from agent.auxiliary_client import _is_connection_error + err = Exception("Internal Server Error") + err.status_code = 500 + assert _is_connection_error(err) is False + + +# --------------------------------------------------------------------------- +# async_call_llm payment / connection fallback (#7512 bug 2) +# --------------------------------------------------------------------------- + + +class TestAsyncCallLlmFallback: + """async_call_llm mirrors call_llm fallback behavior.""" + + def _make_402_error(self, msg="Payment Required: insufficient credits"): + exc = Exception(msg) + exc.status_code = 402 + return exc + + @pytest.mark.asyncio + async def test_402_triggers_async_fallback_when_auto(self, monkeypatch): + """When provider is auto and returns 402, async_call_llm tries fallback.""" + monkeypatch.setenv("OPENROUTER_API_KEY", "or-key") + + primary_client = MagicMock() + primary_client.chat.completions.create = AsyncMock( + side_effect=self._make_402_error()) + + # Fallback client (sync) returned by _try_payment_fallback + fb_sync_client = MagicMock() + fb_async_client = MagicMock() + fb_response = MagicMock() + fb_async_client.chat.completions.create = AsyncMock(return_value=fb_response) + + with patch("agent.auxiliary_client._get_cached_client", + return_value=(primary_client, "google/gemini-3-flash-preview")), \ + patch("agent.auxiliary_client._resolve_task_provider_model", + return_value=("auto", "google/gemini-3-flash-preview", None, None)), \ + patch("agent.auxiliary_client._try_payment_fallback", + return_value=(fb_sync_client, "gpt-5.2-codex", "openai-codex")) as mock_fb, \ + patch("agent.auxiliary_client._to_async_client", + return_value=(fb_async_client, "gpt-5.2-codex")): + result = await async_call_llm( + task="compression", + messages=[{"role": "user", "content": "hello"}], + ) + + assert result is fb_response + mock_fb.assert_called_once_with("auto", "compression", reason="payment error") + + @pytest.mark.asyncio + async def test_402_no_async_fallback_when_explicit(self, monkeypatch): + """When provider is explicit, 402 should NOT trigger async fallback.""" + monkeypatch.setenv("OPENROUTER_API_KEY", "or-key") + + primary_client = MagicMock() + primary_client.chat.completions.create = AsyncMock( + side_effect=self._make_402_error()) + + with patch("agent.auxiliary_client._get_cached_client", + return_value=(primary_client, "local-model")), \ + patch("agent.auxiliary_client._resolve_task_provider_model", + return_value=("custom", "local-model", None, None, None)), \ + patch("agent.auxiliary_client._try_payment_fallback") as mock_fb: + with pytest.raises(Exception, match="insufficient credits"): + await async_call_llm( + task="compression", + messages=[{"role": "user", "content": "hello"}], + ) + + mock_fb.assert_not_called() + + @pytest.mark.asyncio + async def test_connection_error_triggers_async_fallback(self, monkeypatch): + """Connection errors trigger async fallback when provider is auto.""" + monkeypatch.setenv("OPENROUTER_API_KEY", "or-key") + + primary_client = MagicMock() + conn_err = Exception("Connection refused") + conn_err.status_code = None + primary_client.chat.completions.create = AsyncMock(side_effect=conn_err) + + fb_sync_client = MagicMock() + fb_async_client = MagicMock() + fb_response = MagicMock() + fb_async_client.chat.completions.create = AsyncMock(return_value=fb_response) + + with patch("agent.auxiliary_client._get_cached_client", + return_value=(primary_client, "model")), \ + patch("agent.auxiliary_client._resolve_task_provider_model", + return_value=("auto", "model", None, None, None)), \ + patch("agent.auxiliary_client._is_connection_error", return_value=True), \ + patch("agent.auxiliary_client._try_payment_fallback", + return_value=(fb_sync_client, "fb-model", "nous")) as mock_fb, \ + patch("agent.auxiliary_client._to_async_client", + return_value=(fb_async_client, "fb-model")): + result = await async_call_llm( + task="compression", + messages=[{"role": "user", "content": "hello"}], + ) + + assert result is fb_response + mock_fb.assert_called_once_with("auto", "compression", reason="connection error")