reflex 0.4.2a1__py3-none-any.whl → 0.4.3a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflex might be problematic. Click here for more details.
- reflex/.templates/apps/blank/code/blank.py +1 -1
- reflex/.templates/apps/sidebar/README.md +3 -2
- reflex/.templates/apps/sidebar/assets/reflex_white.svg +8 -0
- reflex/.templates/apps/sidebar/code/components/sidebar.py +26 -22
- reflex/.templates/apps/sidebar/code/pages/dashboard.py +6 -5
- reflex/.templates/apps/sidebar/code/pages/settings.py +45 -6
- reflex/.templates/apps/sidebar/code/styles.py +15 -17
- reflex/.templates/apps/sidebar/code/templates/__init__.py +1 -1
- reflex/.templates/apps/sidebar/code/templates/template.py +54 -40
- reflex/.templates/jinja/custom_components/README.md.jinja2 +9 -0
- reflex/.templates/jinja/custom_components/__init__.py.jinja2 +1 -0
- reflex/.templates/jinja/custom_components/demo_app.py.jinja2 +36 -0
- reflex/.templates/jinja/custom_components/pyproject.toml.jinja2 +35 -0
- reflex/.templates/jinja/custom_components/src.py.jinja2 +57 -0
- reflex/.templates/jinja/web/utils/context.js.jinja2 +26 -6
- reflex/.templates/web/utils/state.js +206 -146
- reflex/app.py +21 -18
- reflex/compiler/compiler.py +6 -2
- reflex/compiler/templates.py +17 -0
- reflex/compiler/utils.py +2 -2
- reflex/components/core/__init__.py +2 -1
- reflex/components/core/banner.py +99 -11
- reflex/components/core/banner.pyi +215 -2
- reflex/components/el/elements/__init__.py +1 -0
- reflex/components/el/elements/forms.py +6 -0
- reflex/components/el/elements/forms.pyi +4 -0
- reflex/components/markdown/markdown.py +13 -25
- reflex/components/markdown/markdown.pyi +5 -5
- reflex/components/plotly/plotly.py +3 -0
- reflex/components/plotly/plotly.pyi +2 -0
- reflex/components/radix/primitives/drawer.py +3 -7
- reflex/components/radix/themes/components/select.py +4 -4
- reflex/components/radix/themes/components/text_field.pyi +4 -0
- reflex/constants/__init__.py +4 -0
- reflex/constants/colors.py +1 -0
- reflex/constants/compiler.py +4 -3
- reflex/constants/custom_components.py +30 -0
- reflex/custom_components/__init__.py +1 -0
- reflex/custom_components/custom_components.py +565 -0
- reflex/reflex.py +11 -2
- reflex/route.py +4 -0
- reflex/state.py +594 -124
- reflex/testing.py +6 -0
- reflex/utils/exec.py +9 -0
- reflex/utils/prerequisites.py +28 -2
- reflex/utils/telemetry.py +3 -1
- reflex/utils/types.py +23 -0
- reflex/vars.py +48 -17
- reflex/vars.pyi +8 -3
- {reflex-0.4.2a1.dist-info → reflex-0.4.3a2.dist-info}/METADATA +4 -2
- {reflex-0.4.2a1.dist-info → reflex-0.4.3a2.dist-info}/RECORD +55 -51
- {reflex-0.4.2a1.dist-info → reflex-0.4.3a2.dist-info}/WHEEL +1 -1
- reflex/components/base/bare.pyi +0 -84
- reflex/constants/base.pyi +0 -94
- reflex/constants/event.pyi +0 -59
- reflex/constants/route.pyi +0 -50
- reflex/constants/style.pyi +0 -20
- /reflex/.templates/apps/sidebar/assets/{icon.svg → reflex_black.svg} +0 -0
- {reflex-0.4.2a1.dist-info → reflex-0.4.3a2.dist-info}/LICENSE +0 -0
- {reflex-0.4.2a1.dist-info → reflex-0.4.3a2.dist-info}/entry_points.txt +0 -0
reflex/state.py
CHANGED
|
@@ -8,7 +8,6 @@ import copy
|
|
|
8
8
|
import functools
|
|
9
9
|
import inspect
|
|
10
10
|
import json
|
|
11
|
-
import os
|
|
12
11
|
import traceback
|
|
13
12
|
import urllib.parse
|
|
14
13
|
import uuid
|
|
@@ -45,11 +44,12 @@ from reflex.event import (
|
|
|
45
44
|
)
|
|
46
45
|
from reflex.utils import console, format, prerequisites, types
|
|
47
46
|
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
|
|
47
|
+
from reflex.utils.exec import is_testing_env
|
|
48
48
|
from reflex.utils.serializers import SerializedType, serialize, serializer
|
|
49
|
-
from reflex.vars import BaseVar, ComputedVar, Var
|
|
49
|
+
from reflex.vars import BaseVar, ComputedVar, Var, computed_var
|
|
50
50
|
|
|
51
51
|
Delta = Dict[str, Any]
|
|
52
|
-
var =
|
|
52
|
+
var = computed_var
|
|
53
53
|
|
|
54
54
|
|
|
55
55
|
class HeaderData(Base):
|
|
@@ -151,9 +151,45 @@ RESERVED_BACKEND_VAR_NAMES = {
|
|
|
151
151
|
"_substate_var_dependencies",
|
|
152
152
|
"_always_dirty_computed_vars",
|
|
153
153
|
"_always_dirty_substates",
|
|
154
|
+
"_was_touched",
|
|
154
155
|
}
|
|
155
156
|
|
|
156
157
|
|
|
158
|
+
def _substate_key(
|
|
159
|
+
token: str,
|
|
160
|
+
state_cls_or_name: BaseState | Type[BaseState] | str | list[str],
|
|
161
|
+
) -> str:
|
|
162
|
+
"""Get the substate key.
|
|
163
|
+
|
|
164
|
+
Args:
|
|
165
|
+
token: The token of the state.
|
|
166
|
+
state_cls_or_name: The state class/instance or name or sequence of name parts.
|
|
167
|
+
|
|
168
|
+
Returns:
|
|
169
|
+
The substate key.
|
|
170
|
+
"""
|
|
171
|
+
if isinstance(state_cls_or_name, BaseState) or (
|
|
172
|
+
isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState)
|
|
173
|
+
):
|
|
174
|
+
state_cls_or_name = state_cls_or_name.get_full_name()
|
|
175
|
+
elif isinstance(state_cls_or_name, (list, tuple)):
|
|
176
|
+
state_cls_or_name = ".".join(state_cls_or_name)
|
|
177
|
+
return f"{token}_{state_cls_or_name}"
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _split_substate_key(substate_key: str) -> tuple[str, str]:
|
|
181
|
+
"""Split the substate key into token and state name.
|
|
182
|
+
|
|
183
|
+
Args:
|
|
184
|
+
substate_key: The substate key.
|
|
185
|
+
|
|
186
|
+
Returns:
|
|
187
|
+
Tuple of token and state name.
|
|
188
|
+
"""
|
|
189
|
+
token, _, state_name = substate_key.partition("_")
|
|
190
|
+
return token, state_name
|
|
191
|
+
|
|
192
|
+
|
|
157
193
|
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
158
194
|
"""The state of the app."""
|
|
159
195
|
|
|
@@ -214,34 +250,57 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
214
250
|
# The router data for the current page
|
|
215
251
|
router: RouterData = RouterData()
|
|
216
252
|
|
|
253
|
+
# Whether the state has ever been touched since instantiation.
|
|
254
|
+
_was_touched: bool = False
|
|
255
|
+
|
|
217
256
|
def __init__(
|
|
218
257
|
self,
|
|
219
258
|
*args,
|
|
220
259
|
parent_state: BaseState | None = None,
|
|
221
260
|
init_substates: bool = True,
|
|
261
|
+
_reflex_internal_init: bool = False,
|
|
222
262
|
**kwargs,
|
|
223
263
|
):
|
|
224
264
|
"""Initialize the state.
|
|
225
265
|
|
|
266
|
+
DO NOT INSTANTIATE STATE CLASSES DIRECTLY! Use StateManager.get_state() instead.
|
|
267
|
+
|
|
226
268
|
Args:
|
|
227
269
|
*args: The args to pass to the Pydantic init method.
|
|
228
270
|
parent_state: The parent state.
|
|
229
271
|
init_substates: Whether to initialize the substates in this instance.
|
|
272
|
+
_reflex_internal_init: A flag to indicate that the state is being initialized by the framework.
|
|
230
273
|
**kwargs: The kwargs to pass to the Pydantic init method.
|
|
231
274
|
|
|
275
|
+
Raises:
|
|
276
|
+
RuntimeError: If the state is instantiated directly by end user.
|
|
232
277
|
"""
|
|
278
|
+
if not _reflex_internal_init and not is_testing_env():
|
|
279
|
+
raise RuntimeError(
|
|
280
|
+
"State classes should not be instantiated directly in a Reflex app. "
|
|
281
|
+
"See https://reflex.dev/docs/state/ for further information."
|
|
282
|
+
)
|
|
233
283
|
kwargs["parent_state"] = parent_state
|
|
234
284
|
super().__init__(*args, **kwargs)
|
|
235
285
|
|
|
236
286
|
# Setup the substates (for memory state manager only).
|
|
237
287
|
if init_substates:
|
|
238
288
|
for substate in self.get_substates():
|
|
239
|
-
self.substates[substate.get_name()] = substate(
|
|
289
|
+
self.substates[substate.get_name()] = substate(
|
|
290
|
+
parent_state=self,
|
|
291
|
+
_reflex_internal_init=True,
|
|
292
|
+
)
|
|
240
293
|
# Convert the event handlers to functions.
|
|
241
294
|
self._init_event_handlers()
|
|
242
295
|
|
|
243
296
|
# Create a fresh copy of the backend variables for this instance
|
|
244
|
-
self._backend_vars = copy.deepcopy(
|
|
297
|
+
self._backend_vars = copy.deepcopy(
|
|
298
|
+
{
|
|
299
|
+
name: item
|
|
300
|
+
for name, item in self.backend_vars.items()
|
|
301
|
+
if name not in self.computed_vars
|
|
302
|
+
}
|
|
303
|
+
)
|
|
245
304
|
|
|
246
305
|
def _init_event_handlers(self, state: BaseState | None = None):
|
|
247
306
|
"""Initialize event handlers.
|
|
@@ -277,6 +336,21 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
277
336
|
"""
|
|
278
337
|
return f"{self.__class__.__name__}({self.dict()})"
|
|
279
338
|
|
|
339
|
+
@classmethod
|
|
340
|
+
def _get_computed_vars(cls) -> list[ComputedVar]:
|
|
341
|
+
"""Helper function to get all computed vars of a instance.
|
|
342
|
+
|
|
343
|
+
Returns:
|
|
344
|
+
A list of computed vars.
|
|
345
|
+
"""
|
|
346
|
+
return [
|
|
347
|
+
v
|
|
348
|
+
for mixin in cls.__mro__
|
|
349
|
+
if mixin is cls or not issubclass(mixin, (BaseState, ABC))
|
|
350
|
+
for v in mixin.__dict__.values()
|
|
351
|
+
if isinstance(v, ComputedVar)
|
|
352
|
+
]
|
|
353
|
+
|
|
280
354
|
@classmethod
|
|
281
355
|
def __init_subclass__(cls, **kwargs):
|
|
282
356
|
"""Do some magic for the subclass initialization.
|
|
@@ -287,7 +361,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
287
361
|
Raises:
|
|
288
362
|
ValueError: If a substate class shadows another.
|
|
289
363
|
"""
|
|
290
|
-
is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
|
|
291
364
|
super().__init_subclass__(**kwargs)
|
|
292
365
|
# Event handlers should not shadow builtin state methods.
|
|
293
366
|
cls._check_overridden_methods()
|
|
@@ -295,6 +368,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
295
368
|
# Reset subclass tracking for this class.
|
|
296
369
|
cls.class_subclasses = set()
|
|
297
370
|
|
|
371
|
+
# Reset dirty substate tracking for this class.
|
|
372
|
+
cls._always_dirty_substates = set()
|
|
373
|
+
|
|
298
374
|
# Get the parent vars.
|
|
299
375
|
parent_state = cls.get_parent_state()
|
|
300
376
|
if parent_state is not None:
|
|
@@ -303,7 +379,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
303
379
|
|
|
304
380
|
# Check if another substate class with the same name has already been defined.
|
|
305
381
|
if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses):
|
|
306
|
-
if is_testing_env:
|
|
382
|
+
if is_testing_env():
|
|
307
383
|
# Clear existing subclass with same name when app is reloaded via
|
|
308
384
|
# utils.prerequisites.get_app(reload=True)
|
|
309
385
|
parent_state.class_subclasses = set(
|
|
@@ -321,15 +397,32 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
321
397
|
# Track this new subclass in the parent state's subclasses set.
|
|
322
398
|
parent_state.class_subclasses.add(cls)
|
|
323
399
|
|
|
324
|
-
|
|
400
|
+
# Get computed vars.
|
|
401
|
+
computed_vars = cls._get_computed_vars()
|
|
402
|
+
|
|
403
|
+
new_backend_vars = {
|
|
325
404
|
name: value
|
|
326
405
|
for name, value in cls.__dict__.items()
|
|
327
406
|
if types.is_backend_variable(name, cls)
|
|
407
|
+
and name not in RESERVED_BACKEND_VAR_NAMES
|
|
328
408
|
and name not in cls.inherited_backend_vars
|
|
329
409
|
and not isinstance(value, FunctionType)
|
|
410
|
+
and not isinstance(value, ComputedVar)
|
|
330
411
|
}
|
|
331
412
|
|
|
332
|
-
|
|
413
|
+
# Get backend computed vars
|
|
414
|
+
backend_computed_vars = {
|
|
415
|
+
v._var_name: v._var_set_state(cls)
|
|
416
|
+
for v in computed_vars
|
|
417
|
+
if types.is_backend_variable(v._var_name, cls)
|
|
418
|
+
and v._var_name not in cls.inherited_backend_vars
|
|
419
|
+
}
|
|
420
|
+
|
|
421
|
+
cls.backend_vars = {
|
|
422
|
+
**cls.inherited_backend_vars,
|
|
423
|
+
**new_backend_vars,
|
|
424
|
+
**backend_computed_vars,
|
|
425
|
+
}
|
|
333
426
|
|
|
334
427
|
# Set the base and computed vars.
|
|
335
428
|
cls.base_vars = {
|
|
@@ -339,11 +432,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
339
432
|
for f in cls.get_fields().values()
|
|
340
433
|
if f.name not in cls.get_skip_vars()
|
|
341
434
|
}
|
|
342
|
-
cls.computed_vars = {
|
|
343
|
-
v._var_name: v._var_set_state(cls)
|
|
344
|
-
for v in cls.__dict__.values()
|
|
345
|
-
if isinstance(v, ComputedVar)
|
|
346
|
-
}
|
|
435
|
+
cls.computed_vars = {v._var_name: v._var_set_state(cls) for v in computed_vars}
|
|
347
436
|
cls.vars = {
|
|
348
437
|
**cls.inherited_vars,
|
|
349
438
|
**cls.base_vars,
|
|
@@ -466,7 +555,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
466
555
|
# track that this substate depends on its parent for this var
|
|
467
556
|
state_name = cls.get_name()
|
|
468
557
|
parent_state = cls.get_parent_state()
|
|
469
|
-
while parent_state is not None and var in
|
|
558
|
+
while parent_state is not None and var in {
|
|
559
|
+
**parent_state.vars,
|
|
560
|
+
**parent_state.backend_vars,
|
|
561
|
+
}:
|
|
470
562
|
parent_state._substate_var_dependencies[var].add(state_name)
|
|
471
563
|
state_name, parent_state = (
|
|
472
564
|
parent_state.get_name(),
|
|
@@ -481,7 +573,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
481
573
|
)
|
|
482
574
|
|
|
483
575
|
# Any substate containing a ComputedVar with cache=False always needs to be recomputed
|
|
484
|
-
cls._always_dirty_substates = set()
|
|
485
576
|
if cls._always_dirty_computed_vars:
|
|
486
577
|
# Tell parent classes that this substate has always dirty computed vars
|
|
487
578
|
state_name = cls.get_name()
|
|
@@ -920,8 +1011,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
920
1011
|
**super().__getattribute__("inherited_vars"),
|
|
921
1012
|
**super().__getattribute__("inherited_backend_vars"),
|
|
922
1013
|
}
|
|
923
|
-
|
|
924
|
-
|
|
1014
|
+
|
|
1015
|
+
# For now, handle router_data updates as a special case.
|
|
1016
|
+
if name in inherited_vars or name == constants.ROUTER_DATA:
|
|
1017
|
+
parent_state = super().__getattribute__("parent_state")
|
|
1018
|
+
if parent_state is not None:
|
|
1019
|
+
return getattr(parent_state, name)
|
|
925
1020
|
|
|
926
1021
|
backend_vars = super().__getattribute__("_backend_vars")
|
|
927
1022
|
if name in backend_vars:
|
|
@@ -977,9 +1072,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
977
1072
|
if name == constants.ROUTER_DATA:
|
|
978
1073
|
self.dirty_vars.add(name)
|
|
979
1074
|
self._mark_dirty()
|
|
980
|
-
# propagate router_data updates down the state tree
|
|
981
|
-
for substate in self.substates.values():
|
|
982
|
-
setattr(substate, name, value)
|
|
983
1075
|
|
|
984
1076
|
def reset(self):
|
|
985
1077
|
"""Reset all the base vars to their default values."""
|
|
@@ -988,7 +1080,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
988
1080
|
for prop_name in self.base_vars:
|
|
989
1081
|
if prop_name == constants.ROUTER:
|
|
990
1082
|
continue # never reset the router data
|
|
991
|
-
|
|
1083
|
+
field = fields[prop_name]
|
|
1084
|
+
if default_factory := field.default_factory:
|
|
1085
|
+
default = default_factory()
|
|
1086
|
+
else:
|
|
1087
|
+
default = copy.deepcopy(field.default)
|
|
1088
|
+
setattr(self, prop_name, default)
|
|
992
1089
|
|
|
993
1090
|
# Recursively reset the substates.
|
|
994
1091
|
for substate in self.substates.values():
|
|
@@ -1033,6 +1130,170 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
1033
1130
|
raise ValueError(f"Invalid path: {path}")
|
|
1034
1131
|
return self.substates[path[0]].get_substate(path[1:])
|
|
1035
1132
|
|
|
1133
|
+
@classmethod
|
|
1134
|
+
def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
|
|
1135
|
+
"""Find the name of the nearest common ancestor shared by this and the other state.
|
|
1136
|
+
|
|
1137
|
+
Args:
|
|
1138
|
+
other: The other state.
|
|
1139
|
+
|
|
1140
|
+
Returns:
|
|
1141
|
+
Full name of the nearest common ancestor.
|
|
1142
|
+
"""
|
|
1143
|
+
common_ancestor_parts = []
|
|
1144
|
+
for part1, part2 in zip(
|
|
1145
|
+
cls.get_full_name().split("."),
|
|
1146
|
+
other.get_full_name().split("."),
|
|
1147
|
+
):
|
|
1148
|
+
if part1 != part2:
|
|
1149
|
+
break
|
|
1150
|
+
common_ancestor_parts.append(part1)
|
|
1151
|
+
return ".".join(common_ancestor_parts)
|
|
1152
|
+
|
|
1153
|
+
@classmethod
|
|
1154
|
+
def _determine_missing_parent_states(
|
|
1155
|
+
cls, target_state_cls: Type[BaseState]
|
|
1156
|
+
) -> tuple[str, list[str]]:
|
|
1157
|
+
"""Determine the missing parent states between the target_state_cls and common ancestor of this state.
|
|
1158
|
+
|
|
1159
|
+
Args:
|
|
1160
|
+
target_state_cls: The class of the state to find missing parent states for.
|
|
1161
|
+
|
|
1162
|
+
Returns:
|
|
1163
|
+
The name of the common ancestor and the list of missing parent states.
|
|
1164
|
+
"""
|
|
1165
|
+
common_ancestor_name = cls._get_common_ancestor(target_state_cls)
|
|
1166
|
+
common_ancestor_parts = common_ancestor_name.split(".")
|
|
1167
|
+
target_state_parts = tuple(target_state_cls.get_full_name().split("."))
|
|
1168
|
+
relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :]
|
|
1169
|
+
|
|
1170
|
+
# Determine which parent states to fetch from the common ancestor down to the target_state_cls.
|
|
1171
|
+
fetch_parent_states = [common_ancestor_name]
|
|
1172
|
+
for ix, relative_parent_state_name in enumerate(relative_target_state_parts):
|
|
1173
|
+
fetch_parent_states.append(
|
|
1174
|
+
".".join([*fetch_parent_states[: ix + 1], relative_parent_state_name])
|
|
1175
|
+
)
|
|
1176
|
+
|
|
1177
|
+
return common_ancestor_name, fetch_parent_states[1:-1]
|
|
1178
|
+
|
|
1179
|
+
def _get_parent_states(self) -> list[tuple[str, BaseState]]:
|
|
1180
|
+
"""Get all parent state instances up to the root of the state tree.
|
|
1181
|
+
|
|
1182
|
+
Returns:
|
|
1183
|
+
A list of tuples containing the name and the instance of each parent state.
|
|
1184
|
+
"""
|
|
1185
|
+
parent_states_with_name = []
|
|
1186
|
+
parent_state = self
|
|
1187
|
+
while parent_state.parent_state is not None:
|
|
1188
|
+
parent_state = parent_state.parent_state
|
|
1189
|
+
parent_states_with_name.append((parent_state.get_full_name(), parent_state))
|
|
1190
|
+
return parent_states_with_name
|
|
1191
|
+
|
|
1192
|
+
async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
|
|
1193
|
+
"""Populate substates in the tree between the target_state_cls and common ancestor of this state.
|
|
1194
|
+
|
|
1195
|
+
Args:
|
|
1196
|
+
target_state_cls: The class of the state to populate parent states for.
|
|
1197
|
+
|
|
1198
|
+
Returns:
|
|
1199
|
+
The parent state instance of target_state_cls.
|
|
1200
|
+
|
|
1201
|
+
Raises:
|
|
1202
|
+
RuntimeError: If redis is not used in this backend process.
|
|
1203
|
+
"""
|
|
1204
|
+
state_manager = get_state_manager()
|
|
1205
|
+
if not isinstance(state_manager, StateManagerRedis):
|
|
1206
|
+
raise RuntimeError(
|
|
1207
|
+
f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. "
|
|
1208
|
+
"(All states should already be available -- this is likely a bug).",
|
|
1209
|
+
)
|
|
1210
|
+
|
|
1211
|
+
# Find the missing parent states up to the common ancestor.
|
|
1212
|
+
(
|
|
1213
|
+
common_ancestor_name,
|
|
1214
|
+
missing_parent_states,
|
|
1215
|
+
) = self._determine_missing_parent_states(target_state_cls)
|
|
1216
|
+
|
|
1217
|
+
# Fetch all missing parent states and link them up to the common ancestor.
|
|
1218
|
+
parent_states_by_name = dict(self._get_parent_states())
|
|
1219
|
+
parent_state = parent_states_by_name[common_ancestor_name]
|
|
1220
|
+
for parent_state_name in missing_parent_states:
|
|
1221
|
+
parent_state = await state_manager.get_state(
|
|
1222
|
+
token=_substate_key(
|
|
1223
|
+
self.router.session.client_token, parent_state_name
|
|
1224
|
+
),
|
|
1225
|
+
top_level=False,
|
|
1226
|
+
get_substates=False,
|
|
1227
|
+
parent_state=parent_state,
|
|
1228
|
+
)
|
|
1229
|
+
|
|
1230
|
+
# Return the direct parent of target_state_cls for subsequent linking.
|
|
1231
|
+
return parent_state
|
|
1232
|
+
|
|
1233
|
+
def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
|
|
1234
|
+
"""Get a state instance from the cache.
|
|
1235
|
+
|
|
1236
|
+
Args:
|
|
1237
|
+
state_cls: The class of the state.
|
|
1238
|
+
|
|
1239
|
+
Returns:
|
|
1240
|
+
The instance of state_cls associated with this state's client_token.
|
|
1241
|
+
"""
|
|
1242
|
+
if self.parent_state is None:
|
|
1243
|
+
root_state = self
|
|
1244
|
+
else:
|
|
1245
|
+
root_state = self._get_parent_states()[-1][1]
|
|
1246
|
+
return root_state.get_substate(state_cls.get_full_name().split("."))
|
|
1247
|
+
|
|
1248
|
+
async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
|
|
1249
|
+
"""Get a state instance from redis.
|
|
1250
|
+
|
|
1251
|
+
Args:
|
|
1252
|
+
state_cls: The class of the state.
|
|
1253
|
+
|
|
1254
|
+
Returns:
|
|
1255
|
+
The instance of state_cls associated with this state's client_token.
|
|
1256
|
+
|
|
1257
|
+
Raises:
|
|
1258
|
+
RuntimeError: If redis is not used in this backend process.
|
|
1259
|
+
"""
|
|
1260
|
+
# Fetch all missing parent states from redis.
|
|
1261
|
+
parent_state_of_state_cls = await self._populate_parent_states(state_cls)
|
|
1262
|
+
|
|
1263
|
+
# Then get the target state and all its substates.
|
|
1264
|
+
state_manager = get_state_manager()
|
|
1265
|
+
if not isinstance(state_manager, StateManagerRedis):
|
|
1266
|
+
raise RuntimeError(
|
|
1267
|
+
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
|
1268
|
+
"(All states should already be available -- this is likely a bug).",
|
|
1269
|
+
)
|
|
1270
|
+
return await state_manager.get_state(
|
|
1271
|
+
token=_substate_key(self.router.session.client_token, state_cls),
|
|
1272
|
+
top_level=False,
|
|
1273
|
+
get_substates=True,
|
|
1274
|
+
parent_state=parent_state_of_state_cls,
|
|
1275
|
+
)
|
|
1276
|
+
|
|
1277
|
+
async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
|
|
1278
|
+
"""Get an instance of the state associated with this token.
|
|
1279
|
+
|
|
1280
|
+
Allows for arbitrary access to sibling states from within an event handler.
|
|
1281
|
+
|
|
1282
|
+
Args:
|
|
1283
|
+
state_cls: The class of the state.
|
|
1284
|
+
|
|
1285
|
+
Returns:
|
|
1286
|
+
The instance of state_cls associated with this state's client_token.
|
|
1287
|
+
"""
|
|
1288
|
+
# Fast case - if this state instance is already cached, get_substate from root state.
|
|
1289
|
+
try:
|
|
1290
|
+
return self._get_state_from_cache(state_cls)
|
|
1291
|
+
except ValueError:
|
|
1292
|
+
pass
|
|
1293
|
+
|
|
1294
|
+
# Slow case - fetch missing parent states from redis.
|
|
1295
|
+
return await self._get_state_from_redis(state_cls)
|
|
1296
|
+
|
|
1036
1297
|
def _get_event_handler(
|
|
1037
1298
|
self, event: Event
|
|
1038
1299
|
) -> tuple[BaseState | StateProxy, EventHandler]:
|
|
@@ -1235,6 +1496,27 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
1235
1496
|
for cvar in self._computed_var_dependencies[dirty_var]
|
|
1236
1497
|
)
|
|
1237
1498
|
|
|
1499
|
+
@classmethod
|
|
1500
|
+
def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
|
|
1501
|
+
"""Determine substates which could be affected by dirty vars in this state.
|
|
1502
|
+
|
|
1503
|
+
Returns:
|
|
1504
|
+
Set of State classes that may need to be fetched to recalc computed vars.
|
|
1505
|
+
"""
|
|
1506
|
+
# _always_dirty_substates need to be fetched to recalc computed vars.
|
|
1507
|
+
fetch_substates = set(
|
|
1508
|
+
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
|
|
1509
|
+
for substate_name in cls._always_dirty_substates
|
|
1510
|
+
)
|
|
1511
|
+
for dependent_substates in cls._substate_var_dependencies.values():
|
|
1512
|
+
fetch_substates.update(
|
|
1513
|
+
set(
|
|
1514
|
+
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
|
|
1515
|
+
for substate_name in dependent_substates
|
|
1516
|
+
)
|
|
1517
|
+
)
|
|
1518
|
+
return fetch_substates
|
|
1519
|
+
|
|
1238
1520
|
def get_delta(self) -> Delta:
|
|
1239
1521
|
"""Get the delta for the state.
|
|
1240
1522
|
|
|
@@ -1266,8 +1548,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
1266
1548
|
# Recursively find the substate deltas.
|
|
1267
1549
|
substates = self.substates
|
|
1268
1550
|
for substate in self.dirty_substates.union(self._always_dirty_substates):
|
|
1269
|
-
if substate not in substates:
|
|
1270
|
-
continue # substate not loaded at this time, no delta
|
|
1271
1551
|
delta.update(substates[substate].get_delta())
|
|
1272
1552
|
|
|
1273
1553
|
# Format the delta.
|
|
@@ -1289,20 +1569,45 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
1289
1569
|
# have to mark computed vars dirty to allow access to newly computed
|
|
1290
1570
|
# values within the same ComputedVar function
|
|
1291
1571
|
self._mark_dirty_computed_vars()
|
|
1572
|
+
self._mark_dirty_substates()
|
|
1292
1573
|
|
|
1293
|
-
|
|
1574
|
+
def _mark_dirty_substates(self):
|
|
1575
|
+
"""Propagate dirty var / computed var status into substates."""
|
|
1294
1576
|
substates = self.substates
|
|
1295
1577
|
for var in self.dirty_vars:
|
|
1296
1578
|
for substate_name in self._substate_var_dependencies[var]:
|
|
1297
1579
|
self.dirty_substates.add(substate_name)
|
|
1298
|
-
if substate_name not in substates:
|
|
1299
|
-
continue
|
|
1300
1580
|
substate = substates[substate_name]
|
|
1301
1581
|
substate.dirty_vars.add(var)
|
|
1302
1582
|
substate._mark_dirty()
|
|
1303
1583
|
|
|
1584
|
+
def _update_was_touched(self):
|
|
1585
|
+
"""Update the _was_touched flag based on dirty_vars."""
|
|
1586
|
+
if self.dirty_vars and not self._was_touched:
|
|
1587
|
+
for var in self.dirty_vars:
|
|
1588
|
+
if var in self.base_vars or var in self._backend_vars:
|
|
1589
|
+
self._was_touched = True
|
|
1590
|
+
break
|
|
1591
|
+
|
|
1592
|
+
def _get_was_touched(self) -> bool:
|
|
1593
|
+
"""Check current dirty_vars and flag to determine if state instance was modified.
|
|
1594
|
+
|
|
1595
|
+
If any dirty vars belong to this state, mark _was_touched.
|
|
1596
|
+
|
|
1597
|
+
This flag determines whether this state instance should be persisted to redis.
|
|
1598
|
+
|
|
1599
|
+
Returns:
|
|
1600
|
+
Whether this state instance was ever modified.
|
|
1601
|
+
"""
|
|
1602
|
+
# Ensure the flag is up to date based on the current dirty_vars
|
|
1603
|
+
self._update_was_touched()
|
|
1604
|
+
return self._was_touched
|
|
1605
|
+
|
|
1304
1606
|
def _clean(self):
|
|
1305
1607
|
"""Reset the dirty vars."""
|
|
1608
|
+
# Update touched status before cleaning dirty_vars.
|
|
1609
|
+
self._update_was_touched()
|
|
1610
|
+
|
|
1306
1611
|
# Recursively clean the substates.
|
|
1307
1612
|
for substate in self.dirty_substates:
|
|
1308
1613
|
if substate not in self.substates:
|
|
@@ -1328,11 +1633,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
1328
1633
|
return super().get_value(key.__wrapped__)
|
|
1329
1634
|
return super().get_value(key)
|
|
1330
1635
|
|
|
1331
|
-
def dict(
|
|
1636
|
+
def dict(
|
|
1637
|
+
self, include_computed: bool = True, initial: bool = False, **kwargs
|
|
1638
|
+
) -> dict[str, Any]:
|
|
1332
1639
|
"""Convert the object to a dictionary.
|
|
1333
1640
|
|
|
1334
1641
|
Args:
|
|
1335
1642
|
include_computed: Whether to include computed vars.
|
|
1643
|
+
initial: Whether to get the initial value of computed vars.
|
|
1336
1644
|
**kwargs: Kwargs to pass to the pydantic dict method.
|
|
1337
1645
|
|
|
1338
1646
|
Returns:
|
|
@@ -1348,21 +1656,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
1348
1656
|
prop_name: self.get_value(getattr(self, prop_name))
|
|
1349
1657
|
for prop_name in self.base_vars
|
|
1350
1658
|
}
|
|
1351
|
-
|
|
1352
|
-
{
|
|
1659
|
+
if initial:
|
|
1660
|
+
computed_vars = {
|
|
1661
|
+
# Include initial computed vars.
|
|
1662
|
+
prop_name: cv._initial_value
|
|
1663
|
+
if isinstance(cv, ComputedVar)
|
|
1664
|
+
and not isinstance(cv._initial_value, types.Unset)
|
|
1665
|
+
else self.get_value(getattr(self, prop_name))
|
|
1666
|
+
for prop_name, cv in self.computed_vars.items()
|
|
1667
|
+
}
|
|
1668
|
+
elif include_computed:
|
|
1669
|
+
computed_vars = {
|
|
1353
1670
|
# Include the computed vars.
|
|
1354
1671
|
prop_name: self.get_value(getattr(self, prop_name))
|
|
1355
1672
|
for prop_name in self.computed_vars
|
|
1356
1673
|
}
|
|
1357
|
-
|
|
1358
|
-
|
|
1359
|
-
)
|
|
1674
|
+
else:
|
|
1675
|
+
computed_vars = {}
|
|
1360
1676
|
variables = {**base_vars, **computed_vars}
|
|
1361
1677
|
d = {
|
|
1362
1678
|
self.get_full_name(): {k: variables[k] for k in sorted(variables)},
|
|
1363
1679
|
}
|
|
1364
1680
|
for substate_d in [
|
|
1365
|
-
v.dict(include_computed=include_computed, **kwargs)
|
|
1681
|
+
v.dict(include_computed=include_computed, initial=initial, **kwargs)
|
|
1366
1682
|
for v in self.substates.values()
|
|
1367
1683
|
]:
|
|
1368
1684
|
d.update(substate_d)
|
|
@@ -1408,6 +1724,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
1408
1724
|
state["__dict__"] = state["__dict__"].copy()
|
|
1409
1725
|
state["__dict__"]["parent_state"] = None
|
|
1410
1726
|
state["__dict__"]["substates"] = {}
|
|
1727
|
+
state["__dict__"].pop("_was_touched", None)
|
|
1411
1728
|
return state
|
|
1412
1729
|
|
|
1413
1730
|
|
|
@@ -1417,6 +1734,35 @@ class State(BaseState):
|
|
|
1417
1734
|
# The hydrated bool.
|
|
1418
1735
|
is_hydrated: bool = False
|
|
1419
1736
|
|
|
1737
|
+
|
|
1738
|
+
class UpdateVarsInternalState(State):
|
|
1739
|
+
"""Substate for handling internal state var updates."""
|
|
1740
|
+
|
|
1741
|
+
async def update_vars_internal(self, vars: dict[str, Any]) -> None:
|
|
1742
|
+
"""Apply updates to fully qualified state vars.
|
|
1743
|
+
|
|
1744
|
+
The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
|
|
1745
|
+
and each value will be set on the appropriate substate instance.
|
|
1746
|
+
|
|
1747
|
+
This function is primarily used to apply cookie and local storage
|
|
1748
|
+
updates from the frontend to the appropriate substate.
|
|
1749
|
+
|
|
1750
|
+
Args:
|
|
1751
|
+
vars: The fully qualified vars and values to update.
|
|
1752
|
+
"""
|
|
1753
|
+
for var, value in vars.items():
|
|
1754
|
+
state_name, _, var_name = var.rpartition(".")
|
|
1755
|
+
var_state_cls = State.get_class_substate(tuple(state_name.split(".")))
|
|
1756
|
+
var_state = await self.get_state(var_state_cls)
|
|
1757
|
+
setattr(var_state, var_name, value)
|
|
1758
|
+
|
|
1759
|
+
|
|
1760
|
+
class OnLoadInternalState(State):
|
|
1761
|
+
"""Substate for handling on_load event enumeration.
|
|
1762
|
+
|
|
1763
|
+
This is a separate substate to avoid deserializing the entire state tree for every page navigation.
|
|
1764
|
+
"""
|
|
1765
|
+
|
|
1420
1766
|
def on_load_internal(self) -> list[Event | EventSpec] | None:
|
|
1421
1767
|
"""Queue on_load handlers for the current page.
|
|
1422
1768
|
|
|
@@ -1428,6 +1774,9 @@ class State(BaseState):
|
|
|
1428
1774
|
load_events = app.get_load_events(self.router.page.path)
|
|
1429
1775
|
if not load_events and self.is_hydrated:
|
|
1430
1776
|
return # Fast path for page-to-page navigation
|
|
1777
|
+
if not load_events:
|
|
1778
|
+
self.is_hydrated = True
|
|
1779
|
+
return # Fast path for initial hydrate with no on_load events defined.
|
|
1431
1780
|
self.is_hydrated = False
|
|
1432
1781
|
return [
|
|
1433
1782
|
*fix_events(
|
|
@@ -1435,26 +1784,9 @@ class State(BaseState):
|
|
|
1435
1784
|
self.router.session.client_token,
|
|
1436
1785
|
router_data=self.router_data,
|
|
1437
1786
|
),
|
|
1438
|
-
|
|
1787
|
+
State.set_is_hydrated(True), # type: ignore
|
|
1439
1788
|
]
|
|
1440
1789
|
|
|
1441
|
-
def update_vars_internal(self, vars: dict[str, Any]) -> None:
|
|
1442
|
-
"""Apply updates to fully qualified state vars.
|
|
1443
|
-
|
|
1444
|
-
The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
|
|
1445
|
-
and each value will be set on the appropriate substate instance.
|
|
1446
|
-
|
|
1447
|
-
This function is primarily used to apply cookie and local storage
|
|
1448
|
-
updates from the frontend to the appropriate substate.
|
|
1449
|
-
|
|
1450
|
-
Args:
|
|
1451
|
-
vars: The fully qualified vars and values to update.
|
|
1452
|
-
"""
|
|
1453
|
-
for var, value in vars.items():
|
|
1454
|
-
state_name, _, var_name = var.rpartition(".")
|
|
1455
|
-
var_state = self.get_substate(state_name.split("."))
|
|
1456
|
-
setattr(var_state, var_name, value)
|
|
1457
|
-
|
|
1458
1790
|
|
|
1459
1791
|
class StateProxy(wrapt.ObjectProxy):
|
|
1460
1792
|
"""Proxy of a state instance to control mutability of vars for a background task.
|
|
@@ -1508,9 +1840,10 @@ class StateProxy(wrapt.ObjectProxy):
|
|
|
1508
1840
|
This StateProxy instance in mutable mode.
|
|
1509
1841
|
"""
|
|
1510
1842
|
self._self_actx = self._self_app.modify_state(
|
|
1511
|
-
|
|
1512
|
-
|
|
1513
|
-
|
|
1843
|
+
token=_substate_key(
|
|
1844
|
+
self.__wrapped__.router.session.client_token,
|
|
1845
|
+
self._self_substate_path,
|
|
1846
|
+
)
|
|
1514
1847
|
)
|
|
1515
1848
|
mutable_state = await self._self_actx.__aenter__()
|
|
1516
1849
|
super().__setattr__(
|
|
@@ -1560,7 +1893,15 @@ class StateProxy(wrapt.ObjectProxy):
|
|
|
1560
1893
|
|
|
1561
1894
|
Returns:
|
|
1562
1895
|
The value of the attribute.
|
|
1896
|
+
|
|
1897
|
+
Raises:
|
|
1898
|
+
ImmutableStateError: If the state is not in mutable mode.
|
|
1563
1899
|
"""
|
|
1900
|
+
if name in ["substates", "parent_state"] and not self._self_mutable:
|
|
1901
|
+
raise ImmutableStateError(
|
|
1902
|
+
"Background task StateProxy is immutable outside of a context "
|
|
1903
|
+
"manager. Use `async with self` to modify state."
|
|
1904
|
+
)
|
|
1564
1905
|
value = super().__getattr__(name)
|
|
1565
1906
|
if not name.startswith("_self_") and isinstance(value, MutableProxy):
|
|
1566
1907
|
# ensure mutations to these containers are blocked unless proxy is _mutable
|
|
@@ -1608,6 +1949,60 @@ class StateProxy(wrapt.ObjectProxy):
|
|
|
1608
1949
|
"manager. Use `async with self` to modify state."
|
|
1609
1950
|
)
|
|
1610
1951
|
|
|
1952
|
+
def get_substate(self, path: Sequence[str]) -> BaseState:
    """Return a substate of the wrapped state, permitted only while the lock is held.

    Args:
        path: The path to the substate.

    Returns:
        The substate.

    Raises:
        ImmutableStateError: If the state is not in mutable mode.
    """
    if self._self_mutable:
        # Lock is held (inside `async with self`): delegate to the real state.
        return self.__wrapped__.get_substate(path)
    raise ImmutableStateError(
        "Background task StateProxy is immutable outside of a context "
        "manager. Use `async with self` to modify state."
    )
