Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions strands-py/src/strands/models/openai_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

from ..types.citations import WebLocationDict # noqa: E402
from ..types.content import ContentBlock, Messages, Role, SystemContentBlock # noqa: E402
from ..types.event_loop import Usage # noqa: E402
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException # noqa: E402
from ..types.streaming import StreamEvent # noqa: E402
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse # noqa: E402
Expand Down Expand Up @@ -837,13 +838,19 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent:

case "metadata":
# Responses API uses input_tokens/output_tokens naming convention
usage_data: Usage = {
"inputTokens": getattr(event["data"], "input_tokens", 0),
"outputTokens": getattr(event["data"], "output_tokens", 0),
"totalTokens": getattr(event["data"], "total_tokens", 0),
}

if token_details := getattr(event["data"], "input_tokens_details", None):
if cached := getattr(token_details, "cached_tokens", None):
usage_data["cacheReadInputTokens"] = cached

return {
"metadata": {
"usage": {
"inputTokens": getattr(event["data"], "input_tokens", 0),
"outputTokens": getattr(event["data"], "output_tokens", 0),
"totalTokens": getattr(event["data"], "total_tokens", 0),
},
"usage": usage_data,
"metrics": {
"latencyMs": 0, # TODO
},
Expand Down
106 changes: 96 additions & 10 deletions strands-py/tests/strands/models/test_openai_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,11 +451,11 @@ def test_format_request(model, messages, tool_specs, system_prompt):
{"chunk_type": "message_stop", "data": "stop"},
{"messageStop": {"stopReason": "end_turn"}},
),
# Metadata
# Metadata - no cache tokens
(
{
"chunk_type": "metadata",
"data": unittest.mock.Mock(input_tokens=100, output_tokens=50, total_tokens=150),
"data": unittest.mock.Mock(input_tokens=100, output_tokens=50, total_tokens=150, input_tokens_details=None),
},
{
"metadata": {
Expand All @@ -470,6 +470,31 @@ def test_format_request(model, messages, tool_specs, system_prompt):
},
},
),
# Metadata - with cache read tokens
(
{
"chunk_type": "metadata",
"data": unittest.mock.Mock(
input_tokens=100,
output_tokens=50,
total_tokens=150,
input_tokens_details=unittest.mock.Mock(cached_tokens=80),
),
},
{
"metadata": {
"usage": {
"inputTokens": 100,
"outputTokens": 50,
"totalTokens": 150,
"cacheReadInputTokens": 80,
},
"metrics": {
"latencyMs": 0,
},
},
},
),
],
)
def test_format_chunk(event, exp_chunk, model):
Expand All @@ -490,7 +515,7 @@ async def test_stream(openai_client, model_id, model, agenerator, alist):
mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hello")
mock_complete_event = unittest.mock.Mock(
type="response.completed",
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)),
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None)),
)

openai_client.responses.create = unittest.mock.AsyncMock(
Expand Down Expand Up @@ -526,6 +551,67 @@ async def test_stream(openai_client, model_id, model, agenerator, alist):
openai_client.responses.create.assert_called_once_with(**expected_request)


@pytest.mark.asyncio
async def test_stream_cache_tokens_propagated(openai_client, model, agenerator, alist):
"""Cache read tokens from input_tokens_details are surfaced in the metadata event."""
mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hi")
mock_complete_event = unittest.mock.Mock(
type="response.completed",
response=unittest.mock.Mock(
usage=unittest.mock.Mock(
input_tokens=100,
output_tokens=10,
total_tokens=110,
input_tokens_details=unittest.mock.Mock(cached_tokens=80),
)
),
)

openai_client.responses.create = unittest.mock.AsyncMock(
return_value=agenerator([mock_text_event, mock_complete_event])
)

messages = [{"role": "user", "content": [{"text": "test"}]}]
tru_events = await alist(model.stream(messages))

metadata_events = [e for e in tru_events if "metadata" in e]
assert len(metadata_events) == 1
usage = metadata_events[0]["metadata"]["usage"]
assert usage["inputTokens"] == 100
assert usage["outputTokens"] == 10
assert usage["totalTokens"] == 110
assert usage["cacheReadInputTokens"] == 80


@pytest.mark.asyncio
async def test_stream_no_cache_tokens_when_absent(openai_client, model, agenerator, alist):
"""cacheReadInputTokens is omitted from metadata when input_tokens_details is absent."""
mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hi")
mock_complete_event = unittest.mock.Mock(
type="response.completed",
response=unittest.mock.Mock(
usage=unittest.mock.Mock(
input_tokens=100,
output_tokens=10,
total_tokens=110,
input_tokens_details=None,
)
),
)

openai_client.responses.create = unittest.mock.AsyncMock(
return_value=agenerator([mock_text_event, mock_complete_event])
)

messages = [{"role": "user", "content": [{"text": "test"}]}]
tru_events = await alist(model.stream(messages))

metadata_events = [e for e in tru_events if "metadata" in e]
assert len(metadata_events) == 1
usage = metadata_events[0]["metadata"]["usage"]
assert "cacheReadInputTokens" not in usage


@pytest.mark.asyncio
async def test_stream_with_tool_calls(openai_client, model, agenerator, alist):
# Mock tool call events
Expand All @@ -538,7 +624,7 @@ async def test_stream_with_tool_calls(openai_client, model, agenerator, alist):
)
mock_complete_event = unittest.mock.Mock(
type="response.completed",
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)),
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None)),
)

openai_client.responses.create = unittest.mock.AsyncMock(
Expand Down Expand Up @@ -571,7 +657,7 @@ async def test_stream_with_tool_calls_done_event(openai_client, model, agenerato
)
mock_complete_event = unittest.mock.Mock(
type="response.completed",
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)),
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None)),
)

openai_client.responses.create = unittest.mock.AsyncMock(
Expand All @@ -594,7 +680,7 @@ async def test_stream_response_incomplete(openai_client, model, agenerator, alis
mock_incomplete_event = unittest.mock.Mock(
type="response.incomplete",
response=unittest.mock.Mock(
usage=unittest.mock.Mock(input_tokens=10, output_tokens=100, total_tokens=110),
usage=unittest.mock.Mock(input_tokens=10, output_tokens=100, total_tokens=110, input_tokens_details=None),
incomplete_details=unittest.mock.Mock(reason="max_output_tokens"),
),
)
Expand Down Expand Up @@ -628,7 +714,7 @@ async def test_stream_reasoning_content(openai_client, model, agenerator, alist,
mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="The answer is 42")
mock_complete_event = unittest.mock.Mock(
type="response.completed",
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=20, total_tokens=30)),
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=20, total_tokens=30, input_tokens_details=None)),
)

openai_client.responses.create = unittest.mock.AsyncMock(
Expand Down Expand Up @@ -672,7 +758,7 @@ async def test_stream_citation_annotations(openai_client, model, agenerator, ali
)
mock_complete_event = unittest.mock.Mock(
type="response.completed",
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)),
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None)),
)

openai_client.responses.create = unittest.mock.AsyncMock(
Expand Down Expand Up @@ -708,7 +794,7 @@ async def test_stream_unsupported_annotation_type(openai_client, model, agenerat
)
mock_complete_event = unittest.mock.Mock(
type="response.completed",
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)),
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None)),
)

openai_client.responses.create = unittest.mock.AsyncMock(
Expand Down Expand Up @@ -1152,7 +1238,7 @@ async def test_stream_stateful(openai_client, model_id, agenerator, alist):
type="response.completed",
response=unittest.mock.Mock(
id="resp_abc123",
usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15),
usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None),
),
),
]
Expand Down