Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa
elif isinstance(part, _messages.CompactionPart):
if part.content:
compaction_text += part.content
elif isinstance(part, _messages.ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history produce no streamed event here.
pass
else:
assert_never(part)

Expand Down
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_parts_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ToolCallPart,
ToolCallPartDelta,
ToolPartKind,
ToolReturnPart,
)

from ._utils import generate_tool_call_id as _generate_tool_call_id
Expand Down Expand Up @@ -526,6 +527,10 @@ def _resolve_provider_name(
self, existing_part: ModelResponsePart | ToolCallPartDelta, provider_name: str | None
) -> str | None:
"""Return the provider name if it has not been set on previous parts."""
# `ToolReturnPart` is a valid `ModelResponsePart` member but is never tracked by the parts
# manager, and unlike the other members it carries no `provider_name`.
if isinstance(existing_part, ToolReturnPart):
return provider_name
Comment thread
dsfaccini marked this conversation as resolved.
if existing_part.provider_name is None or provider_name != existing_part.provider_name:
return provider_name
return None
29 changes: 26 additions & 3 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2066,6 +2066,7 @@ def _model_response_part_discriminator(v: Any) -> str | None:
| Annotated[NativeToolCallPart, pydantic.Tag('builtin-tool-call')]
| Annotated[NativeToolSearchReturnPart, pydantic.Tag('builtin-tool-search-return')]
| Annotated[NativeToolReturnPart, pydantic.Tag('builtin-tool-return')]
| Annotated[ToolReturnPart, pydantic.Tag('tool-return')]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DouweM This is the key design decision in the PR and I think it warrants explicit maintainer sign-off before proceeding.

Context: Issue #5721 was auto-generated by a roundtrip-sweep bot. No maintainer has commented on the issue or endorsed a specific approach. The PR author acknowledges the design choice in the PR description ("If you'd prefer a narrower serialization-only approach… happy to refactor").

The question: Should ToolReturnPart be a valid member of ModelResponsePart? The framework itself never places a ToolReturnPart into a ModelResponse_agent_graph.py always stores tool returns in ModelRequest.parts, and model adapters use NativeToolReturnPart for builtin returns. The only way a ToolReturnPart ends up in a ModelResponse is via user-constructed message history, which would currently fail at deserialization.

Trade-offs:

  • This approach (widening the union): fixes the deserialization crash and is defensive, but forces every consumer of ModelResponse.parts across ~20 files to handle a case the framework never produces. All new handler branches are # pragma: no cover dead code.
  • Alternative: serialization-only fix (e.g. a lenient custom discriminator that gracefully handles unknown/unexpected tags without widening the static type union): narrower change, fewer files touched, but less type-safe.
  • Alternative: reject at construction: validate that ModelResponse.parts doesn't contain request-side parts, making the user error explicit rather than silently accepting it.

The right call depends on whether "user-constructed ModelResponse containing ToolReturnPart" is something the framework should support or discourage.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

David's AICA here: Decision on the design question (Douwe is away, so I'm making the call): we'll keep the union-widening approach — adding the base ToolReturnPart to ModelResponsePart. It's the approach with direct precedent and explicit version-policy cover; the two alternatives both reintroduce problems we've deliberately avoided.

Reasoning:

One correction to the framing above: the framework never produces a base ToolReturnPart on a ModelResponse — it routes tool returns into ModelRequest.parts. The shape is reachable only via user-constructed or deserialized message history, which is exactly the round-trip case #5721 hits, so the fix belongs at the type/deserialization layer.

| Annotated[ThinkingPart, pydantic.Tag('thinking')]
| Annotated[CompactionPart, pydantic.Tag('compaction')]
| Annotated[FilePart, pydantic.Tag('file')],
Expand Down Expand Up @@ -2285,7 +2286,7 @@ def new_event_body():

return result

def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]:
def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]: # noqa: C901
Comment thread
dsfaccini marked this conversation as resolved.
parts: list[_otel_messages.MessagePart] = []
for part in self.parts:
if isinstance(part, TextPart):
Expand Down Expand Up @@ -2333,6 +2334,10 @@ def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_me
return_part['result'] = serialize_any(part.content)

