Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion logfire/variables/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,11 @@ def format(self, *, colors: bool = True) -> str:

variables_with_errors = len({e.variable_name for e in self.errors})
valid_count = self.variables_checked - variables_with_errors - len(self.variables_not_on_server)
if valid_count > 0:
# Only advertise "Valid" when the report as a whole is valid. Otherwise
# a partial pass (per-variable type checks succeeded but reference /
# template-field errors exist) emits the contradictory pair
# "=== Valid (N variables) ===" + "=== Reference errors ===".
if valid_count > 0 and self.is_valid:
lines.append(f'\n{green}=== Valid ({valid_count} variables) ==={reset}')

# Show description differences as informational warnings
Expand Down
57 changes: 15 additions & 42 deletions logfire/variables/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,13 @@ def expand_references(
return serialized_value, composed
Comment thread
dmontagu marked this conversation as resolved.

# Collect all unique base variable names referenced anywhere in the decoded
# value. If there are none we still walk the structure through `_render_value`
# — the value may contain only escape sequences (`\@{x}@` etc.) that need
# to be processed through the renderer to produce the literal output.
all_ref_names = _collect_ref_names(decoded)
# value. Sorted so composition resolution order is deterministic — which
# `composed_from` entry surfaces first, which error gets reported when
# several refs fail, etc. shouldn't depend on set-iteration order. If there
# are none we still walk the structure through `_render_value` — the value
# may contain only escape sequences (`\@{x}@` etc.) that need to be
# processed through the renderer to produce the literal output.
all_ref_names = sorted(_collect_ref_names(decoded))

# Resolve each unique variable name and recursively expand nested references.
context: dict[str, Any] = {}
Expand Down Expand Up @@ -267,10 +270,9 @@ def find_references(serialized_value: str) -> list[str]:
serialized_value: The raw JSON-serialized variable value to scan.

Returns:
List of unique top-level variable names referenced, in order of
first occurrence.
Sorted (alphabetical) list of unique top-level variable names referenced.
"""
return _collect_ref_names(_safe_json_load(serialized_value))
return sorted(_collect_ref_names(_safe_json_load(serialized_value)))


# ---------------------------------------------------------------------------
Expand All @@ -286,33 +288,22 @@ def _safe_json_load(serialized_value: str) -> Any:
return None


def _collect_ref_names(value: Any) -> list[str]:
def _collect_ref_names(value: Any) -> set[str]:
Comment thread
dmontagu marked this conversation as resolved.
"""Recursively walk a decoded JSON value and collect unique top-level reference names.

For each string the AST-aware
``pydantic_handlebars.extract_dependencies`` picks the authoritative set
of real references (so block helpers, dotted paths, subexpressions are
handled correctly and Handlebars helper names are excluded). Names from
that set are added to the result list ordered by their first textual
occurrence in the source string, giving deterministic output across
`dict` iteration orders.
handled correctly and Handlebars helper names are excluded).
"""
from logfire.variables._handlebars import extract_composition_dependencies

seen: set[str] = set()
result: list[str] = []
refs: set[str] = set()

def _walk(v: Any) -> None:
if isinstance(v, str):
if not has_references(v):
return
valid = extract_composition_dependencies(v)
if not valid:
return
for name in _order_by_first_position(valid, v):
if name not in seen:
seen.add(name)
result.append(name)
if has_references(v):
refs.update(extract_composition_dependencies(v))
elif isinstance(v, dict):
for val in v.values(): # pyright: ignore[reportUnknownVariableType]
_walk(val)
Expand All @@ -321,25 +312,7 @@ def _walk(v: Any) -> None:
_walk(item)

_walk(value)
return result


def _order_by_first_position(names: set[str], source: str) -> list[str]:
"""Order *names* by their first whole-word occurrence in *source*.

