Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 16 additions & 27 deletions rdagent/app/utils/health_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

import docker
import fire
import litellm
from litellm import completion, embedding
from litellm.utils import ModelResponse

from rdagent.log import rdagent_logger as logger
from rdagent.utils.env import cleanup_container
Expand Down Expand Up @@ -51,24 +48,11 @@ def check_and_list_free_ports(start_port=19899, max_ports=10) -> None:
def test_chat(chat_model, chat_api_key, chat_api_base):
logger.info(f"🧪 Testing chat model: {chat_model}")
try:
if chat_api_base is None:
response: ModelResponse = completion(
model=chat_model,
api_key=chat_api_key,
messages=[
{"role": "user", "content": "Hello!"},
],
)
else:
response: ModelResponse = completion(
model=chat_model,
api_key=chat_api_key,
api_base=chat_api_base,
messages=[
{"role": "user", "content": "Hello!"},
],
)
logger.info(f"✅ Chat test passed.")
from rdagent.oai.backend.litellm import LiteLLMAPIBackend

backend = LiteLLMAPIBackend()
backend.build_messages_and_create_chat_completion(user_prompt="Hello!")
logger.info("✅ Chat test passed.")
return True
except Exception as e:
logger.error(f"❌ Chat test failed: {e}")
Expand All @@ -78,12 +62,10 @@ def test_chat(chat_model, chat_api_key, chat_api_base):
def test_embedding(embedding_model, embedding_api_key, embedding_api_base):
logger.info(f"🧪 Testing embedding model: {embedding_model}")
try:
response = embedding(
model=embedding_model,
api_key=embedding_api_key,
api_base=embedding_api_base,
input="Hello world!",
)
from rdagent.oai.backend.litellm import LiteLLMAPIBackend

backend = LiteLLMAPIBackend()
backend.create_embedding(input_content="Hello world!")
logger.info("✅ Embedding test passed.")
return True
except Exception as e:
Expand Down Expand Up @@ -117,6 +99,13 @@ def env_check():
embedding_model = os.getenv("EMBEDDING_MODEL")
embedding_api_key = chat_api_key
embedding_api_base = chat_api_base
elif "CHAT_OPENAI_COMPATIBLE_API_KEY" in os.environ or "EMBEDDING_OPENAI_COMPATIBLE_API_KEY" in os.environ:
chat_api_key = os.getenv("CHAT_OPENAI_COMPATIBLE_API_KEY")
chat_api_base = os.getenv("CHAT_OPENAI_COMPATIBLE_API_BASE")
chat_model = os.getenv("CHAT_MODEL")
embedding_model = os.getenv("EMBEDDING_MODEL")
embedding_api_key = os.getenv("EMBEDDING_OPENAI_COMPATIBLE_API_KEY")
embedding_api_base = os.getenv("EMBEDDING_OPENAI_COMPATIBLE_API_BASE")
else:
logger.error("No valid configuration was found, please check your .env file.")

Expand Down
28 changes: 24 additions & 4 deletions rdagent/oai/backend/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,24 @@ def _create_embedding_inner_function(self, input_content_list: list[str]) -> lis
f"{LogColors.MAGENTA}Creating embedding{LogColors.END} for: {input_content_list}",
tag="debug_litellm_emb",
)
response = embedding(
model=model_name,
input=input_content_list,
)
call_kwargs = {
"model": model_name,
"input": input_content_list,
}
if LITELLM_SETTINGS.embedding_extra_params:
if isinstance(LITELLM_SETTINGS.embedding_extra_params, dict):
call_kwargs.update(LITELLM_SETTINGS.embedding_extra_params)
else:
logger.error(
f"{LogColors.RED}embedding_extra_params must be a dict, got {type(LITELLM_SETTINGS.embedding_extra_params).__name__}. Ignoring extra params.{LogColors.END}",
tag="debug_litellm_emb",
)
# Pass the OpenAI-compatible embedding api_key / api_base to the call when they are configured (non-empty)
if LITELLM_SETTINGS.embedding_openai_compatible_api_key:
call_kwargs["api_key"] = LITELLM_SETTINGS.embedding_openai_compatible_api_key
if LITELLM_SETTINGS.embedding_openai_compatible_api_base:
call_kwargs["api_base"] = LITELLM_SETTINGS.embedding_openai_compatible_api_base
response = embedding(**call_kwargs)
response_list = [data["embedding"] for data in response.data]
return response_list

Expand Down Expand Up @@ -152,6 +166,12 @@ def _create_chat_completion_inner_function( # type: ignore[no-untyped-def] # no
complete_kwargs = self.get_complete_kwargs()
model = complete_kwargs["model"]

# Set api_key / api_base from the OpenAI-compatible chat settings when they are configured (non-empty)
if LITELLM_SETTINGS.chat_openai_compatible_api_key:
complete_kwargs["api_key"] = LITELLM_SETTINGS.chat_openai_compatible_api_key
if LITELLM_SETTINGS.chat_openai_compatible_api_base:
complete_kwargs["api_base"] = LITELLM_SETTINGS.chat_openai_compatible_api_base

response = completion(
messages=messages,
stream=LITELLM_SETTINGS.chat_stream,
Expand Down
7 changes: 7 additions & 0 deletions rdagent/oai/llm_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class LLMSettings(ExtendedBaseSettings):
chat_openai_base_url: str | None = None #
chat_azure_api_base: str = ""
chat_azure_api_version: str = ""
# OpenAI-Compatible API config (for chat completion)
chat_openai_compatible_api_key: str = ""
chat_openai_compatible_api_base: str = ""
chat_max_tokens: int | None = None
chat_temperature: float = 0.5
chat_stream: bool = True
Expand All @@ -87,6 +90,10 @@ class LLMSettings(ExtendedBaseSettings):
embedding_azure_api_version: str = ""
embedding_max_str_num: int = 50
embedding_max_length: int = 8192
embedding_extra_params: dict = {}
# OpenAI-Compatible API config (for embedding)
embedding_openai_compatible_api_key: str = ""
embedding_openai_compatible_api_base: str = ""

# offline llama2 related config
use_llama2: bool = False
Expand Down