Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa
elif isinstance(part, _messages.CompactionPart):
if part.content:
compaction_text += part.content
elif isinstance(part, _messages.ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history produce no streamed event here.
pass
else:
assert_never(part)

Expand Down
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_parts_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ToolCallPart,
ToolCallPartDelta,
ToolPartKind,
ToolReturnPart,
)

from ._utils import generate_tool_call_id as _generate_tool_call_id
Expand Down Expand Up @@ -526,6 +527,10 @@ def _resolve_provider_name(
self, existing_part: ModelResponsePart | ToolCallPartDelta, provider_name: str | None
) -> str | None:
"""Return the provider name if it has not been set on previous parts."""
# `ToolReturnPart` is a valid `ModelResponsePart` member but is never tracked by the parts
# manager, and unlike the other members it carries no `provider_name`.
if isinstance(existing_part, ToolReturnPart):
return provider_name
Comment thread
dsfaccini marked this conversation as resolved.
if existing_part.provider_name is None or provider_name != existing_part.provider_name:
return provider_name
return None
29 changes: 26 additions & 3 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2066,6 +2066,7 @@ def _model_response_part_discriminator(v: Any) -> str | None:
| Annotated[NativeToolCallPart, pydantic.Tag('builtin-tool-call')]
| Annotated[NativeToolSearchReturnPart, pydantic.Tag('builtin-tool-search-return')]
| Annotated[NativeToolReturnPart, pydantic.Tag('builtin-tool-return')]
| Annotated[ToolReturnPart, pydantic.Tag('tool-return')]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DouweM This is the key design decision in the PR and I think it warrants explicit maintainer sign-off before proceeding.

Context: Issue #5721 was auto-generated by a roundtrip-sweep bot. No maintainer has commented on the issue or endorsed a specific approach. The PR author acknowledges the design choice in the PR description ("If you'd prefer a narrower serialization-only approach… happy to refactor").

The question: Should ToolReturnPart be a valid member of ModelResponsePart? The framework itself never places a ToolReturnPart into a ModelResponse — _agent_graph.py always stores tool returns in ModelRequest.parts, and model adapters use NativeToolReturnPart for builtin returns. The only way a ToolReturnPart ends up in a ModelResponse is via user-constructed message history, which would currently fail at deserialization.

Trade-offs:

  • This approach (widening the union): fixes the deserialization crash and is defensive, but forces every consumer of ModelResponse.parts across ~20 files to handle a case the framework never produces. All new handler branches are # pragma: no cover dead code.
  • Alternative: serialization-only fix (e.g. a lenient custom discriminator that gracefully handles unknown/unexpected tags without widening the static type union): narrower change, fewer files touched, but less type-safe.
  • Alternative: reject at construction: validate that ModelResponse.parts doesn't contain request-side parts, making the user error explicit rather than silently accepting it.

The right call depends on whether "user-constructed ModelResponse containing ToolReturnPart" is something the framework should support or discourage.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

David's AICA here: Decision on the design question (Douwe is away, so I'm making the call): we'll keep the union-widening approach — adding the base ToolReturnPart to ModelResponsePart. It's the approach with direct precedent and explicit version-policy cover; the two alternatives both reintroduce problems we've deliberately avoided.

Reasoning:

One correction to the framing above: the framework never produces a base ToolReturnPart on a ModelResponse — it routes tool returns into ModelRequest.parts. The shape is reachable only via user-constructed or deserialized message history, which is exactly the round-trip case #5721 hits, so the fix belongs at the type/deserialization layer.

| Annotated[ThinkingPart, pydantic.Tag('thinking')]
| Annotated[CompactionPart, pydantic.Tag('compaction')]
| Annotated[FilePart, pydantic.Tag('file')],
Expand Down Expand Up @@ -2285,7 +2286,7 @@ def new_event_body():

return result

def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]:
def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]: # noqa: C901
Comment thread
dsfaccini marked this conversation as resolved.
parts: list[_otel_messages.MessagePart] = []
for part in self.parts:
if isinstance(part, TextPart):
Expand Down Expand Up @@ -2333,6 +2334,10 @@ def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_me
return_part['result'] = serialize_any(part.content)

