Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions rdagent/oai/backend/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
token_counter,
)
from litellm.exceptions import BadRequestError, Timeout
from litellm.utils import get_llm_provider
from pydantic import BaseModel

from rdagent.log import LogColors
Expand Down Expand Up @@ -45,12 +46,44 @@ class Config:
ACC_COST = 0.0


def _validate_litellm_model_provider(model: str, setting_env_var: str) -> None:
"""Validate that ``model`` resolves to a LiteLLM provider before any
completion / embedding call is attempted.

LiteLLM uses a ``<provider>/<model>`` naming convention; for bare names it
cannot disambiguate it falls back to the OpenAI provider, and for names
it does not recognise (e.g. ``deepseek-chat``, ``qwen2:7b``) it raises
``BadRequestError: LLM Provider NOT provided``. That error surfaces from
deep inside the retry loop with no actionable hint, which is the bug
reported in #1016. By calling ``get_llm_provider`` at backend
construction, the misconfiguration is caught before any HTTP call with a
clear message that points at the env-var fix.
"""
try:
get_llm_provider(model)
except BadRequestError as e:
raise RuntimeError(
f"Could not determine the LiteLLM provider for "
f"{setting_env_var}={model!r}.\n"
f"LiteLLM uses a '<provider>/<model>' naming convention. "
f"Update the model name to include the provider prefix, for example:\n"
f" CHAT_MODEL=openai/gpt-4o\n"
f" CHAT_MODEL=deepseek/deepseek-chat\n"
f" CHAT_MODEL=ollama/qwen2:7b\n"
f" EMBEDDING_MODEL=ollama/nomic-embed-text\n"
f"See https://docs.litellm.ai/docs/providers for the full list "
f"of supported provider prefixes."
) from e


class LiteLLMAPIBackend(APIBackend):
"""LiteLLM implementation of APIBackend interface"""

_has_logged_settings: bool = False

def __init__(self, *args: Any, **kwargs: Any) -> None:
_validate_litellm_model_provider(LITELLM_SETTINGS.chat_model, "CHAT_MODEL")
_validate_litellm_model_provider(LITELLM_SETTINGS.embedding_model, "EMBEDDING_MODEL")
if not self.__class__._has_logged_settings:
logger.info(f"{LITELLM_SETTINGS}")
logger.log_object(LITELLM_SETTINGS.model_dump(), tag="LITELLM_SETTINGS")
Expand Down
65 changes: 65 additions & 0 deletions test/oai/test_litellm_model_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Unit tests for the upfront ``<provider>/<model>`` validation added in
response to #1016.

These tests do not hit any live LLM endpoint; they only exercise the
``get_llm_provider`` resolution that LiteLLM exposes purely client-side.
"""

import unittest

import pytest

# get_llm_provider is the same helper the validation function uses; this
# avoids tripping over the lazy environment requirements of constructing a
# full LiteLLMAPIBackend instance.
try:
from litellm.exceptions import BadRequestError
from litellm.utils import get_llm_provider
except ImportError: # pragma: no cover - litellm is a hard dependency of RD-Agent
pytest.skip("litellm is not installed; skipping provider validation tests", allow_module_level=True)

from rdagent.oai.backend.litellm import _validate_litellm_model_provider


class TestValidateLiteLLMModelProvider(unittest.TestCase):
def test_default_chat_model_passes(self) -> None:
# gpt-4-turbo is the default; LiteLLM resolves bare OpenAI names.
_validate_litellm_model_provider("gpt-4-turbo", "CHAT_MODEL")

def test_default_embedding_model_passes(self) -> None:
_validate_litellm_model_provider("text-embedding-3-small", "EMBEDDING_MODEL")

def test_provider_prefixed_chat_models_pass(self) -> None:
for model in [
"openai/gpt-4o",
"deepseek/deepseek-chat",
"ollama/qwen2:7b",
"azure/my-deployment",
]:
with self.subTest(model=model):
_validate_litellm_model_provider(model, "CHAT_MODEL")

def test_bare_unknown_chat_model_raises_actionable_error(self) -> None:
# This is the exact failure mode from #1016: a bare model name that
# LiteLLM cannot map to a known provider.
with pytest.raises(RuntimeError) as exc_info:
_validate_litellm_model_provider("definitely-not-a-real-model-xyz", "CHAT_MODEL")
message = str(exc_info.value)
# Error must surface the offending env var name, the bad value, and
# at least one corrected example, so the user can fix it without
# having to grep the source.
self.assertIn("CHAT_MODEL", message)
self.assertIn("definitely-not-a-real-model-xyz", message)
self.assertIn("<provider>/<model>", message)
# The chained cause should still be the original LiteLLM exception so
# advanced users can inspect it.
self.assertIsInstance(exc_info.value.__cause__, BadRequestError)

def test_bare_unknown_embedding_model_raises_actionable_error(self) -> None:
with pytest.raises(RuntimeError) as exc_info:
_validate_litellm_model_provider("definitely-not-a-real-emb-xyz", "EMBEDDING_MODEL")
self.assertIn("EMBEDDING_MODEL", str(exc_info.value))


if __name__ == "__main__":
unittest.main()
Loading