Used to give `find_references` deterministic, source-order output for a
set produced by `extract_composition_dependencies`. Names that don't
appear textually in the source — which shouldn't happen for refs
returned by the AST walker, but is defensive — sort to the end of the
list in alphabetical order so output remains stable.
"""
positions: dict[str, int] = {}
for name in names:
# Use a word-boundary search so e.g. `key` doesn't match inside `keyword`.
pattern = re.compile(rf'\b{re.escape(name)}\b')
match = pattern.search(source)
positions[name] = match.start() if match is not None else len(source) + sum(map(ord, name))
return sorted(names, key=lambda n: positions[n])
return refs


def _render_value(value: Any, context: dict[str, Any], unresolved_names: set[str]) -> Any:
Expand Down
59 changes: 54 additions & 5 deletions logfire/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@

_VARIABLE_OVERRIDES: ContextVar[dict[str, Any] | None] = ContextVar('_VARIABLE_OVERRIDES', default=None)

# Per-`get()`-call cache for the code-default value. Keyed by `id(variable)`.
# Lets `_get_default_cached` short-circuit when the same callable default
# would otherwise be invoked twice in one resolution (once to feed
# composition, then again on a fallback path). Set up by `Variable._resolve`
# at the top of the call and reset when it returns.
_DEFAULT_CACHE: ContextVar[dict[int, Any] | None] = ContextVar('_DEFAULT_CACHE', default=None)


@dataclass
class _TargetingContextData:
Expand Down Expand Up @@ -237,6 +244,25 @@ def _resolve(
span: logfire.LogfireSpan | None,
label: str | None = None,
render_fn: Callable[[str], str] | None = None,
) -> ResolvedVariable[T_co]:
# `_DEFAULT_CACHE` memoises the code-default value across every
# `_get_default_cached` call inside this `get()` invocation, so a
# callable default isn't re-invoked when the code-default tier
# supplies the value AND a downstream step (composition expansion,
# template render, deserialization) falls back to it.
cache_token = _DEFAULT_CACHE.set({})
try:
return self._resolve_inner(targeting_key, attributes, span, label, render_fn)
finally:
_DEFAULT_CACHE.reset(cache_token)

def _resolve_inner(
self,
targeting_key: str | None,
attributes: Mapping[str, Any] | None,
span: logfire.LogfireSpan | None,
label: str | None,
render_fn: Callable[[str], str] | None,
) -> ResolvedVariable[T_co]:
serialized_result: ResolvedVariable[str | None] | None = None
try:
Expand Down Expand Up @@ -290,7 +316,7 @@ def _resolve(
span.set_attribute('invalid_serialized_label', serialized_result.label)
span.set_attribute('invalid_serialized_value', serialized_result.value)
try:
default = self._get_default(targeting_key, attributes)
default = self._get_default_cached(targeting_key, attributes)
except Exception:
default = cast('T_co', None)
return ResolvedVariable(name=self.name, value=default, exception=e, reason='other_error')
Expand Down Expand Up @@ -473,7 +499,7 @@ def resolve_ref(
reason: str = 'validation_error' if isinstance(value_or_exc, ValidationError) else 'other_error'
return ResolvedVariable(
name=self.name,
value=self._get_default(targeting_key, attributes),
value=self._get_default_cached(targeting_key, attributes),
exception=value_or_exc,
reason=reason,
label=serialized_result.label,
Expand Down Expand Up @@ -515,7 +541,7 @@ def _fallback_to_default(
)
return ResolvedVariable(
name=self.name,
value=self._get_default(targeting_key, attributes),
value=self._get_default_cached(targeting_key, attributes),
exception=exception,
reason='other_error',
label=serialized_result.label,
Expand All @@ -531,12 +557,35 @@ def _get_default(
else:
return self.default

def _get_default_cached(
self, targeting_key: str | None = None, merged_attributes: Mapping[str, Any] | None = None
) -> T_co:
"""Return the code default, memoised for the duration of one `_resolve` call.