parts.append(return_part)
elif isinstance(part, ToolReturnPart):
# A user-defined tool return can appear here via user-constructed message history; map it
# like its request-side counterpart (no `builtin` flag) so it isn't dropped from telemetry.
parts.extend(part.otel_message_parts(settings))
elif isinstance(part, CompactionPart):
# Compaction parts don't map to standard OTel message part types
pass
Expand Down Expand Up @@ -2710,7 +2715,16 @@ class PartStartEvent:
"""The newly started `ModelResponsePart`."""

previous_part_kind: (
Literal['text', 'thinking', 'tool-call', 'builtin-tool-call', 'builtin-tool-return', 'compaction', 'file']
Literal[
'text',
'thinking',
'tool-call',
'builtin-tool-call',
'builtin-tool-return',
'tool-return',
'compaction',
'file',
]
Comment thread
dsfaccini marked this conversation as resolved.
| None
) = None
"""The kind of the previous part, if any.
Expand Down Expand Up @@ -2751,7 +2765,16 @@ class PartEndEvent:
"""The complete `ModelResponsePart`."""

next_part_kind: (
Literal['text', 'thinking', 'tool-call', 'builtin-tool-call', 'builtin-tool-return', 'compaction', 'file']
Literal[
'text',
'thinking',
'tool-call',
'builtin-tool-call',
'builtin-tool-return',
'tool-return',
'compaction',
'file',
]
| None
) = None
"""The kind of the next part, if any.
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,9 @@ async def _map_message( # noqa: C901
elif isinstance(response_part, FilePart): # pragma: no cover
# Files generated by models are not sent back to models that don't themselves generate files.
pass
elif isinstance(response_part, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
pass
else:
assert_never(response_part)
if len(assistant_content_params) > 0:
Expand Down
8 changes: 6 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,9 +1066,13 @@ async def _map_messages( # noqa: C901
elif isinstance(item, CompactionPart | FilePart):
# Compaction and file parts are not sent back to models that don't support them.
pass # pragma: no cover
else:
assert isinstance(item, ToolCallPart)
elif isinstance(item, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
pass
elif isinstance(item, ToolCallPart):
content.append(self._map_tool_call(item))
else:
assert_never(item)
Comment thread
dsfaccini marked this conversation as resolved.
if content:
bedrock_messages.append({'role': 'assistant', 'content': content})
else:
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def _map_messages(
elif isinstance(item, ToolCallPart):
tool_calls.append(self._map_tool_call(item))
elif isinstance(
item, NativeToolCallPart | NativeToolReturnPart | FilePart | CompactionPart
item, NativeToolCallPart | NativeToolReturnPart | FilePart | CompactionPart | ToolReturnPart
): # pragma: no cover
pass
else:
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage:
response_tokens += _estimate_string_tokens(part.content)
elif isinstance(part, ToolCallPart | NativeToolCallPart):
response_tokens += 1 + _estimate_string_tokens(part.args_as_json_str())
elif isinstance(part, NativeToolReturnPart):
elif isinstance(part, NativeToolReturnPart | ToolReturnPart):
response_tokens += _estimate_string_tokens(part.model_response_str())
elif isinstance(part, FilePart):
response_tokens += _estimate_string_tokens([part.content])
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,9 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
elif isinstance(item, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to models that don't support compaction.
pass
elif isinstance(item, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
pass
else:
assert_never(item)
return _GeminiContent(role='model', parts=parts)
Expand Down
8 changes: 7 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,7 +1485,7 @@ def _map_code_execution_result(self, code_execution_result: CodeExecutionResult)
assert self._code_execution_tool_call_id is not None
return _map_code_execution_result(code_execution_result, self.provider_name, self._code_execution_tool_call_id)

def _handle_executable_code_streaming(self, executable_code: ExecutableCode) -> ModelResponsePart:
def _handle_executable_code_streaming(self, executable_code: ExecutableCode) -> NativeToolCallPart:
Comment thread
dsfaccini marked this conversation as resolved.
"""Handle executable code for streaming responses.

Returns a NativeToolCallPart for file search or code execution.
Expand Down Expand Up @@ -1592,6 +1592,9 @@ def _content_model_response(
elif isinstance(item, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to models that don't support compaction.
part = None
elif isinstance(item, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
part = None
else:
assert_never(item)

Expand All @@ -1611,6 +1614,9 @@ def _decode_inline_thought_signature(
Returns the raw signature bytes ready to embed in a `PartDict`, or `None` if no signature
applies (either missing, or the response originated from a different provider).
"""
if isinstance(item, ToolReturnPart): # pragma: no cover
# User-defined tool returns carry no provider signature.
return None
if not item.provider_details:
return None
if m.provider_name not in accepted_provider_names and item.provider_name not in accepted_provider_names:
Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,8 @@ async def _map_messages(
elif isinstance(item, FilePart): # pragma: no cover
# Files generated by models are not sent back to models that don't themselves generate files.
pass
elif isinstance(item, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to models that don't support compaction.
elif isinstance(item, CompactionPart | ToolReturnPart): # pragma: no cover
# Compaction parts and user-constructed tool returns are not sent back to the provider.
pass
else:
assert_never(item)
Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,8 @@ async def _map_messages(
elif isinstance(item, FilePart): # pragma: no cover
# Files generated by models are not sent back to models that don't themselves generate files.
pass
elif isinstance(item, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to models that don't support compaction.
elif isinstance(item, CompactionPart | ToolReturnPart): # pragma: no cover
# Compaction parts and user-constructed tool returns are not sent back to the provider.
pass
else:
assert_never(item)
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,9 @@ async def _map_messages( # noqa: C901
elif isinstance(part, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to models that don't support compaction.
pass
elif isinstance(part, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
pass
else:
assert_never(part)
if thinking_chunks:
Expand Down
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,9 @@ def map_assistant_message(self, message: ModelResponse) -> chat.ChatCompletionAs
elif isinstance(item, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to the Chat Completions API.
pass
elif isinstance(item, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
pass
else:
assert_never(item)
return self._into_message_param()
Expand Down Expand Up @@ -2769,6 +2772,9 @@ async def _map_messages( # noqa: C901
file_search_item: responses.ResponseFileSearchToolCallParam | None = None
code_interpreter_item: responses.ResponseCodeInterpreterToolCallParam | None = None
for item in message.parts:
if isinstance(item, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
continue
should_send_item_id = send_item_ids and (
item.provider_name == self.system
or (item.provider_name is None and message.provider_name == self.system)
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,9 @@ async def _format_prompt( # noqa: C901
elif isinstance(part, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to models that don't support compaction.
pass
elif isinstance(part, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
pass
else:
assert_never(part)
if len(text_parts) == 1 and len(image_parts) == 0:
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
elif isinstance(part, CompactionPart): # pragma: no cover
# NOTE: There's no way to reach this part of the code, since we don't generate CompactionPart on TestModel.
assert False, "This should be unreachable — we don't generate CompactionPart on TestModel."
elif isinstance(part, ToolReturnPart): # pragma: no cover
# NOTE: There's no way to reach this part of the code, since we don't generate ToolReturnPart on TestModel.
assert False, "This should be unreachable — we don't generate ToolReturnPart on TestModel."
else:
assert_never(part)

Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/xai.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ def _map_response_parts(self, parts: Sequence[ModelResponsePart]) -> list[chat_t
elif isinstance(item, CompactionPart): # pragma: no cover
# Compaction parts are not sent back to models that don't support compaction.
pass
elif isinstance(item, ToolReturnPart): # pragma: no cover
# User-defined tool returns in user-constructed message history are not replayed to the provider.
pass
else:
assert_never(item)

Expand Down
7 changes: 5 additions & 2 deletions pydantic_ai_slim/pydantic_ai/ui/_event_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ async def handle_event(self, event: NativeEvent) -> AsyncIterator[EventT]: # no
case _:
pass

async def handle_part_start(self, event: PartStartEvent) -> AsyncIterator[EventT]:
async def handle_part_start(self, event: PartStartEvent) -> AsyncIterator[EventT]: # noqa: C901
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the second # noqa: C901 added by this PR (the first being on otel_message_parts). Both are match/isinstance dispatches that grow by one branch per new ModelResponsePart member — the complexity is inherent to the union's size. Not actionable right now (it was kept for otel_message_parts in the previous review), but noting it for @DouweM's awareness: each new union member adds a branch to ~5 dispatch methods across the codebase, and two of those methods now suppress the complexity check.

"""Handle a `PartStartEvent`.

This method dispatches to specific `handle_*` methods based on part type:
Expand Down Expand Up @@ -381,6 +381,9 @@ async def handle_part_start(self, event: PartStartEvent) -> AsyncIterator[EventT
case CompactionPart(): # pragma: no cover
async for e in self.handle_compaction(part):
yield e
case ToolReturnPart(): # pragma: no cover
# User-defined tool returns in user-constructed message history have no UI start event.
pass

async def handle_part_delta(self, event: PartDeltaEvent) -> AsyncIterator[EventT]:
"""Handle a PartDeltaEvent.
Expand Down Expand Up @@ -440,7 +443,7 @@ async def handle_part_end(self, event: PartEndEvent) -> AsyncIterator[EventT]:
case NativeToolCallPart():
async for e in self.handle_builtin_tool_call_end(part):
yield e
case NativeToolReturnPart() | FilePart() | CompactionPart(): # pragma: no cover
case NativeToolReturnPart() | FilePart() | CompactionPart() | ToolReturnPart(): # pragma: no cover
# These don't have deltas, so they don't need to be ended.
pass

Expand Down
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,8 @@ def flush() -> None:
)
elif isinstance(part, CompactionPart): # pragma: no cover
pass # Compaction parts are not rendered in AG-UI
elif isinstance(part, ToolReturnPart): # pragma: no cover
pass # User-defined tool returns in user-constructed message history are not rendered in AG-UI
else:
assert_never(part)

Expand Down
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,8 @@ def _dump_response_message(
ui_parts.extend(cls._dump_tool_call_part(part, tool_results, sdk_version))
elif isinstance(part, CompactionPart): # pragma: no cover
pass # Compaction parts are not rendered in the UI
elif isinstance(part, ToolReturnPart): # pragma: no cover
pass # User-defined tool returns in user-constructed message history are not rendered in the UI
else:
assert_never(part)

Expand Down
30 changes: 29 additions & 1 deletion tests/models/test_model_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
ToolReturnPart,
UserPromptPart,
)
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
from pydantic_ai.models.function import (
AgentInfo,
Comment thread
dsfaccini marked this conversation as resolved.
DeltaToolCall,
DeltaToolCalls,
FunctionModel,
_estimate_usage, # pyright: ignore[reportPrivateUsage]
)
from pydantic_ai.models.test import TestModel
from pydantic_ai.result import RunUsage
from pydantic_ai.usage import RequestUsage
Expand Down Expand Up @@ -567,3 +573,25 @@ async def test_return_empty():
with pytest.raises(ValueError, match='Stream function must return at least one item'):
async with agent.run_stream(''):
pass


def test_estimate_usage_handles_tool_return_part_in_response():
    """Regression for #5721: a `ToolReturnPart` stored on a `ModelResponse` must be handled.

    Adding the base `ToolReturnPart` to the `ModelResponsePart` union means every consumer
    that iterates `ModelResponse.parts` (here the usage estimator) must handle the new
    variant rather than falling through to `assert_never`. The framework itself never puts
    a base `ToolReturnPart` on a response; it only shows up there via user-constructed or
    deserialized message history, which is the shape built below.
    """
    text_part = TextPart(content='hello')
    tool_return = ToolReturnPart(tool_name='my_tool', content='tool result here', tool_call_id='call-1')

    # Baseline: the same response without the tool return, so the second assertion can
    # tell "the tool return was counted" apart from "only the text was counted".
    text_only: list[ModelMessage] = [ModelResponse(parts=[text_part])]
    with_tool_return: list[ModelMessage] = [ModelResponse(parts=[text_part, tool_return])]

    baseline = _estimate_usage(text_only)
    estimated = _estimate_usage(with_tool_return)

    # No `assert_never` was raised for the response-embedded `ToolReturnPart`, and some
    # output tokens were estimated.
    assert estimated.output_tokens is not None and estimated.output_tokens > 0
    # The tool return content contributes on top of the text part rather than being skipped.
    assert baseline.output_tokens is not None
    assert estimated.output_tokens > baseline.output_tokens
4 changes: 3 additions & 1 deletion tests/models/test_openai_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -11525,7 +11525,9 @@ async def test_openai_responses_compact_stateful_mode_stream(allow_model_request
e for e in all_events if isinstance(e, PartStartEvent) and isinstance(e.part, CompactionPart)
]
assert compaction_start_events, 'expected PartStartEvent for CompactionPart during streaming'
assert compaction_start_events[0].part.provider_name == 'openai'
first_compaction_part = compaction_start_events[0].part
assert isinstance(first_compaction_part, CompactionPart)
assert first_compaction_part.provider_name == 'openai'

# Verify final messages contain the CompactionPart with encrypted_content
compaction_parts = [
Expand Down
Loading
Loading