Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions packages/reflex-base/src/reflex_base/compiler/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,17 +200,20 @@ def app_root_template(
return f"""
{imports_str}
{dynamic_imports_str}
import {{ EventLoopProvider, StateProvider, defaultColorMode }} from "$/utils/context";
import {{ defaultColorMode }} from "$/utils/context";
import {{ ThemeProvider }} from '$/utils/react-theme';
import {{ Layout as AppLayout }} from './_document';
import {{ Outlet }} from 'react-router';
{import_window_libraries}

{custom_code_str}

// AppWrap is the innermost element of the python app-wrap chain (rendered
// in Layout), so providers in the chain are React-tree ancestors of the
// hooks hoisted here.
function AppWrap({{children}}) {{
{_render_hooks(hooks)}
return ({_RenderUtils.render(render)})
return (children);
}}


Expand All @@ -225,11 +228,7 @@ def app_root_template(

return jsx(AppLayout, {{}},
jsx(ThemeProvider, {{defaultTheme: defaultColorMode, attribute: "class"}},
jsx(StateProvider, {{}},
jsx(EventLoopProvider, {{}},
jsx(AppWrap, {{}}, children)
)
)
{_RenderUtils.render(render)}
)
);
}}
Expand Down
16 changes: 11 additions & 5 deletions packages/reflex-base/src/reflex_base/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -1879,14 +1879,20 @@ def _get_vars_hooks(self) -> dict[str, VarData | None]:
def _get_events_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by events referenced in this component.

The ``Hooks.EVENTS`` hook reads ``EventLoopContext``; declaring the
state/event-loop providers in its VarData pulls them into the app
root for every event-triggering component, independent of whether
the app uses ``rx.State`` directly.

Returns:
The hooks for the events.
"""
return (
{Hooks.EVENTS: VarData(position=Hooks.HookPosition.INTERNAL)}
if self.event_triggers
else {}
)
if not self.event_triggers:
return {}
# Lazy import: ``state_context`` imports ``Component`` from this module.
from reflex_base.components.state_context import get_events_hooks_var_data

return {Hooks.EVENTS: get_events_hooks_var_data()}

def _get_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component managed by the framework.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""App-wrap components mounting the state and event-loop React providers.

These wrap children in the ``StateProvider`` / ``EventLoopProvider`` JS
functions emitted into ``utils/context.js`` by ``compile_contexts``. They are
attached to the VarData returned by :meth:`reflex_base.vars.base.VarData.from_state`
so the compiler picks them up through the generic Var-driven app-wrap pipeline,
rather than the JS Layout template hard-coding them around every app.
"""

from __future__ import annotations

from reflex_base.components.component import Component
from reflex_base.constants import Dirs
from reflex_base.constants.compiler import Hooks
from reflex_base.vars.base import VarData


class StateContextProvider(Component):
"""App wrap that mounts the React state-context provider around children."""

library = f"$/{Dirs.CONTEXTS_PATH}"
tag = "StateProvider"


class EventLoopContextProvider(Component):
"""App wrap that mounts the websocket event-loop provider around children."""

library = f"$/{Dirs.CONTEXTS_PATH}"
tag = "EventLoopProvider"


def get_events_hooks_var_data() -> VarData:
"""Build the VarData attached to ``Hooks.EVENTS`` for event triggers.

Higher priority wraps further out, so ``StateProvider`` (100) encloses
``EventLoopProvider`` (90) — the latter reads ``DispatchContext`` from
the former. The returned providers are fresh per call: the compiler's
``app_wrap_components`` registry already dedupes by ``(priority, tag)``,
and caching the instances burned us via ``copy.deepcopy`` carrying
``_cached_render_result`` from a prior compile run forward into the next.

Returns:
A new VarData carrying both providers as app_wraps.
"""
return VarData(
position=Hooks.HookPosition.INTERNAL,
app_wraps=(
(100, StateContextProvider.create()),
(90, EventLoopContextProvider.create()),
),
)
16 changes: 12 additions & 4 deletions packages/reflex-base/src/reflex_base/event/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,14 +1055,14 @@ def _as_event_spec(
"""
from reflex_components_core.core.upload import (
DEFAULT_UPLOAD_ID,
upload_files_context_var_data,
get_upload_files_context_var_data,
)

upload_id = self.upload_id if self.upload_id is not None else DEFAULT_UPLOAD_ID
upload_files_var = Var(
_js_expr="filesById",
_var_type=dict[str, Any],
_var_data=VarData.merge(upload_files_context_var_data),
_var_data=VarData.merge(get_upload_files_context_var_data()),
).to(ObjectVar)[LiteralVar.create(upload_id)]
spec_args = [
(
Expand Down Expand Up @@ -2335,11 +2335,14 @@ def create(
arg_def_expr = Var(_js_expr="args")

if value.invocation is None:
# Lazy import: state_context → component → event (this module).
from reflex_base.components.state_context import get_events_hooks_var_data

invocation = FunctionStringVar.create(
CompileVars.ADD_EVENTS,
_var_data=VarData(
imports=Imports.EVENTS,
hooks={Hooks.EVENTS: None},
hooks={Hooks.EVENTS: get_events_hooks_var_data()},
),
)
else:
Expand Down Expand Up @@ -2380,11 +2383,16 @@ def create(
_js_expr=f"{{{''.join(f'{statement};' for statement in statements)}}}",
)
if value.event_actions:
# Lazy import: state_context → component → event (this module).
from reflex_base.components.state_context import (
get_events_hooks_var_data,
)

apply_event_actions = FunctionStringVar.create(
CompileVars.APPLY_EVENT_ACTIONS,
_var_data=VarData(
imports=Imports.EVENTS,
hooks={Hooks.EVENTS: None},
hooks={Hooks.EVENTS: get_events_hooks_var_data()},
),
)
return_expr = apply_event_actions.call(
Expand Down
90 changes: 89 additions & 1 deletion packages/reflex-base/src/reflex_base/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ class VarData:
# Components that are part of this var
components: tuple[BaseComponent, ...] = dataclasses.field(default_factory=tuple)

# App-level wrapper components this var requires when used (priority, component).
# Higher priority wraps further out, matching Component._get_app_wrap_components semantics.
app_wraps: tuple[tuple[int, BaseComponent], ...] = dataclasses.field(
default_factory=tuple
)

def __init__(
self,
state: str = "",
Expand All @@ -150,6 +156,7 @@ def __init__(
deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None,
components: Iterable[BaseComponent] | None = None,
app_wraps: Iterable[tuple[int, BaseComponent]] | None = None,
):
"""Initialize the var data.

Expand All @@ -161,6 +168,7 @@ def __init__(
deps: Dependencies of the var for useCallback.
position: Position of the hook in the component.
components: Components that are part of this var.
app_wraps: App-level wrapper components this var requires when used.
"""
if isinstance(hooks, str):
hooks = [hooks]
Expand All @@ -176,6 +184,7 @@ def __init__(
object.__setattr__(self, "deps", tuple(deps or []))
object.__setattr__(self, "position", position or None)
object.__setattr__(self, "components", tuple(components or []))
object.__setattr__(self, "app_wraps", tuple(app_wraps or []))

if hooks and any(hooks.values()):
# Merge our dependencies first, so they can be referenced.
Expand All @@ -188,6 +197,7 @@ def __init__(
object.__setattr__(self, "deps", merged_var_data.deps)
object.__setattr__(self, "position", merged_var_data.position)
object.__setattr__(self, "components", merged_var_data.components)
object.__setattr__(self, "app_wraps", merged_var_data.app_wraps)

def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict.
Expand Down Expand Up @@ -259,6 +269,16 @@ def merge(*all: VarData | None) -> VarData | None:
component for var_data in all_var_datas for component in var_data.components
)

app_wraps_seen: set[tuple[int, str]] = set()
app_wraps_list: list[tuple[int, BaseComponent]] = []
for var_data in all_var_datas:
for priority, wrapper in var_data.app_wraps:
key = (priority, wrapper.tag or type(wrapper).__name__)
if key in app_wraps_seen:
continue
Comment thread
masenf marked this conversation as resolved.
Outdated
app_wraps_seen.add(key)
app_wraps_list.append((priority, wrapper))

return VarData(
state=state,
field_name=field_name,
Expand All @@ -267,6 +287,7 @@ def merge(*all: VarData | None) -> VarData | None:
deps=deps,
position=position,
components=components,
app_wraps=tuple(app_wraps_list),
)

Comment thread
FarhanAliRaza marked this conversation as resolved.
def __bool__(self) -> bool:
Expand All @@ -283,8 +304,59 @@ def __bool__(self) -> bool:
or self.deps
or self.position
or self.components
or self.app_wraps
)

def _identity_key(self) -> tuple:
"""Return a hashable key for ``__eq__`` and ``__hash__``.

``components`` and ``app_wraps`` hold ``BaseComponent`` instances whose
``__eq__`` override drops the default hash. Use component identity for
embedded components because they can contribute hooks/imports, and use
the compiler's app-wrap registry key for wrappers so fresh provider
instances with the same role still compare equal.

Returns:
A hashable tuple uniquely identifying this VarData.
"""
return (
self.state,
self.field_name,
self.imports,
self.hooks,
self.deps,
self.position,
tuple(id(component) for component in self.components),
tuple(
(
priority,
component.tag or type(component).__name__,
)
for priority, component in self.app_wraps
),
)

def __eq__(self, other: object) -> bool:
"""Compare two VarData by render-time identity.

Args:
other: The value to compare against.

Returns:
True if ``other`` is a VarData with matching render-time fields.
"""
if not isinstance(other, VarData):
return NotImplemented
return self._identity_key() == other._identity_key()

def __hash__(self) -> int:
"""Hash consistent with ``__eq__``.

Returns:
A hash over render-time fields and hashable component metadata.
"""
return hash(self._identity_key())

@classmethod
def from_state(cls, state: type[BaseState] | str, field_name: str = "") -> VarData:
"""Set the state of the var.
Expand All @@ -296,6 +368,11 @@ def from_state(cls, state: type[BaseState] | str, field_name: str = "") -> VarDa
Returns:
The var with the set state.
"""
# Lazy imports: state_context imports VarData from this module.
from reflex_base.components.state_context import (
EventLoopContextProvider,
StateContextProvider,
)
from reflex_base.utils import format

state_name = state if isinstance(state, str) else state.get_full_name()
Expand All @@ -311,6 +388,17 @@ def from_state(cls, state: type[BaseState] | str, field_name: str = "") -> VarDa
f"$/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],
"react": [ImportVar(tag="useContext")],
},
app_wraps=(
# Higher priority wraps further out. ``StateProvider`` must
# enclose ``EventLoopProvider`` because the latter reads
# ``DispatchContext`` (provided by StateProvider) at its top.
# Both must enclose the chain's other wraps so the hooks
# AppWrap hosts (e.g. ``useContext(EventLoopContext)``) see
# them as React-tree ancestors. The compiler dedupes by
# ``(priority, tag)`` so fresh per-call instances are fine.
(100, StateContextProvider.create()),
(90, EventLoopContextProvider.create()),
),
)


Expand Down Expand Up @@ -362,7 +450,7 @@ def can_use_in_object_var(cls: GenericType) -> bool:
Whether the class can be used in an ObjectVar.
"""
if types.is_union(cls):
return all(can_use_in_object_var(t) for t in types.get_args(cls))
return all(can_use_in_object_var(t) for t in get_args(cls))
return (
isinstance(cls, type)
and not safe_issubclass(cls, Var)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,18 @@
from reflex_base.components.component import Component
from reflex_base.vars.base import Var

from reflex_components_core.base.fragment import Fragment

class AppWrap(Component):
"""Innermost element of the app-wrap chain.

class AppWrap(Fragment):
"""Top-level component that wraps the entire app."""
Renders as ``<AppWrap>{children}</AppWrap>`` — the locally-defined JS
function in ``app_root_template`` that hosts all hooks aggregated from
the python chain and returns its children. Library is ``None`` because
the JS function is defined in the same file the component renders into.
"""

library = None
tag = "AppWrap"

@classmethod
def create(cls) -> Component:
Expand Down
Loading
Loading