Avoids re-invoking a callable default twice when the same `get()`
consults it for the code-default tier (to feed composition) and then
again on a fallback path (composition failure, render failure,
deserialization failure). Outside a `_resolve` call the cache is
not set and this is a direct passthrough to `_get_default`.
"""
cache = _DEFAULT_CACHE.get()
if cache is None: # pragma: no cover
# Defensive: every production call site is inside `_resolve`,
# which sets the cache. Falling back to a direct compute keeps
# the helper safe if someone reaches in from an unexpected
# entry point in the future.
return self._get_default(targeting_key, merged_attributes)
key = id(self)
if key not in cache:
cache[key] = self._get_default(targeting_key, merged_attributes)
return cache[key]

def _get_serialized_default(
self, targeting_key: str | None = None, merged_attributes: Mapping[str, Any] | None = None
) -> str | None:
"""Return the code default serialized as JSON, or None if serialization fails."""
try:
default = self._get_default(targeting_key, merged_attributes)
default = self._get_default_cached(targeting_key, merged_attributes)
return self.type_adapter.dump_json(default).decode('utf-8')
except (ValueError, TypeError, RuntimeError):
return None
Expand All @@ -557,7 +606,7 @@ def _resolve_code_default(
"""
return ResolvedVariable(
name=self.name,
value=self._get_default(targeting_key, attributes),
value=self._get_default_cached(targeting_key, attributes),
exception=serialized_result.exception,
label=serialized_result.label,
version=serialized_result.version,
Expand Down
71 changes: 53 additions & 18 deletions tests/test_variable_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ def test_multiple_references(self):
)
expanded, composed = expand_references('"@{greeting}@ @{name}@!"', 'my_var', resolve_fn)
assert expanded == '"Hello World!"'
assert len(composed) == 2
assert composed[0].name == 'greeting'
assert composed[1].name == 'name'
assert {ref.name for ref in composed} == {'greeting', 'name'}

def test_same_reference_multiple_times(self):
"""The same @{ref}@ used multiple times expands each occurrence."""
Expand Down Expand Up @@ -219,7 +217,7 @@ def test_unresolvable_dotted_reference_preserves_resolved_refs(self):
resolve_fn = _make_resolve_fn({'known': '"there"'})
expanded, composed = expand_references('"Hi @{known}@ @{missing.field}@"', 'my_var', resolve_fn)
assert expanded == '"Hi there @{missing.field}@"'
assert [ref.name for ref in composed] == ['known', 'missing']
assert {ref.name for ref in composed} == {'known', 'missing'}

def test_unresolvable_simple_and_dotted_reference_same_base(self):
"""Simple and dotted unresolved refs for the same base are both preserved."""
Expand Down Expand Up @@ -289,7 +287,7 @@ def test_list_with_references(self):
expanded, composed = expand_references(serialized, 'my_var', resolve_fn)

assert json.loads(expanded) == ['Hello Alice', 42, {'nested': 'Alice'}]
assert [ref.name for ref in composed] == ['greeting', 'name']
assert {ref.name for ref in composed} == {'greeting', 'name'}

