diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index c33462fbf3..f18c857bec 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -1198,6 +1198,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa elif isinstance(part, _messages.CompactionPart): if part.content: compaction_text += part.content + elif isinstance(part, _messages.ToolReturnPart): # pragma: no cover + # User-defined tool returns in user-constructed message history produce no streamed event here. + pass else: assert_never(part) diff --git a/pydantic_ai_slim/pydantic_ai/_parts_manager.py b/pydantic_ai_slim/pydantic_ai/_parts_manager.py index 3f64db0fff..13070cc17c 100644 --- a/pydantic_ai_slim/pydantic_ai/_parts_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_parts_manager.py @@ -32,6 +32,7 @@ ToolCallPart, ToolCallPartDelta, ToolPartKind, + ToolReturnPart, ) from ._utils import generate_tool_call_id as _generate_tool_call_id @@ -526,6 +527,10 @@ def _resolve_provider_name( self, existing_part: ModelResponsePart | ToolCallPartDelta, provider_name: str | None ) -> str | None: """Return the provider name if it has not been set on previous parts.""" + # `ToolReturnPart` is a valid `ModelResponsePart` member but is never tracked by the parts + # manager, and unlike the other members it carries no `provider_name`. + if isinstance(existing_part, ToolReturnPart): + return provider_name if existing_part.provider_name is None or provider_name != existing_part.provider_name: return provider_name return None diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index ac866433a2..664dd01606 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -2075,6 +2075,7 @@ def _model_response_part_discriminator(v: Any) -> str | None: | Annotated[NativeToolCallPart, pydantic.Tag('builtin-tool-call')] | Annotated[NativeToolSearchReturnPart, pydantic.Tag('builtin-tool-search-return')] | Annotated[NativeToolReturnPart, pydantic.Tag('builtin-tool-return')] + | Annotated[ToolReturnPart, pydantic.Tag('tool-return')] | Annotated[ThinkingPart, pydantic.Tag('thinking')] | Annotated[CompactionPart, pydantic.Tag('compaction')] | Annotated[FilePart, pydantic.Tag('file')], @@ -2294,7 +2295,7 @@ def new_event_body(): return result - def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]: + def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]: # noqa: C901 parts: list[_otel_messages.MessagePart] = [] for part in self.parts: if isinstance(part, TextPart): @@ -2342,6 +2343,10 @@ def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_me return_part['result'] = serialize_any(part.content) parts.append(return_part) + elif isinstance(part, ToolReturnPart): + # A user-defined tool return can appear here via user-constructed message history; map it + # like its request-side counterpart (no `builtin` flag) so it isn't dropped from telemetry. + parts.extend(part.otel_message_parts(settings)) elif isinstance(part, CompactionPart): # Compaction parts don't map to standard OTel message part types pass @@ -2719,7 +2724,16 @@ class PartStartEvent: """The newly started `ModelResponsePart`.""" previous_part_kind: ( - Literal['text', 'thinking', 'tool-call', 'builtin-tool-call', 'builtin-tool-return', 'compaction', 'file'] + Literal[ + 'text', + 'thinking', + 'tool-call', + 'builtin-tool-call', + 'builtin-tool-return', + 'tool-return', + 'compaction', + 'file', + ] | None ) = None """The kind of the previous part, if any. @@ -2760,7 +2774,16 @@ class PartEndEvent: """The complete `ModelResponsePart`.""" next_part_kind: ( - Literal['text', 'thinking', 'tool-call', 'builtin-tool-call', 'builtin-tool-return', 'compaction', 'file'] + Literal[ + 'text', + 'thinking', + 'tool-call', + 'builtin-tool-call', + 'builtin-tool-return', + 'tool-return', + 'compaction', + 'file', + ] | None ) = None """The kind of the next part, if any. diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 4d6bc41457..255f1f36be 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -1533,6 +1533,9 @@ async def _map_message( # noqa: C901 elif isinstance(response_part, FilePart): # pragma: no cover # Files generated by models are not sent back to models that don't themselves generate files. pass + elif isinstance(response_part, ToolReturnPart): # pragma: no cover + # User-defined tool returns in user-constructed message history are not replayed to the provider. + pass else: assert_never(response_part) if len(assistant_content_params) > 0: diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index f84fbc6299..40396fcbbd 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -1066,9 +1066,13 @@ async def _map_messages( # noqa: C901 elif isinstance(item, CompactionPart | FilePart): # Compaction and file parts are not sent back to models that don't support them. pass # pragma: no cover - else: - assert isinstance(item, ToolCallPart) + elif isinstance(item, ToolReturnPart): # pragma: no cover + # User-defined tool returns in user-constructed message history are not replayed to the provider. + pass + elif isinstance(item, ToolCallPart): content.append(self._map_tool_call(item)) + else: + assert_never(item) if content: bedrock_messages.append({'role': 'assistant', 'content': content}) else: diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index 26cc08c673..67cf4b8732 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -253,7 +253,7 @@ def _map_messages( elif isinstance(item, ToolCallPart): tool_calls.append(self._map_tool_call(item)) elif isinstance( - item, NativeToolCallPart | NativeToolReturnPart | FilePart | CompactionPart + item, NativeToolCallPart | NativeToolReturnPart | FilePart | CompactionPart | ToolReturnPart ): # pragma: no cover pass else: diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index bd5d5844de..dbe85e94a2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -410,7 +410,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage: response_tokens += _estimate_string_tokens(part.content) elif isinstance(part, ToolCallPart | NativeToolCallPart): response_tokens += 1 + _estimate_string_tokens(part.args_as_json_str()) - elif isinstance(part, NativeToolReturnPart): + elif isinstance(part, NativeToolReturnPart | ToolReturnPart): response_tokens += _estimate_string_tokens(part.model_response_str()) elif isinstance(part, FilePart): response_tokens += _estimate_string_tokens([part.content]) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 85634bdbfe..a6b80fd94f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -683,6 +683,9 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent: elif isinstance(item, CompactionPart): # pragma: no cover # Compaction parts are not sent back to models that don't support compaction. pass + elif isinstance(item, ToolReturnPart): # pragma: no cover + # User-defined tool returns in user-constructed message history are not replayed to the provider. + pass else: assert_never(item) return _GeminiContent(role='model', parts=parts) diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 449d802d20..5635f6ef3c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -1529,7 +1529,7 @@ def _map_code_execution_result(self, code_execution_result: CodeExecutionResult) assert self._code_execution_tool_call_id is not None return _map_code_execution_result(code_execution_result, self.provider_name, self._code_execution_tool_call_id) - def _handle_executable_code_streaming(self, executable_code: ExecutableCode) -> ModelResponsePart: + def _handle_executable_code_streaming(self, executable_code: ExecutableCode) -> NativeToolCallPart: """Handle executable code for streaming responses. Returns a NativeToolCallPart for file search or code execution. @@ -1636,6 +1636,9 @@ def _content_model_response( elif isinstance(item, CompactionPart): # pragma: no cover # Compaction parts are not sent back to models that don't support compaction. part = None + elif isinstance(item, ToolReturnPart): # pragma: no cover + # User-defined tool returns in user-constructed message history are not replayed to the provider. + part = None else: assert_never(item) @@ -1655,6 +1658,9 @@ def _decode_inline_thought_signature( Returns the raw signature bytes ready to embed in a `PartDict`, or `None` if no signature applies (either missing, or the response originated from a different provider). """ + if isinstance(item, ToolReturnPart): # pragma: no cover + # User-defined tool returns carry no provider signature. + return None if not item.provider_details: return None if m.provider_name not in accepted_provider_names and item.provider_name not in accepted_provider_names: diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index c429d21d12..805b390f6a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -483,8 +483,8 @@ async def _map_messages( elif isinstance(item, FilePart): # pragma: no cover # Files generated by models are not sent back to models that don't themselves generate files. pass - elif isinstance(item, CompactionPart): # pragma: no cover - # Compaction parts are not sent back to models that don't support compaction. + elif isinstance(item, CompactionPart | ToolReturnPart): # pragma: no cover + # Compaction parts and user-constructed tool returns are not sent back to the provider. pass else: assert_never(item) diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 253288764f..f126032684 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -400,8 +400,8 @@ async def _map_messages( elif isinstance(item, FilePart): # pragma: no cover # Files generated by models are not sent back to models that don't themselves generate files. pass - elif isinstance(item, CompactionPart): # pragma: no cover - # Compaction parts are not sent back to models that don't support compaction. + elif isinstance(item, CompactionPart | ToolReturnPart): # pragma: no cover + # Compaction parts and user-constructed tool returns are not sent back to the provider. pass else: assert_never(item) diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 6bc04d9664..d26dcc2888 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -598,6 +598,9 @@ async def _map_messages( # noqa: C901 elif isinstance(part, CompactionPart): # pragma: no cover # Compaction parts are not sent back to models that don't support compaction. pass + elif isinstance(part, ToolReturnPart): # pragma: no cover + # User-defined tool returns in user-constructed message history are not replayed to the provider. + pass else: assert_never(part) if thinking_chunks: diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 10b7035b96..97ddda3cac 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -1312,6 +1312,9 @@ def map_assistant_message(self, message: ModelResponse) -> chat.ChatCompletionAs elif isinstance(item, CompactionPart): # pragma: no cover # Compaction parts are not sent back to the Chat Completions API. pass + elif isinstance(item, ToolReturnPart): # pragma: no cover + # User-defined tool returns in user-constructed message history are not replayed to the provider. + pass else: assert_never(item) return self._into_message_param() @@ -2776,6 +2779,9 @@ async def _map_messages( # noqa: C901 file_search_item: responses.ResponseFileSearchToolCallParam | None = None code_interpreter_item: responses.ResponseCodeInterpreterToolCallParam | None = None for item in message.parts: + if isinstance(item, ToolReturnPart): # pragma: no cover + # User-defined tool returns in user-constructed message history are not replayed to the provider. + continue should_send_item_id = send_item_ids and ( item.provider_name == self.system or (item.provider_name is None and message.provider_name == self.system) diff --git a/pydantic_ai_slim/pydantic_ai/models/outlines.py b/pydantic_ai_slim/pydantic_ai/models/outlines.py index 394d660514..d696d09bb6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/outlines.py +++ b/pydantic_ai_slim/pydantic_ai/models/outlines.py @@ -506,6 +506,9 @@ async def _format_prompt( # noqa: C901 elif isinstance(part, CompactionPart): # pragma: no cover # Compaction parts are not sent back to models that don't support compaction. pass + elif isinstance(part, ToolReturnPart): # pragma: no cover + # User-defined tool returns in user-constructed message history are not replayed to the provider. + pass else: assert_never(part) if len(text_parts) == 1 and len(image_parts) == 0: diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 37adb79978..7caa4c690d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -365,6 +365,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: elif isinstance(part, CompactionPart): # pragma: no cover # NOTE: There's no way to reach this part of the code, since we don't generate CompactionPart on TestModel. assert False, "This should be unreachable — we don't generate CompactionPart on TestModel." + elif isinstance(part, ToolReturnPart): # pragma: no cover + # NOTE: There's no way to reach this part of the code, since we don't generate ToolReturnPart on TestModel. + assert False, "This should be unreachable — we don't generate ToolReturnPart on TestModel." else: assert_never(part) diff --git a/pydantic_ai_slim/pydantic_ai/models/xai.py b/pydantic_ai_slim/pydantic_ai/models/xai.py index a17a2b9313..32854a03d6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/xai.py +++ b/pydantic_ai_slim/pydantic_ai/models/xai.py @@ -438,6 +438,9 @@ def _map_response_parts(self, parts: Sequence[ModelResponsePart]) -> list[chat_t elif isinstance(item, CompactionPart): # pragma: no cover # Compaction parts are not sent back to models that don't support compaction. pass + elif isinstance(item, ToolReturnPart): # pragma: no cover + # User-defined tool returns in user-constructed message history are not replayed to the provider. + pass else: assert_never(item) diff --git a/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py b/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py index d5259da7e1..92b442fb3b 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py +++ b/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py @@ -338,7 +338,7 @@ async def handle_event(self, event: NativeEvent) -> AsyncIterator[EventT]: # no case _: pass - async def handle_part_start(self, event: PartStartEvent) -> AsyncIterator[EventT]: + async def handle_part_start(self, event: PartStartEvent) -> AsyncIterator[EventT]: # noqa: C901 """Handle a `PartStartEvent`. This method dispatches to specific `handle_*` methods based on part type: @@ -381,6 +381,9 @@ async def handle_part_start(self, event: PartStartEvent) -> AsyncIterator[EventT case CompactionPart(): # pragma: no cover async for e in self.handle_compaction(part): yield e + case ToolReturnPart(): # pragma: no cover + # User-defined tool returns in user-constructed message history have no UI start event. + pass async def handle_part_delta(self, event: PartDeltaEvent) -> AsyncIterator[EventT]: """Handle a PartDeltaEvent. @@ -440,7 +443,7 @@ async def handle_part_end(self, event: PartEndEvent) -> AsyncIterator[EventT]: case NativeToolCallPart(): async for e in self.handle_builtin_tool_call_end(part): yield e - case NativeToolReturnPart() | FilePart() | CompactionPart(): # pragma: no cover + case NativeToolReturnPart() | FilePart() | CompactionPart() | ToolReturnPart(): # pragma: no cover # These don't have deltas, so they don't need to be ended. pass diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py index 44e315bb23..9ab7e360a2 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py @@ -678,6 +678,8 @@ def flush() -> None: ) elif isinstance(part, CompactionPart): # pragma: no cover pass # Compaction parts are not rendered in AG-UI + elif isinstance(part, ToolReturnPart): # pragma: no cover + pass # User-defined tool returns in user-constructed message history are not rendered in AG-UI else: assert_never(part) diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py index 855a9dc3a7..12d07b0463 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py @@ -682,6 +682,8 @@ def _dump_response_message( ui_parts.extend(cls._dump_tool_call_part(part, tool_results, sdk_version)) elif isinstance(part, CompactionPart): # pragma: no cover pass # Compaction parts are not rendered in the UI + elif isinstance(part, ToolReturnPart): # pragma: no cover + pass # User-defined tool returns in user-constructed message history are not rendered in the UI else: assert_never(part) diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index 46097d0746..f4286c465f 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -21,7 +21,12 @@ ToolReturnPart, UserPromptPart, ) -from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel +from pydantic_ai.models.function import ( + AgentInfo, + DeltaToolCall, + DeltaToolCalls, + FunctionModel, +) from pydantic_ai.models.test import TestModel from pydantic_ai.result import RunUsage from pydantic_ai.usage import RequestUsage @@ -567,3 +572,32 @@ async def test_return_empty(): with pytest.raises(ValueError, match='Stream function must return at least one item'): async with agent.run_stream(''): pass + + +async def test_estimate_usage_handles_tool_return_part_in_response(): + """Regression for #5721: a `ToolReturnPart` on a `ModelResponse` must be handled. + + A base `ToolReturnPart` can appear on a `ModelResponse` only via user-constructed or + deserialized message history (the framework itself always routes it into `ModelRequest.parts`). + Adding it to the `ModelResponsePart` union means every consumer iterating `ModelResponse.parts` + must handle it rather than fall through to `assert_never`. `FunctionModel` reaches that path + when it estimates usage of the request history, so running an agent over such a history + exercises the estimator through the public API. + """ + + def return_text(messages: list[ModelMessage], _: AgentInfo) -> ModelResponse: + return ModelResponse(parts=[TextPart('done')]) + + agent = Agent(FunctionModel(return_text)) + result = await agent.run( + 'hello', + message_history=[ + ModelResponse( + parts=[ + TextPart(content='hello'), + ToolReturnPart(tool_name='my_tool', content='tool result here', tool_call_id='call-1'), + ] + ) + ], + ) + assert result.usage == snapshot(RunUsage(input_tokens=51, output_tokens=5, requests=1)) diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index 84c13d37fa..ebb002de5e 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -11525,7 +11525,9 @@ async def test_openai_responses_compact_stateful_mode_stream(allow_model_request e for e in all_events if isinstance(e, PartStartEvent) and isinstance(e.part, CompactionPart) ] assert compaction_start_events, 'expected PartStartEvent for CompactionPart during streaming' - assert compaction_start_events[0].part.provider_name == 'openai' + first_compaction_part = compaction_start_events[0].part + assert isinstance(first_compaction_part, CompactionPart) + assert first_compaction_part.provider_name == 'openai' # Verify final messages contain the CompactionPart with encrypted_content compaction_parts = [ diff --git a/tests/test_messages.py b/tests/test_messages.py index 0b46d73142..07b0fd7b9b 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -34,7 +34,12 @@ UserPromptPart, VideoUrl, ) -from pydantic_ai.messages import INVALID_JSON_KEY, MULTI_MODAL_CONTENT_TYPES, is_multi_modal_content +from pydantic_ai.messages import ( + INVALID_JSON_KEY, + MULTI_MODAL_CONTENT_TYPES, + ToolSearchReturnPart, + is_multi_modal_content, +) from ._inline_snapshot import snapshot from .conftest import IsDatetime, IsNow, IsStr @@ -1226,6 +1231,86 @@ def test_tool_return_content_nested_multimodal(): assert reloaded_content['regular_data'] == [{'url': '/api/path', 'id': 123, 'name': 'test'}] +def test_tool_return_part_in_model_response_round_trips(): + """`ToolReturnPart` inside a `ModelResponse` must survive serialization round-trip. + + Regression test for #5721: the `ModelResponsePart` discriminated union omitted the + base `ToolReturnPart` (`part_kind='tool-return'`), so `_agent_graph` parts persisted + inside a `ModelResponse` (resumed runs, durable workflows, AG-UI/Vercel adapters) + failed to deserialize with a `union_tag_invalid` `ValidationError`. + """ + msg = ModelResponse( + parts=[ + ToolReturnPart( + tool_name='my_tool', + content={'result': 'success'}, + tool_call_id='call-abc', + metadata={'custom': 'data'}, + ) + ] + ) + + # python mode + dumped_py = ModelMessagesTypeAdapter.dump_python([msg], mode='python') + reloaded_py = ModelMessagesTypeAdapter.validate_python(dumped_py) + part_py = reloaded_py[0].parts[0] + assert isinstance(part_py, ToolReturnPart) + assert part_py.content == {'result': 'success'} + assert part_py.tool_call_id == 'call-abc' + # metadata is the field that silently round-tripped to a wrong shape before the fix + assert part_py.metadata == {'custom': 'data'} + + # json mode + dumped_json = ModelMessagesTypeAdapter.dump_json([msg]) + reloaded_json = ModelMessagesTypeAdapter.validate_json(dumped_json) + part_json = reloaded_json[0].parts[0] + assert isinstance(part_json, ToolReturnPart) + assert part_json.content == {'result': 'success'} + assert part_json.metadata == {'custom': 'data'} + + +def test_tool_return_part_in_model_response_otel_message_parts(): + """A `ToolReturnPart` in a `ModelResponse` must render in OTel message parts, not silently drop. + + `ModelResponse.otel_message_parts` maps `NativeToolReturnPart` to a `tool_call_response` part; + now that base `ToolReturnPart` is a valid `ModelResponsePart` (reachable via user-constructed + history), it must map the same way (without the `builtin` flag) rather than being omitted. + """ + response = ModelResponse( + parts=[ + TextPart(content='hi'), + ToolReturnPart(tool_name='my_tool', content={'result': 'success'}, tool_call_id='call-1'), + ] + ) + settings = InstrumentationSettings(include_content=True) + assert response.otel_message_parts(settings) == snapshot( + [ + {'type': 'text', 'content': 'hi'}, + {'type': 'tool_call_response', 'id': 'call-1', 'name': 'my_tool', 'result': {'result': 'success'}}, + ] + ) + + +def test_tool_search_return_part_in_model_request_still_narrows(): + """Adding base `ToolReturnPart` to the response union must not regress its `ToolSearchReturnPart` subclass. + + `ToolSearchReturnPart` carries `tool_kind='tool-search'` and is registered under the + `'tool-search-return'` tag; the discriminator must still route it there (in a + `ModelRequest`, where search returns live) rather than to the base `'tool-return'` tag. + """ + search_return = ToolSearchReturnPart( + tool_name='search_tools', + content={'discovered_tools': [{'name': 'a_tool', 'description': 'does a thing'}]}, + tool_call_id='call-search', + ) + msg = ModelRequest(parts=[search_return]) + + reloaded = ModelMessagesTypeAdapter.validate_json(ModelMessagesTypeAdapter.dump_json([msg])) + part = reloaded[0].parts[0] + assert isinstance(part, ToolSearchReturnPart) + assert part.tool_name == 'search_tools' + + def test_multi_modal_content_types_matches_union(): """Validate that MULTI_MODAL_CONTENT_TYPES matches the MultiModalContent union members, and that is_multi_modal_content correctly narrows types.""" diff --git a/tests/test_parts_manager.py b/tests/test_parts_manager.py index 6b512c0b5a..b20a80e497 100644 --- a/tests/test_parts_manager.py +++ b/tests/test_parts_manager.py @@ -15,6 +15,7 @@ ThinkingPartDelta, ToolCallPart, ToolCallPartDelta, + ToolReturnPart, UnexpectedModelBehavior, ) from pydantic_ai._parts_manager import ModelResponsePartsManager @@ -708,3 +709,17 @@ def test_get_part_by_vendor_id(): assert part == snapshot(TextPart(content='hello', part_kind='text')) assert manager.get_part_by_vendor_id('missing') is None + + +def test_resolve_provider_name_tool_return_part(): + """`ToolReturnPart` carries no `provider_name`, so the resolver returns the incoming one. + + `ToolReturnPart` is a valid `ModelResponsePart` member but is never tracked by the parts + manager and, unlike the other members, has no `provider_name` attribute. The resolver must + short-circuit on it rather than reading a missing attribute. + """ + manager = ModelResponsePartsManager(model_request_parameters=ModelRequestParameters()) + existing_part = ToolReturnPart(tool_name='tool1', content='result', tool_call_id='call1') + + assert manager._resolve_provider_name(existing_part, 'openai') == 'openai' # pyright: ignore[reportPrivateUsage] + assert manager._resolve_provider_name(existing_part, None) is None # pyright: ignore[reportPrivateUsage]