diff --git a/docs/capabilities.md b/docs/capabilities.md index 67146798fb..1dbddc8671 100644 --- a/docs/capabilities.md +++ b/docs/capabilities.md @@ -117,6 +117,11 @@ agent = Agent('openai:gpt-5.2', name='my_agent', capabilities=[hooks]) All hooks receive [`RunContext`][pydantic_ai.tools.RunContext], which provides access to the running agent via [`ctx.agent`][pydantic_ai.tools.RunContext.agent] — useful for logging, metrics, and other cross-cutting concerns that need to identify which agent is running. +Hooks can also push follow-up messages into the conversation via +[`RunContext.enqueue`][pydantic_ai.tools.RunContext.enqueue] — useful for capability +authors that need to surface an event to the model mid-run without rebuilding the +cached system prompt. See [Injecting messages mid-run](message-history.md#injecting-messages-mid-run). + See the dedicated [Hooks](hooks.md) page for the full API: decorator and constructor registration, timeouts, tool filtering, wrap hooks, per-event hooks, and more. ### Provider-adaptive tools diff --git a/docs/message-history.md b/docs/message-history.md index 29288c1ab2..628b4d1a7a 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -391,6 +391,121 @@ print(result2.all_messages()) """ ``` +## Injecting messages mid-run + +Tools, capability hooks, and external code driving an agent run can inject extra content +into the conversation mid-run with [`RunContext.enqueue`][pydantic_ai.tools.RunContext.enqueue] +(when a `RunContext` is in scope, e.g. inside a tool or capability hook) or +[`AgentRun.enqueue`][pydantic_ai.run.AgentRun.enqueue] (from external code driving +[`agent.iter()`][pydantic_ai.agent.AbstractAgent.iter]). Use this when something happens during a +run that the agent should know about — a tool wants to add follow-up context, an external event +needs to *steer* the agent's plan, or background work needs to reach the agent when it completes. 
+ +A `priority` controls when the enqueued content is delivered: + +- `'asap'` (default): delivered at the earliest opportunity — added to the next [`ModelRequest`][pydantic_ai.messages.ModelRequest], or, if the agent would otherwise terminate before another request, used to redirect the run into one more request. Use when the new context should reach the model as soon as possible; this is what other frameworks often call **steering** an in-flight agent. +- `'when_idle'`: delivered only when the agent would otherwise terminate, after any `'asap'` messages. Use when the agent shouldn't be interrupted but should pick up the new work — a follow-up task — once it's done with what it's doing. + +`enqueue` is variadic — each positional argument is one item, and can be: + +- a piece of [`UserContent`][pydantic_ai.messages.UserContent] — a `str` or multi-modal content like an [`ImageUrl`][pydantic_ai.messages.ImageUrl]. Adjacent user content is gathered into a single [`UserPromptPart`][pydantic_ai.messages.UserPromptPart], so `enqueue('caption', image)` forms one user turn. To pass an existing list, spread it: `enqueue(*items)`; +- a [`ModelRequestPart`][pydantic_ai.messages.ModelRequestPart], such as a [`SystemPromptPart`][pydantic_ai.messages.SystemPromptPart]; +- a complete [`ModelRequest`][pydantic_ai.messages.ModelRequest] or [`ModelResponse`][pydantic_ai.messages.ModelResponse], to control request-level fields like `instructions`/`metadata` or to inject a synthetic prior turn. + +Adjacent part-style items (user content and [`ModelRequestPart`][pydantic_ai.messages.ModelRequestPart]s) are coalesced into one [`ModelRequest`][pydantic_ai.messages.ModelRequest]; complete messages stay separate. This lets a single call inject an interleaved exchange — for example a synthetic tool call (a [`ModelResponse`][pydantic_ai.messages.ModelResponse]) followed by its result (a [`ModelRequest`][pydantic_ai.messages.ModelRequest]). 
The content must end in a request, so the agent has something to respond to. + +### From inside a tool or hook + +Use [`RunContext.enqueue`][pydantic_ai.tools.RunContext.enqueue] when you have a +`RunContext` in scope: + +```python {title="enqueue_from_tool.py"} +from pydantic_ai import Agent, RunContext +from pydantic_ai.messages import SystemPromptPart + +agent = Agent('anthropic:claude-opus-4-7') + + +@agent.tool +def trigger_alert(ctx: RunContext[None]) -> str: + ctx.enqueue('Alert: production is degraded, prioritize triage.') + return 'alert raised' + + +@agent.tool +def enter_incident_mode(ctx: RunContext[None]) -> str: + # Enqueue a `SystemPromptPart` to adjust the agent's standing instructions mid-run. + ctx.enqueue(SystemPromptPart(content='You are now in incident mode: be terse and action-oriented.')) + return 'incident mode enabled' +``` + +The `'asap'` message is appended to the agent's message history and is visible to the +model on the next request, alongside any tool returns from the same step. A +[`SystemPromptPart`][pydantic_ai.messages.SystemPromptPart] is delivered the same way; on +providers that hoist system prompts (e.g. Anthropic, Google) a non-leading one is sent as a +``-tagged user-role message, so it keeps its mid-conversation position rather than being +lifted to the top. + +### From external code driving `agent.iter()` + +Use [`AgentRun.enqueue`][pydantic_ai.run.AgentRun.enqueue] when you're driving a run +from outside (e.g. forwarding events from a webhook, chat platform, or job queue): + +```python {title="enqueue_from_agent_run.py"} +from pydantic_ai import Agent +from pydantic_graph import End + +agent = Agent('anthropic:claude-opus-4-7') + + +async def main(): + async with agent.iter('Summarize the latest deploy report') as agent_run: + # An external system pushes a follow-up while the agent is working. 
+ # When the agent would otherwise finish, the message redirects it + # into a fresh model request so it can incorporate the new context. + agent_run.enqueue( + 'A new error was just reported — include it in the summary.', + priority='when_idle', + ) + node = agent_run.next_node + while not isinstance(node, End): + node = await agent_run.next(node) +``` + +The example drives the run with [`agent.iter()`][pydantic_ai.agent.AbstractAgent.iter] + +[`AgentRun.next()`][pydantic_ai.run.AgentRun.next] because `'when_idle'` messages are only +drained when the agent would otherwise reach an `End` — that drain happens in `after_node_run`, +which doesn't fire inside a bare `async for node in agent_run:` loop. `'asap'` messages are +drained in `before_model_request` (which fires either way) and also at the same end-of-run point +if anything arrived during the final step. Reaching the end of a bare `async for` loop with +undrained pending messages raises [`UndrainedPendingMessagesError`][pydantic_ai.exceptions.UndrainedPendingMessagesError], +since those messages would otherwise be silently lost. + +!!! info "Limitations" + - End-of-run redirects need [`Agent.run`][pydantic_ai.agent.AbstractAgent.run] or + explicit [`AgentRun.next()`][pydantic_ai.run.AgentRun.next] driving — they + aren't drained inside a bare `async for node in agent_run:` loop (which raises + [`UndrainedPendingMessagesError`][pydantic_ai.exceptions.UndrainedPendingMessagesError] + if it ends with undrained messages). Messages delivered into a + `before_model_request` work in either case. + - Inside a [Temporal](durable_execution/temporal.md) workflow, tools run in + activities and don't share state with the workflow, so `ctx.enqueue` from a + tool doesn't currently propagate back to the run. Enqueue from the workflow + context (e.g. via `AgentRun.enqueue`) instead. + - Each end-of-run redirect opens a new model request. If something keeps + enqueueing on every step (e.g. 
a tool that always enqueues, or a + system-prompt callback that re-enqueues on each reinjection), the run will + loop indefinitely. Set [`UsageLimits`][pydantic_ai.usage.UsageLimits] on the + run as a safety net. + - `enqueue` is designed to be called from the same event loop that drives the + agent run. Inside the run that's automatic: async tools, sync tools (which + Pydantic AI auto-wraps in a thread executor), and capability hooks all + enqueue safely because the drain only iterates between graph nodes, never + concurrently with a tool body. If you're forwarding events from a *different* + thread or loop (e.g. a webhook handler), marshal the call onto the agent's + loop first — e.g. `loop.call_soon_threadsafe(agent_run.enqueue, msg)`. The + drain isn't atomic against concurrent cross-thread appends. + ## Processing Message History Sometimes you may want to modify the message history before it's sent to the model. This could be for privacy diff --git a/docs/tools.md b/docs/tools.md index 26563e581e..38b0871650 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -375,6 +375,14 @@ _(This example is complete, it can be run "as is")_ This visibility helps you understand why an agent made specific decisions and identify issues in tool implementations. +## Injecting Follow-up Messages from a Tool + +A tool can push extra messages into the conversation via +[`RunContext.enqueue`][pydantic_ai.tools.RunContext.enqueue] — useful when a tool wants +to add follow-up context, redirect the agent's plan, or surface an event the model +should react to. See [Injecting messages mid-run](message-history.md#injecting-messages-mid-run) +for the full pattern. 
+ ## See Also For more tool features and integrations, see: diff --git a/pydantic_ai_slim/pydantic_ai/.agents/skills/building-pydantic-ai-agents/references/INPUT-AND-HISTORY.md b/pydantic_ai_slim/pydantic_ai/.agents/skills/building-pydantic-ai-agents/references/INPUT-AND-HISTORY.md index bfebb9aedb..de91608d82 100644 --- a/pydantic_ai_slim/pydantic_ai/.agents/skills/building-pydantic-ai-agents/references/INPUT-AND-HISTORY.md +++ b/pydantic_ai_slim/pydantic_ai/.agents/skills/building-pydantic-ai-agents/references/INPUT-AND-HISTORY.md @@ -67,3 +67,28 @@ Good uses: - removing PII before provider calls - summarizing old messages - applying app-specific history policies + +## Inject Messages Mid-Run + +Use `RunContext.enqueue(...)` (from a tool or capability hook) or `AgentRun.enqueue(...)` (from external code driving `agent.iter()`) to add content to the conversation while a run is in progress — e.g. a tool adding follow-up context, or an external event "steering" the agent. + +`enqueue` is variadic; each positional arg is one item: a piece of `UserContent` (a `str` or multi-modal content like an `ImageUrl`), a `ModelRequestPart` (e.g. a `SystemPromptPart`), or a complete `ModelRequest`/`ModelResponse`. Adjacent user content is gathered into one `UserPromptPart`. Pass an existing list by spreading it (`enqueue(*items)`). + +```python +from pydantic_ai import Agent, RunContext + +agent = Agent('anthropic:claude-opus-4-7') + + +@agent.tool +def trigger_alert(ctx: RunContext[None]) -> str: + ctx.enqueue('Alert: production is degraded, prioritize triage.') + return 'alert raised' +``` + +A `priority` controls delivery: + +- `'asap'` (default): delivered at the earliest opportunity — added to the next model request, or, if the agent would otherwise terminate, used to redirect the run into one more request. This is "steering" an in-flight agent. 
+- `'when_idle'`: delivered only when the agent would otherwise terminate, after any `'asap'` messages — a follow-up task that shouldn't interrupt in-flight work. + +`'when_idle'` redirects need `agent.run()` or explicit `AgentRun.next()` driving; they aren't drained inside a bare `async for node in agent_run:` loop. See [message history docs](https://ai.pydantic.dev/message-history/#injecting-messages-mid-run) for details. diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index 530232096f..f6ea35f3b7 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -40,6 +40,7 @@ SkipModelRequest, SkipToolExecution, SkipToolValidation, + UndrainedPendingMessagesError, UnexpectedModelBehavior, UsageLimitExceeded, UserError, @@ -200,6 +201,7 @@ 'SkipModelRequest', 'SkipToolExecution', 'SkipToolValidation', + 'UndrainedPendingMessagesError', 'UnexpectedModelBehavior', 'UsageLimitExceeded', 'UserError', diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 776e9619bb..99dffc826b 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -17,7 +17,7 @@ from pydantic_ai._history_processor import HistoryProcessor from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION -from pydantic_ai._utils import cancel_and_drain, dataclasses_no_defaults_repr, now_utc +from pydantic_ai._utils import cancel_and_drain, dataclasses_no_defaults_repr, fill_run_metadata, now_utc from pydantic_ai._uuid import uuid7 from pydantic_ai.capabilities.abstract import AbstractCapability from pydantic_ai.models import ModelRequestContext @@ -28,7 +28,7 @@ from pydantic_graph.basenode import End, NodeRunEndT from pydantic_graph.graph_builder import Graph -from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage +from . 
import _enqueue, _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage from ._run_context import set_current_run_context from .exceptions import ToolRetryError from .output import OutputDataT, OutputSpec @@ -135,6 +135,9 @@ class GraphAgentState: """Last-resolved `max_tokens` from model settings, used only in error messages.""" last_model_request_parameters: models.ModelRequestParameters | None = None """Last-resolved model request parameters, used for OTel span attributes.""" + pending_messages: list[_enqueue.PendingMessage] = dataclasses.field(default_factory=list[_enqueue.PendingMessage]) + """Internal: queue used by [`PendingMessageDrainCapability`][pydantic_ai.capabilities._pending_messages.PendingMessageDrainCapability] + for messages enqueued via [`enqueue`][pydantic_ai.tools.RunContext.enqueue] or [`AgentRun.enqueue`][pydantic_ai.run.AgentRun.enqueue].""" def check_incomplete_tool_call(self) -> None: """Raise `IncompleteToolCall` if the last model response was truncated mid-tool-call.""" @@ -896,14 +899,8 @@ async def _prepare_request( if not isinstance(messages[-1], _messages.ModelRequest): raise exceptions.UserError('Processed history must end with a `ModelRequest`.') - # Ensure the last request has a timestamp (history processors may create new ModelRequest objects without one) - if messages[-1].timestamp is None: - messages[-1].timestamp = now_utc() - - if messages and messages[-1].run_id is None: - messages[-1].run_id = ctx.state.run_id - if messages and messages[-1].conversation_id is None: - messages[-1].conversation_id = ctx.state.conversation_id + # Fill in framework metadata the history processors may have left unset on a new `ModelRequest`. 
+ fill_run_metadata(messages[-1], run_id=ctx.state.run_id, conversation_id=ctx.state.conversation_id) if self.is_resuming_without_prompt: ctx.deps.resumed_request = self.request @@ -960,8 +957,7 @@ async def _finish_handling( ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], response: _messages.ModelResponse, ) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]: - response.run_id = response.run_id or ctx.state.run_id - response.conversation_id = response.conversation_id or ctx.state.conversation_id + fill_run_metadata(response, run_id=ctx.state.run_id, conversation_id=ctx.state.conversation_id) run_context = build_run_context(ctx) assert self.last_request_context is not None, 'last_request_context must be set before _finish_handling' @@ -1013,8 +1009,7 @@ def _append_response( response: _messages.ModelResponse, ) -> None: """Append a model response to history, updating usage tracking.""" - response.run_id = response.run_id or ctx.state.run_id - response.conversation_id = response.conversation_id or ctx.state.conversation_id + fill_run_metadata(response, run_id=ctx.state.run_id, conversation_id=ctx.state.conversation_id) ctx.state.usage.incr(response.usage) if ctx.deps.usage_limits: # pragma: no branch ctx.deps.usage_limits.check_tokens(ctx.state.usage) @@ -1410,6 +1405,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT conversation_id=ctx.state.conversation_id, metadata=ctx.state.metadata, tool_manager=ctx.deps.tool_manager, + pending_messages=ctx.state.pending_messages, ) validation_context = build_validation_context(ctx.deps.validation_context, run_context) run_context = replace(run_context, validation_context=validation_context) diff --git a/pydantic_ai_slim/pydantic_ai/_enqueue.py b/pydantic_ai_slim/pydantic_ai/_enqueue.py new file mode 100644 index 0000000000..ffccdf8027 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/_enqueue.py @@ -0,0 +1,150 @@ +"""Internal helpers for the 
`RunContext.enqueue` / `AgentRun.enqueue` APIs. + +These types live here (rather than in `messages.py`) because they're internal runtime +state for the pending message queue, not part of the wire-serializable message history. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Literal, TypeAlias + +from .exceptions import UserError +from .messages import ( + ModelMessage, + ModelRequest, + ModelRequestPart, + ModelResponse, + RetryPromptPart, + SystemPromptPart, + ToolReturnPart, + ToolSearchReturnPart, + UserPromptPart, +) + +if TYPE_CHECKING: + from .messages import UserContent + + +PendingMessagePriority: TypeAlias = Literal['asap', 'when_idle'] +"""When to deliver a pending message. + +- `'asap'`: Delivered at the earliest opportunity: either added to the next + [`ModelRequest`][pydantic_ai.messages.ModelRequest], or, if the agent would + otherwise terminate before another request, used to redirect the run into one + more request. +- `'when_idle'`: Delivered only when the agent would otherwise terminate, after + any `'asap'` messages. Doesn't interrupt in-flight work. +""" + + +EnqueueContent: TypeAlias = 'UserContent | ModelRequestPart | ModelMessage' +"""A single item accepted by [`RunContext.enqueue`][pydantic_ai.tools.RunContext.enqueue] +and [`AgentRun.enqueue`][pydantic_ai.run.AgentRun.enqueue]. + +`enqueue` is variadic, so each item is one positional argument: + +- [`UserContent`][pydantic_ai.messages.UserContent] (a `str` or a piece of multi-modal content + like an [`ImageUrl`][pydantic_ai.messages.ImageUrl]): adjacent user content is gathered into a + single [`UserPromptPart`][pydantic_ai.messages.UserPromptPart], so `enqueue('caption', image)` + forms one user turn. To pass an existing list, spread it: `enqueue(*items)`. +- [`ModelRequestPart`][pydantic_ai.messages.ModelRequestPart] (e.g. 
a + [`SystemPromptPart`][pydantic_ai.messages.SystemPromptPart]): included verbatim. +- [`ModelMessage`][pydantic_ai.messages.ModelMessage] (a complete + [`ModelRequest`][pydantic_ai.messages.ModelRequest] or + [`ModelResponse`][pydantic_ai.messages.ModelResponse]): emitted as its own message. + +Consecutive part-style items (user content and `ModelRequestPart`s) are coalesced into a single +`ModelRequest`; complete `ModelMessage`s stay separate. This lets one `enqueue` call inject an +interleaved exchange (e.g. a synthetic tool call + result — a `ModelResponse` followed by a +`ModelRequest`). The assembled sequence must end in a `ModelRequest` so the agent has something to +respond to. +""" + + +def _build_enqueue_messages(items: Sequence[EnqueueContent]) -> list[ModelMessage]: + """Assemble enqueue items into a list of [`ModelMessage`][pydantic_ai.messages.ModelMessage]s. + + Adjacent [`UserContent`][pydantic_ai.messages.UserContent] items are gathered into one + [`UserPromptPart`][pydantic_ai.messages.UserPromptPart], and part-style items (user content and + [`ModelRequestPart`][pydantic_ai.messages.ModelRequestPart]s) are coalesced into a single + [`ModelRequest`][pydantic_ai.messages.ModelRequest]; complete `ModelMessage`s are emitted as-is. + Order is preserved, so a `ModelResponse` followed by part-style items produces the response then + a request built from those parts. + """ + messages: list[ModelMessage] = [] + parts: list[ModelRequestPart] = [] + content: list[UserContent] = [] + + def flush_content() -> None: + if content: + # Collapse a lone string to `str` content, matching `Agent.run('...')`; anything else + # (multiple items, or a single non-string like an image) becomes a content list. 
+ single = content[0] if len(content) == 1 and isinstance(content[0], str) else list(content) + parts.append(UserPromptPart(content=single)) + content.clear() + + def flush_request() -> None: + flush_content() + if parts: + messages.append(ModelRequest(parts=list(parts))) + parts.clear() + + for item in items: + if isinstance(item, (ModelRequest, ModelResponse)): + flush_request() + messages.append(item) + elif isinstance( + item, (SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart, ToolSearchReturnPart) + ): + flush_content() + parts.append(item) + else: + content.append(item) + flush_request() + return messages + + +@dataclass +class PendingMessage: + """One or more [`ModelMessage`][pydantic_ai.messages.ModelMessage]s queued for injection into the agent conversation. + + Enqueued via [`RunContext.enqueue`][pydantic_ai.tools.RunContext.enqueue] or + [`AgentRun.enqueue`][pydantic_ai.run.AgentRun.enqueue] and automatically drained + at the appropriate time during the agent run by + [`PendingMessageDrainCapability`][pydantic_ai.capabilities._pending_messages.PendingMessageDrainCapability]. + """ + + messages: list[ModelMessage] + """The message(s) to inject, in order. Always ends in a + [`ModelRequest`][pydantic_ai.messages.ModelRequest].""" + + priority: PendingMessagePriority = 'asap' + """When to deliver these messages: + + - `'asap'`: at the earliest opportunity (next model request, or redirect if the agent + would otherwise terminate). + - `'when_idle'`: only when the agent would otherwise terminate, after `'asap'` messages. + """ + + @classmethod + def from_content(cls, *content: EnqueueContent, priority: PendingMessagePriority = 'asap') -> PendingMessage | None: + """Build a `PendingMessage` from `enqueue` arguments, or `None` when there's nothing to send. + + Returns `None` for an empty call (enqueueing nothing is a no-op rather than an error). 
+ + Raises: + UserError: If the assembled messages don't end in a + [`ModelRequest`][pydantic_ai.messages.ModelRequest] — e.g. a lone `ModelResponse` — + since the agent needs a request to respond to. + """ + messages = _build_enqueue_messages(content) + if not messages: + return None + if not isinstance(messages[-1], ModelRequest): + raise UserError( + 'Enqueued content must end with a `ModelRequest` (or user content / `ModelRequestPart` ' + 'items that form one), so the agent has a request to respond to.' + ) + return cls(messages=messages, priority=priority) diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index c734151e60..4d557b8cdc 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -13,13 +13,15 @@ from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION from . import _utils, messages as _messages +from ._enqueue import EnqueueContent, PendingMessage, PendingMessagePriority +from .exceptions import UserError if TYPE_CHECKING: from .agent import Agent from .models import Model - from .result import RunUsage from .settings import ModelSettings from .tool_manager import ToolManager + from .usage import RunUsage # TODO (v2): Change the default for all typevars like this from `None` to `object` AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True) @@ -99,6 +101,14 @@ class RunContext(Generic[RunContextAgentDepsT]): `after_model_request`). Currently `None` in tool hooks, output validators, and during agent construction. """ + pending_messages: list[PendingMessage] | None = field(default=None, repr=False) + """Internal: queue read and mutated by [`PendingMessageDrainCapability`][pydantic_ai.capabilities._pending_messages.PendingMessageDrainCapability]. + + Set to the run's live queue during an agent run; `None` in synthetic contexts that aren't + backed by a running agent (e.g. 
the `RunContext` built by `Agent.system_prompt_parts`), where + [`enqueue`][pydantic_ai.tools.RunContext.enqueue] would have nowhere to drain to and so raises. + Use [`enqueue`][pydantic_ai.tools.RunContext.enqueue] to add messages — don't append directly. + """ tool_manager: ToolManager[RunContextAgentDepsT] | None = None """The tool manager for the current run step. @@ -116,6 +126,51 @@ def last_attempt(self) -> bool: """Whether this is the last attempt at running this tool before an error is raised.""" return self.retry == self.max_retries + def enqueue( + self, + *content: EnqueueContent, + priority: PendingMessagePriority = 'asap', + ) -> None: + """Enqueue content to be injected into the conversation. + + Safe to call from anywhere a `RunContext` is available — async tools, + sync tools (auto-wrapped in a thread executor by Pydantic AI), and + capability hooks. The drain only iterates the queue between graph nodes + (in `before_model_request` and `after_node_run`), never concurrently + with the tool body, so `list.append` from a worker thread doesn't race + the drain. + + Args: + *content: One or more [`EnqueueContent`][pydantic_ai._enqueue.EnqueueContent] items. + Adjacent [`UserContent`][pydantic_ai.messages.UserContent] (a `str` or multi-modal + content like an [`ImageUrl`][pydantic_ai.messages.ImageUrl]) is gathered into one + [`UserPromptPart`][pydantic_ai.messages.UserPromptPart], and each + [`ModelRequestPart`][pydantic_ai.messages.ModelRequestPart] (e.g. a + [`SystemPromptPart`][pydantic_ai.messages.SystemPromptPart]) is coalesced with adjacent + part-style items into one [`ModelRequest`][pydantic_ai.messages.ModelRequest]; a complete + [`ModelRequest`][pydantic_ai.messages.ModelRequest] or + [`ModelResponse`][pydantic_ai.messages.ModelResponse] is kept as its own message. The + assembled sequence must end in a request. Calling with no positional args is a no-op. 
+ priority: When to deliver: + `'asap'` (default) — at the earliest opportunity (next model request, + or a redirect if the agent would otherwise end). + `'when_idle'` — only when the agent would otherwise end, after `'asap'` messages. + + Raises: + UserError: If this `RunContext` isn't backed by a running agent's queue (e.g. the + synthetic context from `Agent.system_prompt_parts`), since there'd be nowhere + to deliver the message. + """ + if self.pending_messages is None: + raise UserError( + '`enqueue` is only available during an agent run (from tools, capability hooks, or ' + '`AgentRun.enqueue`). This `RunContext` has no pending-message queue to drain.' + ) + pending = PendingMessage.from_content(*content, priority=priority) + if pending is None: + return + self.pending_messages.append(pending) + __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index afe7a962df..7354654768 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -387,6 +387,18 @@ def now_utc() -> datetime: return datetime.now(tz=timezone.utc) +def fill_run_metadata(message: _messages.ModelMessage, *, run_id: str | None, conversation_id: str | None) -> None: + """Fill in framework-tracked metadata (`timestamp`, `run_id`, `conversation_id`) that's still unset. + + Producer-supplied values are preserved; only unset fields are filled in. Centralizing the field + list here means a new framework-tracked field only needs to be handled in one place, rather than + every site that materializes a message into the history. 
+ """ + message.timestamp = message.timestamp or now_utc() + message.run_id = message.run_id or run_id + message.conversation_id = message.conversation_id or conversation_id + + def guard_tool_call_id( t: _messages.ToolCallPart | _messages.ToolReturnPart diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 4c49eca596..7594ec6742 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -50,6 +50,7 @@ from ..capabilities import AbstractCapability, AgentCapability, CombinedCapability, ToolSearch as ToolSearchCap from ..capabilities._dynamic import wrap_capability_funcs from ..capabilities._ordering import has_capability_type +from ..capabilities._pending_messages import PendingMessageDrainCapability from ..capabilities.instrumentation import Instrumentation as InstrumentationCap from ..models.instrumented import InstrumentationSettings, InstrumentedModel from ..output import OutputDataT, OutputSpec, StructuredDict @@ -1332,6 +1333,7 @@ def _merged_meta(ctx: RunContext[AgentDepsT]) -> dict[str, Any]: if instrumentation_settings else DEFAULT_INSTRUMENTATION_VERSION, run_step=0, + pending_messages=state.pending_messages, run_id=state.run_id, conversation_id=state.conversation_id, ) @@ -2894,7 +2896,10 @@ def _retry_overrides_from_spec(spec: AgentSpec) -> AgentRetries: `PydanticAIDeprecationWarning` raised by `AgentSpec._warn_retry_field_deprecations`. """ -_AUTO_INJECT_CAPABILITY_TYPES: tuple[type[AbstractCapability[Any]], ...] = (ToolSearchCap,) +_AUTO_INJECT_CAPABILITY_TYPES: tuple[type[AbstractCapability[Any]], ...] 
= ( + ToolSearchCap, + PendingMessageDrainCapability, +) """Infrastructure capabilities auto-injected when not already present.""" diff --git a/pydantic_ai_slim/pydantic_ai/capabilities/_pending_messages.py b/pydantic_ai_slim/pydantic_ai/capabilities/_pending_messages.py new file mode 100644 index 0000000000..ab818169dd --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/capabilities/_pending_messages.py @@ -0,0 +1,160 @@ +"""Auto-injected capability that drains the pending message queue at appropriate times.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from pydantic_ai._agent_graph import ModelRequestNode +from pydantic_ai._enqueue import PendingMessage, PendingMessagePriority +from pydantic_ai._utils import fill_run_metadata +from pydantic_ai.capabilities.abstract import AbstractCapability, CapabilityOrdering +from pydantic_ai.exceptions import UserError +from pydantic_ai.messages import ModelMessage, ModelRequest +from pydantic_ai.tools import RunContext +from pydantic_graph import End + +if TYPE_CHECKING: + from pydantic_ai import _agent_graph + from pydantic_ai.models import ModelRequestContext + from pydantic_ai.result import FinalResult + + +def _drain_by_priority( + queue: list[PendingMessage], + priority: PendingMessagePriority, +) -> list[PendingMessage]: + """Remove and return all messages with the given priority from the queue.""" + drained: list[PendingMessage] = [] + remaining: list[PendingMessage] = [] + for msg in queue: + if msg.priority == priority: + drained.append(msg) + else: + remaining.append(msg) + queue[:] = remaining + return drained + + +def _stamped_messages( + drained: list[PendingMessage], + *, + fallback_run_id: str | None, + fallback_conversation_id: str | None, +) -> list[ModelMessage]: + """Flatten drained pending messages, stamping `timestamp` / `run_id` / `conversation_id` where unset. 
+ + Each [`PendingMessage`][pydantic_ai._enqueue.PendingMessage] carries one or more built + [`ModelMessage`][pydantic_ai.messages.ModelMessage]s (assembled at enqueue time by + [`PendingMessage.from_content`][pydantic_ai._enqueue.PendingMessage.from_content]); this only + fills in framework-tracked metadata that the producer left unset, so producer-supplied values + are preserved. + """ + messages: list[ModelMessage] = [] + for pending in drained: + for message in pending.messages: + fill_run_metadata(message, run_id=fallback_run_id, conversation_id=fallback_conversation_id) + messages.append(message) + return messages + + +class PendingMessageDrainCapability(AbstractCapability[Any]): + """Drains the pending message queue at appropriate times. + + - `'asap'` messages drain at the earliest opportunity: into the next + [`ModelRequest`][pydantic_ai.messages.ModelRequest] via `before_model_request`, + or — if the agent would otherwise terminate — redirected through a new + `ModelRequestNode` from `after_node_run`. + - `'when_idle'` messages drain only when the agent would otherwise terminate + and no `'asap'` messages remain, after any `'asap'` redirect. + + This capability is always auto-injected and placed outermost via + [`CapabilityOrdering`][pydantic_ai.capabilities.abstract.CapabilityOrdering] + so it wraps around other capabilities. This ensures `'asap'` messages are + drained into the model request before user capabilities see it, and the + end-of-run redirection runs after all other `after_node_run` hooks (which + run in reverse). + """ + + def get_ordering(self) -> CapabilityOrdering: + return CapabilityOrdering(position='outermost') + + @classmethod + def get_serialization_name(cls) -> str | None: + return None # not spec-constructible (internal, auto-injected) + + async def before_model_request( + self, + ctx: RunContext[Any], + request_context: ModelRequestContext, + ) -> ModelRequestContext: + """Drain `'asap'` messages into the upcoming model request. 
+ + Each drained request is appended to both `request_context.messages` (so the model + sees it this step) and `ctx.messages` (so it persists in the agent's message + history). Stamps `timestamp`/`run_id`/`conversation_id` if the producer didn't — + `ModelRequestNode.run()` only stamps `self.request` (the current node's request), + and capabilities downstream of us might append more messages, so we can't rely on + that fixup. + """ + assert ctx.pending_messages is not None, 'drain runs during an agent run, which always has a queue' + drained = _drain_by_priority(ctx.pending_messages, 'asap') + for message in _stamped_messages( + drained, fallback_run_id=ctx.run_id, fallback_conversation_id=ctx.conversation_id + ): + request_context.messages.append(message) + ctx.messages.append(message) + return request_context + + async def after_node_run( + self, + ctx: RunContext[Any], + *, + node: _agent_graph.AgentNode[Any, Any], + result: _agent_graph.AgentNode[Any, Any] | End[FinalResult[Any]], + ) -> _agent_graph.AgentNode[Any, Any] | End[FinalResult[Any]]: + """Drain remaining `'asap'` and `'when_idle'` messages if the agent would terminate. + + If the run is about to end, drain `'asap'` messages first (anything that arrived + after the most recent `before_model_request` and would otherwise be lost), then + `'when_idle'` messages. Each priority is appended independently so the history + keeps the priority split visible (matches pi-mono's separate steering / follow-up + turns). On the wire, `_clean_message_history` re-merges adjacent requests with + compatible instructions, so the model still sees one turn. + + The last resulting request becomes the redirect + [`ModelRequestNode`][pydantic_ai._agent_graph.ModelRequestNode]'s request; any + earlier ones are appended to `ctx.messages` so they appear in history before the + redirect. 
+ """ + if not isinstance(result, End): + return result + + assert ctx.pending_messages is not None, 'drain runs during an agent run, which always has a queue' + # Pi-mono parity: drain `'asap'` first so anything that arrived during the + # final step (e.g. a background task completing while the model produced + # its final response) gets delivered before `'when_idle'` messages, and the + # agent gets another turn rather than terminating with the message lost. + leftover_asap = _drain_by_priority(ctx.pending_messages, 'asap') + when_idle = _drain_by_priority(ctx.pending_messages, 'when_idle') + if not leftover_asap and not when_idle: + return result + + messages = [ + *_stamped_messages(leftover_asap, fallback_run_id=ctx.run_id, fallback_conversation_id=ctx.conversation_id), + *_stamped_messages(when_idle, fallback_run_id=ctx.run_id, fallback_conversation_id=ctx.conversation_id), + ] + # `final` becomes the redirect node's request; `ModelRequestNode._prepare_request` + # will re-stamp it during the graph lifecycle. `_stamped_messages` already + # stamped it, which is harmless (the lifecycle stamp overwrites). `from_content` + # guarantees each `PendingMessage` ends in a `ModelRequest`, but a producer can + # construct `PendingMessage` (or mutate `RunContext.pending_messages`) directly, so + # we check rather than assert. Any earlier responses/requests become `extras` + # appended to history before the redirect. + *extras, final = messages + if not isinstance(final, ModelRequest): + raise UserError( + 'Enqueued content must end with a `ModelRequest` so the agent has a request to respond to, ' + f'but the last queued message is a `{type(final).__name__}`.' 
+ ) + ctx.messages.extend(extras) + return ModelRequestNode(request=final) diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py index 498cccdea8..cc061457ea 100644 --- a/pydantic_ai_slim/pydantic_ai/exceptions.py +++ b/pydantic_ai_slim/pydantic_ai/exceptions.py @@ -24,6 +24,7 @@ 'SkipToolValidation', 'SkipToolExecution', 'UserError', + 'UndrainedPendingMessagesError', 'AgentRunError', 'UnexpectedModelBehavior', 'UsageLimitExceeded', @@ -166,6 +167,17 @@ def __init__(self, message: str): super().__init__(message) +class UndrainedPendingMessagesError(UserError): + """Error raised when an agent run ends with messages still queued via `enqueue`. + + A bare `async for node in agent_run` loop only drains `'asap'` messages (in + `before_model_request`); `'when_idle'` messages and end-of-run redirects drain in + `after_node_run`, which bare iteration skips. Reaching the run's `End` with a non-empty + queue means those messages were stranded — drive the run with `agent.run()` or + `AgentRun.next()` instead. 
+ """ + + class AgentRunError(RuntimeError): """Base class for errors occurring during an agent run.""" diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index f0715548cb..c8fc93f64b 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -2359,6 +2359,7 @@ def provider_request_id(self) -> str | None: ModelMessage = Annotated[ModelRequest | ModelResponse, pydantic.Discriminator('kind')] """Any message sent to or returned by a model.""" + ModelMessagesTypeAdapter = pydantic.TypeAdapter( list[ModelMessage], config=pydantic.ConfigDict(defer_build=True, ser_json_bytes='base64', val_json_bytes='base64') ) diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py index caeedad20f..664f160c71 100644 --- a/pydantic_ai_slim/pydantic_ai/run.py +++ b/pydantic_ai_slim/pydantic_ai/run.py @@ -19,6 +19,7 @@ usage as _usage, ) from ._deprecated_callable import deprecated_callable_property +from ._enqueue import EnqueueContent, PendingMessage, PendingMessagePriority from ._instrumentation import current_otel_traceparent from .output import OutputDataT from .tools import AgentDepsT @@ -216,7 +217,20 @@ async def __anext__( except BaseException as exc: self._node_error = exc raise - return self._task_to_node(task) + node = self._task_to_node(task) + if isinstance(node, End) and self._graph_run.state.pending_messages: + # `asap` messages drain in `before_model_request` (which fires either way), but + # `when_idle` messages and end-of-run redirects drain in `after_node_run`, which + # bare iteration skips. Reaching `End` with a non-empty queue means those were + # stranded — fail loudly rather than silently dropping the messages. + raise exceptions.UndrainedPendingMessagesError( + 'The agent run ended with undrained pending messages enqueued via `enqueue`. 
' + 'Bare `async for node in agent_run` does not drain `when_idle` messages or ' + 'end-of-run redirects, because they fire in `after_node_run`, which bare iteration ' + 'skips. Use `agent_run.next(node)` to advance the run, or `agent.run()` which drives ' + 'via `next()` automatically.' + ) + return node def _task_to_node( self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTaskRequest] @@ -411,6 +425,49 @@ def conversation_id(self) -> str: """The unique identifier for the conversation this run belongs to.""" return self._graph_run.state.conversation_id + @property + def pending_messages(self) -> list[PendingMessage]: + """Internal: live view of the queue mutated by `enqueue` and drained by [`PendingMessageDrainCapability`][pydantic_ai.capabilities._pending_messages.PendingMessageDrainCapability]. + + Exposed for inspection / debugging; use [`enqueue`][pydantic_ai.run.AgentRun.enqueue] to add messages. + """ + return self._graph_run.state.pending_messages + + def enqueue( + self, + *content: EnqueueContent, + priority: PendingMessagePriority = 'asap', + ) -> None: + """Enqueue content to be injected into the conversation. + + Designed to be called from the same event loop driving `agent.iter()`. If + you're forwarding events from a different thread (e.g. a webhook handler + running on its own loop or thread), marshal the call back onto the agent's + loop first (e.g. `loop.call_soon_threadsafe(agent_run.enqueue, msg)`). + The drain's `queue[:] = remaining` pattern in `_drain_by_priority` isn't + atomic against concurrent appends from a different thread. + + Args: + *content: One or more [`EnqueueContent`][pydantic_ai._enqueue.EnqueueContent] items. 
+ Adjacent [`UserContent`][pydantic_ai.messages.UserContent] (a `str` or multi-modal + content like an [`ImageUrl`][pydantic_ai.messages.ImageUrl]) is gathered into one + [`UserPromptPart`][pydantic_ai.messages.UserPromptPart], and each + [`ModelRequestPart`][pydantic_ai.messages.ModelRequestPart] (e.g. a + [`SystemPromptPart`][pydantic_ai.messages.SystemPromptPart]) is coalesced with adjacent + part-style items into one [`ModelRequest`][pydantic_ai.messages.ModelRequest]; a complete + [`ModelRequest`][pydantic_ai.messages.ModelRequest] or + [`ModelResponse`][pydantic_ai.messages.ModelResponse] is kept as its own message. The + assembled sequence must end in a request. Calling with no positional args is a no-op. + priority: When to deliver: + `'asap'` (default) — at the earliest opportunity (next model request, + or a redirect if the agent would otherwise end). + `'when_idle'` — only when the agent would otherwise end, after `'asap'` messages. + """ + pending = PendingMessage.from_content(*content, priority=priority) + if pending is None: + return + self._graph_run.state.pending_messages.append(pending) + def __repr__(self) -> str: # pragma: no cover result = self._graph_run.output result_repr = '' if result is None else repr(result.output) diff --git a/tests/test_capabilities.py b/tests/test_capabilities.py index cac0a57a31..b52d5bb6cc 100644 --- a/tests/test_capabilities.py +++ b/tests/test_capabilities.py @@ -7,7 +7,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from datetime import datetime +from datetime import datetime, timezone from pathlib import Path from typing import Any @@ -55,6 +55,7 @@ SkipModelRequest, SkipToolExecution, SkipToolValidation, + UndrainedPendingMessagesError, UnexpectedModelBehavior, UserError, ) @@ -62,11 +63,13 @@ AgentStreamEvent, BinaryImage, FilePart, + ImageUrl, ModelMessage, ModelRequest, ModelResponse, 
PartStartEvent, RetryPromptPart, + SystemPromptPart, TextPart, ToolCallPart, ToolReturn, @@ -11151,6 +11154,1281 @@ def my_tool() -> str: assert model_call_count == 1 +# ===== Pending Message Queue Tests ===== + + +async def test_enqueue_asap_message_from_tool(): + """`'asap'` messages enqueued from a tool are injected before the next model request.""" + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if any(isinstance(msg, ModelResponse) for msg in messages): + return ModelResponse( + parts=[TextPart(content='done')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + return ModelResponse( + parts=[ToolCallPart(tool_name='inject_msg', args='{}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent(FunctionModel(model_fn)) + + @agent.tool + def inject_msg(ctx: RunContext[None]) -> str: + ctx.enqueue('Injected asap message') + return 'ok' + + result = await agent.run('Hello') + assert result.output == 'done' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='Hello', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='inject_msg', args='{}', tool_call_id=IsStr())], + usage=RequestUsage(input_tokens=10, output_tokens=5), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='inject_msg', + content='ok', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelRequest( + parts=[UserPromptPart(content='Injected asap message', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='done')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + 
model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ] + ) + + +async def test_enqueue_when_idle_message_prevents_end(): + """`'when_idle'` messages prevent the agent from ending and are drained into a new ModelRequest.""" + call_count = 0 + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + if call_count == 1: + return ModelResponse( + parts=[ToolCallPart(tool_name='inject_follow_up', args='{}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + elif call_count == 2: + # Agent produces final result, but follow-up is pending + return ModelResponse( + parts=[TextPart(content='premature end')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + else: + # After follow-up is drained, agent produces real final result + return ModelResponse( + parts=[TextPart(content='final answer after follow-up')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent(FunctionModel(model_fn)) + + @agent.tool + def inject_follow_up(ctx: RunContext[None]) -> str: + ctx.enqueue('Follow-up context', priority='when_idle') + return 'ok' + + result = await agent.run('Hello') + assert result.output == 'final answer after follow-up' + assert call_count == 3 + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='Hello', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='inject_follow_up', args='{}', tool_call_id=IsStr())], + usage=RequestUsage(input_tokens=10, output_tokens=5), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='inject_follow_up', + content='ok', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + 
conversation_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='premature end')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelRequest( + parts=[UserPromptPart(content='Follow-up context', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='final answer after follow-up')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ] + ) + + +async def test_enqueue_when_idle_redirects_after_output_tool_end(): + """A `when_idle` message redirects the run even when it would end via an output tool. + + The run terminates when the model calls an output tool (`ToolOutput` mode), which produces + an `End` from `CallToolsNode`. The drain's `after_node_run` still sees that `End` and + redirects into a fresh request, so the agent gets another turn after the structured output — + and the final `result.output` comes from that later turn. + """ + + class Answer(BaseModel): + value: int + + call_count = 0 + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + output_tool = info.output_tools[0].name + if call_count == 1: + return ModelResponse( + parts=[ToolCallPart(tool_name='inject_follow_up', args='{}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + if call_count == 2: + # Would end the run via the output tool, but a `when_idle` message is pending. + return ModelResponse( + parts=[ToolCallPart(tool_name=output_tool, args='{"value": 1}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + # After the follow-up is drained, the model produces the real final output. 
+ return ModelResponse( + parts=[ToolCallPart(tool_name=output_tool, args='{"value": 2}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent(FunctionModel(model_fn), output_type=Answer) + + @agent.tool + def inject_follow_up(ctx: RunContext[None]) -> str: + ctx.enqueue('Follow-up context', priority='when_idle') + return 'ok' + + result = await agent.run('Hello') + + assert result.output == Answer(value=2) + assert call_count == 3 + # The `when_idle` follow-up lands as its own request after the first (superseded) output-tool + # call, redirecting the run so the second output-tool call produces the real output. + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='Hello', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='inject_follow_up', args='{}', tool_call_id=IsStr())], + usage=RequestUsage(input_tokens=10, output_tokens=5), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='inject_follow_up', + content='ok', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"value": 1}', + tool_call_id=IsStr(), + ) + ], + usage=RequestUsage(input_tokens=10, output_tokens=5), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelRequest( + parts=[UserPromptPart(content='Follow-up context', timestamp=IsDatetime())], + 
timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"value": 2}', + tool_call_id=IsStr(), + ) + ], + usage=RequestUsage(input_tokens=10, output_tokens=5), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ] + ) + + +async def test_enqueue_from_agent_run(): + """Messages can be enqueued from external code via AgentRun.enqueue.""" + call_count = 0 + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + return ModelResponse( + parts=[TextPart(content=f'response {call_count}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent(FunctionModel(model_fn)) + + async with agent.iter('Hello') as agent_run: + assert agent_run.pending_messages == [] + # Enqueue a when_idle message from external code before iteration + agent_run.enqueue('External follow-up', priority='when_idle') + assert len(agent_run.pending_messages) == 1 + # Use next() to drive iteration so after_node_run fires + node = agent_run.next_node + while not isinstance(node, End): + node = await agent_run.next(node) + + assert agent_run.result is not None + assert call_count == 2 # First response triggers End, follow-up prevents it, second response is final + assert agent_run.result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='Hello', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='response 1')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + model_name='function:model_fn:', + 
timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelRequest( + parts=[UserPromptPart(content='External follow-up', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='response 2')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ] + ) + + +async def test_bare_async_for_raises_with_undrained_pending_messages(): + """Bare `async for` reaching End with undrained `when_idle` messages raises rather than stranding them. + + `when_idle` (and end-of-step `asap` leftovers) drain in `after_node_run`, which bare + iteration skips — so they'd be silently lost. `__anext__` raises + `UndrainedPendingMessagesError` when it would yield the `End` node with a non-empty queue, + pointing the user at `next()` driving. + """ + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + return ModelResponse( + parts=[TextPart(content='done')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent(FunctionModel(model_fn)) + + async with agent.iter('hi') as agent_run: + agent_run.enqueue('stranded follow-up', priority='when_idle') + with pytest.raises(UndrainedPendingMessagesError, match='undrained pending messages'): + async for _ in agent_run: + pass + + # The message was never delivered: it's still queued. 
+ assert len(agent_run.pending_messages) == 1 + + +async def test_pending_messages_accessible_on_run_context(): + """RunContext.pending_messages is accessible and initially empty.""" + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if any(isinstance(msg, ModelResponse) for msg in messages): + return ModelResponse( + parts=[TextPart(content='done')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + return ModelResponse( + parts=[ToolCallPart(tool_name='check_queue', args='{}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent(FunctionModel(model_fn)) + + @agent.tool + def check_queue(ctx: RunContext[None]) -> str: + # The queue must be live (mutations from inside a tool reach the drain). + assert ctx.pending_messages is not None + assert len(ctx.pending_messages) == 0 + ctx.enqueue('observed', priority='asap') + assert len(ctx.pending_messages) == 1 + return 'done' + + result = await agent.run('Test') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='Test', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='check_queue', args='{}', tool_call_id=IsStr())], + usage=RequestUsage(input_tokens=10, output_tokens=5), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='check_queue', + content='done', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelRequest( + parts=[UserPromptPart(content='observed', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='done')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + model_name='function:model_fn:', + 
timestamp=IsDatetime(), + run_id=IsStr(), + conversation_id=IsStr(), + ), + ] + ) + + +async def test_enqueue_with_no_args_is_a_noop(): + """`ctx.enqueue()` and `agent_run.enqueue()` with no content are silent no-ops. + + Producers that conditionally enqueue (e.g. "announce if new tools were discovered") + can call `enqueue(*maybe_items)` without guarding for the empty case — `enqueue` + simply doesn't append a `PendingMessage` when there's nothing to send. + """ + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if any(isinstance(msg, ModelResponse) for msg in messages): + return ModelResponse( + parts=[TextPart(content='done')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + return ModelResponse( + parts=[ToolCallPart(tool_name='from_tool', args='{}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent(FunctionModel(model_fn)) + + @agent.tool + def from_tool(ctx: RunContext[None]) -> str: + ctx.enqueue() # no-op, no exception + assert ctx.pending_messages == [] + return 'ok' + + async with agent.iter('hi') as agent_run: + agent_run.enqueue() # no-op, no exception + assert agent_run.pending_messages == [] + async for _ in agent_run: + pass + + +async def test_enqueue_coerces_string_to_user_prompt(): + """A bare string passed to `enqueue` is wrapped in a `UserPromptPart`.""" + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if any(isinstance(msg, ModelResponse) for msg in messages): + return ModelResponse( + parts=[TextPart(content='done')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + return ModelResponse( + parts=[ToolCallPart(tool_name='inject_msg', args='{}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent(FunctionModel(model_fn)) + + @agent.tool + def inject_msg(ctx: RunContext[None]) -> str: + ctx.enqueue('steering as plain string') + return 'ok' + + result = await agent.run('Hello') + injected = [ + 
part + for msg in result.all_messages() + if isinstance(msg, ModelRequest) + for part in msg.parts + if isinstance(part, UserPromptPart) and part.content == 'steering as plain string' + ] + assert len(injected) == 1, 'string-coerced enqueue did not land as a UserPromptPart' + + +async def test_enqueue_accepts_multimodal_user_content(): + """Adjacent user-content args (text + multi-modal) are gathered into one `UserPromptPart`.""" + image = ImageUrl(url='https://example.com/image.png') + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if any(isinstance(msg, ModelResponse) for msg in messages): + return ModelResponse( + parts=[TextPart(content='done')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + return ModelResponse( + parts=[ToolCallPart(tool_name='inject_msg', args='{}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent(FunctionModel(model_fn)) + + @agent.tool + def inject_msg(ctx: RunContext[None]) -> str: + ctx.enqueue('look at this', image) + return 'ok' + + result = await agent.run('Hello') + injected = [ + part + for msg in result.all_messages() + if isinstance(msg, ModelRequest) + for part in msg.parts + if isinstance(part, UserPromptPart) and part.content == ['look at this', image] + ] + assert len(injected) == 1 + + +async def test_enqueue_accepts_model_request_passthrough(): + """A full `ModelRequest` is enqueued verbatim, preserving `instructions`/`metadata`. + + Two passthroughs cover both branches of the fill-in-if-unset stamping logic: + one with `timestamp`/`run_id` unset (drain stamps them); one with both set + (drain leaves them alone). 
+ """ + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if any(isinstance(msg, ModelResponse) for msg in messages): + return ModelResponse( + parts=[TextPart(content='done')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + return ModelResponse( + parts=[ToolCallPart(tool_name='inject_msg', args='{}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent(FunctionModel(model_fn)) + unstamped = ModelRequest( + parts=[UserPromptPart(content='wire-level payload')], + instructions='do this carefully', + metadata={'origin': 'webhook-42'}, + ) + preset_timestamp = datetime(2024, 1, 1, tzinfo=timezone.utc) + prestamped = ModelRequest( + parts=[UserPromptPart(content='already stamped')], + instructions='preserve me', + timestamp=preset_timestamp, + run_id='caller-run-id', + conversation_id='caller-conv-id', + ) + + @agent.tool + def inject_msg(ctx: RunContext[None]) -> str: + ctx.enqueue(unstamped) + ctx.enqueue(prestamped) + return 'ok' + + result = await agent.run('Hello') + + injected_unstamped = next( + msg + for msg in result.all_messages() + if isinstance(msg, ModelRequest) and msg.instructions == 'do this carefully' + ) + assert injected_unstamped.metadata == {'origin': 'webhook-42'} + # Drain should have stamped timestamp/run_id/conversation_id since the user didn't set them. + assert injected_unstamped.timestamp is not None + assert injected_unstamped.run_id is not None + assert injected_unstamped.conversation_id is not None + + injected_prestamped = next( + msg for msg in result.all_messages() if isinstance(msg, ModelRequest) and msg.instructions == 'preserve me' + ) + # Producer-supplied timestamp/run_id/conversation_id are preserved (drain doesn't overwrite). 
+ assert injected_prestamped.timestamp == preset_timestamp + assert injected_prestamped.run_id == 'caller-run-id' + assert injected_prestamped.conversation_id == 'caller-conv-id' + + +def test_pending_message_drain_capability_is_not_spec_constructible(): + """`PendingMessageDrainCapability` is auto-injected only; can't be in an `AgentSpec`.""" + from pydantic_ai.capabilities._pending_messages import PendingMessageDrainCapability + + assert PendingMessageDrainCapability.get_serialization_name() is None + + +def test_pending_message_allows_empty_request(): + """`PendingMessage` doesn't validate its `messages`; empty-parts requests are tolerated. + + `enqueue()` already filters out the no-args case (no `PendingMessage` is appended). + An empty `ModelRequest` reaching the queue is harmless — the drain stamps and forwards + it, and downstream wire-merging absorbs zero-part messages as a natural no-op. + """ + from pydantic_ai._enqueue import PendingMessage + + msg = PendingMessage(messages=[ModelRequest(parts=[])]) + assert msg.priority == 'asap' + assert msg.messages[0].parts == [] + + +def test_enqueue_without_live_queue_raises(): + """`ctx.enqueue` raises when the `RunContext` isn't backed by a running agent's queue. + + Synthetic contexts (e.g. the one `Agent.system_prompt_parts` builds to resolve system + prompts outside a run) have no queue to drain to, so enqueue fails loudly instead of + silently dropping the message. + """ + ctx = RunContext[None](deps=None, model=TestModel(), usage=RunUsage(), prompt=None, messages=[]) + assert ctx.pending_messages is None + with pytest.raises(UserError, match='only available during an agent run'): + ctx.enqueue('this has nowhere to go') + + +async def test_enqueue_parts_style_calls_produce_one_request_per_call(): + """Each `enqueue` call produces its own `ModelRequest` in history. 
+
+    Each `enqueue` call pre-packages its content into a `ModelRequest` at enqueue time,
+    so two calls produce two `PendingMessage`s with two separate `ModelRequest`s. The
+    history reflects per-call structure; wire-level `_clean_message_history` still merges
+    adjacent compatible `ModelRequest`s so the model sees one turn. Producers wanting a
+    single message should pass a single `ModelRequest(parts=[...])` themselves.
+    """
+
+    def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        if any(isinstance(msg, ModelResponse) for msg in messages):
+            return ModelResponse(
+                parts=[TextPart(content='done')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+            )
+        return ModelResponse(
+            parts=[ToolCallPart(tool_name='inject_msg', args='{}')],
+            usage=RequestUsage(input_tokens=10, output_tokens=5),
+        )
+
+    agent = Agent(FunctionModel(model_fn))
+
+    @agent.tool
+    def inject_msg(ctx: RunContext[None]) -> str:
+        ctx.enqueue('first hint')
+        ctx.enqueue('second hint')
+        return 'ok'
+
+    result = await agent.run('Hello')
+    drained = [
+        msg
+        for msg in result.all_messages()
+        if isinstance(msg, ModelRequest)
+        and any(isinstance(p, UserPromptPart) and p.content in ('first hint', 'second hint') for p in msg.parts)
+    ]
+    assert len(drained) == 2, 'expected one ModelRequest per enqueue call'
+    assert [p.content for msg in drained for p in msg.parts if isinstance(p, UserPromptPart)] == [
+        'first hint',
+        'second hint',
+    ]
+
+
+async def test_enqueue_passthrough_stays_separate_from_parts_style():
+    """A passthrough `ModelRequest` stays its own message even when surrounded by parts-style enqueues."""
+
+    def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        if any(isinstance(msg, ModelResponse) for msg in messages):
+            return ModelResponse(
+                parts=[TextPart(content='done')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+            )
+        return ModelResponse(
+            parts=[ToolCallPart(tool_name='inject_msg', args='{}')],
+            usage=RequestUsage(input_tokens=10, output_tokens=5),
+        )
+
+    agent = Agent(FunctionModel(model_fn))
+
+    @agent.tool
+    def inject_msg(ctx: RunContext[None]) -> str:
+        ctx.enqueue('before')
+        ctx.enqueue(
+            ModelRequest(parts=[UserPromptPart(content='passthrough')], instructions='careful'),
+        )
+        ctx.enqueue('after')
+        return 'ok'
+
+    result = await agent.run('Hello')
+    # Three drained requests: synthesized(["before"]), passthrough, synthesized(["after"]).
+    drained = [
+        msg
+        for msg in result.all_messages()
+        if isinstance(msg, ModelRequest)
+        and any(isinstance(p, UserPromptPart) and p.content in ('before', 'passthrough', 'after') for p in msg.parts)
+    ]
+    assert len(drained) == 3
+    contents = [
+        next(
+            p.content
+            for p in r.parts
+            if isinstance(p, UserPromptPart) and p.content in ('before', 'passthrough', 'after')
+        )
+        for r in drained
+    ]
+    assert contents == ['before', 'passthrough', 'after']
+    # Passthrough preserved its instructions.
+    assert drained[1].instructions == 'careful'
+    assert drained[0].instructions is None
+    assert drained[2].instructions is None
+
+
+async def test_enqueue_system_prompt_part():
+    """A bare `SystemPromptPart` is coalesced into a `ModelRequest` and delivered.
+
+    Now that mid-conversation `SystemPromptPart`s are rendered inline (not hoisted) on all
+    providers, `enqueue` accepts request parts directly — no `ModelRequest` wrapper needed.
+    """
+
+    def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        if any(isinstance(msg, ModelResponse) for msg in messages):
+            return ModelResponse(
+                parts=[TextPart(content='done')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+            )
+        return ModelResponse(
+            parts=[ToolCallPart(tool_name='announce', args='{}')],
+            usage=RequestUsage(input_tokens=10, output_tokens=5),
+        )
+
+    agent = Agent(FunctionModel(model_fn))
+
+    @agent.tool
+    def announce(ctx: RunContext[None]) -> str:
+        ctx.enqueue(SystemPromptPart(content='New tools are now available.'))
+        return 'ok'
+
+    result = await agent.run('Hello')
+    injected = next(
+        msg
+        for msg in result.all_messages()
+        if isinstance(msg, ModelRequest)
+        and any(isinstance(p, SystemPromptPart) and p.content == 'New tools are now available.' for p in msg.parts)
+    )
+    assert injected is not None
+
+
+async def test_enqueue_interleaved_response_and_request():
+    """One `enqueue` call can inject an interleaved `ModelResponse` + `ModelRequest` exchange.
+
+    This is the synthetic "tool-search call + result" shape (a `ModelResponse` carrying the call
+    followed by a `ModelRequest` carrying the return). Both land in history in order, and the
+    trailing request is what the agent responds to next.
+    """
+
+    def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        if any(isinstance(msg, ModelResponse) for msg in messages):
+            return ModelResponse(
+                parts=[TextPart(content='done')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+            )
+        return ModelResponse(
+            parts=[ToolCallPart(tool_name='inject_exchange', args='{}')],
+            usage=RequestUsage(input_tokens=10, output_tokens=5),
+        )
+
+    agent = Agent(FunctionModel(model_fn))
+    synthetic_response = ModelResponse(
+        parts=[TextPart(content='synthetic prior turn')],
+        usage=RequestUsage(input_tokens=1, output_tokens=1),
+    )
+
+    @agent.tool
+    def inject_exchange(ctx: RunContext[None]) -> str:
+        ctx.enqueue(
+            synthetic_response,
+            ModelRequest(parts=[UserPromptPart(content='follow-up after synthetic turn')]),
+            priority='when_idle',
+        )
+        return 'ok'
+
+    result = await agent.run('Hello')
+    # The synthetic response is appended to history immediately before its paired request.
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[UserPromptPart(content='Hello', timestamp=IsDatetime())],
+                timestamp=IsDatetime(),
+                run_id=IsStr(),
+                conversation_id=IsStr(),
+            ),
+            ModelResponse(
+                parts=[ToolCallPart(tool_name='inject_exchange', args='{}', tool_call_id=IsStr())],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+                model_name='function:model_fn:',
+                timestamp=IsDatetime(),
+                run_id=IsStr(),
+                conversation_id=IsStr(),
+            ),
+            ModelRequest(
+                parts=[
+                    ToolReturnPart(
+                        tool_name='inject_exchange',
+                        content='ok',
+                        tool_call_id=IsStr(),
+                        timestamp=IsDatetime(),
+                    )
+                ],
+                timestamp=IsDatetime(),
+                run_id=IsStr(),
+                conversation_id=IsStr(),
+            ),
+            ModelResponse(
+                parts=[TextPart(content='done')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+                model_name='function:model_fn:',
+                timestamp=IsDatetime(),
+                run_id=IsStr(),
+                conversation_id=IsStr(),
+            ),
+            ModelResponse(
+                parts=[TextPart(content='synthetic prior turn')],
+                usage=RequestUsage(input_tokens=1, output_tokens=1),
+                timestamp=IsDatetime(),
+                run_id=IsStr(),
+                conversation_id=IsStr(),
+            ),
+            ModelRequest(
+                parts=[UserPromptPart(content='follow-up after synthetic turn', timestamp=IsDatetime())],
+                timestamp=IsDatetime(),
+                run_id=IsStr(),
+                conversation_id=IsStr(),
+            ),
+            ModelResponse(
+                parts=[TextPart(content='done')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+                model_name='function:model_fn:',
+                timestamp=IsDatetime(),
+                run_id=IsStr(),
+                conversation_id=IsStr(),
+            ),
+        ]
+    )
+
+
+async def test_enqueue_rejects_content_not_ending_in_request():
+    """Enqueued content must end in a `ModelRequest`; a lone `ModelResponse` is rejected.
+
+    The agent needs a request to respond to — content that ends in a `ModelResponse` (with no
+    trailing request/part-style items) would leave nothing for the model to react to.
+    """
+
+    def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        if any(isinstance(msg, ModelResponse) for msg in messages):
+            return ModelResponse(
+                parts=[TextPart(content='done')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+            )
+        return ModelResponse(
+            parts=[ToolCallPart(tool_name='from_tool', args='{}')],
+            usage=RequestUsage(input_tokens=10, output_tokens=5),
+        )
+
+    agent = Agent(FunctionModel(model_fn))
+    lone_response = ModelResponse(
+        parts=[TextPart(content='synthetic')], usage=RequestUsage(input_tokens=1, output_tokens=1)
+    )
+
+    @agent.tool
+    def from_tool(ctx: RunContext[None]) -> str:
+        with pytest.raises(UserError, match='must end with a `ModelRequest`'):
+            ctx.enqueue(lone_response)
+        return 'ok'
+
+    async with agent.iter('hi') as agent_run:
+        with pytest.raises(UserError, match='must end with a `ModelRequest`'):
+            agent_run.enqueue(lone_response)
+        async for _ in agent_run:
+            pass
+
+
+async def test_drain_rejects_directly_queued_content_not_ending_in_request():
+    """Directly appending a malformed `PendingMessage` raises a `UserError` at end-of-run drain.
+
+    `enqueue` enforces the "ends in a `ModelRequest`" rule up front, but `RunContext.pending_messages`
+    is public, so a producer can append a `PendingMessage` directly. The end-of-run drain catches a
+    request-less message with a helpful `UserError` rather than a bare assertion.
+    """
+    from pydantic_ai._enqueue import PendingMessage
+
+    lone_response = ModelResponse(
+        parts=[TextPart(content='synthetic')], usage=RequestUsage(input_tokens=1, output_tokens=1)
+    )
+
+    def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        if any(isinstance(p, ToolReturnPart) for m in messages if isinstance(m, ModelRequest) for p in m.parts):
+            return ModelResponse(parts=[TextPart(content='done')], usage=RequestUsage(input_tokens=10, output_tokens=5))
+        return ModelResponse(
+            parts=[ToolCallPart(tool_name='queue_bad', args='{}')],
+            usage=RequestUsage(input_tokens=10, output_tokens=5),
+        )
+
+    agent = Agent(FunctionModel(model_fn))
+
+    @agent.tool
+    def queue_bad(ctx: RunContext[None]) -> str:
+        assert ctx.pending_messages is not None
+        ctx.pending_messages.append(PendingMessage(messages=[lone_response], priority='when_idle'))
+        return 'ok'
+
+    with pytest.raises(UserError, match='must end with a `ModelRequest`'):
+        await agent.run('hi')
+
+
+async def test_enqueue_asap_with_rich_message_history_tail():
+    """`'asap'` enqueue lands as its own `ModelRequest` in history *and* gets wire-merged into the rich tail.
+
+    The history keeps the un-merged view (drain's request is a separate `ModelRequest`
+    after the rich tail) so `all_messages()` reflects per-call structure. On the wire,
+    `_clean_message_history` merges the two adjacent `ModelRequest`s and sorts
+    `ToolReturnPart`/`RetryPromptPart` first — non-tool parts keep arrival order, so the
+    enqueued content lands at the *end* of the merged turn (not interleaved between
+    existing parts). Captures the `messages` arg `FunctionModel` actually received to
+    validate the wire-level merge through the public path.
+    """
+    captured_wire_messages: list[list[ModelMessage]] = []
+
+    def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        captured_wire_messages.append(messages)
+        return ModelResponse(
+            parts=[TextPart(content='done')],
+            usage=RequestUsage(input_tokens=10, output_tokens=5),
+        )
+
+    agent = Agent(FunctionModel(model_fn))
+    history: list[ModelMessage] = [
+        ModelRequest(parts=[UserPromptPart(content='original prompt')]),
+        ModelResponse(
+            parts=[ToolCallPart(tool_name='hint', args='{}', tool_call_id='call-1')],
+            usage=RequestUsage(input_tokens=10, output_tokens=5),
+            model_name='function:model_fn:',
+        ),
+        ModelRequest(
+            parts=[
+                ToolReturnPart(tool_name='hint', content='ok', tool_call_id='call-1'),
+                UserPromptPart(content='follow-up question'),
+            ],
+        ),
+    ]
+
+    async with agent.iter(message_history=history) as agent_run:
+        agent_run.enqueue('injected after rich tail')
+        async for _ in agent_run:
+            pass
+
+    assert agent_run.result is not None
+    # `all_messages()` keeps the un-merged view (drain's request is a separate
+    # `ModelRequest` after the rich tail).
+    assert agent_run.result.all_messages() == snapshot(
+        [
+            ModelRequest(parts=[UserPromptPart(content='original prompt', timestamp=IsDatetime())]),
+            ModelResponse(
+                parts=[ToolCallPart(tool_name='hint', args='{}', tool_call_id='call-1')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+                model_name='function:model_fn:',
+                timestamp=IsDatetime(),
+            ),
+            ModelRequest(
+                parts=[
+                    ToolReturnPart(tool_name='hint', content='ok', tool_call_id='call-1', timestamp=IsDatetime()),
+                    UserPromptPart(content='follow-up question', timestamp=IsDatetime()),
+                ],
+                timestamp=IsDatetime(),
+            ),
+            ModelRequest(
+                parts=[UserPromptPart(content='injected after rich tail', timestamp=IsDatetime())],
+                timestamp=IsDatetime(),
+                run_id=IsStr(),
+                conversation_id=IsStr(),
+            ),
+            ModelResponse(
+                parts=[TextPart(content='done')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+                model_name='function:model_fn:',
+                timestamp=IsDatetime(),
+                run_id=IsStr(),
+                conversation_id=IsStr(),
+            ),
+        ]
+    )
+
+    # And the wire-level view: the rich tail and the drained request merged into one
+    # `ModelRequest`, with `ToolReturnPart` first and the user-prompt parts in arrival
+    # order (so the enqueued content lands at the end, not interleaved).
+    assert len(captured_wire_messages) == 1
+    assert captured_wire_messages[0] == snapshot(
+        [
+            ModelRequest(parts=[UserPromptPart(content='original prompt', timestamp=IsDatetime())]),
+            ModelResponse(
+                parts=[ToolCallPart(tool_name='hint', args='{}', tool_call_id='call-1')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+                model_name='function:model_fn:',
+                timestamp=IsDatetime(),
+            ),
+            ModelRequest(
+                parts=[
+                    ToolReturnPart(tool_name='hint', content='ok', tool_call_id='call-1', timestamp=IsDatetime()),
+                    UserPromptPart(content='follow-up question', timestamp=IsDatetime()),
+                    UserPromptPart(content='injected after rich tail', timestamp=IsDatetime()),
+                ],
+                timestamp=IsDatetime(),
+            ),
+        ]
+    )
+
+
+async def test_enqueue_asap_drains_at_end_if_arrived_during_final_step():
+    """`'asap'` arriving during the final step (after its `before_model_request` drain) still gets delivered.
+
+    Simulates the background-tools pattern: a long-running task completes *during* what
+    would have been the model's final response. The enqueue happens after the step's
+    `before_model_request` drain has already fired, so the message can only be picked up
+    by the end-of-run drain (matching pi-mono's drain-on-end). Without this fallback the
+    message would be lost. `'asap'` semantically means "deliver at the earliest opportunity"
+    — including redirecting if the agent would otherwise terminate before another call.
+    """
+    call_count = 0
+
+    def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        nonlocal call_count
+        call_count += 1
+        if call_count == 1:
+            return ModelResponse(
+                parts=[TextPart(content='would-have-ended')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+            )
+        return ModelResponse(
+            parts=[TextPart(content='final after late asap')],
+            usage=RequestUsage(input_tokens=10, output_tokens=5),
+        )
+
+    @dataclass
+    class BackgroundTaskCap(AbstractCapability[Any]):
+        """Simulates a background task that completes mid-model-response on the first call only."""
+
+        fired: bool = False
+
+        async def after_model_request(
+            self,
+            ctx: RunContext[Any],
+            *,
+            request_context: ModelRequestContext,
+            response: ModelResponse,
+        ) -> ModelResponse:
+            if not self.fired:
+                ctx.enqueue('background task result', priority='asap')
+                self.fired = True
+            return response
+
+    agent = Agent(FunctionModel(model_fn), capabilities=[BackgroundTaskCap()])
+
+    result = await agent.run('Hello')
+    assert result.output == 'final after late asap'
+    assert call_count == 2
+    # The 'asap' message landed in its own ModelRequest before the final response,
+    # not lost despite the agent producing a no-tool-call response.
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[UserPromptPart(content='Hello', timestamp=IsDatetime())],
+                timestamp=IsDatetime(),
+                run_id=IsStr(),
+                conversation_id=IsStr(),
+            ),
+            ModelResponse(
+                parts=[TextPart(content='would-have-ended')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+                model_name='function:model_fn:',
+                timestamp=IsDatetime(),
+                run_id=IsStr(),
+                conversation_id=IsStr(),
+            ),
+            ModelRequest(
+                parts=[UserPromptPart(content='background task result', timestamp=IsDatetime())],
+                timestamp=IsDatetime(),
+                run_id=IsStr(),
+                conversation_id=IsStr(),
+            ),
+            ModelResponse(
+                parts=[TextPart(content='final after late asap')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+                model_name='function:model_fn:',
+                timestamp=IsDatetime(),
+                run_id=IsStr(),
+                conversation_id=IsStr(),
+            ),
+        ]
+    )
+
+
+async def test_enqueue_when_idle_drains_after_leftover_asap():
+    """If both `'asap'` and `'when_idle'` are queued at end-of-run, `'asap'` drains first."""
+
+    def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        # Only fire enqueues once.
+        already_enqueued = any(
+            isinstance(p, UserPromptPart) and p.content in ('A', 'B')
+            for msg in messages
+            if isinstance(msg, ModelRequest)
+            for p in msg.parts
+        )
+        # If we've already seen our injected messages, just terminate.
+        if already_enqueued:
+            return ModelResponse(
+                parts=[TextPart(content='done')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+            )
+        return ModelResponse(
+            parts=[ToolCallPart(tool_name='inject', args='{}')],
+            usage=RequestUsage(input_tokens=10, output_tokens=5),
+        )
+
+    agent = Agent(FunctionModel(model_fn))
+
+    @agent.tool
+    def inject(ctx: RunContext[None]) -> str:
+        ctx.enqueue('B', priority='when_idle')
+        ctx.enqueue('A', priority='asap')
+        return 'ok'
+
+    result = await agent.run('Hello')
+    # Both A and B should appear in history. `'asap'` (A) drains in `before_model_request`
+    # before the second call. `'when_idle'` (B) drains at end-of-run when the second
+    # response has no tool calls.
+    requests_with_injected = [
+        msg
+        for msg in result.all_messages()
+        if isinstance(msg, ModelRequest)
+        and any(isinstance(p, UserPromptPart) and p.content in ('A', 'B') for p in msg.parts)
+    ]
+    contents = [
+        [p.content for p in r.parts if isinstance(p, UserPromptPart) and p.content in ('A', 'B')]
+        for r in requests_with_injected
+    ]
+    assert contents == [['A'], ['B']], f'expected A before B in separate requests, got {contents}'
+
+
+async def test_enqueue_priorities_stay_separate_when_both_drain_at_end_of_run():
+    """When both `'asap'` and `'when_idle'` parts-style payloads drain together at end-of-run,
+    they land in separate `ModelRequest`s — the priority split stays visible in history.
+
+    Reaches the case Devin flagged: a tool enqueues `'when_idle'` (which sits until
+    end-of-run), and a capability `after_model_request` hook enqueues `'asap'` during the
+    final step (after that step's `before_model_request` drain has already fired). Both
+    arrive at `after_node_run`. Without the per-priority split they'd merge into one
+    synthesized request, blurring the priority distinction in the persisted history.
+    On the wire `_clean_message_history` still merges them for the model.
+    """
+    call_count = 0
+
+    def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        nonlocal call_count
+        call_count += 1
+        if call_count == 1:
+            return ModelResponse(
+                parts=[ToolCallPart(tool_name='inject', args='{}')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+            )
+        if call_count == 2:
+            return ModelResponse(
+                parts=[TextPart(content='would-have-ended')],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+            )
+        return ModelResponse(
+            parts=[TextPart(content='final')],
+            usage=RequestUsage(input_tokens=10, output_tokens=5),
+        )
+
+    @dataclass
+    class LateAsapCap(AbstractCapability[Any]):
+        """Enqueues an `'asap'` message during `after_model_request` of the no-tool-call step.
+
+        Fires after the step's `before_model_request` drain, so the message can only be
+        delivered via the end-of-run drain in `after_node_run`.
+        """
+
+        fired: bool = False
+
+        async def after_model_request(
+            self,
+            ctx: RunContext[Any],
+            *,
+            request_context: ModelRequestContext,
+            response: ModelResponse,
+        ) -> ModelResponse:
+            if not self.fired and any(
+                isinstance(p, TextPart) and p.content == 'would-have-ended' for p in response.parts
+            ):
+                ctx.enqueue('asap-from-cap')
+                self.fired = True
+            return response
+
+    agent = Agent(FunctionModel(model_fn), capabilities=[LateAsapCap()])
+
+    @agent.tool
+    def inject(ctx: RunContext[None]) -> str:
+        ctx.enqueue('when-idle-from-tool', priority='when_idle')
+        return 'ok'
+
+    result = await agent.run('Hello')
+    assert result.output == 'final'
+
+    # Find the two end-of-run drained requests: one with the 'asap' content, one with 'when_idle'.
+    drained = [
+        msg
+        for msg in result.all_messages()
+        if isinstance(msg, ModelRequest)
+        and any(
+            isinstance(p, UserPromptPart) and p.content in ('asap-from-cap', 'when-idle-from-tool') for p in msg.parts
+        )
+    ]
+    contents = [
+        next(
+            p.content
+            for p in r.parts
+            if isinstance(p, UserPromptPart) and p.content in ('asap-from-cap', 'when-idle-from-tool')
+        )
+        for r in drained
+    ]
+    assert contents == ['asap-from-cap', 'when-idle-from-tool'], (
+        f'asap and when_idle should land in separate ModelRequests with asap first, got {contents}'
+    )
+    # Each priority bucket got its own ModelRequest (not merged into one).
+    assert all(len([p for p in r.parts if isinstance(p, UserPromptPart)]) == 1 for r in drained)
+
+
 # --- Output hook tests ---
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 946086ea97..6f71c847e3 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -390,6 +390,10 @@ async def call_tool(
         tool_name='get_weather_forecast', args={'location': 'Paris'}, tool_call_id='0001'
     ),
     'Tell me a joke.': 'Did you hear about the toothpaste scandal? They called it Colgate.',
+    'Summarize the latest deploy report': 'Initial deploy summary: 12 services updated, all green.',
+    'A new error was just reported — include it in the summary.': (
+        'Updated summary: 12 services updated, all green; one new error in the auth service was just reported.'
+    ),
     'Tell me a different joke.': 'No.',
     'Explain?': 'This is an excellent joke invented by Samuel Colvin, it needs no explanation.',
     'What is the weather in Tokyo?': 'As of 7:48 AM on Wednesday, April 2, 2025, in Tokyo, Japan, the weather is cloudy with a temperature of 53°F (12°C).',