Skip to content
Merged
107 changes: 83 additions & 24 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import traceback
import urllib.parse
from collections.abc import AsyncIterator, Callable, Coroutine, Mapping, Sequence
from contextvars import Token
from datetime import datetime
from itertools import chain
from pathlib import Path
Expand All @@ -31,6 +32,7 @@
evaluate_style_namespaces,
)
from reflex_base.config import get_config
from reflex_base.context.base import BaseContext
from reflex_base.environment import ExecutorType, environment
from reflex_base.event import (
_EVENT_FIELDS,
Expand All @@ -40,6 +42,7 @@
IndividualEventType,
noop,
)
from reflex_base.event.context import EventContext
from reflex_base.event.processor import BaseStateEventProcessor, EventProcessor
from reflex_base.registry import RegistrationContext
from reflex_base.utils import console
Expand Down Expand Up @@ -574,8 +577,65 @@ async def modified_send(message: Message):
# Ensure the event processor starts and stops with the server.
self.register_lifespan_task(self._setup_event_processor)

def _registration_context_middleware(self, app: ASGIApp) -> ASGIApp:
"""Ensure the RegistrationContext is attached to the ASGI app.
def _set_contexts_internal(self) -> dict[type[BaseContext], Token]:
"""Set Reflex contexts if not already present, returning reset tokens.

Returns:
A dict mapping context class to the contextvars Token for each
context that was set. Empty if all contexts were already present.
"""
tokens: dict[type[BaseContext], Token] = {}

if self._registration_context is not None:
try:
RegistrationContext.get()
except LookupError:
tokens[RegistrationContext] = RegistrationContext.set(
self._registration_context
)

if (
self._event_processor is not None
and self._event_processor._root_context is not None
):
try:
EventContext.get()
except LookupError:
tokens[EventContext] = EventContext.set(
self._event_processor._root_context
)

return tokens

def set_contexts(self) -> contextlib.AbstractContextManager:
"""Set Reflex contexts needed for state and event processing.

Pushes RegistrationContext and EventContext into the current
contextvars scope, but only if they are not already set.

Can be used as a context manager::

with app.set_contexts():
async with app.modify_state(token) as state:
...

Returns:
A context manager that resets any contexts that were set on exit.
"""
tokens = self._set_contexts_internal()
if not tokens:
return contextlib.nullcontext()
stack = contextlib.ExitStack()
for ctx_cls, tok in tokens.items():
stack.callback(ctx_cls.reset, tok)
return stack

def _context_middleware(self, app: ASGIApp) -> ASGIApp:
"""Ensure Reflex contexts are attached for each ASGI request.

Many ASGI servers start each request with a fresh contextvars scope,
so this middleware re-applies the RegistrationContext and EventContext
that are needed for Reflex state and event processing.

Args:
app: The ASGI app to attach the middleware to.
Expand All @@ -584,14 +644,11 @@ def _registration_context_middleware(self, app: ASGIApp) -> ASGIApp:
The ASGI app with the middleware attached.
"""

async def registration_context_middleware(
scope: Scope, receive: Receive, send: Send
):
if self._registration_context is not None:
RegistrationContext.set(self._registration_context)
async def context_middleware(scope: Scope, receive: Receive, send: Send):
self._set_contexts_internal()
await app(scope, receive, send)

return registration_context_middleware
return context_middleware

@contextlib.asynccontextmanager
async def _setup_event_processor(self) -> AsyncIterator[None]:
Expand Down Expand Up @@ -669,10 +726,10 @@ def __call__(self) -> ASGIApp:
asgi_app = api_transformer(asgi_app)

top_asgi_app = Starlette(lifespan=self._run_lifespan_tasks)
# Make sure the RegistrationContext is attached.
# Make sure Reflex contexts are attached for each request.
top_asgi_app.mount(
"",
self._registration_context_middleware(asgi_app),
self._context_middleware(asgi_app),
)
App._add_cors(top_asgi_app)
return top_asgi_app
Expand Down Expand Up @@ -1607,20 +1664,22 @@ async def modify_state(
if isinstance(token, str):
token = BaseStateToken.from_legacy_token(token, root_state=self._state)

# Get exclusive access to the state.
async with self.state_manager.modify_state_with_links(
token, previous_dirty_vars=previous_dirty_vars, **context
) as state:
# No other event handler can modify the state while in this context.
yield state
delta = await state._get_resolved_delta()
state._clean()
if delta:
# When the frontend vars are modified emit the delta to the frontend.
await self.event_namespace.emit_update(
update=StateUpdate(delta=delta),
token=token.ident,
)
# Ensure Reflex contexts are available (e.g. when called from an API route).
with self.set_contexts():
# Get exclusive access to the state.
async with self.state_manager.modify_state_with_links(
token, previous_dirty_vars=previous_dirty_vars, **context
) as state:
# No other event handler can modify the state while in this context.
yield state
delta = await state._get_resolved_delta()
state._clean()
if delta:
# When the frontend vars are modified emit the delta to the frontend.
await self.event_namespace.emit_update(
update=StateUpdate(delta=delta),
token=token.ident,
)

def _validate_exception_handlers(self):
"""Validate the custom event exception handlers for front- and backend.
Expand Down
90 changes: 89 additions & 1 deletion reflex/istate/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,38 @@ def _rehydrate(self):
State.set_is_hydrated(True),
]

