diff --git a/rdagent/app/utils/health_check.py b/rdagent/app/utils/health_check.py index 95597b257..467d2d23d 100644 --- a/rdagent/app/utils/health_check.py +++ b/rdagent/app/utils/health_check.py @@ -3,9 +3,6 @@ import docker import fire -import litellm -from litellm import completion, embedding -from litellm.utils import ModelResponse from rdagent.log import rdagent_logger as logger from rdagent.utils.env import cleanup_container @@ -51,24 +48,11 @@ def check_and_list_free_ports(start_port=19899, max_ports=10) -> None: def test_chat(chat_model, chat_api_key, chat_api_base): logger.info(f"๐Ÿงช Testing chat model: {chat_model}") try: - if chat_api_base is None: - response: ModelResponse = completion( - model=chat_model, - api_key=chat_api_key, - messages=[ - {"role": "user", "content": "Hello!"}, - ], - ) - else: - response: ModelResponse = completion( - model=chat_model, - api_key=chat_api_key, - api_base=chat_api_base, - messages=[ - {"role": "user", "content": "Hello!"}, - ], - ) - logger.info(f"โœ… Chat test passed.") + from rdagent.oai.backend.litellm import LiteLLMAPIBackend + + backend = LiteLLMAPIBackend() + backend.build_messages_and_create_chat_completion(user_prompt="Hello!") + logger.info("โœ… Chat test passed.") return True except Exception as e: logger.error(f"โŒ Chat test failed: {e}") @@ -78,12 +62,10 @@ def test_chat(chat_model, chat_api_key, chat_api_base): def test_embedding(embedding_model, embedding_api_key, embedding_api_base): logger.info(f"๐Ÿงช Testing embedding model: {embedding_model}") try: - response = embedding( - model=embedding_model, - api_key=embedding_api_key, - api_base=embedding_api_base, - input="Hello world!", - ) + from rdagent.oai.backend.litellm import LiteLLMAPIBackend + + backend = LiteLLMAPIBackend() + backend.create_embedding(input_content="Hello world!") logger.info("โœ… Embedding test passed.") return True except Exception as e: @@ -117,6 +99,13 @@ def env_check(): embedding_model = os.getenv("EMBEDDING_MODEL") embedding_api_key = chat_api_key embedding_api_base = chat_api_base + elif "CHAT_OPENAI_COMPATIBLE_API_KEY" in os.environ or "EMBEDDING_OPENAI_COMPATIBLE_API_KEY" in os.environ: + chat_api_key = os.getenv("CHAT_OPENAI_COMPATIBLE_API_KEY") + chat_api_base = os.getenv("CHAT_OPENAI_COMPATIBLE_API_BASE") + chat_model = os.getenv("CHAT_MODEL") + embedding_model = os.getenv("EMBEDDING_MODEL") + embedding_api_key = os.getenv("EMBEDDING_OPENAI_COMPATIBLE_API_KEY") + embedding_api_base = os.getenv("EMBEDDING_OPENAI_COMPATIBLE_API_BASE") else: logger.error("No valid configuration was found, please check your .env file.") diff --git a/rdagent/oai/backend/litellm.py b/rdagent/oai/backend/litellm.py index 86cbe8461..9dd489000 100644 --- a/rdagent/oai/backend/litellm.py +++ b/rdagent/oai/backend/litellm.py @@ -79,10 +79,24 @@ def _create_embedding_inner_function(self, input_content_list: list[str]) -> lis f"{LogColors.MAGENTA}Creating embedding{LogColors.END} for: {input_content_list}", tag="debug_litellm_emb", ) - response = embedding( - model=model_name, - input=input_content_list, - ) + call_kwargs = { + "model": model_name, + "input": input_content_list, + } + if LITELLM_SETTINGS.embedding_extra_params: + if isinstance(LITELLM_SETTINGS.embedding_extra_params, dict): + call_kwargs.update(LITELLM_SETTINGS.embedding_extra_params) + else: + logger.error( + f"{LogColors.RED}embedding_extra_params must be a dict, got {type(LITELLM_SETTINGS.embedding_extra_params).__name__}. Ignoring extra params.{LogColors.END}", + tag="debug_litellm_emb", + ) + # Use embedding OpenAI-Compatible config + if LITELLM_SETTINGS.embedding_openai_compatible_api_key: + call_kwargs["api_key"] = LITELLM_SETTINGS.embedding_openai_compatible_api_key + if LITELLM_SETTINGS.embedding_openai_compatible_api_base: + call_kwargs["api_base"] = LITELLM_SETTINGS.embedding_openai_compatible_api_base + response = embedding(**call_kwargs) response_list = [data["embedding"] for data in response.data] return response_list @@ -152,6 +166,12 @@ def _create_chat_completion_inner_function( # type: ignore[no-untyped-def] # no complete_kwargs = self.get_complete_kwargs() model = complete_kwargs["model"] + # Use OpenAI-Compatible API config for chat completion + if LITELLM_SETTINGS.chat_openai_compatible_api_key: + complete_kwargs["api_key"] = LITELLM_SETTINGS.chat_openai_compatible_api_key + if LITELLM_SETTINGS.chat_openai_compatible_api_base: + complete_kwargs["api_base"] = LITELLM_SETTINGS.chat_openai_compatible_api_base + response = completion( messages=messages, stream=LITELLM_SETTINGS.chat_stream, diff --git a/rdagent/oai/llm_conf.py b/rdagent/oai/llm_conf.py index a9a1130e7..708b494b0 100644 --- a/rdagent/oai/llm_conf.py +++ b/rdagent/oai/llm_conf.py @@ -66,6 +66,9 @@ class LLMSettings(ExtendedBaseSettings): chat_openai_base_url: str | None = None # chat_azure_api_base: str = "" chat_azure_api_version: str = "" + # OpenAI-Compatible API config (for chat completion) + chat_openai_compatible_api_key: str = "" + chat_openai_compatible_api_base: str = "" chat_max_tokens: int | None = None chat_temperature: float = 0.5 chat_stream: bool = True @@ -87,6 +90,10 @@ class LLMSettings(ExtendedBaseSettings): embedding_azure_api_version: str = "" embedding_max_str_num: int = 50 embedding_max_length: int = 8192 + embedding_extra_params: dict = {} + # OpenAI-Compatible API config (for embedding) + embedding_openai_compatible_api_key: str = "" + embedding_openai_compatible_api_base: str = "" # offline llama2 related config use_llama2: bool = False