def test_keyword_block_references_are_ignored(self):
"""Handlebars built-in names (`this`, helpers, `else`) aren't treated as variable references.
Expand Down Expand Up @@ -382,11 +380,13 @@ def test_single_reference(self):
assert find_references('"@{greeting}@"') == ['greeting']

def test_multiple_unique_references(self):
assert find_references('"@{a}@ @{b}@ @{c}@"') == ['a', 'b', 'c']
# Sorted alphabetically — the parser doesn't surface source order, and
# callers shouldn't depend on iteration-order-dependent behaviour.
assert find_references('"@{c}@ @{a}@ @{b}@"') == ['a', 'b', 'c']

def test_duplicate_references(self):
"""Duplicates are deduplicated, order preserved."""
assert find_references('"@{a}@ @{b}@ @{a}@"') == ['a', 'b']
"""The same name appearing in multiple `@{ref}@` slots is deduplicated."""
assert find_references('"@{b}@ @{a}@ @{b}@"') == ['a', 'b']

def test_escaped_not_matched(self):
assert find_references(r'"\\@{escaped}@"') == []
Expand All @@ -397,7 +397,7 @@ def test_mixed_escaped_and_real(self):

def test_in_structured_json(self):
serialized = json.dumps({'prompt': '@{safety}@', 'other': '@{format}@'})
assert find_references(serialized) == ['safety', 'format']
assert find_references(serialized) == ['format', 'safety']

def test_find_references_block_helpers(self):
"""find_references detects variable names from block helper syntax."""
Expand Down Expand Up @@ -560,16 +560,13 @@ class TestFindReferencesNativeHandlebarsSyntax:

def test_dotted_path_in_block_helper_header_contributes_top_level(self):
# `@{#if user.active}@` only references `user` at the top level.
refs = find_references('"@{#if user.active}@x@{/if}@"')
assert refs == ['user']
assert find_references('"@{#if user.active}@x@{/if}@"') == ['user']

def test_each_block_helper_contributes_iterable_name(self):
refs = find_references('"@{#each tags}@@{this}@@{/each}@"')
assert refs == ['tags']
assert find_references('"@{#each tags}@@{this}@@{/each}@"') == ['tags']

def test_lookup_helper_arguments_are_refs(self):
refs = find_references('"@{lookup obj key}@"')
assert sorted(refs) == ['key', 'obj']
assert find_references('"@{lookup obj key}@"') == ['key', 'obj']

def test_known_helpers_are_not_treated_as_context_refs(self):
# `if` / `each` / `lookup` are registered helpers; their names must
Expand All @@ -578,13 +575,12 @@ def test_known_helpers_are_not_treated_as_context_refs(self):
# resolve against each iteration item rather than the top-level
# context and are not top-level dependencies.
refs = find_references('"@{#if cond}@@{#each items}@@{lookup obj key}@@{/each}@@{/if}@"')
assert sorted(refs) == ['cond', 'items']
assert refs == ['cond', 'items']

def test_lookup_args_at_top_level_are_refs(self):
# When the helper call is at the top level (no enclosing context-
# shifting block), its arguments are top-level deps.
refs = find_references('"@{lookup obj key}@"')
assert sorted(refs) == ['key', 'obj']
assert find_references('"@{lookup obj key}@"') == ['key', 'obj']


# =============================================================================
Expand Down Expand Up @@ -713,6 +709,45 @@ def raise_composition_error(*args: Any, **kwargs: Any) -> Any:
assert result.exception is not None
assert result.reason == 'other_error'

def test_callable_default_invoked_once_on_composition_failure(
self, config_kwargs: dict[str, Any], monkeypatch: pytest.MonkeyPatch
):
"""A callable default must not be re-invoked on the composition-failure fallback path.

Regression for #1954 r3287513610 — when the code-default tier
supplies the value AND composition then fails, both the
serialize step (in `_lookup_serialized` → `_get_serialized_default`)
and the fallback step (in `_fallback_to_default`) previously
invoked the callable, doubling side effects. With `_DEFAULT_CACHE`
in place the callable is invoked once per `get()`.
"""
config_kwargs['variables'] = LocalVariablesOptions(config=VariablesConfig(variables={}))
lf = logfire.configure(**config_kwargs)

call_count = 0

def make_default(targeting_key: str | None, attributes: Any) -> str:
nonlocal call_count
call_count += 1
# Returns a value that contains a reference, so composition
# runs against it; we force composition to fail below so the
# fallback path also needs the default.
return '@{missing_for_test}@'

# Provider has nothing; code default (the callable above) supplies
# the serialized value. Then we force composition to fail.
def raise_composition_error(*args: Any, **kwargs: Any) -> Any:
raise VariableCompositionError('forced composition failure')

monkeypatch.setattr('logfire.variables.variable.expand_references', raise_composition_error)

var = lf.var(name='callable_default', default=make_default, type=str)
with pytest.warns(RuntimeWarning, match='composition failed'):
result = var.get()

assert result.reason == 'other_error'
assert call_count == 1, f'callable default invoked {call_count} times, expected 1'

def test_nested_reference(self, config_kwargs: dict[str, Any]):
"""A→B→C chain resolves fully."""
variables_config = _make_variables_config(
Expand Down
Loading