Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_enqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .exceptions import UserError
from .messages import (
InstructionPart,
ModelMessage,
ModelRequest,
ModelRequestPart,
Expand Down Expand Up @@ -96,7 +97,8 @@ def flush_request() -> None:
flush_request()
messages.append(item)
elif isinstance(
item, (SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart, ToolSearchReturnPart)
item,
(SystemPromptPart, InstructionPart, UserPromptPart, ToolReturnPart, RetryPromptPart, ToolSearchReturnPart),
):
flush_content()
parts.append(item)
Expand Down
10 changes: 10 additions & 0 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1538,6 +1538,15 @@ def sorted(parts: Sequence[InstructionPart]) -> list[InstructionPart]:
"""Sort instruction parts with static (`dynamic=False`) before dynamic, preserving relative order."""
return sorted(parts, key=lambda p: p.dynamic)

def otel_event(self, settings: InstrumentationSettings) -> LogRecord:
return LogRecord(
attributes={'event.name': 'gen_ai.system.message'},
body={'role': 'system', **({'content': self.content} if settings.include_content else {})},
)

def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]:
return [_otel_messages.TextPart(type='text', **{'content': self.content} if settings.include_content else {})]

__repr__ = _utils.dataclasses_no_defaults_repr


Expand Down Expand Up @@ -2022,6 +2031,7 @@ def _model_request_part_discriminator(v: Any) -> str | None:

