From a4ce8f38c367317071a330d380fd8717286c331f Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Thu, 28 May 2026 18:34:31 +0800 Subject: [PATCH] fix: round-trip instruction request parts --- pydantic_ai_slim/pydantic_ai/_enqueue.py | 4 ++- pydantic_ai_slim/pydantic_ai/messages.py | 10 +++++++ .../pydantic_ai/models/__init__.py | 19 ++++++++----- .../pydantic_ai/models/bedrock.py | 3 +- pydantic_ai_slim/pydantic_ai/models/cohere.py | 3 +- .../pydantic_ai/models/function.py | 3 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 3 +- pydantic_ai_slim/pydantic_ai/models/google.py | 3 +- pydantic_ai_slim/pydantic_ai/models/groq.py | 3 +- .../pydantic_ai/models/huggingface.py | 3 +- .../pydantic_ai/models/instrumented.py | 5 +++- .../pydantic_ai/models/mistral.py | 3 +- pydantic_ai_slim/pydantic_ai/models/openai.py | 5 ++-- .../pydantic_ai/models/outlines.py | 3 +- pydantic_ai_slim/pydantic_ai/models/xai.py | 3 +- pydantic_ai_slim/pydantic_ai/ui/_adapter.py | 3 +- .../pydantic_ai/ui/ag_ui/_adapter.py | 3 +- .../pydantic_ai/ui/vercel_ai/_adapter.py | 3 +- tests/models/test_model.py | 14 ++++++++++ tests/test_messages.py | 28 +++++++++++++++++++ 20 files changed, 100 insertions(+), 24 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_enqueue.py b/pydantic_ai_slim/pydantic_ai/_enqueue.py index ffccdf8027..dbf40723f0 100644 --- a/pydantic_ai_slim/pydantic_ai/_enqueue.py +++ b/pydantic_ai_slim/pydantic_ai/_enqueue.py @@ -12,6 +12,7 @@ from .exceptions import UserError from .messages import ( + InstructionPart, ModelMessage, ModelRequest, ModelRequestPart, @@ -96,7 +97,8 @@ def flush_request() -> None: flush_request() messages.append(item) elif isinstance( - item, (SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart, ToolSearchReturnPart) + item, + (SystemPromptPart, InstructionPart, UserPromptPart, ToolReturnPart, RetryPromptPart, ToolSearchReturnPart), ): flush_content() parts.append(item) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index c8fc93f64b..65f1307bce 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -1538,6 +1538,15 @@ def sorted(parts: Sequence[InstructionPart]) -> list[InstructionPart]: """Sort instruction parts with static (`dynamic=False`) before dynamic, preserving relative order.""" return sorted(parts, key=lambda p: p.dynamic) + def otel_event(self, settings: InstrumentationSettings) -> LogRecord: + return LogRecord( + attributes={'event.name': 'gen_ai.system.message'}, + body={'role': 'system', **({'content': self.content} if settings.include_content else {})}, + ) + + def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]: + return [_otel_messages.TextPart(type='text', **{'content': self.content} if settings.include_content else {})] + __repr__ = _utils.dataclasses_no_defaults_repr @@ -2022,6 +2031,7 @@ def _model_request_part_discriminator(v: Any) -> str | None: ModelRequestPart = Annotated[ Annotated[SystemPromptPart, pydantic.Tag('system-prompt')] + | Annotated[InstructionPart, pydantic.Tag('instruction')] | Annotated[UserPromptPart, pydantic.Tag('user-prompt')] | Annotated[ToolSearchReturnPart, pydantic.Tag('tool-search-return')] | Annotated[ToolReturnPart, pydantic.Tag('tool-return')] diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index a3066d1c8f..a72bd1296e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -40,6 +40,7 @@ InstructionPart, ModelMessage, ModelRequest, + ModelRequestPart, ModelResponse, ModelResponsePart, ModelResponseState, @@ -1823,13 +1824,17 @@ def _wrap_non_leading_system_prompts(messages: list[ModelMessage]) -> list[Model new_messages: list[ModelMessage] = list(messages[: first_request_idx + 1]) changed = False for msg in messages[first_request_idx + 1 :]: - if isinstance(msg, ModelRequest) and any(isinstance(p, SystemPromptPart) for p in msg.parts): - new_parts = [ - UserPromptPart(content=f'{part.content}', timestamp=part.timestamp) - if isinstance(part, SystemPromptPart) - else part - for part in msg.parts - ] + if isinstance(msg, ModelRequest) and any(isinstance(p, SystemPromptPart | InstructionPart) for p in msg.parts): + new_parts: list[ModelRequestPart] = [] + for part in msg.parts: + if isinstance(part, SystemPromptPart): + new_parts.append( + UserPromptPart(content=f'{part.content}', timestamp=part.timestamp) + ) + elif isinstance(part, InstructionPart): + new_parts.append(UserPromptPart(content=f'{part.content}')) + else: + new_parts.append(part) new_messages.append(replace(msg, parts=new_parts)) changed = True else: diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index f84fbc6299..543075253e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -34,6 +34,7 @@ FilePart, FinishReason, ImageUrl, + InstructionPart, ModelMessage, ModelProfileSpec, ModelRequest, @@ -919,7 +920,7 @@ async def _map_messages( # noqa: C901 for message in messages: if isinstance(message, ModelRequest): for part in message.parts: - if isinstance(part, SystemPromptPart): + if isinstance(part, SystemPromptPart | InstructionPart): if part.content: # pragma: no branch system_prompt.append({'text': part.content}) elif isinstance(part, UserPromptPart): diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index 26cc08c673..6caab63046 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -15,6 +15,7 @@ CompactionPart, FilePart, FinishReason, + InstructionPart, ModelMessage, ModelRequest, ModelResponse, @@ -309,7 +310,7 @@ def _map_tool_definition(f: ToolDefinition) -> ToolV2: @classmethod def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]: for part in message.parts: - if isinstance(part, SystemPromptPart): + if isinstance(part, SystemPromptPart | InstructionPart): yield SystemChatMessageV2(role='system', content=part.content) elif isinstance(part, UserPromptPart): if isinstance(part.content, str): diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index bd5d5844de..f7db7ab6e2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -19,6 +19,7 @@ BinaryContent, CompactionPart, FilePart, + InstructionPart, ModelMessage, ModelRequest, ModelResponse, @@ -394,7 +395,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage: for message in messages: if isinstance(message, ModelRequest): for part in message.parts: - if isinstance(part, SystemPromptPart | UserPromptPart): + if isinstance(part, SystemPromptPart | InstructionPart | UserPromptPart): request_tokens += _estimate_string_tokens(part.content) elif isinstance(part, ToolReturnPart): request_tokens += _estimate_string_tokens(part.model_response_str()) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 85634bdbfe..4c1117a16e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -31,6 +31,7 @@ CompactionPart, FilePart, FileUrl, + InstructionPart, ModelMessage, ModelRequest, ModelResponse, @@ -350,7 +351,7 @@ async def _message_to_gemini_content( message_parts: list[_GeminiPartUnion] = [] for part in m.parts: - if isinstance(part, SystemPromptPart): + if isinstance(part, SystemPromptPart | InstructionPart): sys_prompt_parts.append(_GeminiTextPart(text=part.content)) elif isinstance(part, UserPromptPart): message_parts.extend(await self._map_user_prompt(part)) diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index f708c46e51..b2aeb4ef97 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -24,6 +24,7 @@ FilePart, FileUrl, FinishReason, + InstructionPart, ModelMessage, ModelRequest, ModelResponse, @@ -1020,7 +1021,7 @@ async def _map_messages( # noqa: C901 message_parts: list[PartDict] = [] for part in m.parts: - if isinstance(part, SystemPromptPart): + if isinstance(part, SystemPromptPart | InstructionPart): system_parts.append({'text': part.content}) elif isinstance(part, UserPromptPart): message_parts.extend(await self._map_user_prompt(part)) diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index c429d21d12..7aad08de65 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -25,6 +25,7 @@ FilePart, FinishReason, ImageUrl, + InstructionPart, ModelMessage, ModelRequest, ModelResponse, @@ -542,7 +543,7 @@ def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_ async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]: file_content: list[UserContent] = [] for part in message.parts: - if isinstance(part, SystemPromptPart): + if isinstance(part, SystemPromptPart | InstructionPart): yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content) elif isinstance(part, UserPromptPart): yield await self._map_user_prompt(part) diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 253288764f..b1026c40a7 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -21,6 +21,7 @@ FilePart, FinishReason, ImageUrl, + InstructionPart, ModelMessage, ModelRequest, ModelResponse, @@ -455,7 +456,7 @@ async def _map_user_message( self, message: ModelRequest ) -> AsyncIterable[ChatCompletionInputMessage | ChatCompletionOutputMessage]: for part in message.parts: - if isinstance(part, SystemPromptPart): + if isinstance(part, SystemPromptPart | InstructionPart): yield ChatCompletionInputMessage.parse_obj_as_instance({'role': 'system', 'content': part.content}) # type: ignore elif isinstance(part, UserPromptPart): yield await self._map_user_prompt(part) diff --git a/pydantic_ai_slim/pydantic_ai/models/instrumented.py b/pydantic_ai_slim/pydantic_ai/models/instrumented.py index b04cf8150b..470816edb9 100644 --- a/pydantic_ai_slim/pydantic_ai/models/instrumented.py +++ b/pydantic_ai_slim/pydantic_ai/models/instrumented.py @@ -32,6 +32,7 @@ from .. import _otel_messages from .._run_context import RunContext from ..messages import ( + InstructionPart, ModelMessage, ModelRequest, ModelResponse, @@ -221,7 +222,9 @@ def messages_to_otel_messages(self, messages: list[ModelMessage]) -> list[_otel_ result: list[_otel_messages.ChatMessage] = [] for message in messages: if isinstance(message, ModelRequest): - for is_system, group in itertools.groupby(message.parts, key=lambda p: isinstance(p, SystemPromptPart)): + for is_system, group in itertools.groupby( + message.parts, key=lambda p: isinstance(p, SystemPromptPart | InstructionPart) + ): message_parts: list[_otel_messages.MessagePart] = [] for part in group: if hasattr(part, 'otel_message_parts'): diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 6bc04d9664..d7bf3fff71 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -23,6 +23,7 @@ FilePart, FinishReason, ImageUrl, + InstructionPart, ModelMessage, ModelRequest, ModelResponse, @@ -544,7 +545,7 @@ def _get_timeout_ms(timeout: Timeout | float | None) -> int | None: async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[MistralMessages]: file_content: list[UserContent] = [] for part in message.parts: - if isinstance(part, SystemPromptPart): + if isinstance(part, SystemPromptPart | InstructionPart): yield MistralSystemMessage(content=part.content) elif isinstance(part, UserPromptPart): yield await self._map_user_prompt(part) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 96bf139e0d..925b45e354 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -40,6 +40,7 @@ FilePart, FinishReason, ImageUrl, + InstructionPart, ModelMessage, ModelRequest, ModelResponse, @@ -1511,7 +1512,7 @@ def _map_tool_definition(self, f: ToolDefinition, model_settings: ModelSettings) async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]: file_content: list[UserContent] = [] for part in message.parts: - if isinstance(part, SystemPromptPart): + if isinstance(part, SystemPromptPart | InstructionPart): system_prompt_role = OpenAIModelProfile.from_profile(self.profile).openai_system_prompt_role if system_prompt_role == 'developer': yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content) @@ -2718,7 +2719,7 @@ async def _map_messages( # noqa: C901 for message in messages: if isinstance(message, ModelRequest): for part in message.parts: - if isinstance(part, SystemPromptPart): + if isinstance(part, SystemPromptPart | InstructionPart): openai_messages.append( responses.EasyInputMessageParam( role=profile.openai_system_prompt_role or 'system', content=part.content diff --git a/pydantic_ai_slim/pydantic_ai/models/outlines.py b/pydantic_ai_slim/pydantic_ai/models/outlines.py index 394d660514..7fbc1aa683 100644 --- a/pydantic_ai_slim/pydantic_ai/models/outlines.py +++ b/pydantic_ai_slim/pydantic_ai/models/outlines.py @@ -24,6 +24,7 @@ CompactionPart, FilePart, ImageUrl, + InstructionPart, ModelMessage, ModelRequest, ModelResponse, @@ -448,7 +449,7 @@ async def _format_prompt( # noqa: C901 for message in messages: if isinstance(message, ModelRequest): for part in message.parts: - if isinstance(part, SystemPromptPart): + if isinstance(part, SystemPromptPart | InstructionPart): chat.add_system_message(part.content) elif isinstance(part, UserPromptPart): if isinstance(part.content, str): diff --git a/pydantic_ai_slim/pydantic_ai/models/xai.py b/pydantic_ai_slim/pydantic_ai/models/xai.py index 09d2d9717c..129cd2fce9 100644 --- a/pydantic_ai_slim/pydantic_ai/models/xai.py +++ b/pydantic_ai_slim/pydantic_ai/models/xai.py @@ -24,6 +24,7 @@ FilePart, FinishReason, ImageUrl, + InstructionPart, ModelMessage, ModelRequest, ModelRequestPart, @@ -314,7 +315,7 @@ async def _map_request_parts( tool_results: list[ToolReturnPart | RetryPromptPart] = [] for part in parts: - if isinstance(part, SystemPromptPart): + if isinstance(part, SystemPromptPart | InstructionPart): xai_messages.append(system(part.content)) elif isinstance(part, UserPromptPart): if user_msg := await self._map_user_prompt(part): diff --git a/pydantic_ai_slim/pydantic_ai/ui/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/_adapter.py index d280f84f49..e01b9d724b 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/_adapter.py @@ -30,6 +30,7 @@ BaseToolReturnPart, FileUrl, ForceDownloadMode, + InstructionPart, ModelMessage, ModelRequest, ModelRequestPart, @@ -440,7 +441,7 @@ def _sanitize_request_parts( stripped_system_prompt = False new_parts: list[ModelRequestPart] = [] for part in parts: - if strip_system_prompt and isinstance(part, SystemPromptPart): + if strip_system_prompt and isinstance(part, SystemPromptPart | InstructionPart): stripped_system_prompt = True continue if isinstance(part, UserPromptPart) and not isinstance(part.content, str): diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py index 44e315bb23..11b8543d6f 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py @@ -28,6 +28,7 @@ FilePart, ForceDownloadMode, ImageUrl, + InstructionPart, ModelMessage, ModelRequest, ModelResponse, @@ -511,7 +512,7 @@ def _dump_request_parts( ] = [] for part in msg.parts: - if isinstance(part, SystemPromptPart): + if isinstance(part, SystemPromptPart | InstructionPart): system_content.append(part.content) elif isinstance(part, UserPromptPart): if isinstance(part.content, str): diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py index 855a9dc3a7..3113507694 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py @@ -24,6 +24,7 @@ FilePart, ForceDownloadMode, ImageUrl, + InstructionPart, ModelMessage, ModelRequest, ModelResponse, @@ -517,7 +518,7 @@ def _dump_request_message(msg: ModelRequest) -> tuple[list[UIMessagePart], list[ user_ui_parts: list[UIMessagePart] = [] for part in msg.parts: - if isinstance(part, SystemPromptPart): + if isinstance(part, SystemPromptPart | InstructionPart): system_ui_parts.append(TextUIPart(text=part.content, state='done')) elif isinstance(part, UserPromptPart): user_ui_parts.extend(_convert_user_prompt_part(part)) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index d64305ca8a..f4801f8490 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -8,6 +8,7 @@ from pydantic_ai import UserError from pydantic_ai._warnings import PydanticAIDeprecationWarning from pydantic_ai.messages import ( + InstructionPart, ModelMessage, ModelRequest, ModelResponse, @@ -461,6 +462,19 @@ def _request_parts(messages: list[ModelMessage]) -> list[list[tuple[str, object] ], id='wraps-multiple-non-leading-system-prompts', ), + pytest.param( + False, + [ + ModelRequest(parts=[UserPromptPart(content='hi')]), + ModelResponse(parts=[TextPart(content='hello')]), + ModelRequest(parts=[InstructionPart(content='Use a short answer.'), UserPromptPart(content='ok?')]), + ], + [ + [('UserPromptPart', 'hi')], + [('UserPromptPart', 'Use a short answer.'), ('UserPromptPart', 'ok?')], + ], + id='wraps-non-leading-instruction-part', + ), pytest.param( False, [ diff --git a/tests/test_messages.py b/tests/test_messages.py index 0b46d73142..77fc1cc7df 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -1501,6 +1501,16 @@ def test_serialization_round_trip(self): assert isinstance(msg, ModelRequest) assert msg.instructions == 'static part\n\ndynamic part' + def test_request_part_serialization_round_trip(self): + original = ModelRequest(parts=[InstructionPart(content='Use the short path.', dynamic=True)]) + + serialized = ModelMessagesTypeAdapter.dump_json([original]) + deserialized = ModelMessagesTypeAdapter.validate_json(serialized) + + msg = deserialized[0] + assert isinstance(msg, ModelRequest) + assert msg.parts == original.parts + def test_repr(self): """InstructionPart repr omits default values.""" part = InstructionPart(content='hello') @@ -1508,6 +1518,24 @@ def test_repr(self): dynamic_part = InstructionPart(content='world', dynamic=True) assert repr(dynamic_part) == "InstructionPart(content='world', dynamic=True)" + def test_otel_event(self): + part = InstructionPart(content='Keep replies short.') + + event = part.otel_event(InstrumentationSettings(include_content=True)) + assert event.attributes == snapshot({'event.name': 'gen_ai.system.message'}) + assert event.body == snapshot({'role': 'system', 'content': 'Keep replies short.'}) + + event_without_content = part.otel_event(InstrumentationSettings(include_content=False)) + assert event_without_content.body == snapshot({'role': 'system'}) + + def test_otel_message_parts(self): + part = InstructionPart(content='Keep replies short.') + + assert part.otel_message_parts(InstrumentationSettings(include_content=True)) == snapshot( + [{'type': 'text', 'content': 'Keep replies short.'}] + ) + assert part.otel_message_parts(InstrumentationSettings(include_content=False)) == snapshot([{'type': 'text'}]) + def test_retry_prompt_strips_input_from_top_level_errors(): """Top-level validation errors should not include `input` in model_response() since it duplicates the entire generated output."""