async def _resolve_linked_state(
self, state_cls: type["BaseState"], linked_token: str
) -> "BaseState":
"""Load and patch a linked state that was not pre-loaded in the tree.

Called by State._get_state_from_redis when a state in
_reflex_internal_links is not yet in the cache. This loads the
private copy into the tree first, then patches the linked version
on top of it via _internal_patch_linked_state.

Args:
state_cls: The shared state class to resolve.
linked_token: The shared token the state is linked to.

Returns:
The linked state instance, patched into the current tree.

Raises:
ReflexRuntimeError: If the resolved state is not a SharedState.
"""
root_state = self._get_root_state()

# Load the private copy into the tree so _internal_patch_linked_state
# has an original to swap out (needed for unlink / restore).
original_state = await BaseState._get_state_from_redis(root_state, state_cls)

if isinstance(original_state, SharedStateBaseInternal):
return await original_state._internal_patch_linked_state(linked_token)

msg = f"Failed to resolve linked state {state_cls.get_full_name()} for token {linked_token}: state does not inherit from rx.SharedState"
raise ReflexRuntimeError(msg)

async def _link_to(self, token: str) -> Self:
"""Link this shared state to a token.

Expand All @@ -194,7 +226,7 @@ async def _link_to(self, token: str) -> Self:
raise ReflexRuntimeError(msg)
if not isinstance(self, SharedState):
msg = "Can only link SharedState instances."
raise RuntimeError(msg)
raise ReflexRuntimeError(msg)
if self._linked_to == token:
return self # already linked to this token
if self._linked_to and self._linked_to != token:
Expand Down Expand Up @@ -280,6 +312,18 @@ async def _internal_patch_linked_state(
BaseStateToken(ident=token, cls=type(self))
)
)
# Set client_token on the linked root so that subsequent get_state
# calls when directly modifying a linked token will load the
# associated instance.
if linked_root_state.router.session.client_token != token:
import dataclasses as dc

linked_root_state.router = dc.replace(
linked_root_state.router,
session=dc.replace(
linked_root_state.router.session, client_token=token
),
)
self._held_locks.setdefault(token, {})
else:
linked_root_state = await get_state_manager().get_state(
Expand Down Expand Up @@ -386,6 +430,22 @@ async def _modify_linked_states(
for token in linked_state._linked_from
if token != self.router.session.client_token
)
# When modifying a shared token directly (empty _reflex_internal_links),
# the held locks will be empty. Check SharedState substates for linked
# clients that need to be notified.
if not self._reflex_internal_links:
shared_state_base_internal = await self.get_state(
SharedStateBaseInternal
)
if not isinstance(
shared_state_base_internal, SharedStateBaseInternal
):
msg = "Expected SharedStateBaseInternal in substates."
raise ReflexRuntimeError(msg)
# Collect affected tokens from all potentially linked states.
shared_state_base_internal._collect_shared_token_updates(
affected_tokens, current_dirty_vars
)
finally:
self._exit_stack = None

Expand All @@ -397,6 +457,34 @@ async def _modify_linked_states(
state_type=type(self),
)

def _collect_shared_token_updates(
self,
affected_tokens: set[str],
current_dirty_vars: dict[str, set[str]],
) -> None:
"""Recursively collect dirty vars and linked clients from SharedState substates.

When a shared state is modified directly by its shared token (rather than
through a private client token), the held locks are empty so the normal
collection loop above finds nothing. This method recursively checks
SharedState substates for linked clients that need to be notified.

Args:
affected_tokens: Set to update with client tokens that need notification.
current_dirty_vars: Dict to update with dirty var mappings per state.
"""
for substate in self.substates.values():
if not isinstance(substate, SharedState):
continue
if substate._linked_from:
if substate._previous_dirty_vars:
current_dirty_vars[substate.get_full_name()] = set(
substate._previous_dirty_vars
)
if substate._get_was_touched() or substate._previous_dirty_vars:
affected_tokens.update(substate._linked_from)
substate._collect_shared_token_updates(affected_tokens, current_dirty_vars)


class SharedState(SharedStateBaseInternal, mixin=True):
"""Mixin for defining new shared states."""
Expand Down
14 changes: 5 additions & 9 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2146,7 +2146,6 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE:
Returns:
The instance of state_cls associated with this state's client_token.
"""
state_instance = await super()._get_state_from_redis(state_cls)
if (
self._reflex_internal_links
and (
Expand All @@ -2155,15 +2154,12 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE:
)
)
is not None
and (
internal_patch_linked_state := getattr(
state_instance, "_internal_patch_linked_state", None
)
)
is not None
):
return await internal_patch_linked_state(linked_token)
return state_instance
from reflex.istate.shared import SharedStateBaseInternal

shared_base = await self.get_state(SharedStateBaseInternal)
return await shared_base._resolve_linked_state(state_cls, linked_token) # type: ignore[return-value]
return await super()._get_state_from_redis(state_cls)

@event
async def hydrate(self) -> None:
Expand Down
Loading
Loading