From 8ee25f9bc748c5c4453ea48aa224a9f3cd2075aa Mon Sep 17 00:00:00 2001 From: Anmol Jaiswal <68013660+anmolg1997@users.noreply.github.com> Date: Wed, 15 Apr 2026 09:43:18 +0530 Subject: [PATCH 1/6] fix(evals): constrain judge_output reason to be concise and retry-safe The judge agents' system prompts now explicitly instruct the model to keep the `reason` field to a concise 1-2 sentence summary, preventing reasoning/thinking text from leaking into the public reason. The GradingOutput.reason field also gains a description that reinforces this constraint via the JSON schema. This makes `reason` stable and suitable for use in ModelRetry feedback loops, where verbose or self-contradictory reasoning text would otherwise degrade retry quality. Fixes #5034 --- .../evaluators/llm_as_a_judge.py | 51 ++++++++++++------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py b/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py index 0b6fed34e9..f5158ae14d 100644 --- a/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py +++ b/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py @@ -27,26 +27,37 @@ class GradingOutput(BaseModel, populate_by_name=True): """The output of a grading operation.""" - reason: str + reason: str = Field( + description='A concise 1-2 sentence explanation of why the output passed or failed.', + ) pass_: bool = Field(validation_alias='pass', serialization_alias='pass') score: float +_JUDGE_REASON_INSTRUCTION = ( + 'The "reason" field must be a concise 1-2 sentence summary of your verdict. ' + 'Do not include your reasoning process, self-corrections, or re-checking in the reason. ' + 'State only the final conclusion.' +) + + _judge_output_agent = Agent( name='judge_output', system_prompt=dedent( - """ - You are grading output according to a user-specified rubric. If the statement in the rubric is true, then the output passes the test. You respond with a JSON object with this structure: {reason: string, pass: boolean, score: number} + f""" + You are grading output according to a user-specified rubric. If the statement in the rubric is true, then the output passes the test. You respond with a JSON object with this structure: {{reason: string, pass: boolean, score: number}} + + {_JUDGE_REASON_INSTRUCTION} Examples: Hello world Content contains a greeting - {"reason": "the content contains the word 'Hello'", "pass": true, "score": 1.0} + {{"reason": "the content contains the word 'Hello'", "pass": true, "score": 1.0}} Avast ye swabs, repel the invaders! Does not speak like a pirate - {"reason": "'avast ye' is a common pirate term", "pass": false, "score": 0.0} + {{"reason": "'avast ye' is a common pirate term", "pass": false, "score": 0.0}} """ ), output_type=GradingOutput, @@ -73,20 +84,22 @@ async def judge_output( _judge_input_output_agent = Agent( name='judge_input_output', system_prompt=dedent( - """ - You are grading output according to a user-specified rubric. If the statement in the rubric is true for the provided input and output, then the output passes the test. You respond with a JSON object with this structure: {reason: string, pass: boolean, score: number} + f""" + You are grading output according to a user-specified rubric. If the statement in the rubric is true for the provided input and output, then the output passes the test. You respond with a JSON object with this structure: {{reason: string, pass: boolean, score: number}} + + {_JUDGE_REASON_INSTRUCTION} Examples: Hello world Hello Content contains a greeting word which is present in the input - {"reason": "the content contains the word 'Hello'", "pass": true, "score": 1.0} + {{"reason": "the content contains the word 'Hello'", "pass": true, "score": 1.0}} Pirate Avast ye swabs, repel the invaders! Does not speak in the style described by the input - {"reason": "'avast ye' is a common pirate term", "pass": false, "score": 0.0} + {{"reason": "'avast ye' is a common pirate term", "pass": false, "score": 0.0}} """ ), output_type=GradingOutput, @@ -115,8 +128,10 @@ async def judge_input_output( _judge_input_output_expected_agent = Agent( name='judge_input_output_expected', system_prompt=dedent( - """ - You are grading output according to a user-specified rubric. If the statement in the rubric is true for the provided input, expected output, and output, then the output passes the test. You respond with a JSON object with this structure: {reason: string, pass: boolean, score: number} + f""" + You are grading output according to a user-specified rubric. If the statement in the rubric is true for the provided input, expected output, and output, then the output passes the test. You respond with a JSON object with this structure: {{reason: string, pass: boolean, score: number}} + + {_JUDGE_REASON_INSTRUCTION} Examples: @@ -124,13 +139,13 @@ async def judge_input_output( Blue Cerulean The output is consistent with the expected output but doesn't have to match exactly - {"reason": "'Cerulean' is a shade of blue", "pass": true, "score": 1.0} + {{"reason": "'Cerulean' is a shade of blue", "pass": true, "score": 1.0}} How many legs does a spider have? 8 Six The output is factually consistent with the expected output - {"reason": "Spiders have 8 legs", "pass": false, "score": 0.0} + {{"reason": "Spiders have 8 legs", "pass": false, "score": 0.0}} """ ), output_type=GradingOutput, @@ -162,20 +177,22 @@ async def judge_input_output_expected( _judge_output_expected_agent = Agent( name='judge_output_expected', system_prompt=dedent( - """ - You are grading output according to a user-specified rubric. If the statement in the rubric is true for the provided expected output and output, then the output passes the test. You respond with a JSON object with this structure: {reason: string, pass: boolean, score: number} + f""" + You are grading output according to a user-specified rubric. If the statement in the rubric is true for the provided expected output and output, then the output passes the test. You respond with a JSON object with this structure: {{reason: string, pass: boolean, score: number}} + + {_JUDGE_REASON_INSTRUCTION} Examples: Blue Cerulean The output should be a shade of the expected output color - {"reason": "'Cerulean' is a shade of blue", "pass": true, "score": 1.0} + {{"reason": "'Cerulean' is a shade of blue", "pass": true, "score": 1.0}} 8 Six The output should be a number written in words which matches the number written in digits in the expected output - {"reason": "The output is 'Six' which is a different number than 8", "pass": false, "score": 0.0} + {{"reason": "The output is 'Six' which is a different number than 8", "pass": false, "score": 0.0}} """ ), output_type=GradingOutput, From a5c0e3effc28a1665057e4bf5a965b41cd02534e Mon Sep 17 00:00:00 2001 From: Anmol Jaiswal <68013660+anmolg1997@users.noreply.github.com> Date: Wed, 15 Apr 2026 09:47:26 +0530 Subject: [PATCH 2/6] feat: add built-in `repair_orphaned_tool_parts` history processor Adds a ready-to-use history processor that removes structurally invalid tool call/return pairs from message history. This prevents 400 errors from providers (especially Anthropic) that reject orphaned tool references after streaming timeouts, deferred tool drops, or history trimming. Two-pass repair: 1. Remove ToolReturnPart/RetryPromptPart whose tool_call_id has no matching ToolCallPart 2. Remove ToolCallPart whose tool_call_id has no matching return Output-validation RetryPromptParts (tool_name=None) are preserved since they are not tied to tool calls. Closes #4728 --- .../pydantic_ai/history_processors.py | 110 ++++++++++ tests/test_history_processors.py | 205 ++++++++++++++++++ 2 files changed, 315 insertions(+) create mode 100644 pydantic_ai_slim/pydantic_ai/history_processors.py create mode 100644 tests/test_history_processors.py diff --git a/pydantic_ai_slim/pydantic_ai/history_processors.py b/pydantic_ai_slim/pydantic_ai/history_processors.py new file mode 100644 index 0000000000..14a161a153 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/history_processors.py @@ -0,0 +1,110 @@ +"""Built-in history processor functions for common message history repair tasks. + +These functions can be passed directly to `Agent(history_processors=[...])` or +used with `capabilities.HistoryProcessor(processor=...)`. +""" + +from __future__ import annotations + +import logging +from dataclasses import replace + +from pydantic_ai import messages as _messages + +__all__ = ('repair_orphaned_tool_parts',) + +logger = logging.getLogger(__name__) + + +def repair_orphaned_tool_parts( + messages: list[_messages.ModelMessage], +) -> list[_messages.ModelMessage]: + """Remove orphaned tool call/return parts from message history. + + Multi-turn agent conversations can accumulate structurally invalid history + when tool calls and their corresponding results become mismatched. Common + causes include streaming timeouts, deferred tool result drops, and history + trimming by other processors. + + Providers like Anthropic strictly enforce that every `ToolCallPart` has a + matching `ToolReturnPart` (or `RetryPromptPart`) and vice versa; orphaned + entries cause 400 errors. + + This processor performs a two-pass repair: + + 1. **Orphaned returns/retries**: `ToolReturnPart` or `RetryPromptPart` whose + `tool_call_id` does not match any preceding `ToolCallPart` are removed. + 2. **Orphaned calls**: `ToolCallPart` whose `tool_call_id` does not match + any following `ToolReturnPart` or `RetryPromptPart` are removed. + + Empty messages (all parts removed) are dropped entirely. + + Example: + ```python + from pydantic_ai import Agent + from pydantic_ai.history_processors import repair_orphaned_tool_parts + + agent = Agent('openai:gpt-4o', history_processors=[repair_orphaned_tool_parts]) + ``` + """ + call_ids: set[str] = set() + for message in messages: + if isinstance(message, _messages.ModelResponse): + for part in message.parts: + if isinstance(part, _messages.ToolCallPart) and part.tool_call_id: + call_ids.add(part.tool_call_id) + + return_ids: set[str] = set() + for message in messages: + if isinstance(message, _messages.ModelRequest): + for part in message.parts: + if isinstance(part, (_messages.ToolReturnPart, _messages.RetryPromptPart)): + if part.tool_call_id: + return_ids.add(part.tool_call_id) + + repaired: list[_messages.ModelMessage] = [] + for message in messages: + if isinstance(message, _messages.ModelRequest): + kept_parts: list[_messages.ModelRequestPart] = [] + for part in message.parts: + if isinstance(part, _messages.ToolReturnPart): + if part.tool_call_id and part.tool_call_id not in call_ids: + logger.debug( + 'Removing orphaned ToolReturnPart with tool_call_id=%r (no matching ToolCallPart)', + part.tool_call_id, + ) + continue + elif isinstance(part, _messages.RetryPromptPart): + if part.tool_name is not None and part.tool_call_id and part.tool_call_id not in call_ids: + logger.debug( + 'Removing orphaned RetryPromptPart with tool_call_id=%r (no matching ToolCallPart)', + part.tool_call_id, + ) + continue + kept_parts.append(part) + + if kept_parts: + if len(kept_parts) != len(message.parts): + repaired.append(replace(message, parts=kept_parts)) + else: + repaired.append(message) + + elif isinstance(message, _messages.ModelResponse): + kept_response_parts: list[_messages.ModelResponsePart] = [] + for part in message.parts: + if isinstance(part, _messages.ToolCallPart): + if part.tool_call_id and part.tool_call_id not in return_ids: + logger.debug( + 'Removing orphaned ToolCallPart with tool_call_id=%r (no matching return)', + part.tool_call_id, + ) + continue + kept_response_parts.append(part) + + if kept_response_parts: + if len(kept_response_parts) != len(message.parts): + repaired.append(replace(message, parts=kept_response_parts)) + else: + repaired.append(message) + + return repaired diff --git a/tests/test_history_processors.py b/tests/test_history_processors.py new file mode 100644 index 0000000000..7a59077d5c --- /dev/null +++ b/tests/test_history_processors.py @@ -0,0 +1,205 @@ +"""Tests for built-in history processor functions.""" + +from __future__ import annotations + +import pytest + +from pydantic_ai.history_processors import repair_orphaned_tool_parts +from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + RetryPromptPart, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, +) + + +def test_no_changes_needed(): + """Matched pairs pass through untouched.""" + messages = [ + ModelRequest(parts=[UserPromptPart(content='hello')]), + ModelResponse(parts=[ToolCallPart(tool_name='get_data', tool_call_id='call_1')]), + ModelRequest(parts=[ToolReturnPart(tool_name='get_data', content='result', tool_call_id='call_1')]), + ModelResponse(parts=[TextPart(content='done')]), + ] + result = repair_orphaned_tool_parts(messages) + assert len(result) == 4 + assert result[0] == messages[0] + assert result[1] == messages[1] + assert result[2] == messages[2] + assert result[3] == messages[3] + + +def test_orphaned_tool_return_removed(): + """ToolReturnPart with no matching ToolCallPart is removed.""" + messages = [ + ModelRequest(parts=[UserPromptPart(content='hello')]), + ModelRequest( + parts=[ + ToolReturnPart(tool_name='unknown', content='orphan', tool_call_id='call_missing'), + UserPromptPart(content='continue'), + ] + ), + ModelResponse(parts=[TextPart(content='ok')]), + ] + result = repair_orphaned_tool_parts(messages) + assert len(result) == 3 + assert len(result[1].parts) == 1 + assert isinstance(result[1].parts[0], UserPromptPart) + + +def test_orphaned_retry_prompt_removed(): + """RetryPromptPart with no matching ToolCallPart is removed.""" + messages = [ + ModelRequest(parts=[UserPromptPart(content='hello')]), + ModelRequest( + parts=[ + RetryPromptPart(content='try again', tool_name='missing_tool', tool_call_id='call_gone'), + ] + ), + ModelResponse(parts=[TextPart(content='ok')]), + ] + result = repair_orphaned_tool_parts(messages) + assert len(result) == 2 + + +def test_orphaned_tool_call_removed(): + """ToolCallPart with no matching ToolReturnPart or RetryPromptPart is removed.""" + messages = [ + ModelRequest(parts=[UserPromptPart(content='hello')]), + ModelResponse( + parts=[ + TextPart(content='Let me call a tool'), + ToolCallPart(tool_name='timed_out', tool_call_id='call_orphan'), + ] + ), + ModelRequest(parts=[UserPromptPart(content='what happened?')]), + ModelResponse(parts=[TextPart(content='sorry about that')]), + ] + result = repair_orphaned_tool_parts(messages) + assert len(result) == 4 + response = result[1] + assert isinstance(response, ModelResponse) + assert len(response.parts) == 1 + assert isinstance(response.parts[0], TextPart) + + +def test_empty_message_removed(): + """Messages with all parts removed are dropped entirely.""" + messages = [ + ModelRequest(parts=[UserPromptPart(content='hello')]), + ModelResponse(parts=[ToolCallPart(tool_name='lost', tool_call_id='call_lost')]), + ModelRequest( + parts=[ToolReturnPart(tool_name='ghost', content='data', tool_call_id='call_ghost')] + ), + ModelResponse(parts=[TextPart(content='end')]), + ] + result = repair_orphaned_tool_parts(messages) + assert len(result) == 2 + assert isinstance(result[0].parts[0], UserPromptPart) + assert isinstance(result[1], ModelResponse) + assert isinstance(result[1].parts[0], TextPart) + + +def test_multiple_matched_pairs(): + """Multiple valid tool call/return pairs are preserved.""" + messages = [ + ModelRequest(parts=[UserPromptPart(content='do work')]), + ModelResponse( + parts=[ + ToolCallPart(tool_name='a', tool_call_id='id_a'), + ToolCallPart(tool_name='b', tool_call_id='id_b'), + ] + ), + ModelRequest( + parts=[ + ToolReturnPart(tool_name='a', content='result_a', tool_call_id='id_a'), + ToolReturnPart(tool_name='b', content='result_b', tool_call_id='id_b'), + ] + ), + ModelResponse(parts=[TextPart(content='all done')]), + ] + result = repair_orphaned_tool_parts(messages) + assert len(result) == 4 + assert result == messages + + +def test_mixed_orphans_and_valid(): + """Only orphaned parts are removed; valid pairs remain.""" + messages = [ + ModelRequest(parts=[UserPromptPart(content='go')]), + ModelResponse( + parts=[ + ToolCallPart(tool_name='valid', tool_call_id='id_ok'), + ToolCallPart(tool_name='orphan_call', tool_call_id='id_orphan'), + ] + ), + ModelRequest( + parts=[ + ToolReturnPart(tool_name='valid', content='good', tool_call_id='id_ok'), + ToolReturnPart(tool_name='orphan_return', content='bad', tool_call_id='id_no_call'), + ] + ), + ModelResponse(parts=[TextPart(content='done')]), + ] + result = repair_orphaned_tool_parts(messages) + assert len(result) == 4 + + response = result[1] + assert isinstance(response, ModelResponse) + assert len(response.parts) == 1 + assert isinstance(response.parts[0], ToolCallPart) + assert response.parts[0].tool_call_id == 'id_ok' + + request = result[2] + assert isinstance(request, ModelRequest) + assert len(request.parts) == 1 + assert isinstance(request.parts[0], ToolReturnPart) + assert request.parts[0].tool_call_id == 'id_ok' + + +def test_retry_prompt_matches_call(): + """RetryPromptPart with matching ToolCallPart is preserved.""" + messages = [ + ModelRequest(parts=[UserPromptPart(content='try')]), + ModelResponse(parts=[ToolCallPart(tool_name='flaky', tool_call_id='id_retry')]), + ModelRequest( + parts=[RetryPromptPart(content='bad args', tool_name='flaky', tool_call_id='id_retry')] + ), + ModelResponse(parts=[TextPart(content='ok')]), + ] + result = repair_orphaned_tool_parts(messages) + assert len(result) == 4 + assert result == messages + + +def test_empty_history(): + """Empty input returns empty output.""" + assert repair_orphaned_tool_parts([]) == [] + + +def test_text_only_history(): + """History with no tool parts passes through unchanged.""" + messages = [ + ModelRequest(parts=[UserPromptPart(content='hello')]), + ModelResponse(parts=[TextPart(content='hi')]), + ModelRequest(parts=[UserPromptPart(content='bye')]), + ModelResponse(parts=[TextPart(content='goodbye')]), + ] + result = repair_orphaned_tool_parts(messages) + assert result == messages + + +def test_retry_prompt_without_tool_name_preserved(): + """RetryPromptPart without tool_name (output validation retry) is kept.""" + messages = [ + ModelRequest(parts=[UserPromptPart(content='generate')]), + ModelResponse(parts=[TextPart(content='bad output')]), + ModelRequest(parts=[RetryPromptPart(content='validation failed')]), + ModelResponse(parts=[TextPart(content='better output')]), + ] + result = repair_orphaned_tool_parts(messages) + assert len(result) == 4 + assert result == messages From 387a38ee7b9752877e699d0cf48c0f108c62a747 Mon Sep 17 00:00:00 2001 From: Anmol Jaiswal <68013660+anmolg1997@users.noreply.github.com> Date: Wed, 15 Apr 2026 10:14:03 +0530 Subject: [PATCH 3/6] fix: address lint, coverage, and Devin review findings - Remove unused `import pytest` from test file - Fix formatting (pre-commit auto-format compliance) - Remove redundant `tool_call_id` truthiness guards (always set by default) - Add `pragma: no branch` for exhaustive ModelMessage union branch - Achieves 100% branch coverage --- .../pydantic_ai/history_processors.py | 15 +++++++-------- tests/test_history_processors.py | 10 ++-------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/history_processors.py b/pydantic_ai_slim/pydantic_ai/history_processors.py index 14a161a153..3c86182de0 100644 --- a/pydantic_ai_slim/pydantic_ai/history_processors.py +++ b/pydantic_ai_slim/pydantic_ai/history_processors.py @@ -44,7 +44,7 @@ def repair_orphaned_tool_parts( from pydantic_ai import Agent from pydantic_ai.history_processors import repair_orphaned_tool_parts - agent = Agent('openai:gpt-4o', history_processors=[repair_orphaned_tool_parts]) + agent = Agent('openai:gpt-4o-mini', history_processors=[repair_orphaned_tool_parts]) ``` """ call_ids: set[str] = set() @@ -58,9 +58,8 @@ def repair_orphaned_tool_parts( for message in messages: if isinstance(message, _messages.ModelRequest): for part in message.parts: - if isinstance(part, (_messages.ToolReturnPart, _messages.RetryPromptPart)): - if part.tool_call_id: - return_ids.add(part.tool_call_id) + if isinstance(part, (_messages.ToolReturnPart, _messages.RetryPromptPart)) and part.tool_call_id: + return_ids.add(part.tool_call_id) repaired: list[_messages.ModelMessage] = [] for message in messages: @@ -68,14 +67,14 @@ def repair_orphaned_tool_parts( kept_parts: list[_messages.ModelRequestPart] = [] for part in message.parts: if isinstance(part, _messages.ToolReturnPart): - if part.tool_call_id and part.tool_call_id not in call_ids: + if part.tool_call_id not in call_ids: logger.debug( 'Removing orphaned ToolReturnPart with tool_call_id=%r (no matching ToolCallPart)', part.tool_call_id, ) continue elif isinstance(part, _messages.RetryPromptPart): - if part.tool_name is not None and part.tool_call_id and part.tool_call_id not in call_ids: + if part.tool_name is not None and part.tool_call_id not in call_ids: logger.debug( 'Removing orphaned RetryPromptPart with tool_call_id=%r (no matching ToolCallPart)', part.tool_call_id, @@ -89,11 +88,11 @@ def repair_orphaned_tool_parts( else: repaired.append(message) - elif isinstance(message, _messages.ModelResponse): + elif isinstance(message, _messages.ModelResponse): # pragma: no branch kept_response_parts: list[_messages.ModelResponsePart] = [] for part in message.parts: if isinstance(part, _messages.ToolCallPart): - if part.tool_call_id and part.tool_call_id not in return_ids: + if part.tool_call_id not in return_ids: logger.debug( 'Removing orphaned ToolCallPart with tool_call_id=%r (no matching return)', part.tool_call_id, diff --git a/tests/test_history_processors.py b/tests/test_history_processors.py index 7a59077d5c..6dcdd2515b 100644 --- a/tests/test_history_processors.py +++ b/tests/test_history_processors.py @@ -2,8 +2,6 @@ from __future__ import annotations -import pytest - from pydantic_ai.history_processors import repair_orphaned_tool_parts from pydantic_ai.messages import ( ModelRequest, @@ -91,9 +89,7 @@ def test_empty_message_removed(): messages = [ ModelRequest(parts=[UserPromptPart(content='hello')]), ModelResponse(parts=[ToolCallPart(tool_name='lost', tool_call_id='call_lost')]), - ModelRequest( - parts=[ToolReturnPart(tool_name='ghost', content='data', tool_call_id='call_ghost')] - ), + ModelRequest(parts=[ToolReturnPart(tool_name='ghost', content='data', tool_call_id='call_ghost')]), ModelResponse(parts=[TextPart(content='end')]), ] result = repair_orphaned_tool_parts(messages) @@ -165,9 +161,7 @@ def test_retry_prompt_matches_call(): messages = [ ModelRequest(parts=[UserPromptPart(content='try')]), ModelResponse(parts=[ToolCallPart(tool_name='flaky', tool_call_id='id_retry')]), - ModelRequest( - parts=[RetryPromptPart(content='bad args', tool_name='flaky', tool_call_id='id_retry')] - ), + ModelRequest(parts=[RetryPromptPart(content='bad args', tool_name='flaky', tool_call_id='id_retry')]), ModelResponse(parts=[TextPart(content='ok')]), ] result = repair_orphaned_tool_parts(messages) From 40797f21fb98fac2fc94339a8f456d2d030ff0f8 Mon Sep 17 00:00:00 2001 From: Anmol Jaiswal <68013660+anmolg1997@users.noreply.github.com> Date: Wed, 15 Apr 2026 10:14:24 +0530 Subject: [PATCH 4/6] fix: use frontier model name in docstring example per repo conventions --- pydantic_ai_slim/pydantic_ai/history_processors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/history_processors.py b/pydantic_ai_slim/pydantic_ai/history_processors.py index 3c86182de0..89b144f5a1 100644 --- a/pydantic_ai_slim/pydantic_ai/history_processors.py +++ b/pydantic_ai_slim/pydantic_ai/history_processors.py @@ -44,7 +44,7 @@ def repair_orphaned_tool_parts( from pydantic_ai import Agent from pydantic_ai.history_processors import repair_orphaned_tool_parts - agent = Agent('openai:gpt-4o-mini', history_processors=[repair_orphaned_tool_parts]) + agent = Agent('openai:gpt-5.2', history_processors=[repair_orphaned_tool_parts]) ``` """ call_ids: set[str] = set() From 972394805f21acbef0d79b5ca64aef350760b668 Mon Sep 17 00:00:00 2001 From: Anmol Jaiswal <68013660+anmolg1997@users.noreply.github.com> Date: Wed, 15 Apr 2026 10:26:46 +0530 Subject: [PATCH 5/6] refactor: extract helpers to fix C901 complexity (24 > 15) Split repair_orphaned_tool_parts into focused helpers: - _collect_tool_call_ids / _collect_tool_return_ids - _is_orphaned_request_part - _repair_request / _repair_response - _rebuild_or_drop All 11 tests pass, 100% branch coverage, ruff clean. --- .../pydantic_ai/history_processors.py | 129 +++++++++++------- 1 file changed, 78 insertions(+), 51 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/history_processors.py b/pydantic_ai_slim/pydantic_ai/history_processors.py index 89b144f5a1..4bc2daec55 100644 --- a/pydantic_ai_slim/pydantic_ai/history_processors.py +++ b/pydantic_ai_slim/pydantic_ai/history_processors.py @@ -7,6 +7,7 @@ from __future__ import annotations import logging +from collections.abc import Sequence from dataclasses import replace from pydantic_ai import messages as _messages @@ -47,63 +48,89 @@ def repair_orphaned_tool_parts( agent = Agent('openai:gpt-5.2', history_processors=[repair_orphaned_tool_parts]) ``` """ - call_ids: set[str] = set() + call_ids = _collect_tool_call_ids(messages) + return_ids = _collect_tool_return_ids(messages) + + repaired: list[_messages.ModelMessage] = [] + for message in messages: + if isinstance(message, _messages.ModelRequest): + repaired_msg = _repair_request(message, call_ids) + elif isinstance(message, _messages.ModelResponse): # pragma: no branch + repaired_msg = _repair_response(message, return_ids) + if repaired_msg is not None: + repaired.append(repaired_msg) + + return repaired + + +def _collect_tool_call_ids(messages: list[_messages.ModelMessage]) -> set[str]: + """Collect all tool_call_ids from ToolCallPart in ModelResponse messages.""" + ids: set[str] = set() for message in messages: if isinstance(message, _messages.ModelResponse): for part in message.parts: if isinstance(part, _messages.ToolCallPart) and part.tool_call_id: - call_ids.add(part.tool_call_id) + ids.add(part.tool_call_id) + return ids - return_ids: set[str] = set() - for message in messages: - if isinstance(message, _messages.ModelRequest): - for part in message.parts: - if isinstance(part, (_messages.ToolReturnPart, _messages.RetryPromptPart)) and part.tool_call_id: - return_ids.add(part.tool_call_id) - repaired: list[_messages.ModelMessage] = [] +def _collect_tool_return_ids(messages: list[_messages.ModelMessage]) -> set[str]: + """Collect all tool_call_ids from ToolReturnPart/RetryPromptPart in ModelRequest messages.""" + ids: set[str] = set() for message in messages: if isinstance(message, _messages.ModelRequest): - kept_parts: list[_messages.ModelRequestPart] = [] - for part in message.parts: - if isinstance(part, _messages.ToolReturnPart): - if part.tool_call_id not in call_ids: - logger.debug( - 'Removing orphaned ToolReturnPart with tool_call_id=%r (no matching ToolCallPart)', - part.tool_call_id, - ) - continue - elif isinstance(part, _messages.RetryPromptPart): - if part.tool_name is not None and part.tool_call_id not in call_ids: - logger.debug( - 'Removing orphaned RetryPromptPart with tool_call_id=%r (no matching ToolCallPart)', - part.tool_call_id, - ) - continue - kept_parts.append(part) - - if kept_parts: - if len(kept_parts) != len(message.parts): - repaired.append(replace(message, parts=kept_parts)) - else: - repaired.append(message) - - elif isinstance(message, _messages.ModelResponse): # pragma: no branch - kept_response_parts: list[_messages.ModelResponsePart] = [] for part in message.parts: - if isinstance(part, _messages.ToolCallPart): - if part.tool_call_id not in return_ids: - logger.debug( - 'Removing orphaned ToolCallPart with tool_call_id=%r (no matching return)', - part.tool_call_id, - ) - continue - kept_response_parts.append(part) - - if kept_response_parts: - if len(kept_response_parts) != len(message.parts): - repaired.append(replace(message, parts=kept_response_parts)) - else: - repaired.append(message) - - return repaired + if isinstance(part, (_messages.ToolReturnPart, _messages.RetryPromptPart)) and part.tool_call_id: + ids.add(part.tool_call_id) + return ids + + +def _is_orphaned_request_part(part: _messages.ModelRequestPart, call_ids: set[str]) -> bool: + """Check if a request part is orphaned (no matching tool call).""" + if isinstance(part, _messages.ToolReturnPart): + return part.tool_call_id not in call_ids + if isinstance(part, _messages.RetryPromptPart): + return part.tool_name is not None and part.tool_call_id not in call_ids + return False + + +def _repair_request(message: _messages.ModelRequest, call_ids: set[str]) -> _messages.ModelMessage | None: + """Remove orphaned ToolReturnPart/RetryPromptPart from a ModelRequest.""" + kept: list[_messages.ModelRequestPart] = [] + for part in message.parts: + if _is_orphaned_request_part(part, call_ids): + logger.debug( + 'Removing orphaned %s with tool_call_id=%r (no matching ToolCallPart)', + type(part).__name__, + getattr(part, 'tool_call_id', None), + ) + continue + kept.append(part) + return _rebuild_or_drop(message, message.parts, kept) + + +def _repair_response(message: _messages.ModelResponse, return_ids: set[str]) -> _messages.ModelMessage | None: + """Remove orphaned ToolCallPart from a ModelResponse.""" + kept: list[_messages.ModelResponsePart] = [] + for part in message.parts: + if isinstance(part, _messages.ToolCallPart) and part.tool_call_id not in return_ids: + logger.debug( + 'Removing orphaned ToolCallPart with tool_call_id=%r (no matching return)', + part.tool_call_id, + ) + continue + kept.append(part) + return _rebuild_or_drop(message, message.parts, kept) + + +def _rebuild_or_drop( + message: _messages.ModelMessage, + original_parts: Sequence[object], + kept_parts: list[object], +) -> _messages.ModelMessage | None: + """Return the message with filtered parts, or None if all parts were removed.""" + if not kept_parts: + return None + if len(kept_parts) != len(original_parts): + return replace(message, parts=kept_parts) # type: ignore[arg-type] + return message From a56ab1c236bcf19dbae1c6a1ddb88dbdfed9ab9d Mon Sep 17 00:00:00 2001 From: Anmol Jaiswal <68013660+anmolg1997@users.noreply.github.com> Date: Wed, 15 Apr 2026 10:40:21 +0530 Subject: [PATCH 6/6] fix: resolve pyright type errors in history_processors - Fix "possibly unbound" by using if/else instead of if/elif for exhaustive ModelMessage union - Fix list invariance errors by inlining rebuild logic into _repair_request and _repair_response with proper return types - Remove _rebuild_or_drop helper and its type: ignore comment Locally verified: pyright 0 errors, ruff clean, 100% branch coverage. --- .../pydantic_ai/history_processors.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/history_processors.py b/pydantic_ai_slim/pydantic_ai/history_processors.py index 4bc2daec55..abb8a3253a 100644 --- a/pydantic_ai_slim/pydantic_ai/history_processors.py +++ b/pydantic_ai_slim/pydantic_ai/history_processors.py @@ -7,7 +7,6 @@ from __future__ import annotations import logging -from collections.abc import Sequence from dataclasses import replace from pydantic_ai import messages as _messages @@ -54,11 +53,11 @@ def repair_orphaned_tool_parts( repaired: list[_messages.ModelMessage] = [] for message in messages: if isinstance(message, _messages.ModelRequest): - repaired_msg = _repair_request(message, call_ids) - elif isinstance(message, _messages.ModelResponse): # pragma: no branch - repaired_msg = _repair_response(message, return_ids) - if repaired_msg is not None: - repaired.append(repaired_msg) + result = _repair_request(message, call_ids) + else: + result = _repair_response(message, return_ids) + if result is not None: + repaired.append(result) return repaired @@ -94,7 +93,10 @@ def _is_orphaned_request_part(part: _messages.ModelRequestPart, call_ids: set[st return False -def _repair_request(message: _messages.ModelRequest, call_ids: set[str]) -> _messages.ModelMessage | None: +def _repair_request( + message: _messages.ModelRequest, + call_ids: set[str], +) -> _messages.ModelRequest | None: """Remove orphaned ToolReturnPart/RetryPromptPart from a ModelRequest.""" kept: list[_messages.ModelRequestPart] = [] for part in message.parts: @@ -106,10 +108,17 @@ def _repair_request(message: _messages.ModelRequest, call_ids: set[str]) -> _mes ) continue kept.append(part) - return _rebuild_or_drop(message, message.parts, kept) + if not kept: + return None + if len(kept) != len(message.parts): + return replace(message, parts=kept) + return message -def _repair_response(message: _messages.ModelResponse, return_ids: set[str]) -> _messages.ModelMessage | None: +def _repair_response( + message: _messages.ModelResponse, + return_ids: set[str], +) -> _messages.ModelResponse | None: """Remove orphaned ToolCallPart from a ModelResponse.""" kept: list[_messages.ModelResponsePart] = [] for part in message.parts: @@ -120,17 +129,8 @@ def _repair_response(message: _messages.ModelResponse, return_ids: set[str]) -> ) continue kept.append(part) - return _rebuild_or_drop(message, message.parts, kept) - - -def _rebuild_or_drop( - message: _messages.ModelMessage, - original_parts: Sequence[object], - kept_parts: list[object], -) -> _messages.ModelMessage | None: - """Return the message with filtered parts, or None if all parts were removed.""" - if not kept_parts: + if not kept: return None - if len(kept_parts) != len(original_parts): - return replace(message, parts=kept_parts) # type: ignore[arg-type] + if len(kept) != len(message.parts): + return replace(message, parts=kept) return message