Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions src/strands_evals/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
stop_after_attempt,
wait_exponential,
)
from typing_extensions import Any, Generic
from typing_extensions import Any, Generic, TypeVar

from .case import Case
from .detectors.diagnosis import diagnose_session
Expand All @@ -40,6 +40,10 @@
_INITIAL_RETRY_DELAY = 4
_MAX_RETRY_DELAY = 240 # 4 minutes

# Subclasses can bind ReportT to a custom report element type and set `report_cls` to
# the matching class. run_evaluations* always returns list[ReportT].
ReportT = TypeVar("ReportT", bound=EvaluationReport, default=EvaluationReport)


def _get_label_from_score(evaluator: Evaluator, score: float) -> str:
"""
Expand All @@ -65,7 +69,7 @@ def _get_label_from_score(evaluator: Evaluator, score: float) -> str:
return "YES" if score >= 0.5 else "NO"


class Experiment(Generic[InputT, OutputT]):
class Experiment(Generic[InputT, OutputT, ReportT]):
Comment thread
poshinchen marked this conversation as resolved.
"""
An evaluation experiment containing test cases and evaluators.

Expand Down Expand Up @@ -100,6 +104,12 @@ class Experiment(Generic[InputT, OutputT]):
)
"""

# Subclasses bind ReportT and set report_cls to construct that type per evaluator.
Comment thread
poshinchen marked this conversation as resolved.
# The ignore is needed because at the base-class definition site mypy can't prove
# type[EvaluationReport] satisfies type[ReportT] for every ReportT binding; subclasses
# override report_cls with their bound type, where the substitution holds.
report_cls: type[ReportT] = EvaluationReport # type: ignore[assignment]

def __init__(
self,
cases: list[Case[InputT, OutputT]] | None = None,
Expand Down Expand Up @@ -187,7 +197,15 @@ async def _run_task_async(
Return:
An EvaluationData record containing the input and actual output, name, expected output, and metadata.
"""
# Create evaluation context
# Fields declared on Case subclasses (e.g. a typed `config`) pass through
# to the evaluator via EvaluationData's extra="allow". getattr preserves nested BaseModels.
# Filter out names that collide with base Case fields (already passed explicitly below)
# and EvaluationData fields (would silently overwrite the explicit kwargs).
extra_fields = {
name: getattr(case, name)
for name in type(case).model_fields
if name not in Case.model_fields and name not in EvaluationData.model_fields
}
evaluation_context = EvaluationData(
name=case.name,
input=case.input,
Expand All @@ -197,6 +215,7 @@ async def _run_task_async(
expected_interactions=case.expected_interactions,
expected_environment_state=case.expected_environment_state,
metadata=case.metadata,
**extra_fields,
)

# Handle both async and sync tasks
Expand Down Expand Up @@ -544,7 +563,7 @@ def run_evaluations(
self,
task: Callable[[Case[InputT, OutputT]], OutputT | dict[str, Any]],
evaluation_data_store: EvaluationDataStore | None = None,
) -> list[EvaluationReport]:
) -> list[ReportT]:
"""
Run the evaluations for all of the test cases with all evaluators.

Expand All @@ -557,8 +576,8 @@ def run_evaluations(
results are loaded instead of running the task, and new results are saved after task execution.

Return:
A list of EvaluationReport objects, one for each evaluator, containing the overall score,
individual case results, and basic feedback for each test case.
A list of report objects, one per evaluator. The element type defaults to
EvaluationReport; subclasses bind ReportT to a subclass and set `report_cls`.
"""
if asyncio.iscoroutinefunction(task):
raise ValueError("Async task is not supported. Please use run_evaluations_async instead.")
Expand All @@ -570,7 +589,7 @@ async def run_evaluations_async(
task: Callable,
max_workers: int = 10,
evaluation_data_store: EvaluationDataStore | None = None,
) -> list[EvaluationReport]:
) -> list[ReportT]:
"""
Run evaluations asynchronously using a queue for parallel processing.

Expand All @@ -583,7 +602,8 @@ async def run_evaluations_async(
results are loaded instead of running the task, and new results are saved after task execution.

Returns:
List of EvaluationReport objects, one for each evaluator, containing evaluation results
A list of report objects, one per evaluator. The element type defaults to
EvaluationReport; subclasses bind ReportT to a subclass and set `report_cls`.
"""
if evaluation_data_store is not None:
self._validate_case_names()
Expand Down Expand Up @@ -638,7 +658,7 @@ async def run_evaluations_async(
eval_name = evaluator.get_type_name()
data = evaluator_data[eval_name]
scores = data["scores"]
report = EvaluationReport(
report = self.report_cls(
evaluator_name=eval_name,
overall_score=sum(scores) / len(scores) if scores else 0,
scores=scores,
Expand Down
9 changes: 8 additions & 1 deletion src/strands_evals/types/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from typing_extensions import Any, Generic, TypedDict, TypeVar

from .trace import Session
Expand Down Expand Up @@ -93,8 +93,15 @@ class EvaluationData(BaseModel, Generic[InputT, OutputT]):
metadata: Additional information about the test case.
actual_interactions: The actual interaction sequence given the input.
expected_interactions: The expected interaction sequence given the input.

Comment thread
poshinchen marked this conversation as resolved.
Extra fields from `Case` subclasses are preserved as model attributes (e.g. a typed
`config` field on a `Case` subclass passes through to the evaluator with its type intact).
Note: `extra="allow"` means misspelled field names will be silently accepted rather than
rejected — the tradeoff for supporting subclass passthrough.
"""

model_config = ConfigDict(extra="allow")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's double-check that there's not a subtle gotcha lurking here.

The pydantic extras are preserved as live objects in memory, but they have no field annotation, so model_validate can't reconstruct their type. Any EvaluationDataStore round-trips via model_dump/model_validate, so on a cache hit a typed config: SomeModel comes back as a plain dict:

  FRESH path : config -> SomeModel   (config.label works)
  CACHED path: config -> dict        (config.label raises AttributeError)

So an evaluator written with evaluation_case.config.label passes on first run and raises an exception on the second once evaluation_data_store= is set. The new tests only cover the cache-miss path.

Please either:

  1. add a cache-path test (save → load → assert the evaluator still reads the field), and document that extras survive as dict through a store; or
  2. have subclasses declare the field on a typed EvaluationData subclass instead of relying on extra="allow".

Minimal repro:

  ed = EvaluationData(input="hi", config=SomeModel(label="x", weight=0.7))
  EvaluationData.model_validate(ed.model_dump()).config  # -> dict, not SomeModel

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I don't think this is a blocker for now, but we definitely need to think about how to recover the original types when loading from the EvaluationData.


input: InputT
actual_output: OutputT | None = None
name: str | None = None
Expand Down
104 changes: 104 additions & 0 deletions tests/strands_evals/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from botocore.exceptions import ClientError
from pydantic import BaseModel
from strands.models.model import Model
from strands.types.exceptions import EventLoopException, ModelThrottledException

Expand All @@ -21,6 +22,7 @@
from strands_evals.experiment import is_throttling_error
from strands_evals.providers.trace_provider import TraceProvider
from strands_evals.types import EvaluationData, EvaluationOutput
from strands_evals.types.evaluation_report import EvaluationReport


class MockEvaluator(Evaluator[str, str]):
Expand Down Expand Up @@ -2041,3 +2043,105 @@ def task_with_session(c):
assert len(reports) == 1
assert reports[0].diagnoses == [None]
assert reports[0].recommendations == [None]


# ---------------------------------------------------------------------------
# Extension-point tests: Case-subclass extras pass through, and Experiment
# subclasses can swap the report element type via ReportT + report_cls.
# ---------------------------------------------------------------------------


class _NestedConfig(BaseModel):
    """Typed nested object used to verify it survives the Case -> EvaluationData hop.

    The tests in this section attach an instance to a ``_CaseWithExtras`` and then
    assert that the value read back is still a ``_NestedConfig`` rather than a dict.
    """

    # Free-form label; tests read it back to confirm the value is preserved.
    label: str
    # Numeric weight; one test reuses it as the score returned by an evaluator.
    weight: float


class _CaseWithExtras(Case[str, str]):
    """Case subclass that adds typed fields beyond the base Case schema."""

    # Required nested model; lets tests check that a BaseModel value is passed
    # through as an object.
    config: _NestedConfig
    # Optional plain-string extra; None when a test does not set it.
    tag: str | None = None


@pytest.mark.asyncio
async def test_evaluation_data_preserves_subclass_extra_fields():
    """Fields declared only on a Case subclass arrive on EvaluationData with their types."""
    nested = _NestedConfig(label="phishing", weight=0.7)
    extras_case = _CaseWithExtras(name="extras", input="hi", config=nested, tag="redteam")
    exp = Experiment(cases=[extras_case], evaluators=[MockEvaluator()])

    ctx = await exp._run_task_async(lambda c: c.input, extras_case)

    # The nested model must still be a live object, not a dict produced by model_dump.
    passed_config = ctx.config
    assert isinstance(passed_config, _NestedConfig)
    assert (passed_config.label, passed_config.weight) == ("phishing", 0.7)
    assert ctx.tag == "redteam"
    # Base fields are still populated.
    assert ctx.input == "hi"
    assert ctx.actual_output == "hi"


def test_evaluation_data_evaluator_can_read_subclass_extras():
    """End-to-end: an evaluator can read a typed subclass field off the evaluation data."""

    class ConfigReadingEvaluator(Evaluator[str, str]):
        def evaluate(self, evaluation_case: EvaluationData[str, str]) -> list[EvaluationOutput]:
            config = getattr(evaluation_case, "config", None)
            assert isinstance(config, _NestedConfig)
            output = EvaluationOutput(score=config.weight, test_pass=True, reason=config.label)
            return [output]

    nested = _NestedConfig(label="bypass", weight=0.42)
    exp = Experiment(
        cases=[_CaseWithExtras(name="x", input="hi", config=nested)],
        evaluators=[ConfigReadingEvaluator()],
    )

    results = exp.run_evaluations(lambda c: c.input)

    assert len(results) == 1
    only_report = results[0]
    # Score and reason are taken straight from the nested config by the evaluator above.
    assert only_report.scores == [0.42]
    assert only_report.reasons == ["bypass"]


def test_experiment_default_report_cls_is_evaluation_report():
    """A plain Experiment (no subclassing) still yields exactly EvaluationReport objects."""
    exp = Experiment(
        cases=[Case(name="t", input="hi", expected_output="hi")],
        evaluators=[MockEvaluator()],
    )

    results = exp.run_evaluations(lambda c: c.input)

    assert len(results) == 1
    # Exact-type comparison on purpose: a subclass instance must not satisfy this.
    assert type(results[0]) is EvaluationReport


class _CustomReport(EvaluationReport):
    """Report subclass used to check that an Experiment can emit a custom report type.

    It adds ``headline()``, which the swap test calls to confirm that the objects
    returned at runtime really are instances of this subclass.
    """

    def headline(self) -> str:
        """Return ``"<evaluator_name>: <overall_score>"`` with the score to two decimals."""
        return f"{self.evaluator_name}: {self.overall_score:.2f}"


class _CustomExperiment(Experiment[str, str, _CustomReport]):
    """Experiment subclass that binds the report type parameter to ``_CustomReport``."""

    # Class the experiment instantiates when building each evaluator's report.
    report_cls = _CustomReport


def test_experiment_subclass_can_swap_report_cls():
    """An Experiment subclass that sets report_cls returns instances of that class."""
    exp = _CustomExperiment(
        cases=[
            Case(name="a", input="hi", expected_output="hi"),
            Case(name="b", input="hi", expected_output="bye"),
        ],
        evaluators=[MockEvaluator()],
    )

    results = exp.run_evaluations(lambda c: c.input)

    assert len(results) == 1
    report = results[0]
    assert type(report) is _CustomReport
    # headline() is defined on the subclass, so calling it confirms the runtime class.
    assert report.headline().startswith("MockEvaluator:")
    # Base fields are still populated.
    assert report.scores == [1.0, 0.0]
Loading