diff --git a/docs/models/xai.md b/docs/models/xai.md index e0d0ebe08d..54c0444fac 100644 --- a/docs/models/xai.md +++ b/docs/models/xai.md @@ -57,6 +57,25 @@ agent = Agent(model) ... ``` +For gateway, regional, or proxy deployments you can also point the provider at a custom host and set a client-level default timeout, both of which are forwarded to the underlying `xai_sdk.AsyncClient`: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.xai import XaiModel +from pydantic_ai.providers.xai import XaiProvider + +provider = XaiProvider( + api_key='your-api-key', + api_host='gateway.example.com', + timeout=30, +) +model = XaiModel('grok-4.3', provider=provider) +agent = Agent(model) +... +``` + +`api_host` is the hostname of the xAI API server (the SDK connects over gRPC), and `timeout` is the default timeout in seconds applied to every request the client makes. The provider-level `timeout` is distinct from [`ModelSettings.timeout`][pydantic_ai.settings.ModelSettings.timeout], which overrides the timeout for an individual request. Both options are omitted when left unset, so the SDK's own defaults apply. + Or with a custom `xai_sdk.AsyncClient`: ```python diff --git a/pydantic_ai_slim/pydantic_ai/providers/xai.py b/pydantic_ai_slim/pydantic_ai/providers/xai.py index 93ecaef7d2..459577b4e2 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/xai.py +++ b/pydantic_ai_slim/pydantic_ai/providers/xai.py @@ -56,6 +56,8 @@ def name(self) -> str: @property def base_url(self) -> str: + # Canonical pricing/identity label, not the transport host: the xAI SDK is gRPC and the actual + # channel target is set via `api_host`. This URL is used for usage/price lookup and telemetry only. return 'https://api.x.ai/v1' @property @@ -69,10 +71,9 @@ def model_profile(model_name: str) -> ModelProfile | None: return grok_model_profile(model_name) @overload - def __init__(self) -> None: ... - - @overload - def __init__(self, *, api_key: str) -> None: ... 
+ def __init__( + self, *, api_key: str | None = None, api_host: str | None = None, timeout: float | None = None + ) -> None: ... @overload def __init__(self, *, xai_client: AsyncClient) -> None: ... @@ -81,6 +82,8 @@ def __init__( self, *, api_key: str | None = None, + api_host: str | None = None, + timeout: float | None = None, xai_client: AsyncClient | None = None, ) -> None: """Create a new xAI provider. @@ -88,7 +91,12 @@ def __init__( Args: api_key: The API key to use for authentication, if not provided, the `XAI_API_KEY` environment variable will be used if available. - xai_client: An existing `xai_sdk.AsyncClient` to use. This takes precedence over `api_key`. + api_host: The hostname of the xAI API server to connect to (over gRPC), e.g. `gateway.example.com`. + timeout: The client-level default timeout for the xAI SDK client, in seconds, applied to all requests + made through it. This is distinct from `ModelSettings.timeout`, which overrides the timeout for an + individual request. + xai_client: An existing `xai_sdk.AsyncClient` to use. This takes precedence over `api_key`, `api_host`, + and `timeout`. """ self._lazy_client: _LazyAsyncClient | None = None if xai_client is not None: @@ -100,5 +108,10 @@ def __init__( 'Set the `XAI_API_KEY` environment variable or pass it via `XaiProvider(api_key=...)`' ' to use the xAI provider.' 
) - self._lazy_client = _LazyAsyncClient(api_key=api_key) + client_kwargs: dict[str, str | float] = {'api_key': api_key} + if api_host is not None: + client_kwargs['api_host'] = api_host + if timeout is not None: + client_kwargs['timeout'] = timeout + self._lazy_client = _LazyAsyncClient(**client_kwargs) self._client = None # type: ignore[assignment] diff --git a/tests/providers/test_xai.py b/tests/providers/test_xai.py index b52f179a22..aff03948c1 100644 --- a/tests/providers/test_xai.py +++ b/tests/providers/test_xai.py @@ -10,6 +10,7 @@ with try_import() as imports_successful: from xai_sdk import AsyncClient + from pydantic_ai.providers import xai as xai_provider_module from pydantic_ai.providers.xai import XaiProvider pytestmark = pytest.mark.skipif(not imports_successful(), reason='xai_sdk not installed') @@ -40,6 +41,33 @@ def test_xai_pass_xai_client() -> None: assert provider.client == xai_client +@pytest.fixture +def captured_client_kwargs(monkeypatch: pytest.MonkeyPatch) -> list[dict[str, object]]: + """Capture the kwargs passed to the xAI `AsyncClient` constructor.""" + captured: list[dict[str, object]] = [] + + class FakeAsyncClient: + def __init__(self, **kwargs: object) -> None: + captured.append(kwargs) + + monkeypatch.setattr(xai_provider_module, 'AsyncClient', FakeAsyncClient) + return captured + + +def test_xai_provider_forwards_api_host_and_timeout(captured_client_kwargs: list[dict[str, object]]) -> None: + provider = XaiProvider(api_key='api-key', api_host='gateway.x.ai', timeout=30) + + assert provider.client is not None # triggers lazy client creation + assert captured_client_kwargs == [{'api_key': 'api-key', 'api_host': 'gateway.x.ai', 'timeout': 30}] + + +def test_xai_provider_omits_unset_client_kwargs(captured_client_kwargs: list[dict[str, object]]) -> None: + provider = XaiProvider(api_key='api-key') + + assert provider.client is not None # triggers lazy client creation + assert captured_client_kwargs == [{'api_key': 'api-key'}] + + def 
test_xai_model_profile(): from pydantic_ai.profiles.grok import GrokModelProfile