|
|
1970
|
+
|
|
1971
|
+
async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
    """Fetch another state instance for this token, permitted only while the lock is held.

    Args:
        state_cls: The class of the state.

    Returns:
        The state.

    Raises:
        ImmutableStateError: If the state is not in mutable mode.
    """
    if self._self_mutable:
        # Lock is held (inside `async with self`): delegate to the real state.
        return await self.__wrapped__.get_state(state_cls)
    raise ImmutableStateError(
        "Background task StateProxy is immutable outside of a context "
        "manager. Use `async with self` to modify state."
    )
|
|
1989
|
+
|
|
1990
|
+
def _as_state_update(self, *args, **kwargs) -> StateUpdate:
    """Temporarily allow mutability so the wrapped state can build an update.

    Building the update touches `parent_state`, which `__getattr__` refuses
    to expose while the proxy is immutable. Mutability is switched on for
    the duration of the call and then restored to its previous value.
    Restoring (instead of forcing False) means a call made while the proxy
    is already mutable, e.g. inside `async with self`, does not silently
    flip the proxy back to immutable mode.

    Args:
        *args: The args to pass to the underlying state instance.
        **kwargs: The kwargs to pass to the underlying state instance.

    Returns:
        The state update.
    """
    was_mutable = self._self_mutable
    self._self_mutable = True
    try:
        return self.__wrapped__._as_state_update(*args, **kwargs)
    finally:
        # Restore the prior mode rather than unconditionally locking the proxy.
        self._self_mutable = was_mutable
