diff --git a/src/evaluate/__init__.py b/src/evaluate/__init__.py index a8c25bd9..d52146d6 100644 --- a/src/evaluate/__init__.py +++ b/src/evaluate/__init__.py @@ -45,7 +45,15 @@ from .info import ComparisonInfo, EvaluationModuleInfo, MeasurementInfo, MetricInfo from .inspect import inspect_evaluation_module, list_evaluation_modules from .loading import load -from .module import CombinedEvaluations, Comparison, EvaluationModule, Measurement, Metric, combine +from .module import ( + CombinedEvaluations, + Comparison, + EvaluationModule, + EvaluationModuleError, + Measurement, + Metric, + combine, +) from .saving import save from .utils import * from .utils import gradio, logging diff --git a/src/evaluate/module.py b/src/evaluate/module.py index ca38b9b1..ed188c42 100644 --- a/src/evaluate/module.py +++ b/src/evaluate/module.py @@ -41,6 +41,16 @@ logger = get_logger(__name__) +class EvaluationModuleError(Exception): + """Base error raised when an evaluation module fails to compute its result. + + Failures coming from the underlying ``_compute`` implementation (for example a + ``ValueError`` or ``KeyError`` raised by scikit-learn) are wrapped in this error so + that callers can catch evaluate-specific failures without catching a bare + ``Exception``. The original exception is preserved on ``__cause__``. + """ + + class FileFreeLock(BaseFileLock): """Thread lock until a file **cannot** be locked""" @@ -464,7 +474,12 @@ def compute(self, *, predictions=None, references=None, **kwargs) -> Optional[di inputs = {input_name: self.data[input_name][:] for input_name in self._feature_names()} with temp_seed(self.seed): - output = self._compute(**inputs, **compute_kwargs) + try: + output = self._compute(**inputs, **compute_kwargs) + except EvaluationModuleError: + raise + except Exception as e: + raise EvaluationModuleError(f"Error computing {self.name} metric: {type(e).__name__}: {e}") from e if self.buf_writer is not None: self.buf_writer = None diff --git a/tests/test_metric.py b/tests/test_metric.py index 598b0f92..ab3b6981 100644 --- a/tests/test_metric.py +++ b/tests/test_metric.py @@ -8,7 +8,8 @@ import pytest from datasets.features import Features, Sequence, Value -from evaluate.module import EvaluationModule, EvaluationModuleInfo, combine +import evaluate +from evaluate.module import EvaluationModule, EvaluationModuleError, EvaluationModuleInfo, combine from .utils import require_tf, require_torch @@ -757,3 +758,35 @@ def test_modules_from_string_poslabel(self): self.assertDictEqual( expected_result, combined_evaluation.compute(predictions=predictions, references=references, pos_label=0) ) + + +class RaisingMetric(EvaluationModule): + """Dummy metric whose ``_compute`` raises a bare ``ValueError``, as scikit-learn does.""" + + def _info(self): + return EvaluationModuleInfo( + description="dummy metric that raises in _compute", + citation="insert citation here", + features=Features({"predictions": Value("int64"), "references": Value("int64")}), + ) + + def _compute(self, predictions, references): + raise ValueError("Found input variables with inconsistent numbers of samples") + + +class TestEvaluationModuleError(TestCase): + def test_error_is_exported_from_public_api(self): + self.assertTrue(hasattr(evaluate, "EvaluationModuleError")) + self.assertIs(evaluate.EvaluationModuleError, EvaluationModuleError) + + def test_compute_wraps_underlying_error(self): + metric = RaisingMetric(experiment_id="test_compute_wraps_underlying_error") + with self.assertRaises(EvaluationModuleError) as ctx: + metric.compute(predictions=[1], references=[1]) + # The original exception is preserved for debugging. + self.assertIsInstance(ctx.exception.__cause__, ValueError) + + def test_compute_catchable_via_public_api(self): + metric = RaisingMetric(experiment_id="test_compute_catchable_via_public_api") + with self.assertRaises(evaluate.EvaluationModuleError): + metric.compute(predictions=[1], references=[1])