diff --git a/src/strands_evals/experiment.py b/src/strands_evals/experiment.py index 40ec86ae..8439a669 100644 --- a/src/strands_evals/experiment.py +++ b/src/strands_evals/experiment.py @@ -14,7 +14,7 @@ stop_after_attempt, wait_exponential, ) -from typing_extensions import Any, Generic +from typing_extensions import Any, Generic, TypeVar from .case import Case from .detectors.diagnosis import diagnose_session @@ -40,6 +40,10 @@ _INITIAL_RETRY_DELAY = 4 _MAX_RETRY_DELAY = 240 # 4 minutes +# Subclasses can bind ReportT to a custom report element type and set `report_cls` to +# the matching class. run_evaluations* always returns list[ReportT]. +ReportT = TypeVar("ReportT", bound=EvaluationReport, default=EvaluationReport) + def _get_label_from_score(evaluator: Evaluator, score: float) -> str: """ @@ -65,7 +69,7 @@ def _get_label_from_score(evaluator: Evaluator, score: float) -> str: return "YES" if score >= 0.5 else "NO" -class Experiment(Generic[InputT, OutputT]): +class Experiment(Generic[InputT, OutputT, ReportT]): """ An evaluation experiment containing test cases and evaluators. @@ -100,6 +104,12 @@ class Experiment(Generic[InputT, OutputT]): ) """ + # Subclasses bind ReportT and set report_cls to construct that type per evaluator. + # The ignore is needed because at the base-class definition site mypy can't prove + # type[EvaluationReport] satisfies type[ReportT] for every ReportT binding; subclasses + # override report_cls with their bound type, where the substitution holds. + report_cls: type[ReportT] = EvaluationReport # type: ignore[assignment] + def __init__( self, cases: list[Case[InputT, OutputT]] | None = None, @@ -187,7 +197,15 @@ async def _run_task_async( Return: An EvaluationData record containing the input and actual output, name, expected output, and metadata. """ - # Create evaluation context + # Fields declared on Case subclasses (e.g. a typed `config`) pass through + # to the evaluator via EvaluationData's extra="allow". 
getattr preserves nested BaseModels. + # Filter out names that collide with base Case fields (already passed explicitly below) + # and EvaluationData fields (a duplicate kwarg raises TypeError; others would silently set that field). + extra_fields = { + name: getattr(case, name) + for name in type(case).model_fields + if name not in Case.model_fields and name not in EvaluationData.model_fields + } evaluation_context = EvaluationData( name=case.name, input=case.input, @@ -197,6 +215,7 @@ async def _run_task_async( expected_interactions=case.expected_interactions, expected_environment_state=case.expected_environment_state, metadata=case.metadata, + **extra_fields, ) # Handle both async and sync tasks @@ -544,7 +563,7 @@ def run_evaluations( self, task: Callable[[Case[InputT, OutputT]], OutputT | dict[str, Any]], evaluation_data_store: EvaluationDataStore | None = None, - ) -> list[EvaluationReport]: + ) -> list[ReportT]: """ Run the evaluations for all of the test cases with all evaluators. @@ -557,8 +576,8 @@ def run_evaluations( results are loaded instead of running the task, and new results are saved after task execution. Return: - A list of EvaluationReport objects, one for each evaluator, containing the overall score, - individual case results, and basic feedback for each test case. + A list of report objects, one per evaluator. The element type defaults to + EvaluationReport; subclasses bind ReportT to a subclass and set `report_cls`. """ if asyncio.iscoroutinefunction(task): raise ValueError("Async task is not supported. Please use run_evaluations_async instead.") @@ -570,7 +589,7 @@ async def run_evaluations_async( task: Callable, max_workers: int = 10, evaluation_data_store: EvaluationDataStore | None = None, - ) -> list[EvaluationReport]: + ) -> list[ReportT]: """ Run evaluations asynchronously using a queue for parallel processing. @@ -583,7 +602,8 @@ async def run_evaluations_async( results are loaded instead of running the task, and new results are saved after task execution. 
Returns: - List of EvaluationReport objects, one for each evaluator, containing evaluation results + A list of report objects, one per evaluator. The element type defaults to + EvaluationReport; subclasses bind ReportT to a subclass and set `report_cls`. """ if evaluation_data_store is not None: self._validate_case_names() @@ -638,7 +658,7 @@ async def run_evaluations_async( eval_name = evaluator.get_type_name() data = evaluator_data[eval_name] scores = data["scores"] - report = EvaluationReport( + report = self.report_cls( evaluator_name=eval_name, overall_score=sum(scores) / len(scores) if scores else 0, scores=scores, diff --git a/src/strands_evals/types/evaluation.py b/src/strands_evals/types/evaluation.py index 4e25f3b0..56e0ee89 100644 --- a/src/strands_evals/types/evaluation.py +++ b/src/strands_evals/types/evaluation.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from typing_extensions import Any, Generic, TypedDict, TypeVar from .trace import Session @@ -93,8 +93,15 @@ class EvaluationData(BaseModel, Generic[InputT, OutputT]): metadata: Additional information about the test case. actual_interactions: The actual interaction sequence given the input. expected_interactions: The expected interaction sequence given the input. + + Extra fields from `Case` subclasses are preserved as model attributes (e.g. a typed + `config` field on a `Case` subclass passes through to the evaluator with its type intact). + Note: `extra="allow"` means misspelled field names will be silently accepted rather than + rejected — the tradeoff for supporting subclass passthrough. 
""" + model_config = ConfigDict(extra="allow") + input: InputT actual_output: OutputT | None = None name: str | None = None diff --git a/tests/strands_evals/test_experiment.py b/tests/strands_evals/test_experiment.py index 095b850a..44296541 100644 --- a/tests/strands_evals/test_experiment.py +++ b/tests/strands_evals/test_experiment.py @@ -3,6 +3,7 @@ import pytest from botocore.exceptions import ClientError +from pydantic import BaseModel from strands.models.model import Model from strands.types.exceptions import EventLoopException, ModelThrottledException @@ -21,6 +22,7 @@ from strands_evals.experiment import is_throttling_error from strands_evals.providers.trace_provider import TraceProvider from strands_evals.types import EvaluationData, EvaluationOutput +from strands_evals.types.evaluation_report import EvaluationReport class MockEvaluator(Evaluator[str, str]): @@ -2041,3 +2043,105 @@ def task_with_session(c): assert len(reports) == 1 assert reports[0].diagnoses == [None] assert reports[0].recommendations == [None] + + +# --------------------------------------------------------------------------- +# Extension-point tests: Case-subclass extras pass through, and Experiment +# subclasses can swap the report element type via ReportT + report_cls. 
+# --------------------------------------------------------------------------- + + +class _NestedConfig(BaseModel): + """Typed nested object used to verify it survives the Case -> EvaluationData hop.""" + + label: str + weight: float + + +class _CaseWithExtras(Case[str, str]): + """Case subclass that adds a typed field beyond the base Case schema.""" + + config: _NestedConfig + tag: str | None = None + + +@pytest.mark.asyncio +async def test_evaluation_data_preserves_subclass_extra_fields(): + """Subclass-only Case fields reach EvaluationData with their types intact.""" + case = _CaseWithExtras( + name="extras", + input="hi", + config=_NestedConfig(label="phishing", weight=0.7), + tag="redteam", + ) + experiment = Experiment(cases=[case], evaluators=[MockEvaluator()]) + + evaluation_context = await experiment._run_task_async(lambda c: c.input, case) + + # Nested BaseModel survives as an object (not a dict from model_dump). + assert isinstance(evaluation_context.config, _NestedConfig) + assert evaluation_context.config.label == "phishing" + assert evaluation_context.config.weight == 0.7 + assert evaluation_context.tag == "redteam" + # Base fields still populated as before. 
+ assert evaluation_context.input == "hi" + assert evaluation_context.actual_output == "hi" + + +def test_evaluation_data_evaluator_can_read_subclass_extras(): + """An evaluator reading a typed subclass field via getattr works end-to-end.""" + + class ConfigReadingEvaluator(Evaluator[str, str]): + def evaluate(self, evaluation_case: EvaluationData[str, str]) -> list[EvaluationOutput]: + cfg = getattr(evaluation_case, "config", None) + assert isinstance(cfg, _NestedConfig) + return [EvaluationOutput(score=cfg.weight, test_pass=True, reason=cfg.label)] + + case = _CaseWithExtras(name="x", input="hi", config=_NestedConfig(label="bypass", weight=0.42)) + experiment = Experiment(cases=[case], evaluators=[ConfigReadingEvaluator()]) + + reports = experiment.run_evaluations(lambda c: c.input) + + assert len(reports) == 1 + assert reports[0].scores == [0.42] + assert reports[0].reasons == ["bypass"] + + +def test_experiment_default_report_cls_is_evaluation_report(): + """Without subclassing, run_evaluations still returns plain EvaluationReport.""" + case = Case(name="t", input="hi", expected_output="hi") + experiment = Experiment(cases=[case], evaluators=[MockEvaluator()]) + + reports = experiment.run_evaluations(lambda c: c.input) + + assert len(reports) == 1 + assert type(reports[0]) is EvaluationReport + + +class _CustomReport(EvaluationReport): + """Subclass report type with a subclass-only `headline()` method used to verify the runtime class.""" + + def headline(self) -> str: + return f"{self.evaluator_name}: {self.overall_score:.2f}" + + +class _CustomExperiment(Experiment[str, str, _CustomReport]): + report_cls = _CustomReport + + +def test_experiment_subclass_can_swap_report_cls(): + """Binding ReportT and setting report_cls produces list[ReportT] at runtime.""" + cases = [ + Case(name="a", input="hi", expected_output="hi"), + Case(name="b", input="hi", expected_output="bye"), + ] + experiment = _CustomExperiment(cases=cases, evaluators=[MockEvaluator()]) + + reports = 
experiment.run_evaluations(lambda c: c.input) + + assert len(reports) == 1 + assert type(reports[0]) is _CustomReport + # Subclass-only method works because the runtime object is the right class. + assert reports[0].headline().startswith("MockEvaluator:") + # Base fields still populated. + assert reports[0].scores == [1.0, 0.0]