parts.append(return_part)
elif isinstance(part, ToolReturnPart):
# A user-defined tool return can appear here via user-constructed message history; map it
# like its request-side counterpart (no `builtin` flag) so it isn't dropped from telemetry.
parts.extend(part.otel_message_parts(settings))
elif isinstance(part, CompactionPart):
# Compaction parts don't map to standard OTel message part types
pass
Expand Down Expand Up @@ -2710,7 +2715,16 @@ class PartStartEvent:
"""The newly started `ModelResponsePart`."""

previous_part_kind: (
Literal['text', 'thinking', 'tool-call', 'builtin-tool-call', 'builtin-tool-return', 'compaction', 'file']
Literal[
'text',
'thinking',
'tool-call',
'builtin-tool-call',
'builtin-tool-return',
'tool-return',
'compaction',
'file',
]
Comment thread
dsfaccini marked this conversation as resolved.
| None
) = None
"""The kind of the previous part, if any.
Expand Down Expand Up @@ -2751,7 +2765,16 @@ class PartEndEvent:
"""The complete `ModelResponsePart`."""

next_part_kind: (
Literal['text', 'thinking', 'tool-call', 'builtin-tool-call', 'builtin-tool-return', 'compaction', 'file']
Literal[
'text',
'thinking',
'tool-call',
'builtin-tool-call',
'builtin-tool-return',
'tool-return',
'compaction',
'file',
]
| None
) = None
"""The kind of the next part, if any.
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,9 @@ async def _map_message( # noqa: C901
elif isinstance(response_part, FilePart): # pragma: no cover
# Files generated by models are not sent back to models that don't themselves generate files.
pass
elif isinstance(response_part, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
pass
else:
assert_never(response_part)
if len(assistant_content_params) > 0:
Expand Down
8 changes: 6 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,9 +1066,13 @@ async def _map_messages( # noqa: C901
elif isinstance(item, CompactionPart | FilePart):
# Compaction and file parts are not sent back to models that don't support them.
pass # pragma: no cover
else:
assert isinstance(item, ToolCallPart)
elif isinstance(item, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
pass
elif isinstance(item, ToolCallPart):
content.append(self._map_tool_call(item))
else:
assert_never(item)
Comment thread
dsfaccini marked this conversation as resolved.
if content:
bedrock_messages.append({'role': 'assistant', 'content': content})
else:
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def _map_messages(
elif isinstance(item, ToolCallPart):
tool_calls.append(self._map_tool_call(item))
elif isinstance(
item, NativeToolCallPart | NativeToolReturnPart | FilePart | CompactionPart
item, NativeToolCallPart | NativeToolReturnPart | FilePart | CompactionPart | ToolReturnPart
): # pragma: no cover
pass
else:
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage:
response_tokens += _estimate_string_tokens(part.content)
elif isinstance(part, ToolCallPart | NativeToolCallPart):
response_tokens += 1 + _estimate_string_tokens(part.args_as_json_str())
elif isinstance(part, NativeToolReturnPart):
elif isinstance(part, NativeToolReturnPart | ToolReturnPart):
response_tokens += _estimate_string_tokens(part.model_response_str())
elif isinstance(part, FilePart):
response_tokens += _estimate_string_tokens([part.content])
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,9 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
elif isinstance(item, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to models that don't support compaction.
pass
elif isinstance(item, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
pass
else:
assert_never(item)
return _GeminiContent(role='model', parts=parts)
Expand Down
8 changes: 7 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,7 +1485,7 @@ def _map_code_execution_result(self, code_execution_result: CodeExecutionResult)
assert self._code_execution_tool_call_id is not None
return _map_code_execution_result(code_execution_result, self.provider_name, self._code_execution_tool_call_id)

def _handle_executable_code_streaming(self, executable_code: ExecutableCode) -> ModelResponsePart:
def _handle_executable_code_streaming(self, executable_code: ExecutableCode) -> NativeToolCallPart:
Comment thread
dsfaccini marked this conversation as resolved.
"""Handle executable code for streaming responses.

Returns a NativeToolCallPart for file search or code execution.
Expand Down Expand Up @@ -1592,6 +1592,9 @@ def _content_model_response(
elif isinstance(item, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to models that don't support compaction.
part = None
elif isinstance(item, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
part = None
else:
assert_never(item)

Expand All @@ -1611,6 +1614,9 @@ def _decode_inline_thought_signature(
Returns the raw signature bytes ready to embed in a `PartDict`, or `None` if no signature
applies (either missing, or the response originated from a different provider).
"""
if isinstance(item, ToolReturnPart): # pragma: no cover
# User-defined tool returns carry no provider signature.
return None
if not item.provider_details:
return None
if m.provider_name not in accepted_provider_names and item.provider_name not in accepted_provider_names:
Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,8 @@ async def _map_messages(
elif isinstance(item, FilePart): # pragma: no cover
# Files generated by models are not sent back to models that don't themselves generate files.
pass
elif isinstance(item, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to models that don't support compaction.
elif isinstance(item, CompactionPart | ToolReturnPart): # pragma: no cover
# Compaction parts and user-constructed tool returns are not sent back to the provider.
pass
else:
assert_never(item)
Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,8 @@ async def _map_messages(
elif isinstance(item, FilePart): # pragma: no cover
# Files generated by models are not sent back to models that don't themselves generate files.
pass
elif isinstance(item, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to models that don't support compaction.
elif isinstance(item, CompactionPart | ToolReturnPart): # pragma: no cover
# Compaction parts and user-constructed tool returns are not sent back to the provider.
pass
else:
assert_never(item)
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,9 @@ async def _map_messages( # noqa: C901
elif isinstance(part, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to models that don't support compaction.
pass
elif isinstance(part, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
pass
else:
assert_never(part)
if thinking_chunks:
Expand Down
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,9 @@ def map_assistant_message(self, message: ModelResponse) -> chat.ChatCompletionAs
elif isinstance(item, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to the Chat Completions API.
pass
elif isinstance(item, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
pass
else:
assert_never(item)
return self._into_message_param()
Expand Down Expand Up @@ -2769,6 +2772,9 @@ async def _map_messages( # noqa: C901
file_search_item: responses.ResponseFileSearchToolCallParam | None = None
code_interpreter_item: responses.ResponseCodeInterpreterToolCallParam | None = None
for item in message.parts:
if isinstance(item, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
continue
should_send_item_id = send_item_ids and (
item.provider_name == self.system
or (item.provider_name is None and message.provider_name == self.system)
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,9 @@ async def _format_prompt( # noqa: C901
elif isinstance(part, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to models that don't support compaction.
pass
elif isinstance(part, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
pass
else:
assert_never(part)
if len(text_parts) == 1 and len(image_parts) == 0:
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
elif isinstance(part, CompactionPart): # pragma: no cover
# NOTE: There's no way to reach this part of the code, since we don't generate CompactionPart on TestModel.
assert False, "This should be unreachable — we don't generate CompactionPart on TestModel."
elif isinstance(part, ToolReturnPart): # pragma: no cover
# NOTE: There's no way to reach this part of the code, since we don't generate ToolReturnPart on TestModel.
assert False, "This should be unreachable — we don't generate ToolReturnPart on TestModel."
else:
assert_never(part)

Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/xai.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ def _map_response_parts(self, parts: Sequence[ModelResponsePart]) -> list[chat_t
elif isinstance(item, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to models that don't support compaction.
pass
elif isinstance(item, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
pass
else:
assert_never(item)

Expand Down
7 changes: 5 additions & 2 deletions pydantic_ai_slim/pydantic_ai/ui/_event_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ async def handle_event(self, event: NativeEvent) -> AsyncIterator[EventT]: # no
case _:
pass

async def handle_part_start(self, event: PartStartEvent) -> AsyncIterator[EventT]:
async def handle_part_start(self, event: PartStartEvent) -> AsyncIterator[EventT]: # noqa: C901
"""Handle a `PartStartEvent`.

This method dispatches to specific `handle_*` methods based on part type:
Expand Down Expand Up @@ -381,6 +381,9 @@ async def handle_part_start(self, event: PartStartEvent) -> AsyncIterator[EventT
case CompactionPart(): # pragma: no cover
async for e in self.handle_compaction(part):
yield e
case ToolReturnPart(): # pragma: no cover
# User-defined tool returns in user-constructed message history have no UI start event.
pass

async def handle_part_delta(self, event: PartDeltaEvent) -> AsyncIterator[EventT]:
"""Handle a PartDeltaEvent.
Expand Down Expand Up @@ -440,7 +443,7 @@ async def handle_part_end(self, event: PartEndEvent) -> AsyncIterator[EventT]:
case NativeToolCallPart():
async for e in self.handle_builtin_tool_call_end(part):
yield e
case NativeToolReturnPart() | FilePart() | CompactionPart(): # pragma: no cover
case NativeToolReturnPart() | FilePart() | CompactionPart() | ToolReturnPart(): # pragma: no cover
# These don't have deltas, so they don't need to be ended.
pass

Expand Down
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,8 @@ def flush() -> None:
)
elif isinstance(part, CompactionPart): # pragma: no cover
pass # Compaction parts are not rendered in AG-UI
elif isinstance(part, ToolReturnPart): # pragma: no cover
pass # User-defined tool returns in user-constructed message history are not rendered in AG-UI
else:
assert_never(part)

Expand Down
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,8 @@ def _dump_response_message(
ui_parts.extend(cls._dump_tool_call_part(part, tool_results, sdk_version))
elif isinstance(part, CompactionPart): # pragma: no cover
pass # Compaction parts are not rendered in the UI
elif isinstance(part, ToolReturnPart): # pragma: no cover
pass # User-defined tool returns in user-constructed message history are not rendered in the UI
else:
assert_never(part)

Expand Down
36 changes: 35 additions & 1 deletion tests/models/test_model_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
ToolReturnPart,
UserPromptPart,
)
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
from pydantic_ai.models.function import (
AgentInfo,
Comment thread
dsfaccini marked this conversation as resolved.
DeltaToolCall,
DeltaToolCalls,
FunctionModel,
)
from pydantic_ai.models.test import TestModel
from pydantic_ai.result import RunUsage
from pydantic_ai.usage import RequestUsage
Expand Down Expand Up @@ -567,3 +572,32 @@ async def test_return_empty():
with pytest.raises(ValueError, match='Stream function must return at least one item'):
async with agent.run_stream(''):
pass


async def test_estimate_usage_handles_tool_return_part_in_response():
"""Regression for #5721: a `ToolReturnPart` on a `ModelResponse` must be handled.

A base `ToolReturnPart` can appear on a `ModelResponse` only via user-constructed or
deserialized message history (the framework itself always routes it into `ModelRequest.parts`).
Adding it to the `ModelResponsePart` union means every consumer iterating `ModelResponse.parts`
must handle it rather than fall through to `assert_never`. `FunctionModel` reaches that path
when it estimates usage of the request history, so running an agent over such a history
exercises the estimator through the public API.
"""

def return_text(messages: list[ModelMessage], _: AgentInfo) -> ModelResponse:
return ModelResponse(parts=[TextPart('done')])

agent = Agent(FunctionModel(return_text))
result = await agent.run(
'hello',
message_history=[
ModelResponse(
parts=[
TextPart(content='hello'),
ToolReturnPart(tool_name='my_tool', content='tool result here', tool_call_id='call-1'),
]
)
],
)
assert result.usage == snapshot(RunUsage(input_tokens=51, output_tokens=5, requests=1))
4 changes: 3 additions & 1 deletion tests/models/test_openai_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -11525,7 +11525,9 @@ async def test_openai_responses_compact_stateful_mode_stream(allow_model_request
e for e in all_events if isinstance(e, PartStartEvent) and isinstance(e.part, CompactionPart)
]
assert compaction_start_events, 'expected PartStartEvent for CompactionPart during streaming'
assert compaction_start_events[0].part.provider_name == 'openai'
first_compaction_part = compaction_start_events[0].part
assert isinstance(first_compaction_part, CompactionPart)
assert first_compaction_part.provider_name == 'openai'

# Verify final messages contain the CompactionPart with encrypted_content
compaction_parts = [
Expand Down
Loading
Loading