|
|
2005
|
+
|
|
1611
2006
|
|
|
1612
2007
|
class StateUpdate(Base):
|
|
1613
2008
|
"""A state update sent to the frontend."""
|
|
@@ -1708,9 +2103,9 @@ class StateManagerMemory(StateManager):
|
|
|
1708
2103
|
The state for the token.
|
|
1709
2104
|
"""
|
|
1710
2105
|
# Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
1711
|
-
token = token
|
|
2106
|
+
token = _split_substate_key(token)[0]
|
|
1712
2107
|
if token not in self.states:
|
|
1713
|
-
self.states[token] = self.state()
|
|
2108
|
+
self.states[token] = self.state(_reflex_internal_init=True)
|
|
1714
2109
|
return self.states[token]
|
|
1715
2110
|
|
|
1716
2111
|
async def set_state(self, token: str, state: BaseState):
|
|
@@ -1733,7 +2128,7 @@ class StateManagerMemory(StateManager):
|
|
|
1733
2128
|
The state for the token.
|
|
1734
2129
|
"""
|
|
1735
2130
|
# Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
1736
|
-
token = token
|
|
2131
|
+
token = _split_substate_key(token)[0]
|
|
1737
2132
|
if token not in self._states_locks:
|
|
1738
2133
|
async with self._state_manager_lock:
|
|
1739
2134
|
if token not in self._states_locks:
|
|
@@ -1773,6 +2168,81 @@ class StateManagerRedis(StateManager):
|
|
|
1773
2168
|
b"evicted",
|
|
1774
2169
|
}
|
|
1775
2170
|
|
|
2171
|
+
def _get_root_state(self, state: BaseState) -> BaseState:
    """Walk up `parent_state` links to find an instance of the top-level state.

    The walk stops at the first instance whose type is `self.state`. It also
    stops at an instance that has no parent, in which case that instance is
    returned.

    Args:
        state: The state to start from.

    Returns:
        An instance of the top-level state (self.state).
    """
    node = state
    # Stop at the configured root class, or at a node with no parent to follow.
    while not (type(node) == self.state or node.parent_state is None):
        node = node.parent_state
    return node
