Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions logfire/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,12 @@
# Per-`get()`-call cache for the code-default value. Keyed by `id(variable)`.
# Lets `_get_default_cached` short-circuit when the same callable default
# would otherwise be invoked twice in one resolution (once to feed
# composition, then again on a fallback path). Set up by `Variable._resolve`
# composition, then again on a fallback path). Each entry is
# `(ok: bool, value_or_exception: Any)`: successful invocations cache the
# returned value; raising invocations cache the exception so re-entry
# re-raises without re-invoking the callable. Set up by `Variable._resolve`
# at the top of the call and reset when it returns.
_DEFAULT_CACHE: ContextVar[dict[int, Any] | None] = ContextVar('_DEFAULT_CACHE', default=None)
_DEFAULT_CACHE: ContextVar[dict[int, tuple[bool, Any]] | None] = ContextVar('_DEFAULT_CACHE', default=None)


@dataclass
Expand Down Expand Up @@ -565,8 +568,12 @@ def _get_default_cached(
Avoids re-invoking a callable default twice when the same `get()`
consults it for the code-default tier (to feed composition) and then
again on a fallback path (composition failure, render failure,
deserialization failure). Outside a `_resolve` call the cache is
not set and this is a direct passthrough to `_get_default`.
deserialization failure). Both successful values and raised
exceptions are cached — a callable that raises on first invocation
re-raises (without re-invoking) on subsequent calls, so a failing
default doesn't get called multiple times either. Outside a
`_resolve` call the cache is not set and this is a direct
passthrough to `_get_default`.
"""
cache = _DEFAULT_CACHE.get()
if cache is None: # pragma: no cover
Expand All @@ -577,8 +584,14 @@ def _get_default_cached(
return self._get_default(targeting_key, merged_attributes)
key = id(self)
if key not in cache:
cache[key] = self._get_default(targeting_key, merged_attributes)
return cache[key]
try:
cache[key] = (True, self._get_default(targeting_key, merged_attributes))
except Exception as e:
cache[key] = (False, e)
ok, payload = cache[key]
if ok:
return cast('T_co', payload)
raise cast('Exception', payload)

def _get_serialized_default(
self, targeting_key: str | None = None, merged_attributes: Mapping[str, Any] | None = None
Expand Down
29 changes: 29 additions & 0 deletions tests/test_variable_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,35 @@ def raise_composition_error(*args: Any, **kwargs: Any) -> Any:
assert result.reason == 'other_error'
assert call_count == 1, f'callable default invoked {call_count} times, expected 1'

def test_failing_callable_default_invoked_once_per_get(self, config_kwargs: dict[str, Any]):
"""A callable default that *raises* is invoked only once per `get()`.

Regression for #1954 r3296066209 — `_get_default_cached` originally
cached only successful values, so a raising callable escaped the
cache and could be re-invoked up to three times in one `get()`
(once each in `_get_serialized_default`, `_resolve_code_default`,
and the outer-`except` fallback). The cache now records the
exception too and re-raises it on subsequent lookups.
"""
config_kwargs['variables'] = LocalVariablesOptions(config=VariablesConfig(variables={}))
lf = logfire.configure(**config_kwargs)

call_count = 0

def always_raises(targeting_key: str | None, attributes: Any) -> str:
nonlocal call_count
call_count += 1
raise RuntimeError('default unavailable')

var = lf.var(name='failing_default', default=always_raises, type=str)
result = var.get()

# The variable still resolves to something — the outer `except` in
# `_resolve` swallows the raised exception and returns `None`-typed.
# The point under test is the call count.
assert result.reason == 'other_error'
assert call_count == 1, f'failing callable invoked {call_count} times, expected 1'

def test_nested_reference(self, config_kwargs: dict[str, Any]):
"""A→B→C chain resolves fully."""
variables_config = _make_variables_config(
Expand Down
Loading