Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/models/google.md
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,25 @@ avg_logprobs = result.response.provider_details.get('avg_logprobs')

See the [Google Dev Blog](https://developers.googleblog.com/unlock-gemini-reasoning-with-logprobs-on-vertex-ai/) for more information.

### Context caching (`google_cached_content`)

When you've created a Gemini [cached content resource](https://ai.google.dev/gemini-api/docs/caching), pass its resource name through [`google_cached_content`][pydantic_ai.models.google.GoogleModelSettings.google_cached_content] to reuse it across requests:

```python {test="skip"}
from pydantic_ai import Agent
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings

# Resource name of a cached content resource that has already been created.
model_settings = GoogleModelSettings(
    google_cached_content='projects/p/locations/global/cachedContents/your-cache-id',
)

# The settings are attached to the agent, so its requests reuse the cached content.
agent = Agent(GoogleModel('gemini-2.5-pro'), model_settings=model_settings)
result = agent.run_sync('Summarise the cached document')
```

!!! warning "Cached fields are owned by the cache resource"
    The cache owns `system_instruction`, `tools`, and `tool_config`; both the Gemini API and Vertex AI reject requests that supply them alongside `cached_content` (`400 INVALID_ARGUMENT`). Pydantic AI therefore strips those fields from the outgoing request when `google_cached_content` is set, which means agent instructions and registered tools are ignored on cached requests. A `UserWarning` is emitted whenever stripping drops a field, so the mismatch is discoverable.

## Streaming cancellation

!!! warning "Cancellation limitations"
Expand Down
49 changes: 45 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,16 @@ class GoogleModelSettings(ModelSettings, total=False):
google_cached_content: str
"""The name of the cached content to use for the model.

When set, `system_instruction`, `tools`, and `tool_config` are omitted from
the outgoing request — the cached content resource owns those fields, and
both the Gemini API and Vertex AI reject requests that supply them
alongside `cached_content` (`400 INVALID_ARGUMENT`: "Tool config, tools and
system instruction should not be set in the request when using cached
content."). Any tools registered on the agent and any system prompt are
therefore ignored on requests that go through the cache; a `UserWarning`
is emitted whenever stripping actually drops a field so the mismatch is
discoverable.

See <https://ai.google.dev/gemini-api/docs/caching> for more information.
"""

Expand Down Expand Up @@ -325,6 +335,30 @@ def _get_deprecated_google_service_tier(model_settings: GoogleModelSettings) ->
return None


def _warn_on_cached_content_strips(
    cached_content: str | None,
    system_instruction: ContentDict | None,
    tools: list[ToolDict] | None,
) -> None:
    """Emit a `UserWarning` when `google_cached_content` would strip a field that the caller populated.

    Args:
        cached_content: The `google_cached_content` setting. Nothing is reported when it is
            unset or empty.
        system_instruction: The system instruction built for the request, or `None` when
            there is none.
        tools: The tools built for the request, or `None` when there are none. `tool_config`
            is reported together with `tools`; it is not inspected separately.
    """
    if not cached_content:
        return

    dropped: list[str] = []
    if system_instruction is not None:
        dropped.append('system_instruction')
    if tools is not None:
        dropped.extend(('tools', 'tool_config'))
    if not dropped:
        return

    # Render the field names as a readable, backticked list instead of a Python list repr.
    dropped_names = ', '.join(f'`{name}`' for name in dropped)
    warnings.warn(
        '`google_cached_content` is set; the cached content resource owns '
        f'{dropped_names}, so these fields are stripped from the outgoing request '
        'and any agent instructions or registered tools are ignored. '
        'See https://ai.google.dev/gemini-api/docs/caching.',
        UserWarning,
        # Attributes the warning two frames above this helper.
        # NOTE(review): confirm that frame is the intended caller.
        stacklevel=3,
    )


def _get_deprecated_google_vertex_service_tier(model_settings: GoogleModelSettings) -> GoogleCloudServiceTier | None:
"""Return `google_vertex_service_tier`, emitting a `PydanticAIDeprecationWarning` when it is set.

Expand Down Expand Up @@ -884,9 +918,16 @@ async def _build_content_and_config(
else:
raise UserError('Google does not support setting ModelSettings.timeout to a httpx.Timeout')

# Vertex AI and the Gemini API reject requests that combine
# `cached_content` with `system_instruction`, `tools`, or `tool_config`
# (`400 INVALID_ARGUMENT`). The cache resource itself owns those
# fields, so we pass `None` for them when caching is requested.
Comment thread
dsfaccini marked this conversation as resolved.
Outdated
cached_content = model_settings.get('google_cached_content')
_warn_on_cached_content_strips(cached_content, system_instruction, tools)

config = GenerateContentConfigDict(
http_options=http_options,
system_instruction=system_instruction,
system_instruction=None if cached_content else system_instruction,
temperature=model_settings.get('temperature'),
top_p=model_settings.get('top_p'),
top_k=model_settings.get('top_k'),
Expand All @@ -899,9 +940,9 @@ async def _build_content_and_config(
thinking_config=self._translate_thinking(model_settings, model_request_parameters),
labels=model_settings.get('google_labels'),
media_resolution=model_settings.get('google_video_resolution'),
cached_content=model_settings.get('google_cached_content'),
tools=cast(ToolListUnionDict, tools),
tool_config=tool_config,
cached_content=cached_content,
tools=None if cached_content else cast(ToolListUnionDict, tools),
tool_config=None if cached_content else tool_config,
Comment thread
dsfaccini marked this conversation as resolved.
response_mime_type=response_mime_type,
response_json_schema=response_schema,
response_modalities=modalities,
Expand Down
88 changes: 88 additions & 0 deletions tests/models/test_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import random
import tempfile
from collections.abc import AsyncIterator
from contextlib import nullcontext
from datetime import date, timezone
from decimal import Decimal
from typing import Any, cast
Expand Down Expand Up @@ -6589,3 +6590,90 @@ async def test_google_top_k_propagation(
assert mock_generate.call_count == 1
_, kwargs = mock_generate.call_args
assert kwargs['config']['top_k'] == 40


@pytest.mark.parametrize(
    ('cached', 'instructions', 'register_tool', 'expect_warning'),
    [
        pytest.param(True, 'You are a helpful chatbot.', True, True, id='cache_set_with_instructions_and_tools'),
        pytest.param(False, 'You are a helpful chatbot.', True, False, id='cache_unset_with_instructions_and_tools'),
        pytest.param(True, None, False, False, id='cache_set_no_instructions_no_tools'),
    ],
)
@pytest.mark.parametrize('stream', [False, True])
async def test_google_model_cached_content_request_config(
    allow_model_requests: None,
    google_provider: GoogleProvider,
    mocker: MockerFixture,
    cached: bool,
    instructions: str | None,
    register_tool: bool,
    expect_warning: bool,
    stream: bool,
):
    """Check which cache-owned fields reach the Google client, with and without `google_cached_content`.

    With `google_cached_content` set, the request config must carry the cache
    name and leave `system_instruction`, `tools`, and `tool_config` unset,
    because the cache resource owns them. Without it, all three are still
    sent. A `UserWarning` is expected exactly when caching discards agent
    instructions or registered tools.

    The `stream` parameter exercises both `request()` (`generate_content`) and
    `request_stream()` (`generate_content_stream`), which share
    `_build_content_and_config` and must therefore behave the same way.
    See issue #5671.
    """
    cache_name = 'projects/p/locations/global/cachedContents/test-cache'
    model = GoogleModel('gemini-2.5-pro', provider=google_provider)

    canned_response = GenerateContentResponse(
        candidates=[
            Candidate(
                content=Content(parts=[Part(text='Paris')], role='model'),
                finish_reason=GoogleFinishReason.STOP,
            )
        ],
        response_id='cached',
        model_version='gemini-2.5-pro',
    )

    async def stream_iterator():
        yield canned_response

    # Patch whichever client method the chosen request path ends up calling.
    patched_call = mocker.patch.object(
        model.client.aio.models,
        'generate_content_stream' if stream else 'generate_content',
        return_value=stream_iterator() if stream else canned_response,
    )

    settings = GoogleModelSettings()
    if cached:
        settings = GoogleModelSettings(google_cached_content=cache_name)
    agent = Agent(model=model, instructions=instructions, model_settings=settings)

    if register_tool:

        @agent.tool_plain
        def echo(text: str) -> str:
            return text  # pragma: no cover

    if expect_warning:
        warning_expectation = pytest.warns(UserWarning, match='`google_cached_content` is set')
    else:
        warning_expectation = nullcontext()

    with warning_expectation:
        if not stream:
            await agent.run('say hi')
        else:
            async with agent.run_stream('say hi') as result:
                await result.get_output()

    assert patched_call.call_count == 1
    config = patched_call.call_args.kwargs['config']

    if not cached:
        # Without a cache, the request still carries its own instructions and tools.
        assert not config.get('cached_content')
        assert config['system_instruction']
        assert config['tools']
        assert config['tool_config'] is not None
        return

    assert config['cached_content'] == cache_name
    # The three cache-owned fields must be absent (or unset) on the request.
    for cache_owned_field in ('system_instruction', 'tools', 'tool_config'):
        assert not config.get(cache_owned_field)
Loading