diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index c8fc93f64b..95c08773a8 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -28,7 +28,7 @@ from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc from ._warnings import PydanticAIDeprecationWarning from .exceptions import UnexpectedModelBehavior -from .usage import RequestUsage +from .usage import RequestUsage, RunUsage if TYPE_CHECKING: from .models.instrumented import InstrumentationSettings @@ -2083,7 +2083,7 @@ class ModelResponse: _: KW_ONLY - usage: RequestUsage = field(default_factory=RequestUsage) + usage: RequestUsage | RunUsage = field(default_factory=RequestUsage) """Usage information for the request. This has a default to make tests easier, and to support loading old messages where usage will be missing. diff --git a/pydantic_ai_slim/pydantic_ai/models/wrapper.py b/pydantic_ai_slim/pydantic_ai/models/wrapper.py index 14591d8df6..5e8c6d45b6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/models/wrapper.py @@ -4,7 +4,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import datetime -from typing import Any +from typing import Any, cast from typing_extensions import Self @@ -46,7 +46,7 @@ def get(self) -> ModelResponse: '`StreamedResponse.usage` is no longer a method; access it as a property (drop the parentheses).' ) def usage(self) -> RequestUsage: - return self.response.usage # pragma: no cover + return cast(RequestUsage, self.response.usage) # pragma: no cover @property def model_name(self) -> str: diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index a6282917a7..0a6e5dba88 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -5283,7 +5283,7 @@ async def test_stream_with_continuous_usage_stats(allow_model_requests: None): # Verify usage is updated at each step via stream_response usage_at_each_step: list[RequestUsage] = [] async for response in result.stream_response(debounce_by=None): - usage_at_each_step.append(response.usage) + usage_at_each_step.append(cast(RequestUsage, response.usage)) # Each step should have the cumulative usage from that chunk (not accumulated) # The stream emits responses for each content chunk plus final diff --git a/tests/test_messages.py b/tests/test_messages.py index 0b46d73142..0cf121078b 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -24,6 +24,7 @@ NativeToolReturnPart, RequestUsage, RetryPromptPart, + RunUsage, TextContent, TextPart, ThinkingPart, @@ -509,6 +510,26 @@ def test_pre_usage_refactor_messages_deserializable(): ) +def test_model_response_preserves_run_usage_roundtrip(): + response = ModelResponse( + parts=[TextPart(content='Hello!')], + usage=RunUsage(requests=5, tool_calls=3, input_tokens=1000, output_tokens=500), + model_name='test', + ) + + serialized = ModelMessagesTypeAdapter.dump_json([response]) + deserialized = ModelMessagesTypeAdapter.validate_json(serialized) + + msg = deserialized[0] + assert isinstance(msg, ModelResponse) + usage = msg.usage + assert isinstance(usage, RunUsage) + assert usage.requests == 5 + assert usage.tool_calls == 3 + assert usage.input_tokens == 1000 + assert usage.output_tokens == 500 + + def test_file_part_has_content(): filepart = FilePart(content=BinaryContent(data=b'', media_type='application/pdf')) assert not filepart.has_content()