From 2b30ddda11008333b70f925b2a2ce8c2089e10b7 Mon Sep 17 00:00:00 2001 From: Cole McIntosh Date: Mon, 1 Jun 2026 15:49:59 -0500 Subject: [PATCH] Add `api_host` and `timeout` to `XaiProvider` Forward the xAI SDK client's `api_host` (gateway/regional/proxy setups) and `timeout` (client-level default) options through `XaiProvider`, so users don't have to drop down to a prebuilt `xai_client` for common deployment ergonomics. Unset options are omitted so the SDK's own defaults apply. --- pydantic_ai_slim/pydantic_ai/providers/xai.py | 21 ++++++++++---- tests/providers/test_xai.py | 28 +++++++++++++++++++ 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/providers/xai.py b/pydantic_ai_slim/pydantic_ai/providers/xai.py index 93ecaef7d2..dba4715c44 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/xai.py +++ b/pydantic_ai_slim/pydantic_ai/providers/xai.py @@ -69,10 +69,9 @@ def model_profile(model_name: str) -> ModelProfile | None: return grok_model_profile(model_name) @overload - def __init__(self) -> None: ... - - @overload - def __init__(self, *, api_key: str) -> None: ... + def __init__( + self, *, api_key: str | None = None, api_host: str | None = None, timeout: float | None = None + ) -> None: ... @overload def __init__(self, *, xai_client: AsyncClient) -> None: ... @@ -81,6 +80,8 @@ def __init__( self, *, api_key: str | None = None, + api_host: str | None = None, + timeout: float | None = None, xai_client: AsyncClient | None = None, ) -> None: """Create a new xAI provider. @@ -88,7 +89,10 @@ def __init__( Args: api_key: The API key to use for authentication, if not provided, the `XAI_API_KEY` environment variable will be used if available. - xai_client: An existing `xai_sdk.AsyncClient` to use. This takes precedence over `api_key`. + api_host: The API host to use for the xAI SDK client. + timeout: The default timeout for the xAI SDK client, in seconds. + xai_client: An existing `xai_sdk.AsyncClient` to use. This takes precedence over `api_key`, `api_host`, + and `timeout`. """ self._lazy_client: _LazyAsyncClient | None = None if xai_client is not None: @@ -100,5 +104,10 @@ def __init__( 'Set the `XAI_API_KEY` environment variable or pass it via `XaiProvider(api_key=...)`' ' to use the xAI provider.' ) - self._lazy_client = _LazyAsyncClient(api_key=api_key) + client_kwargs: dict[str, Any] = {'api_key': api_key} + if api_host is not None: + client_kwargs['api_host'] = api_host + if timeout is not None: + client_kwargs['timeout'] = timeout + self._lazy_client = _LazyAsyncClient(**client_kwargs) self._client = None # type: ignore[assignment] diff --git a/tests/providers/test_xai.py b/tests/providers/test_xai.py index 9d30da42ce..8c1de21742 100644 --- a/tests/providers/test_xai.py +++ b/tests/providers/test_xai.py @@ -10,6 +10,7 @@ with try_import() as imports_successful: from xai_sdk import AsyncClient + from pydantic_ai.providers import xai as xai_provider_module from pydantic_ai.providers.xai import XaiProvider pytestmark = pytest.mark.skipif(not imports_successful(), reason='xai_sdk not installed') @@ -40,6 +41,33 @@ def test_xai_pass_xai_client() -> None: assert provider.client == xai_client +@pytest.fixture +def captured_client_kwargs(monkeypatch: pytest.MonkeyPatch) -> list[dict[str, object]]: + """Capture the kwargs passed to the xAI `AsyncClient` constructor.""" + captured: list[dict[str, object]] = [] + + class FakeAsyncClient: + def __init__(self, **kwargs: object) -> None: + captured.append(kwargs) + + monkeypatch.setattr(xai_provider_module, 'AsyncClient', FakeAsyncClient) + return captured + + +def test_xai_provider_forwards_api_host_and_timeout(captured_client_kwargs: list[dict[str, object]]) -> None: + provider = XaiProvider(api_key='api-key', api_host='gateway.x.ai', timeout=30) + + assert provider.client is not None # triggers lazy client creation + assert captured_client_kwargs == [{'api_key': 'api-key', 'api_host': 'gateway.x.ai', 'timeout': 30}] + + +def test_xai_provider_omits_unset_client_kwargs(captured_client_kwargs: list[dict[str, object]]) -> None: + provider = XaiProvider(api_key='api-key') + + assert provider.client is not None # triggers lazy client creation + assert captured_client_kwargs == [{'api_key': 'api-key'}] + + def test_xai_model_profile(): from pydantic_ai.profiles.grok import GrokModelProfile