|
|
2183
|
+
|
|
2184
|
+
async def _get_parent_state(self, token: str) -> BaseState | None:
    """Load the parent of the state named in a substate token.

    Args:
        token: The token to get the state for (_substate_key).

    Returns:
        The parent state for the state requested by the token or None if there is no such parent.
    """
    client_token, state_path = _split_substate_key(token)
    parent_state_name = state_path.rpartition(".")[0]
    if not parent_state_name:
        # No dot in the state path, so there is no parent to fetch.
        return None
    # Retrieve the parent state (without its substates) to populate event
    # handlers onto this substate.
    return await self.get_state(
        token=_substate_key(client_token, parent_state_name),
        top_level=False,
        get_substates=False,
    )
|
|
2204
|
+
|
|
2205
|
+
async def _populate_substates(
    self,
    token: str,
    state: BaseState,
    all_substates: bool = False,
):
    """Fetch substates for `state` and attach them to it.

    Nothing is returned. As a side effect, `state.substates` is filled in,
    and each fetched substate is loaded with `parent_state=state`.

    Args:
        token: The token to get the state for.
        state: The state instance to populate substates for.
        all_substates: Whether to fetch all substates or just required substates.
    """
    client_token = _split_substate_key(token)[0]

    # Either every substate, or only those whose computed vars may need recalculation.
    substate_classes = (
        state.get_substates()
        if all_substates
        else state._potentially_dirty_substates()
    )

    # Schedule every fetch before awaiting any, so the lookups run concurrently.
    pending = {
        substate_cls.get_name(): asyncio.create_task(
            self.get_state(
                token=_substate_key(client_token, substate_cls),
                top_level=False,
                get_substates=all_substates,
                parent_state=state,
            )
        )
        for substate_cls in substate_classes
    }

    # Collect results in scheduling order and link them onto the parent.
    for name, task in pending.items():
        state.substates[name] = await task
