From 276816e471d3a1145b3fba05a9fbd1e5d4a9041b Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Thu, 21 May 2026 21:00:24 +0200 Subject: [PATCH] Fix variable helper review findings --- logfire/variables/variable.py | 14 ++++-- tests/test_variables.py | 94 ++++++++++++++++++++++++++++++++++- 2 files changed, 102 insertions(+), 6 deletions(-) diff --git a/logfire/variables/variable.py b/logfire/variables/variable.py index 6dcc31003..62a1e1689 100644 --- a/logfire/variables/variable.py +++ b/logfire/variables/variable.py @@ -171,11 +171,11 @@ def get_template_inputs_schema(self) -> dict[str, Any] | None: """ return None - def _deserialize(self, serialized_value: str) -> T_co | ValidationError | ValueError: + def _deserialize(self, serialized_value: str) -> T_co | ValidationError | ValueError | TypeError: """Deserialize a JSON string to the variable's type, returning an Exception on failure.""" try: return self.type_adapter.validate_json(serialized_value) - except (ValidationError, ValueError) as e: + except (ValidationError, ValueError, TypeError) as e: return e @contextmanager @@ -241,6 +241,7 @@ def _resolve( attributes, span, render_fn=render_fn, + provider_exception=serialized_result.exception, ) if default_result is not None: return default_result @@ -270,7 +271,7 @@ def _render_default(self, default: Any, render_fn: Callable[[str], str]) -> T_co serialized = self.type_adapter.dump_json(default).decode('utf-8') rendered = render_fn(serialized) result = self._deserialize(rendered) - if isinstance(result, (ValidationError, ValueError)): + if isinstance(result, (ValidationError, ValueError, TypeError)): raise result return result @@ -338,13 +339,14 @@ def _resolve_serialized_default( attributes: Mapping[str, Any] | None, span: logfire.LogfireSpan | None, render_fn: Callable[[str], str] | None = None, + provider_exception: Exception | None = None, ) -> ResolvedVariable[T_co] | None: """Resolve the code default through composition/rendering when needed.""" + if render_fn is None: + return None serialized_default = self._get_serialized_default(targeting_key, attributes) if serialized_default is None: return None - if render_fn is None: - return None result = self._expand_and_deserialize( ResolvedVariable(name=self.name, value=serialized_default, reason='missing_config'), @@ -358,6 +360,8 @@ def _resolve_serialized_default( # The expansion succeeded against the code default; flag the top-level # reason as 'code_default' so callers can distinguish from a provider hit. result.reason = 'code_default' + if result.exception is None: + result.exception = provider_exception return result def _get_merged_attributes(self, attributes: Mapping[str, Any] | None = None) -> Mapping[str, Any]: diff --git a/tests/test_variables.py b/tests/test_variables.py index 2bd2ea91c..a821c0832 100644 --- a/tests/test_variables.py +++ b/tests/test_variables.py @@ -14,7 +14,7 @@ import pytest import requests_mock as requests_mock_module from inline_snapshot import snapshot -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel, ValidationError, field_validator from requests import Session import logfire @@ -1665,6 +1665,55 @@ def test_get_uses_default_when_no_config(self, config_kwargs: dict[str, Any]): assert result.value == 'my_default' assert result.reason == 'code_default' + def test_get_calls_function_default_once_when_no_config(self, config_kwargs: dict[str, Any]): + config_kwargs['variables'] = LocalVariablesOptions(config=VariablesConfig(variables={})) + lf = logfire.configure(**config_kwargs) + calls = 0 + + def default(targeting_key: str | None, attributes: Mapping[str, Any] | None) -> str: + nonlocal calls + calls += 1 + return 'my_default' + + var = lf.var(name='unconfigured', default=default, type=str) + result = var.get() + assert result.value == 'my_default' + assert result.reason == 'code_default' + assert calls == 1 + + def test_get_preserves_metadata_with_deserialization_type_error(self, config_kwargs: dict[str, Any]): + class TypeErrorModel(BaseModel): + value: int + + @field_validator('value') + @classmethod + def fail_for_one(cls, value: int) -> int: + if value == 1: + raise TypeError('validator exploded') + return value + + config_kwargs['variables'] = LocalVariablesOptions( + config=VariablesConfig( + variables={ + 'type_error_var': VariableConfig( + name='type_error_var', + labels={'default': LabeledValue(version=1, serialized_value='{"value": 1}')}, + rollout=Rollout(labels={'default': 1.0}), + overrides=[], + ) + } + ) + ) + lf = logfire.configure(**config_kwargs) + + var = lf.var(name='type_error_var', default=TypeErrorModel(value=0), type=TypeErrorModel) + result = var.get() + assert result.value == TypeErrorModel(value=0) + assert isinstance(result.exception, TypeError) + assert result.reason == 'other_error' + assert result.label == 'default' + assert result.version == 1 + def test_plain_variable_has_no_template_inputs_schema(self, config_kwargs: dict[str, Any]): lf = logfire.configure(**config_kwargs) @@ -1707,6 +1756,27 @@ def test_render_fn_applies_to_context_override( assert invalid.reason == 'other_error' assert isinstance(invalid.exception, ValidationError) + class TypeErrorModel(BaseModel): + value: int + + @field_validator('value') + @classmethod + def fail_for_one(cls, value: int) -> int: + if value == 1: + raise TypeError('validator exploded') + return value + + type_error_var = lf.var(name='type_error_var', default=TypeErrorModel(value=0), type=TypeErrorModel) + + with type_error_var.override(TypeErrorModel(value=0)): + type_error_result = type_error_var._get_result_and_record_span( + None, None, None, render_fn=lambda _: '{"value": 1}' + ) + + assert type_error_result.value == TypeErrorModel(value=0) + assert type_error_result.reason == 'other_error' + assert isinstance(type_error_result.exception, TypeError) + def test_render_fn_applies_to_code_default(self, config_kwargs: dict[str, Any]): config_kwargs['variables'] = LocalVariablesOptions(config=VariablesConfig(variables={})) lf = logfire.configure(**config_kwargs) @@ -1724,6 +1794,28 @@ def test_render_fn_applies_to_code_default(self, config_kwargs: dict[str, Any]): assert invalid.reason == 'validation_error' assert isinstance(invalid.exception, ValidationError) + def test_render_fn_preserves_provider_exception_when_using_code_default( + self, config_kwargs: dict[str, Any], monkeypatch: pytest.MonkeyPatch + ): + config_kwargs['variables'] = LocalVariablesOptions(config=VariablesConfig(variables={})) + lf = logfire.configure(**config_kwargs) + provider_error = RuntimeError('missing') + + def missing_get( + variable_name: str, targeting_key: str | None = None, attributes: Mapping[str, Any] | None = None + ) -> ResolvedVariable[str | None]: + return ResolvedVariable( + name=variable_name, value=None, exception=provider_error, reason='unrecognized_variable' + ) + + monkeypatch.setattr(lf.config._variable_provider, 'get_serialized_value', missing_get) + + var = lf.var(name='unconfigured', default='my_default', type=str) + result = var._get_result_and_record_span(None, None, None, render_fn=lambda _: '"rendered_default"') + assert result.value == 'rendered_default' + assert result.reason == 'code_default' + assert result.exception is provider_error + def test_render_fn_skips_code_default_when_default_cannot_be_serialized(self, config_kwargs: dict[str, Any]): config_kwargs['variables'] = LocalVariablesOptions(config=VariablesConfig(variables={})) lf = logfire.configure(**config_kwargs)