Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions logfire/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,11 @@ def get_template_inputs_schema(self) -> dict[str, Any] | None:
"""
return None

def _deserialize(self, serialized_value: str) -> T_co | ValidationError | ValueError:
def _deserialize(self, serialized_value: str) -> T_co | ValidationError | ValueError | TypeError:
"""Deserialize a JSON string to the variable's type, returning an Exception on failure."""
try:
return self.type_adapter.validate_json(serialized_value)
except (ValidationError, ValueError) as e:
except (ValidationError, ValueError, TypeError) as e:
return e

@contextmanager
Expand Down Expand Up @@ -241,6 +241,7 @@ def _resolve(
attributes,
span,
render_fn=render_fn,
provider_exception=serialized_result.exception,
)
if default_result is not None:
return default_result
Expand Down Expand Up @@ -270,7 +271,7 @@ def _render_default(self, default: Any, render_fn: Callable[[str], str]) -> T_co
serialized = self.type_adapter.dump_json(default).decode('utf-8')
rendered = render_fn(serialized)
result = self._deserialize(rendered)
if isinstance(result, (ValidationError, ValueError)):
if isinstance(result, (ValidationError, ValueError, TypeError)):
raise result
return result

Expand Down Expand Up @@ -338,13 +339,14 @@ def _resolve_serialized_default(
attributes: Mapping[str, Any] | None,
span: logfire.LogfireSpan | None,
render_fn: Callable[[str], str] | None = None,
provider_exception: Exception | None = None,
) -> ResolvedVariable[T_co] | None:
"""Resolve the code default through composition/rendering when needed."""
if render_fn is None:
return None
serialized_default = self._get_serialized_default(targeting_key, attributes)
if serialized_default is None:
return None
if render_fn is None:
return None

result = self._expand_and_deserialize(
ResolvedVariable(name=self.name, value=serialized_default, reason='missing_config'),
Expand All @@ -358,6 +360,8 @@ def _resolve_serialized_default(
# The expansion succeeded against the code default; flag the top-level
# reason as 'code_default' so callers can distinguish from a provider hit.
result.reason = 'code_default'
if result.exception is None:
result.exception = provider_exception
return result

def _get_merged_attributes(self, attributes: Mapping[str, Any] | None = None) -> Mapping[str, Any]:
Expand Down
94 changes: 93 additions & 1 deletion tests/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pytest
import requests_mock as requests_mock_module
from inline_snapshot import snapshot
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, ValidationError, field_validator
from requests import Session

import logfire
Expand Down Expand Up @@ -1665,6 +1665,55 @@ def test_get_uses_default_when_no_config(self, config_kwargs: dict[str, Any]):
assert result.value == 'my_default'
assert result.reason == 'code_default'

def test_get_calls_function_default_once_when_no_config(self, config_kwargs: dict[str, Any]):
config_kwargs['variables'] = LocalVariablesOptions(config=VariablesConfig(variables={}))
lf = logfire.configure(**config_kwargs)
calls = 0

def default(targeting_key: str | None, attributes: Mapping[str, Any] | None) -> str:
nonlocal calls
calls += 1
return 'my_default'

var = lf.var(name='unconfigured', default=default, type=str)
result = var.get()
assert result.value == 'my_default'
assert result.reason == 'code_default'
assert calls == 1

def test_get_preserves_metadata_with_deserialization_type_error(self, config_kwargs: dict[str, Any]):
class TypeErrorModel(BaseModel):
value: int

@field_validator('value')
@classmethod
def fail_for_one(cls, value: int) -> int:
if value == 1:
raise TypeError('validator exploded')
return value

config_kwargs['variables'] = LocalVariablesOptions(
config=VariablesConfig(
variables={
'type_error_var': VariableConfig(
name='type_error_var',
labels={'default': LabeledValue(version=1, serialized_value='{"value": 1}')},
rollout=Rollout(labels={'default': 1.0}),
overrides=[],
)
}
)
)
lf = logfire.configure(**config_kwargs)

var = lf.var(name='type_error_var', default=TypeErrorModel(value=0), type=TypeErrorModel)
result = var.get()
assert result.value == TypeErrorModel(value=0)
assert isinstance(result.exception, TypeError)
assert result.reason == 'other_error'
assert result.label == 'default'
assert result.version == 1

def test_plain_variable_has_no_template_inputs_schema(self, config_kwargs: dict[str, Any]):
lf = logfire.configure(**config_kwargs)

Expand Down Expand Up @@ -1707,6 +1756,27 @@ def test_render_fn_applies_to_context_override(
assert invalid.reason == 'other_error'
assert isinstance(invalid.exception, ValidationError)

class TypeErrorModel(BaseModel):
value: int

@field_validator('value')
@classmethod
def fail_for_one(cls, value: int) -> int:
if value == 1:
raise TypeError('validator exploded')
return value

type_error_var = lf.var(name='type_error_var', default=TypeErrorModel(value=0), type=TypeErrorModel)

with type_error_var.override(TypeErrorModel(value=0)):
type_error_result = type_error_var._get_result_and_record_span(
None, None, None, render_fn=lambda _: '{"value": 1}'
)

assert type_error_result.value == TypeErrorModel(value=0)
assert type_error_result.reason == 'other_error'
assert isinstance(type_error_result.exception, TypeError)

def test_render_fn_applies_to_code_default(self, config_kwargs: dict[str, Any]):
config_kwargs['variables'] = LocalVariablesOptions(config=VariablesConfig(variables={}))
lf = logfire.configure(**config_kwargs)
Expand All @@ -1724,6 +1794,28 @@ def test_render_fn_applies_to_code_default(self, config_kwargs: dict[str, Any]):
assert invalid.reason == 'validation_error'
assert isinstance(invalid.exception, ValidationError)

def test_render_fn_preserves_provider_exception_when_using_code_default(
self, config_kwargs: dict[str, Any], monkeypatch: pytest.MonkeyPatch
):
config_kwargs['variables'] = LocalVariablesOptions(config=VariablesConfig(variables={}))
lf = logfire.configure(**config_kwargs)
provider_error = RuntimeError('missing')

def missing_get(
variable_name: str, targeting_key: str | None = None, attributes: Mapping[str, Any] | None = None
) -> ResolvedVariable[str | None]:
return ResolvedVariable(
name=variable_name, value=None, exception=provider_error, reason='unrecognized_variable'
)

monkeypatch.setattr(lf.config._variable_provider, 'get_serialized_value', missing_get)

var = lf.var(name='unconfigured', default='my_default', type=str)
result = var._get_result_and_record_span(None, None, None, render_fn=lambda _: '"rendered_default"')
assert result.value == 'rendered_default'
assert result.reason == 'code_default'
assert result.exception is provider_error

def test_render_fn_skips_code_default_when_default_cannot_be_serialized(self, config_kwargs: dict[str, Any]):
config_kwargs['variables'] = LocalVariablesOptions(config=VariablesConfig(variables={}))
lf = logfire.configure(**config_kwargs)
Expand Down
Loading