Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/evaluate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,15 @@
from .info import ComparisonInfo, EvaluationModuleInfo, MeasurementInfo, MetricInfo
from .inspect import inspect_evaluation_module, list_evaluation_modules
from .loading import load
from .module import CombinedEvaluations, Comparison, EvaluationModule, Measurement, Metric, combine
from .module import (
CombinedEvaluations,
Comparison,
EvaluationModule,
EvaluationModuleError,
Measurement,
Metric,
combine,
)
from .saving import save
from .utils import *
from .utils import gradio, logging
17 changes: 16 additions & 1 deletion src/evaluate/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@
logger = get_logger(__name__)


class EvaluationModuleError(Exception):
"""Base error raised when an evaluation module fails to compute its result.

Failures coming from the underlying ``_compute`` implementation (for example a
``ValueError`` or ``KeyError`` raised by scikit-learn) are wrapped in this error so
that callers can catch evaluate-specific failures without catching a bare
``Exception``. The original exception is preserved on ``__cause__``.
"""


class FileFreeLock(BaseFileLock):
"""Thread lock until a file **cannot** be locked"""

Expand Down Expand Up @@ -464,7 +474,12 @@ def compute(self, *, predictions=None, references=None, **kwargs) -> Optional[di

inputs = {input_name: self.data[input_name][:] for input_name in self._feature_names()}
with temp_seed(self.seed):
output = self._compute(**inputs, **compute_kwargs)
try:
output = self._compute(**inputs, **compute_kwargs)
except EvaluationModuleError:
raise
except Exception as e:
raise EvaluationModuleError(f"Error computing {self.name} metric: {type(e).__name__}: {e}") from e

if self.buf_writer is not None:
self.buf_writer = None
Expand Down
35 changes: 34 additions & 1 deletion tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pytest
from datasets.features import Features, Sequence, Value

from evaluate.module import EvaluationModule, EvaluationModuleInfo, combine
import evaluate
from evaluate.module import EvaluationModule, EvaluationModuleError, EvaluationModuleInfo, combine

from .utils import require_tf, require_torch

Expand Down Expand Up @@ -757,3 +758,35 @@ def test_modules_from_string_poslabel(self):
self.assertDictEqual(
expected_result, combined_evaluation.compute(predictions=predictions, references=references, pos_label=0)
)


class RaisingMetric(EvaluationModule):
"""Dummy metric whose ``_compute`` raises a bare ``ValueError``, as scikit-learn does."""

def _info(self):
return EvaluationModuleInfo(
description="dummy metric that raises in _compute",
citation="insert citation here",
features=Features({"predictions": Value("int64"), "references": Value("int64")}),
)

def _compute(self, predictions, references):
raise ValueError("Found input variables with inconsistent numbers of samples")


class TestEvaluationModuleError(TestCase):
def test_error_is_exported_from_public_api(self):
self.assertTrue(hasattr(evaluate, "EvaluationModuleError"))
self.assertIs(evaluate.EvaluationModuleError, EvaluationModuleError)

def test_compute_wraps_underlying_error(self):
metric = RaisingMetric(experiment_id="test_compute_wraps_underlying_error")
with self.assertRaises(EvaluationModuleError) as ctx:
metric.compute(predictions=[1], references=[1])
# The original exception is preserved for debugging.
self.assertIsInstance(ctx.exception.__cause__, ValueError)

def test_compute_catchable_via_public_api(self):
metric = RaisingMetric(experiment_id="test_compute_catchable_via_public_api")
with self.assertRaises(evaluate.EvaluationModuleError):
metric.compute(predictions=[1], references=[1])