|
|
2245
|
+
|
|
1776
2246
|
async def get_state(
|
|
1777
2247
|
self,
|
|
1778
2248
|
token: str,
|
|
@@ -1784,8 +2254,8 @@ class StateManagerRedis(StateManager):
|
|
|
1784
2254
|
|
|
1785
2255
|
Args:
|
|
1786
2256
|
token: The token to get the state for.
|
|
1787
|
-
top_level: If true, return an instance of the top-level state.
|
|
1788
|
-
get_substates: If true, also retrieve substates
|
|
2257
|
+
top_level: If true, return an instance of the top-level state (self.state).
|
|
2258
|
+
get_substates: If true, also retrieve substates.
|
|
1789
2259
|
parent_state: If provided, use this parent_state instead of getting it from redis.
|
|
1790
2260
|
|
|
1791
2261
|
Returns:
|
|
@@ -1795,7 +2265,7 @@ class StateManagerRedis(StateManager):
|
|
|
1795
2265
|
RuntimeError: when the state_cls is not specified in the token
|
|
1796
2266
|
"""
|
|
1797
2267
|
# Split the actual token from the fully qualified substate name.
|
|
1798
|
-
|
|
2268
|
+
_, state_path = _split_substate_key(token)
|
|
1799
2269
|
if state_path:
|
|
1800
2270
|
# Get the State class associated with the given path.
|
|
1801
2271
|
state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
|
|
@@ -1811,66 +2281,49 @@ class StateManagerRedis(StateManager):
|
|
|
1811
2281
|
# Deserialize the substate.
|
|
1812
2282
|
state = cloudpickle.loads(redis_state)
|
|
1813
2283
|
|
|
1814
|
-
# Populate parent
|
|
2284
|
+
# Populate parent state if missing and requested.
|
|
1815
2285
|
if parent_state is None:
|
|
1816
|
-
|
|
1817
|
-
parent_state_name = state_path.rpartition(".")[0]
|
|
1818
|
-
if parent_state_name:
|
|
1819
|
-
parent_state_key = token.rpartition(".")[0]
|
|
1820
|
-
parent_state = await self.get_state(
|
|
1821
|
-
parent_state_key, top_level=False, get_substates=False
|
|
1822
|
-
)
|
|
2286
|
+
parent_state = await self._get_parent_state(token)
|
|
1823
2287
|
# Set up Bidirectional linkage between this state and its parent.
|
|
1824
2288
|
if parent_state is not None:
|
|
1825
2289
|
parent_state.substates[state.get_name()] = state
|
|
1826
2290
|
state.parent_state = parent_state
|
|
1827
|
-
if
|
|
1828
|
-
|
|
1829
|
-
|
|
1830
|
-
substate_name = substate_cls.get_name()
|
|
1831
|
-
substate_key = token + "." + substate_name
|
|
1832
|
-
state.substates[substate_name] = await self.get_state(
|
|
1833
|
-
substate_key, top_level=False, parent_state=state
|
|
1834
|
-
)
|
|
2291
|
+
# Populate substates if requested.
|
|
2292
|
+
await self._populate_substates(token, state, all_substates=get_substates)
|
|
2293
|
+
|
|
1835
2294
|
# To retain compatibility with previous implementation, by default, we return
|
|
1836
2295
|
# the top-level state by chasing `parent_state` pointers up the tree.
|
|
1837
2296
|
if top_level:
|
|
1838
|
-
|
|
1839
|
-
state = state.parent_state
|
|
2297
|
+
return self._get_root_state(state)
|
|
1840
2298
|
return state
|
|
1841
2299
|
|
|
1842
|
-
#
|
|
2300
|
+
# TODO: dedupe the following logic with the above block
|
|
2301
|
+
# Key didn't exist so we have to create a new instance for this token.
|
|
1843
2302
|
if parent_state is None:
|
|
1844
|
-
|
|
1845
|
-
|
|
1846
|
-
|
|
1847
|
-
parent_state_key = client_token + "_" + parent_state_name
|
|
1848
|
-
parent_state = await self.get_state(
|
|
1849
|
-
parent_state_key, top_level=False, get_substates=False
|
|
1850
|
-
)
|
|
1851
|
-
# Persist the new state class to redis.
|
|
1852
|
-
await self.set_state(
|
|
1853
|
-
token,
|
|
1854
|
-
state_cls(
|
|
1855
|
-
parent_state=parent_state,
|
|
1856
|
-
init_substates=False,
|
|
1857
|
-
),
|
|
1858
|
-
)
|
|
1859
|
-
# After creating the state key, recursively call `get_state` to populate substates.
|
|
1860
|
-
return await self.get_state(
|
|
1861
|
-
token,
|
|
1862
|
-
top_level=top_level,
|
|
1863
|
-
get_substates=get_substates,
|
|
2303
|
+
parent_state = await self._get_parent_state(token)
|
|
2304
|
+
# Instantiate the new state class (but don't persist it yet).
|
|
2305
|
+
state = state_cls(
|
|
1864
2306
|
parent_state=parent_state,
|
|
2307
|
+
init_substates=False,
|
|
2308
|
+
_reflex_internal_init=True,
|
|
1865
2309
|
)
|
|
2310
|
+
# Set up Bidirectional linkage between this state and its parent.
|
|
2311
|
+
if parent_state is not None:
|
|
2312
|
+
parent_state.substates[state.get_name()] = state
|
|
2313
|
+
state.parent_state = parent_state
|
|
2314
|
+
# Populate substates for the newly created state.
|
|
2315
|
+
await self._populate_substates(token, state, all_substates=get_substates)
|
|
2316
|
+
# To retain compatibility with previous implementation, by default, we return
|
|
2317
|
+
# the top-level state by chasing `parent_state` pointers up the tree.
|
|
2318
|
+
if top_level:
|
|
2319
|
+
return self._get_root_state(state)
|
|
2320
|
+
return state
|
|
1866
2321
|
|
|
1867
2322
|
async def set_state(
|
|
1868
2323
|
self,
|
|
1869
2324
|
token: str,
|
|
1870
2325
|
state: BaseState,
|
|
1871
2326
|
lock_id: bytes | None = None,
|
|
1872
|
-
set_substates: bool = True,
|
|
1873
|
-
set_parent_state: bool = True,
|
|
1874
2327
|
):
|
|
1875
2328
|
"""Set the state for a token.
|
|
1876
2329
|
|
|
@@ -1878,11 +2331,10 @@ class StateManagerRedis(StateManager):
|
|
|
1878
2331
|
token: The token to set the state for.
|
|
1879
2332
|
state: The state to set.
|
|
1880
2333
|
lock_id: If provided, the lock_key must be set to this value to set the state.
|
|
1881
|
-
set_substates: If True, write substates to redis
|
|
1882
|
-
set_parent_state: If True, write parent state to redis
|
|
1883
2334
|
|
|
1884
2335
|
Raises:
|
|
1885
2336
|
LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
|
|
2337
|
+
RuntimeError: If the state instance doesn't match the state name in the token.
|
|
1886
2338
|
"""
|
|
1887
2339
|
# Check that we're holding the lock.
|
|
1888
2340
|
if (
|
|
@@ -1894,28 +2346,36 @@ class StateManagerRedis(StateManager):
|
|
|
1894
2346
|
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
|
|
1895
2347
|
"or use `@rx.background` decorator for long-running tasks."
|
|
1896
2348
|
)
|
|
1897
|
-
|
|
1898
|
-
|
|
1899
|
-
if
|
|
1900
|
-
|
|
1901
|
-
|
|
1902
|
-
if state.parent_state is not None and set_parent_state:
|
|
1903
|
-
parent_state_key = token.rpartition(".")[0]
|
|
1904
|
-
await self.set_state(
|
|
1905
|
-
parent_state_key,
|
|
1906
|
-
state.parent_state,
|
|
1907
|
-
lock_id=lock_id,
|
|
1908
|
-
set_substates=False,
|
|
2349
|
+
client_token, substate_name = _split_substate_key(token)
|
|
2350
|
+
# If the substate name on the token doesn't match the instance name, it cannot have a parent.
|
|
2351
|
+
if state.parent_state is not None and state.get_full_name() != substate_name:
|
|
2352
|
+
raise RuntimeError(
|
|
2353
|
+
f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
|
|
1909
2354
|
)
|
|
1910
|
-
|
|
1911
|
-
|
|
1912
|
-
|
|
1913
|
-
|
|
1914
|
-
|
|
1915
|
-
|
|
2355
|
+
|
|
2356
|
+
# Recursively set_state on all known substates.
|
|
2357
|
+
tasks = []
|
|
2358
|
+
for substate in state.substates.values():
|
|
2359
|
+
tasks.append(
|
|
2360
|
+
asyncio.create_task(
|
|
2361
|
+
self.set_state(
|
|
2362
|
+
token=_substate_key(client_token, substate),
|
|
2363
|
+
state=substate,
|
|
2364
|
+
lock_id=lock_id,
|
|
2365
|
+
)
|
|
1916
2366
|
)
|
|
2367
|
+
)
|
|
1917
2368
|
# Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
|
|
1918
|
-
|
|
2369
|
+
if state._get_was_touched():
|
|
2370
|
+
await self.redis.set(
|
|
2371
|
+
_substate_key(client_token, state),
|
|
2372
|
+
cloudpickle.dumps(state),
|
|
2373
|
+
ex=self.token_expiration,
|
|
2374
|
+
)
|
|
2375
|
+
|
|
2376
|
+
# Wait for substates to be persisted.
|
|
2377
|
+
for t in tasks:
|
|
2378
|
+
await t
|
|
1919
2379
|
|
|
1920
2380
|
@contextlib.asynccontextmanager
|
|
1921
2381
|
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
|
@@ -1943,7 +2403,7 @@ class StateManagerRedis(StateManager):
|
|
|
1943
2403
|
The redis lock key for the token.
|
|
1944
2404
|
"""
|
|
1945
2405
|
# All substates share the same lock domain, so ignore any substate path suffix.
|
|
1946
|
-
client_token = token
|
|
2406
|
+
client_token = _split_substate_key(token)[0]
|
|
1947
2407
|
return f"{client_token}_lock".encode()
|
|
1948
2408
|
|
|
1949
2409
|
async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
|
|
@@ -2038,6 +2498,16 @@ class StateManagerRedis(StateManager):
|
|
|
2038
2498
|
await self.redis.close(close_connection_pool=True)
|
|
2039
2499
|
|
|
2040
2500
|
|
|
2501
|
+
def get_state_manager() -> StateManager:
    """Return the state manager owned by the currently running app.

    Returns:
        The state manager.
    """
    app_module = prerequisites.get_app()
    # The app instance is exposed on its module under the name in CompileVars.APP.
    running_app = getattr(app_module, constants.CompileVars.APP)
    return running_app.state_manager
|
|
2509
|
+
|
|
2510
|
+
|
|
2041
2511
|
class ClientStorageBase:
|
|
2042
2512
|
"""Base class for client-side storage."""
|
|
2043
2513
|
|