ModelRequestPart = Annotated[
Annotated[SystemPromptPart, pydantic.Tag('system-prompt')]
| Annotated[InstructionPart, pydantic.Tag('instruction')]
| Annotated[UserPromptPart, pydantic.Tag('user-prompt')]
| Annotated[ToolSearchReturnPart, pydantic.Tag('tool-search-return')]
| Annotated[ToolReturnPart, pydantic.Tag('tool-return')]
Expand Down
19 changes: 12 additions & 7 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
InstructionPart,
ModelMessage,
ModelRequest,
ModelRequestPart,
ModelResponse,
ModelResponsePart,
ModelResponseState,
Expand Down Expand Up @@ -1823,13 +1824,17 @@ def _wrap_non_leading_system_prompts(messages: list[ModelMessage]) -> list[Model
new_messages: list[ModelMessage] = list(messages[: first_request_idx + 1])
changed = False
for msg in messages[first_request_idx + 1 :]:
if isinstance(msg, ModelRequest) and any(isinstance(p, SystemPromptPart) for p in msg.parts):
new_parts = [
UserPromptPart(content=f'<system>{part.content}</system>', timestamp=part.timestamp)
if isinstance(part, SystemPromptPart)
else part
for part in msg.parts
]
if isinstance(msg, ModelRequest) and any(isinstance(p, SystemPromptPart | InstructionPart) for p in msg.parts):
new_parts: list[ModelRequestPart] = []
for part in msg.parts:
if isinstance(part, SystemPromptPart):
new_parts.append(
UserPromptPart(content=f'<system>{part.content}</system>', timestamp=part.timestamp)
)
elif isinstance(part, InstructionPart):
new_parts.append(UserPromptPart(content=f'<system>{part.content}</system>'))
else:
new_parts.append(part)
new_messages.append(replace(msg, parts=new_parts))
changed = True
else:
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
FilePart,
FinishReason,
ImageUrl,
InstructionPart,
ModelMessage,
ModelProfileSpec,
ModelRequest,
Expand Down Expand Up @@ -919,7 +920,7 @@ async def _map_messages( # noqa: C901
for message in messages:
if isinstance(message, ModelRequest):
for part in message.parts:
if isinstance(part, SystemPromptPart):
if isinstance(part, SystemPromptPart | InstructionPart):
if part.content: # pragma: no branch
system_prompt.append({'text': part.content})
elif isinstance(part, UserPromptPart):
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
CompactionPart,
FilePart,
FinishReason,
InstructionPart,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -309,7 +310,7 @@ def _map_tool_definition(f: ToolDefinition) -> ToolV2:
@classmethod
def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
for part in message.parts:
if isinstance(part, SystemPromptPart):
if isinstance(part, SystemPromptPart | InstructionPart):
yield SystemChatMessageV2(role='system', content=part.content)
elif isinstance(part, UserPromptPart):
if isinstance(part.content, str):
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
BinaryContent,
CompactionPart,
FilePart,
InstructionPart,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -394,7 +395,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage:
for message in messages:
if isinstance(message, ModelRequest):
for part in message.parts:
if isinstance(part, SystemPromptPart | UserPromptPart):
if isinstance(part, SystemPromptPart | InstructionPart | UserPromptPart):
request_tokens += _estimate_string_tokens(part.content)
elif isinstance(part, ToolReturnPart):
request_tokens += _estimate_string_tokens(part.model_response_str())
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
CompactionPart,
FilePart,
FileUrl,
InstructionPart,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -350,7 +351,7 @@ async def _message_to_gemini_content(
message_parts: list[_GeminiPartUnion] = []

for part in m.parts:
if isinstance(part, SystemPromptPart):
if isinstance(part, SystemPromptPart | InstructionPart):
sys_prompt_parts.append(_GeminiTextPart(text=part.content))
elif isinstance(part, UserPromptPart):
message_parts.extend(await self._map_user_prompt(part))
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
FilePart,
FileUrl,
FinishReason,
InstructionPart,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -1020,7 +1021,7 @@ async def _map_messages( # noqa: C901
message_parts: list[PartDict] = []

for part in m.parts:
if isinstance(part, SystemPromptPart):
if isinstance(part, SystemPromptPart | InstructionPart):
system_parts.append({'text': part.content})
elif isinstance(part, UserPromptPart):
message_parts.extend(await self._map_user_prompt(part))
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
FilePart,
FinishReason,
ImageUrl,
InstructionPart,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -542,7 +543,7 @@ def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_
async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]:
file_content: list[UserContent] = []
for part in message.parts:
if isinstance(part, SystemPromptPart):
if isinstance(part, SystemPromptPart | InstructionPart):
yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
elif isinstance(part, UserPromptPart):
yield await self._map_user_prompt(part)
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
FilePart,
FinishReason,
ImageUrl,
InstructionPart,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -455,7 +456,7 @@ async def _map_user_message(
self, message: ModelRequest
) -> AsyncIterable[ChatCompletionInputMessage | ChatCompletionOutputMessage]:
for part in message.parts:
if isinstance(part, SystemPromptPart):
if isinstance(part, SystemPromptPart | InstructionPart):
yield ChatCompletionInputMessage.parse_obj_as_instance({'role': 'system', 'content': part.content}) # type: ignore
elif isinstance(part, UserPromptPart):
yield await self._map_user_prompt(part)
Expand Down
5 changes: 4 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/instrumented.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .. import _otel_messages
from .._run_context import RunContext
from ..messages import (
InstructionPart,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -221,7 +222,9 @@ def messages_to_otel_messages(self, messages: list[ModelMessage]) -> list[_otel_
result: list[_otel_messages.ChatMessage] = []
for message in messages:
if isinstance(message, ModelRequest):
for is_system, group in itertools.groupby(message.parts, key=lambda p: isinstance(p, SystemPromptPart)):
for is_system, group in itertools.groupby(
message.parts, key=lambda p: isinstance(p, SystemPromptPart | InstructionPart)
):
message_parts: list[_otel_messages.MessagePart] = []
for part in group:
if hasattr(part, 'otel_message_parts'):
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
FilePart,
FinishReason,
ImageUrl,
InstructionPart,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -544,7 +545,7 @@ def _get_timeout_ms(timeout: Timeout | float | None) -> int | None:
async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[MistralMessages]:
file_content: list[UserContent] = []
for part in message.parts:
if isinstance(part, SystemPromptPart):
if isinstance(part, SystemPromptPart | InstructionPart):
yield MistralSystemMessage(content=part.content)
elif isinstance(part, UserPromptPart):
yield await self._map_user_prompt(part)
Expand Down
5 changes: 3 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
FilePart,
FinishReason,
ImageUrl,
InstructionPart,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -1511,7 +1512,7 @@ def _map_tool_definition(self, f: ToolDefinition, model_settings: ModelSettings)
async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]:
file_content: list[UserContent] = []
for part in message.parts:
if isinstance(part, SystemPromptPart):
if isinstance(part, SystemPromptPart | InstructionPart):
system_prompt_role = OpenAIModelProfile.from_profile(self.profile).openai_system_prompt_role
if system_prompt_role == 'developer':
yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
Expand Down Expand Up @@ -2718,7 +2719,7 @@ async def _map_messages( # noqa: C901
for message in messages:
if isinstance(message, ModelRequest):
for part in message.parts:
if isinstance(part, SystemPromptPart):
if isinstance(part, SystemPromptPart | InstructionPart):
openai_messages.append(
responses.EasyInputMessageParam(
role=profile.openai_system_prompt_role or 'system', content=part.content
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
CompactionPart,
FilePart,
ImageUrl,
InstructionPart,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -448,7 +449,7 @@ async def _format_prompt( # noqa: C901
for message in messages:
if isinstance(message, ModelRequest):
for part in message.parts:
if isinstance(part, SystemPromptPart):
if isinstance(part, SystemPromptPart | InstructionPart):
chat.add_system_message(part.content)
elif isinstance(part, UserPromptPart):
if isinstance(part.content, str):
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/xai.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
FilePart,
FinishReason,
ImageUrl,
InstructionPart,
ModelMessage,
ModelRequest,
ModelRequestPart,
Expand Down Expand Up @@ -314,7 +315,7 @@ async def _map_request_parts(
tool_results: list[ToolReturnPart | RetryPromptPart] = []

for part in parts:
if isinstance(part, SystemPromptPart):
if isinstance(part, SystemPromptPart | InstructionPart):
xai_messages.append(system(part.content))
elif isinstance(part, UserPromptPart):
if user_msg := await self._map_user_prompt(part):
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/ui/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
BaseToolReturnPart,
FileUrl,
ForceDownloadMode,
InstructionPart,
ModelMessage,
ModelRequest,
ModelRequestPart,
Expand Down Expand Up @@ -440,7 +441,7 @@ def _sanitize_request_parts(
stripped_system_prompt = False
new_parts: list[ModelRequestPart] = []
for part in parts:
if strip_system_prompt and isinstance(part, SystemPromptPart):
if strip_system_prompt and isinstance(part, SystemPromptPart | InstructionPart):
stripped_system_prompt = True
continue
if isinstance(part, UserPromptPart) and not isinstance(part.content, str):
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
FilePart,
ForceDownloadMode,
ImageUrl,
InstructionPart,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -511,7 +512,7 @@ def _dump_request_parts(
] = []

for part in msg.parts:
if isinstance(part, SystemPromptPart):
if isinstance(part, SystemPromptPart | InstructionPart):
system_content.append(part.content)
elif isinstance(part, UserPromptPart):
if isinstance(part.content, str):
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
FilePart,
ForceDownloadMode,
ImageUrl,
InstructionPart,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -517,7 +518,7 @@ def _dump_request_message(msg: ModelRequest) -> tuple[list[UIMessagePart], list[
user_ui_parts: list[UIMessagePart] = []

for part in msg.parts:
if isinstance(part, SystemPromptPart):
if isinstance(part, SystemPromptPart | InstructionPart):
system_ui_parts.append(TextUIPart(text=part.content, state='done'))
elif isinstance(part, UserPromptPart):
user_ui_parts.extend(_convert_user_prompt_part(part))
Expand Down
14 changes: 14 additions & 0 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pydantic_ai import UserError
from pydantic_ai._warnings import PydanticAIDeprecationWarning
from pydantic_ai.messages import (
InstructionPart,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -461,6 +462,19 @@ def _request_parts(messages: list[ModelMessage]) -> list[list[tuple[str, object]
],
id='wraps-multiple-non-leading-system-prompts',
),
pytest.param(
False,
[
ModelRequest(parts=[UserPromptPart(content='hi')]),
ModelResponse(parts=[TextPart(content='hello')]),
ModelRequest(parts=[InstructionPart(content='Use a short answer.'), UserPromptPart(content='ok?')]),
],
[
[('UserPromptPart', 'hi')],
[('UserPromptPart', '<system>Use a short answer.</system>'), ('UserPromptPart', 'ok?')],
],
id='wraps-non-leading-instruction-part',
),
pytest.param(
False,
[
Expand Down
Loading
Loading