diff --git a/rdagent/oai/backend/litellm.py b/rdagent/oai/backend/litellm.py index 86cbe8461..beb2d5f2e 100644 --- a/rdagent/oai/backend/litellm.py +++ b/rdagent/oai/backend/litellm.py @@ -12,6 +12,7 @@ token_counter, ) from litellm.exceptions import BadRequestError, Timeout +from litellm.utils import get_llm_provider from pydantic import BaseModel from rdagent.log import LogColors @@ -45,12 +46,44 @@ class Config: ACC_COST = 0.0 +def _validate_litellm_model_provider(model: str, setting_env_var: str) -> None: + """Validate that ``model`` resolves to a LiteLLM provider before any + completion / embedding call is attempted. + + LiteLLM uses a ``/`` naming convention; for bare names it + cannot disambiguate it falls back to the OpenAI provider, and for names + it does not recognise (e.g. ``deepseek-chat``, ``qwen2:7b``) it raises + ``BadRequestError: LLM Provider NOT provided``. That error surfaces from + deep inside the retry loop with no actionable hint, which is the bug + reported in #1016. By calling ``get_llm_provider`` at backend + construction, the misconfiguration is caught before any HTTP call with a + clear message that points at the env-var fix. + """ + try: + get_llm_provider(model) + except BadRequestError as e: + raise RuntimeError( + f"Could not determine the LiteLLM provider for " + f"{setting_env_var}={model!r}.\n" + f"LiteLLM uses a '/' naming convention. " + f"Update the model name to include the provider prefix, for example:\n" + f" CHAT_MODEL=openai/gpt-4o\n" + f" CHAT_MODEL=deepseek/deepseek-chat\n" + f" CHAT_MODEL=ollama/qwen2:7b\n" + f" EMBEDDING_MODEL=ollama/nomic-embed-text\n" + f"See https://docs.litellm.ai/docs/providers for the full list " + f"of supported provider prefixes." + ) from e + + class LiteLLMAPIBackend(APIBackend): """LiteLLM implementation of APIBackend interface""" _has_logged_settings: bool = False def __init__(self, *args: Any, **kwargs: Any) -> None: + _validate_litellm_model_provider(LITELLM_SETTINGS.chat_model, "CHAT_MODEL") + _validate_litellm_model_provider(LITELLM_SETTINGS.embedding_model, "EMBEDDING_MODEL") if not self.__class__._has_logged_settings: logger.info(f"{LITELLM_SETTINGS}") logger.log_object(LITELLM_SETTINGS.model_dump(), tag="LITELLM_SETTINGS") diff --git a/test/oai/test_litellm_model_validation.py b/test/oai/test_litellm_model_validation.py new file mode 100644 index 000000000..e24c99545 --- /dev/null +++ b/test/oai/test_litellm_model_validation.py @@ -0,0 +1,65 @@ +"""Unit tests for the upfront ``/`` validation added in +response to #1016. + +These tests do not hit any live LLM endpoint; they only exercise the +``get_llm_provider`` resolution that LiteLLM exposes purely client-side. +""" + +import unittest + +import pytest + +# get_llm_provider is the same helper the validation function uses; this +# avoids tripping over the lazy environment requirements of constructing a +# full LiteLLMAPIBackend instance. +try: + from litellm.exceptions import BadRequestError + from litellm.utils import get_llm_provider +except ImportError: # pragma: no cover - litellm is a hard dependency of RD-Agent + pytest.skip("litellm is not installed; skipping provider validation tests", allow_module_level=True) + +from rdagent.oai.backend.litellm import _validate_litellm_model_provider + + +class TestValidateLiteLLMModelProvider(unittest.TestCase): + def test_default_chat_model_passes(self) -> None: + # gpt-4-turbo is the default; LiteLLM resolves bare OpenAI names. + _validate_litellm_model_provider("gpt-4-turbo", "CHAT_MODEL") + + def test_default_embedding_model_passes(self) -> None: + _validate_litellm_model_provider("text-embedding-3-small", "EMBEDDING_MODEL") + + def test_provider_prefixed_chat_models_pass(self) -> None: + for model in [ + "openai/gpt-4o", + "deepseek/deepseek-chat", + "ollama/qwen2:7b", + "azure/my-deployment", + ]: + with self.subTest(model=model): + _validate_litellm_model_provider(model, "CHAT_MODEL") + + def test_bare_unknown_chat_model_raises_actionable_error(self) -> None: + # This is the exact failure mode from #1016: a bare model name that + # LiteLLM cannot map to a known provider. + with pytest.raises(RuntimeError) as exc_info: + _validate_litellm_model_provider("definitely-not-a-real-model-xyz", "CHAT_MODEL") + message = str(exc_info.value) + # Error must surface the offending env var name, the bad value, and + # at least one corrected example, so the user can fix it without + # having to grep the source. + self.assertIn("CHAT_MODEL", message) + self.assertIn("definitely-not-a-real-model-xyz", message) + self.assertIn("/", message) + # The chained cause should still be the original LiteLLM exception so + # advanced users can inspect it. + self.assertIsInstance(exc_info.value.__cause__, BadRequestError) + + def test_bare_unknown_embedding_model_raises_actionable_error(self) -> None: + with pytest.raises(RuntimeError) as exc_info: + _validate_litellm_model_provider("definitely-not-a-real-emb-xyz", "EMBEDDING_MODEL") + self.assertIn("EMBEDDING_MODEL", str(exc_info.value)) + + +if __name__ == "__main__": + unittest.main()