From 597e2a51172e221866a307d7e1fb4a6382079b1c Mon Sep 17 00:00:00 2001 From: gautham18113 Date: Mon, 1 Jun 2026 23:04:14 -0700 Subject: [PATCH] fix(openai-responses): surface cache read tokens in metadata chunk OpenAIResponsesModel._format_chunk dropped input_tokens_details when building the metadata usage dict, so cacheReadInputTokens was never set and cache hits were invisible to telemetry and cost tooling. Mirror the fix already present in OpenAIModel.format_chunk (added in #2115 / #2116): read input_tokens_details.cached_tokens and set cacheReadInputTokens when the field is present. Fixes #2407 --- .../src/strands/models/openai_responses.py | 17 ++- .../strands/models/test_openai_responses.py | 106 ++++++++++++++++-- 2 files changed, 108 insertions(+), 15 deletions(-) diff --git a/strands-py/src/strands/models/openai_responses.py b/strands-py/src/strands/models/openai_responses.py index 8914fb01c0..a79004e963 100644 --- a/strands-py/src/strands/models/openai_responses.py +++ b/strands-py/src/strands/models/openai_responses.py @@ -55,6 +55,7 @@ from ..types.citations import WebLocationDict # noqa: E402 from ..types.content import ContentBlock, Messages, Role, SystemContentBlock # noqa: E402 +from ..types.event_loop import Usage # noqa: E402 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException # noqa: E402 from ..types.streaming import StreamEvent # noqa: E402 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse # noqa: E402 @@ -837,13 +838,19 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: case "metadata": # Responses API uses input_tokens/output_tokens naming convention + usage_data: Usage = { + "inputTokens": getattr(event["data"], "input_tokens", 0), + "outputTokens": getattr(event["data"], "output_tokens", 0), + "totalTokens": getattr(event["data"], "total_tokens", 0), + } + + if token_details := getattr(event["data"], "input_tokens_details", None): + if cached := getattr(token_details, "cached_tokens", None): + usage_data["cacheReadInputTokens"] = cached + return { "metadata": { - "usage": { - "inputTokens": getattr(event["data"], "input_tokens", 0), - "outputTokens": getattr(event["data"], "output_tokens", 0), - "totalTokens": getattr(event["data"], "total_tokens", 0), - }, + "usage": usage_data, "metrics": { "latencyMs": 0, # TODO }, diff --git a/strands-py/tests/strands/models/test_openai_responses.py b/strands-py/tests/strands/models/test_openai_responses.py index 697508339c..40ca79c522 100644 --- a/strands-py/tests/strands/models/test_openai_responses.py +++ b/strands-py/tests/strands/models/test_openai_responses.py @@ -451,11 +451,11 @@ def test_format_request(model, messages, tool_specs, system_prompt): {"chunk_type": "message_stop", "data": "stop"}, {"messageStop": {"stopReason": "end_turn"}}, ), - # Metadata + # Metadata - no cache tokens ( { "chunk_type": "metadata", - "data": unittest.mock.Mock(input_tokens=100, output_tokens=50, total_tokens=150), + "data": unittest.mock.Mock(input_tokens=100, output_tokens=50, total_tokens=150, input_tokens_details=None), }, { "metadata": { @@ -470,6 +470,31 @@ def test_format_request(model, messages, tool_specs, system_prompt): }, }, ), + # Metadata - with cache read tokens + ( + { + "chunk_type": "metadata", + "data": unittest.mock.Mock( + input_tokens=100, + output_tokens=50, + total_tokens=150, + input_tokens_details=unittest.mock.Mock(cached_tokens=80), + ), + }, + { + "metadata": { + "usage": { + "inputTokens": 100, + "outputTokens": 50, + "totalTokens": 150, + "cacheReadInputTokens": 80, + }, + "metrics": { + "latencyMs": 0, + }, + }, + }, + ), ], ) def test_format_chunk(event, exp_chunk, model): @@ -490,7 +515,7 @@ async def test_stream(openai_client, model_id, model, agenerator, alist): mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hello") mock_complete_event = unittest.mock.Mock( type="response.completed", - response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None)), ) openai_client.responses.create = unittest.mock.AsyncMock( @@ -526,6 +551,67 @@ async def test_stream(openai_client, model_id, model, agenerator, alist): openai_client.responses.create.assert_called_once_with(**expected_request) +@pytest.mark.asyncio +async def test_stream_cache_tokens_propagated(openai_client, model, agenerator, alist): + """Cache read tokens from input_tokens_details are surfaced in the metadata event.""" + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hi") + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock( + usage=unittest.mock.Mock( + input_tokens=100, + output_tokens=10, + total_tokens=110, + input_tokens_details=unittest.mock.Mock(cached_tokens=80), + ) + ), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_text_event, mock_complete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + tru_events = await alist(model.stream(messages)) + + metadata_events = [e for e in tru_events if "metadata" in e] + assert len(metadata_events) == 1 + usage = metadata_events[0]["metadata"]["usage"] + assert usage["inputTokens"] == 100 + assert usage["outputTokens"] == 10 + assert usage["totalTokens"] == 110 + assert usage["cacheReadInputTokens"] == 80 + + +@pytest.mark.asyncio +async def test_stream_no_cache_tokens_when_absent(openai_client, model, agenerator, alist): + """cacheReadInputTokens is omitted from metadata when input_tokens_details is absent.""" + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hi") + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock( + usage=unittest.mock.Mock( + input_tokens=100, + output_tokens=10, + total_tokens=110, + input_tokens_details=None, + ) + ), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_text_event, mock_complete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + tru_events = await alist(model.stream(messages)) + + metadata_events = [e for e in tru_events if "metadata" in e] + assert len(metadata_events) == 1 + usage = metadata_events[0]["metadata"]["usage"] + assert "cacheReadInputTokens" not in usage + + @pytest.mark.asyncio async def test_stream_with_tool_calls(openai_client, model, agenerator, alist): # Mock tool call events @@ -538,7 +624,7 @@ async def test_stream_with_tool_calls(openai_client, model, agenerator, alist): ) mock_complete_event = unittest.mock.Mock( type="response.completed", - response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None)), ) openai_client.responses.create = unittest.mock.AsyncMock( @@ -571,7 +657,7 @@ async def test_stream_with_tool_calls_done_event(openai_client, model, agenerato ) mock_complete_event = unittest.mock.Mock( type="response.completed", - response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None)), ) openai_client.responses.create = unittest.mock.AsyncMock( @@ -594,7 +680,7 @@ async def test_stream_response_incomplete(openai_client, model, agenerator, alis mock_incomplete_event = unittest.mock.Mock( type="response.incomplete", response=unittest.mock.Mock( - usage=unittest.mock.Mock(input_tokens=10, output_tokens=100, total_tokens=110), + usage=unittest.mock.Mock(input_tokens=10, output_tokens=100, total_tokens=110, input_tokens_details=None), incomplete_details=unittest.mock.Mock(reason="max_output_tokens"), ), ) @@ -628,7 +714,7 @@ async def test_stream_reasoning_content(openai_client, model, agenerator, alist, mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="The answer is 42") mock_complete_event = unittest.mock.Mock( type="response.completed", - response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=20, total_tokens=30)), + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=20, total_tokens=30, input_tokens_details=None)), ) openai_client.responses.create = unittest.mock.AsyncMock( @@ -672,7 +758,7 @@ async def test_stream_citation_annotations(openai_client, model, agenerator, ali ) mock_complete_event = unittest.mock.Mock( type="response.completed", - response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None)), ) openai_client.responses.create = unittest.mock.AsyncMock( @@ -708,7 +794,7 @@ async def test_stream_unsupported_annotation_type(openai_client, model, agenerat ) mock_complete_event = unittest.mock.Mock( type="response.completed", - response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None)), ) openai_client.responses.create = unittest.mock.AsyncMock( @@ -1152,7 +1238,7 @@ async def test_stream_stateful(openai_client, model_id, agenerator, alist): type="response.completed", response=unittest.mock.Mock( id="resp_abc123", - usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15), + usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None), ), ), ]