diff --git a/README.md b/README.md index 45cbe52..6669280 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,16 @@ We leverage [LiteLLM](https://github.com/BerriAI/litellm) to support chat comple Note that regardless of the model pair used, an `OPENAI_API_KEY` will currently still be required to generate embeddings for the `mf` and `sw_ranking` routers. +B.AI's LLM service is supported through its OpenAI-compatible API. Set `BAI_API_KEY`, optionally set `BAI_API_BASE` if you need a custom endpoint, and prefix B.AI model names with `bai/`: + +``` +export BAI_API_KEY=sk-XXXXXX +python -m routellm.openai_server \ + --routers mf \ + --strong-model bai/gpt-5.2 \ + --weak-model bai/claude-sonnet-4-6 +``` + Instructions for setting up your API keys for popular providers: - Local models with Ollama: see [this guide](examples/routing_to_local_models.md) - [Anthropic](https://litellm.vercel.app/docs/providers/anthropic#api-keys) diff --git a/routellm/controller.py b/routellm/controller.py index 8a02a05..f9c10d4 100644 --- a/routellm/controller.py +++ b/routellm/controller.py @@ -1,3 +1,4 @@ +import os from collections import defaultdict from dataclasses import dataclass from types import SimpleNamespace @@ -32,6 +33,10 @@ class RoutingError(Exception): pass +BAI_API_BASE = "https://api.b.ai/v1" +BAI_MODEL_PREFIX = "bai/" + + @dataclass class ModelPair: strong: str @@ -114,6 +119,25 @@ def _get_routed_model_for_completion( return routed_model + def _completion_kwargs_for_model(self, model: str): + completion_kwargs = { + "api_base": self.api_base, + "api_key": self.api_key, + "model": model, + } + + if model.startswith(BAI_MODEL_PREFIX): + completion_kwargs.update( + { + "api_base": self.api_base + or os.environ.get("BAI_API_BASE", BAI_API_BASE), + "api_key": self.api_key or os.environ.get("BAI_API_KEY"), + "model": f"openai/{model.removeprefix(BAI_MODEL_PREFIX)}", + } + ) + + return completion_kwargs + # Mainly used for evaluations def batch_calculate_win_rate( self, @@ -147,10 +171,11 @@ def completion( router, threshold = self._parse_model_name(kwargs["model"]) self._validate_router_threshold(router, threshold) - kwargs["model"] = self._get_routed_model_for_completion( + routed_model = self._get_routed_model_for_completion( kwargs["messages"], router, threshold ) - return completion(api_base=self.api_base, api_key=self.api_key, **kwargs) + kwargs.update(self._completion_kwargs_for_model(routed_model)) + return completion(**kwargs) # Matches OpenAI's Async Chat Completions interface, but also supports optional router and threshold args async def acompletion( @@ -164,7 +189,8 @@ async def acompletion( router, threshold = self._parse_model_name(kwargs["model"]) self._validate_router_threshold(router, threshold) - kwargs["model"] = self._get_routed_model_for_completion( + routed_model = self._get_routed_model_for_completion( kwargs["messages"], router, threshold ) - return await acompletion(api_base=self.api_base, api_key=self.api_key, **kwargs) + kwargs.update(self._completion_kwargs_for_model(routed_model)) + return await acompletion(**kwargs) diff --git a/routellm/tests/test_bai_support.py b/routellm/tests/test_bai_support.py new file mode 100644 index 0000000..1e3cfdd --- /dev/null +++ b/routellm/tests/test_bai_support.py @@ -0,0 +1,100 @@ +import asyncio +import sys +from types import ModuleType +from unittest.mock import AsyncMock, Mock, patch + +routers_module = ModuleType("routellm.routers.routers") +routers_module.ROUTER_CLS = {} +sys.modules["routellm.routers.routers"] = routers_module + +from routellm.controller import BAI_API_BASE, Controller + +MESSAGES = [{"role": "user", "content": "hello"}] + + +class StubRouter: + def __init__(self, model): + self.model = model + + def route(self, prompt, threshold, model_pair): + return self.model + + +def controller_for_model(model, **kwargs): + controller = Controller( + routers=[], + strong_model="strong", + weak_model="weak", + **kwargs, + ) + controller.routers["random"] = StubRouter(model) + return controller + + +def test_completion_uses_bai_defaults(monkeypatch): + monkeypatch.setenv("BAI_API_KEY", "bai-key") + monkeypatch.delenv("BAI_API_BASE", raising=False) + controller = controller_for_model("bai/gpt-5.2") + + with patch("routellm.controller.completion", Mock()) as completion: + controller.completion(router="random", threshold=0.5, messages=MESSAGES) + + completion.assert_called_once_with( + api_base=BAI_API_BASE, + api_key="bai-key", + model="openai/gpt-5.2", + messages=MESSAGES, + ) + + +def test_completion_uses_explicit_bai_api_overrides(monkeypatch): + monkeypatch.setenv("BAI_API_KEY", "env-key") + monkeypatch.setenv("BAI_API_BASE", "https://env.example/v1") + controller = controller_for_model( + "bai/gpt-5.2", + api_base="https://custom.example/v1", + api_key="custom-key", + ) + + with patch("routellm.controller.completion", Mock()) as completion: + controller.completion(router="random", threshold=0.5, messages=MESSAGES) + + completion.assert_called_once_with( + api_base="https://custom.example/v1", + api_key="custom-key", + model="openai/gpt-5.2", + messages=MESSAGES, + ) + + +def test_acompletion_uses_bai_defaults(monkeypatch): + monkeypatch.setenv("BAI_API_KEY", "bai-key") + monkeypatch.delenv("BAI_API_BASE", raising=False) + controller = controller_for_model("bai/gpt-5.2") + + with patch("routellm.controller.acompletion", AsyncMock()) as acompletion: + asyncio.run( + controller.acompletion(router="random", threshold=0.5, messages=MESSAGES) + ) + + acompletion.assert_awaited_once_with( + api_base=BAI_API_BASE, + api_key="bai-key", + model="openai/gpt-5.2", + messages=MESSAGES, + ) + + +def test_completion_leaves_non_bai_model_unchanged(monkeypatch): + monkeypatch.setenv("BAI_API_KEY", "bai-key") + controller = controller_for_model("anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1") + + with patch("routellm.controller.completion", Mock()) as completion: + controller.completion(router="random", threshold=0.5, messages=MESSAGES) + + completion.assert_called_once_with( + api_base=None, + api_key=None, + model="anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1", + messages=MESSAGES, + )