Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ We leverage [LiteLLM](https://github.com/BerriAI/litellm) to support chat comple

Note that regardless of the model pair used, an `OPENAI_API_KEY` will currently still be required to generate embeddings for the `mf` and `sw_ranking` routers.

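B.AI's LLM service is supported through its OpenAI-compatible API. Set `BAI_API_KEY` and prefix B.AI model names with `bai/`. Requests go to `https://api.b.ai/v1` by default; set `BAI_API_BASE` if you need a custom endpoint. An `api_base` or `api_key` passed explicitly to the controller takes precedence over these environment variables: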

```
export BAI_API_KEY=sk-XXXXXX
python -m routellm.openai_server \
--routers mf \
--strong-model bai/gpt-5.2 \
--weak-model bai/claude-sonnet-4-6
```

Instructions for setting up your API keys for popular providers:
- Local models with Ollama: see [this guide](examples/routing_to_local_models.md)
- [Anthropic](https://litellm.vercel.app/docs/providers/anthropic#api-keys)
Expand Down
34 changes: 30 additions & 4 deletions routellm/controller.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from collections import defaultdict
from dataclasses import dataclass
from types import SimpleNamespace
Expand Down Expand Up @@ -32,6 +33,10 @@ class RoutingError(Exception):
pass


# Default B.AI endpoint, used for "bai/"-prefixed models when neither the
# controller's api_base nor the BAI_API_BASE environment variable is set.
BAI_API_BASE = "https://api.b.ai/v1"
# Prefix marking a model name as a B.AI model; it is stripped and replaced
# with "openai/" before the request is handed to the completion call.
BAI_MODEL_PREFIX = "bai/"


@dataclass
class ModelPair:
strong: str
Expand Down Expand Up @@ -114,6 +119,25 @@ def _get_routed_model_for_completion(

return routed_model

def _completion_kwargs_for_model(self, model: str):
    """Build the ``api_base``/``api_key``/``model`` kwargs for a routed model.

    Models without the B.AI prefix are passed through unchanged together with
    the controller-level ``api_base`` and ``api_key``.

    Models starting with ``BAI_MODEL_PREFIX`` are rewritten to the
    ``openai/<name>`` form so they go through the OpenAI-compatible code path,
    and are pointed at the B.AI endpoint. Resolution order for both settings:
    the controller's explicit value, then the ``BAI_API_BASE`` /
    ``BAI_API_KEY`` environment variable, then (base URL only) the built-in
    ``BAI_API_BASE`` default.

    Args:
        model: The model name chosen by the router.

    Returns:
        A dict with the keys ``api_base``, ``api_key`` and ``model``.

    Raises:
        ValueError: If ``model`` is a B.AI model and no API key is available
            from either the controller or ``BAI_API_KEY``.
    """
    if not model.startswith(BAI_MODEL_PREFIX):
        return {
            "api_base": self.api_base,
            "api_key": self.api_key,
            "model": model,
        }

    api_key = self.api_key or os.environ.get("BAI_API_KEY")
    if not api_key:
        # Passing api_key=None with an "openai/" model would let the
        # completion backend fall back to its own ambient OpenAI credentials
        # and send them to the B.AI endpoint, so refuse instead.
        raise ValueError(
            "No API key available for B.AI model "
            f"{model!r}: set BAI_API_KEY or pass api_key to the Controller."
        )

    # `or` rather than a .get() default so that an empty BAI_API_BASE value
    # still resolves to the built-in B.AI endpoint.
    api_base = self.api_base or os.environ.get("BAI_API_BASE") or BAI_API_BASE

    return {
        "api_base": api_base,
        "api_key": api_key,
        "model": f"openai/{model.removeprefix(BAI_MODEL_PREFIX)}",
    }

# Mainly used for evaluations
def batch_calculate_win_rate(
self,
Expand Down Expand Up @@ -147,10 +171,11 @@ def completion(
router, threshold = self._parse_model_name(kwargs["model"])

self._validate_router_threshold(router, threshold)
kwargs["model"] = self._get_routed_model_for_completion(
routed_model = self._get_routed_model_for_completion(
kwargs["messages"], router, threshold
)
return completion(api_base=self.api_base, api_key=self.api_key, **kwargs)
kwargs.update(self._completion_kwargs_for_model(routed_model))
return completion(**kwargs)

# Matches OpenAI's Async Chat Completions interface, but also supports optional router and threshold args
async def acompletion(
Expand All @@ -164,7 +189,8 @@ async def acompletion(
router, threshold = self._parse_model_name(kwargs["model"])

self._validate_router_threshold(router, threshold)
kwargs["model"] = self._get_routed_model_for_completion(
routed_model = self._get_routed_model_for_completion(
kwargs["messages"], router, threshold
)
return await acompletion(api_base=self.api_base, api_key=self.api_key, **kwargs)
kwargs.update(self._completion_kwargs_for_model(routed_model))
return await acompletion(**kwargs)
100 changes: 100 additions & 0 deletions routellm/tests/test_bai_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import asyncio
import sys
from types import ModuleType
from unittest.mock import AsyncMock, Mock, patch

# Install an empty stand-in for routellm.routers.routers *before* the
# controller import below, so that import resolves to this stub.
# NOTE(review): presumably this avoids loading the real router
# implementations and their dependencies — confirm. The stub is never removed
# from sys.modules, so it stays in place for any test collected afterwards.
routers_module = ModuleType("routellm.routers.routers")
routers_module.ROUTER_CLS = {}
sys.modules["routellm.routers.routers"] = routers_module

from routellm.controller import BAI_API_BASE, Controller

MESSAGES = [{"role": "user", "content": "hello"}]


class StubRouter:
    """Router test double that always selects one preconfigured model."""

    def __init__(self, model):
        # Model name handed back by every route() call.
        self.model = model

    def route(self, prompt, threshold, model_pair):
        """Return the preconfigured model; all three arguments are ignored."""
        chosen = self.model
        return chosen


def controller_for_model(model, **kwargs):
    """Create a router-less Controller whose "random" router always picks *model*.

    Extra keyword arguments are forwarded to the Controller constructor.
    """
    constructor_args = dict(routers=[], strong_model="strong", weak_model="weak")
    stubbed = Controller(**constructor_args, **kwargs)
    # Register the stub under the "random" router name used by the tests.
    stubbed.routers["random"] = StubRouter(model)
    return stubbed


def test_completion_uses_bai_defaults(monkeypatch):
    """A bai/ model uses BAI_API_KEY and the built-in base URL by default."""
    monkeypatch.setenv("BAI_API_KEY", "bai-key")
    monkeypatch.delenv("BAI_API_BASE", raising=False)
    expected_call = {
        "api_base": BAI_API_BASE,
        "api_key": "bai-key",
        "model": "openai/gpt-5.2",
        "messages": MESSAGES,
    }
    subject = controller_for_model("bai/gpt-5.2")

    with patch("routellm.controller.completion", Mock()) as mocked_completion:
        subject.completion(router="random", threshold=0.5, messages=MESSAGES)

    mocked_completion.assert_called_once_with(**expected_call)


def test_completion_uses_explicit_bai_api_overrides(monkeypatch):
    """Controller-level api_base/api_key win over the BAI_* environment values."""
    monkeypatch.setenv("BAI_API_KEY", "env-key")
    monkeypatch.setenv("BAI_API_BASE", "https://env.example/v1")
    explicit = {
        "api_base": "https://custom.example/v1",
        "api_key": "custom-key",
    }
    subject = controller_for_model("bai/gpt-5.2", **explicit)

    with patch("routellm.controller.completion", Mock()) as mocked_completion:
        subject.completion(router="random", threshold=0.5, messages=MESSAGES)

    mocked_completion.assert_called_once_with(
        model="openai/gpt-5.2",
        messages=MESSAGES,
        **explicit,
    )


def test_acompletion_uses_bai_defaults(monkeypatch):
    """The async path applies the same bai/ defaults as the sync path."""
    monkeypatch.setenv("BAI_API_KEY", "bai-key")
    monkeypatch.delenv("BAI_API_BASE", raising=False)
    expected_call = {
        "api_base": BAI_API_BASE,
        "api_key": "bai-key",
        "model": "openai/gpt-5.2",
        "messages": MESSAGES,
    }
    subject = controller_for_model("bai/gpt-5.2")

    with patch("routellm.controller.acompletion", AsyncMock()) as mocked_acompletion:
        pending = subject.acompletion(
            router="random", threshold=0.5, messages=MESSAGES
        )
        asyncio.run(pending)

    mocked_acompletion.assert_awaited_once_with(**expected_call)


def test_completion_leaves_non_bai_model_unchanged(monkeypatch):
    """A model without the bai/ prefix is forwarded as-is, ignoring BAI_API_KEY."""
    monkeypatch.setenv("BAI_API_KEY", "bai-key")
    other_model = "anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1"
    subject = controller_for_model(other_model)

    with patch("routellm.controller.completion", Mock()) as mocked_completion:
        subject.completion(router="random", threshold=0.5, messages=MESSAGES)

    expected_call = {
        "api_base": None,
        "api_key": None,
        "model": other_model,
        "messages": MESSAGES,
    }
    mocked_completion.assert_called_once_with(**expected_call)