reflex 0.4.2a1__py3-none-any.whl → 0.4.3a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflex might be problematic; see the package's registry page for more details.

Files changed (60)
  1. reflex/.templates/apps/blank/code/blank.py +1 -1
  2. reflex/.templates/apps/sidebar/README.md +3 -2
  3. reflex/.templates/apps/sidebar/assets/reflex_white.svg +8 -0
  4. reflex/.templates/apps/sidebar/code/components/sidebar.py +26 -22
  5. reflex/.templates/apps/sidebar/code/pages/dashboard.py +6 -5
  6. reflex/.templates/apps/sidebar/code/pages/settings.py +45 -6
  7. reflex/.templates/apps/sidebar/code/styles.py +15 -17
  8. reflex/.templates/apps/sidebar/code/templates/__init__.py +1 -1
  9. reflex/.templates/apps/sidebar/code/templates/template.py +54 -40
  10. reflex/.templates/jinja/custom_components/README.md.jinja2 +9 -0
  11. reflex/.templates/jinja/custom_components/__init__.py.jinja2 +1 -0
  12. reflex/.templates/jinja/custom_components/demo_app.py.jinja2 +36 -0
  13. reflex/.templates/jinja/custom_components/pyproject.toml.jinja2 +35 -0
  14. reflex/.templates/jinja/custom_components/src.py.jinja2 +57 -0
  15. reflex/.templates/jinja/web/utils/context.js.jinja2 +26 -6
  16. reflex/.templates/web/utils/state.js +206 -146
  17. reflex/app.py +21 -18
  18. reflex/compiler/compiler.py +6 -2
  19. reflex/compiler/templates.py +17 -0
  20. reflex/compiler/utils.py +2 -2
  21. reflex/components/core/__init__.py +2 -1
  22. reflex/components/core/banner.py +99 -11
  23. reflex/components/core/banner.pyi +215 -2
  24. reflex/components/el/elements/__init__.py +1 -0
  25. reflex/components/el/elements/forms.py +6 -0
  26. reflex/components/el/elements/forms.pyi +4 -0
  27. reflex/components/markdown/markdown.py +13 -25
  28. reflex/components/markdown/markdown.pyi +5 -5
  29. reflex/components/plotly/plotly.py +3 -0
  30. reflex/components/plotly/plotly.pyi +2 -0
  31. reflex/components/radix/primitives/drawer.py +3 -7
  32. reflex/components/radix/themes/components/select.py +4 -4
  33. reflex/components/radix/themes/components/text_field.pyi +4 -0
  34. reflex/constants/__init__.py +4 -0
  35. reflex/constants/colors.py +1 -0
  36. reflex/constants/compiler.py +4 -3
  37. reflex/constants/custom_components.py +30 -0
  38. reflex/custom_components/__init__.py +1 -0
  39. reflex/custom_components/custom_components.py +565 -0
  40. reflex/reflex.py +11 -2
  41. reflex/route.py +4 -0
  42. reflex/state.py +594 -124
  43. reflex/testing.py +6 -0
  44. reflex/utils/exec.py +9 -0
  45. reflex/utils/prerequisites.py +28 -2
  46. reflex/utils/telemetry.py +3 -1
  47. reflex/utils/types.py +23 -0
  48. reflex/vars.py +48 -17
  49. reflex/vars.pyi +8 -3
  50. {reflex-0.4.2a1.dist-info → reflex-0.4.3a2.dist-info}/METADATA +4 -2
  51. {reflex-0.4.2a1.dist-info → reflex-0.4.3a2.dist-info}/RECORD +55 -51
  52. {reflex-0.4.2a1.dist-info → reflex-0.4.3a2.dist-info}/WHEEL +1 -1
  53. reflex/components/base/bare.pyi +0 -84
  54. reflex/constants/base.pyi +0 -94
  55. reflex/constants/event.pyi +0 -59
  56. reflex/constants/route.pyi +0 -50
  57. reflex/constants/style.pyi +0 -20
  58. /reflex/.templates/apps/sidebar/assets/{icon.svg → reflex_black.svg} +0 -0
  59. {reflex-0.4.2a1.dist-info → reflex-0.4.3a2.dist-info}/LICENSE +0 -0
  60. {reflex-0.4.2a1.dist-info → reflex-0.4.3a2.dist-info}/entry_points.txt +0 -0
reflex/state.py CHANGED
@@ -8,7 +8,6 @@ import copy
8
8
  import functools
9
9
  import inspect
10
10
  import json
11
- import os
12
11
  import traceback
13
12
  import urllib.parse
14
13
  import uuid
@@ -45,11 +44,12 @@ from reflex.event import (
45
44
  )
46
45
  from reflex.utils import console, format, prerequisites, types
47
46
  from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
47
+ from reflex.utils.exec import is_testing_env
48
48
  from reflex.utils.serializers import SerializedType, serialize, serializer
49
- from reflex.vars import BaseVar, ComputedVar, Var
49
+ from reflex.vars import BaseVar, ComputedVar, Var, computed_var
50
50
 
51
51
  Delta = Dict[str, Any]
52
- var = ComputedVar
52
+ var = computed_var
53
53
 
54
54
 
55
55
  class HeaderData(Base):
@@ -151,9 +151,45 @@ RESERVED_BACKEND_VAR_NAMES = {
151
151
  "_substate_var_dependencies",
152
152
  "_always_dirty_computed_vars",
153
153
  "_always_dirty_substates",
154
+ "_was_touched",
154
155
  }
155
156
 
156
157
 
158
+ def _substate_key(
159
+ token: str,
160
+ state_cls_or_name: BaseState | Type[BaseState] | str | list[str],
161
+ ) -> str:
162
+ """Get the substate key.
163
+
164
+ Args:
165
+ token: The token of the state.
166
+ state_cls_or_name: The state class/instance or name or sequence of name parts.
167
+
168
+ Returns:
169
+ The substate key.
170
+ """
171
+ if isinstance(state_cls_or_name, BaseState) or (
172
+ isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState)
173
+ ):
174
+ state_cls_or_name = state_cls_or_name.get_full_name()
175
+ elif isinstance(state_cls_or_name, (list, tuple)):
176
+ state_cls_or_name = ".".join(state_cls_or_name)
177
+ return f"{token}_{state_cls_or_name}"
178
+
179
+
180
+ def _split_substate_key(substate_key: str) -> tuple[str, str]:
181
+ """Split the substate key into token and state name.
182
+
183
+ Args:
184
+ substate_key: The substate key.
185
+
186
+ Returns:
187
+ Tuple of token and state name.
188
+ """
189
+ token, _, state_name = substate_key.partition("_")
190
+ return token, state_name
191
+
192
+
157
193
  class BaseState(Base, ABC, extra=pydantic.Extra.allow):
158
194
  """The state of the app."""
159
195
 
@@ -214,34 +250,57 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
214
250
  # The router data for the current page
215
251
  router: RouterData = RouterData()
216
252
 
253
+ # Whether the state has ever been touched since instantiation.
254
+ _was_touched: bool = False
255
+
217
256
  def __init__(
218
257
  self,
219
258
  *args,
220
259
  parent_state: BaseState | None = None,
221
260
  init_substates: bool = True,
261
+ _reflex_internal_init: bool = False,
222
262
  **kwargs,
223
263
  ):
224
264
  """Initialize the state.
225
265
 
266
+ DO NOT INSTANTIATE STATE CLASSES DIRECTLY! Use StateManager.get_state() instead.
267
+
226
268
  Args:
227
269
  *args: The args to pass to the Pydantic init method.
228
270
  parent_state: The parent state.
229
271
  init_substates: Whether to initialize the substates in this instance.
272
+ _reflex_internal_init: A flag to indicate that the state is being initialized by the framework.
230
273
  **kwargs: The kwargs to pass to the Pydantic init method.
231
274
 
275
+ Raises:
276
+ RuntimeError: If the state is instantiated directly by end user.
232
277
  """
278
+ if not _reflex_internal_init and not is_testing_env():
279
+ raise RuntimeError(
280
+ "State classes should not be instantiated directly in a Reflex app. "
281
+ "See https://reflex.dev/docs/state/ for further information."
282
+ )
233
283
  kwargs["parent_state"] = parent_state
234
284
  super().__init__(*args, **kwargs)
235
285
 
236
286
  # Setup the substates (for memory state manager only).
237
287
  if init_substates:
238
288
  for substate in self.get_substates():
239
- self.substates[substate.get_name()] = substate(parent_state=self)
289
+ self.substates[substate.get_name()] = substate(
290
+ parent_state=self,
291
+ _reflex_internal_init=True,
292
+ )
240
293
  # Convert the event handlers to functions.
241
294
  self._init_event_handlers()
242
295
 
243
296
  # Create a fresh copy of the backend variables for this instance
244
- self._backend_vars = copy.deepcopy(self.backend_vars)
297
+ self._backend_vars = copy.deepcopy(
298
+ {
299
+ name: item
300
+ for name, item in self.backend_vars.items()
301
+ if name not in self.computed_vars
302
+ }
303
+ )
245
304
 
246
305
  def _init_event_handlers(self, state: BaseState | None = None):
247
306
  """Initialize event handlers.
@@ -277,6 +336,21 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
277
336
  """
278
337
  return f"{self.__class__.__name__}({self.dict()})"
279
338
 
339
+ @classmethod
340
+ def _get_computed_vars(cls) -> list[ComputedVar]:
341
+ """Helper function to get all computed vars of a instance.
342
+
343
+ Returns:
344
+ A list of computed vars.
345
+ """
346
+ return [
347
+ v
348
+ for mixin in cls.__mro__
349
+ if mixin is cls or not issubclass(mixin, (BaseState, ABC))
350
+ for v in mixin.__dict__.values()
351
+ if isinstance(v, ComputedVar)
352
+ ]
353
+
280
354
  @classmethod
281
355
  def __init_subclass__(cls, **kwargs):
282
356
  """Do some magic for the subclass initialization.
@@ -287,7 +361,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
287
361
  Raises:
288
362
  ValueError: If a substate class shadows another.
289
363
  """
290
- is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
291
364
  super().__init_subclass__(**kwargs)
292
365
  # Event handlers should not shadow builtin state methods.
293
366
  cls._check_overridden_methods()
@@ -295,6 +368,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
295
368
  # Reset subclass tracking for this class.
296
369
  cls.class_subclasses = set()
297
370
 
371
+ # Reset dirty substate tracking for this class.
372
+ cls._always_dirty_substates = set()
373
+
298
374
  # Get the parent vars.
299
375
  parent_state = cls.get_parent_state()
300
376
  if parent_state is not None:
@@ -303,7 +379,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
303
379
 
304
380
  # Check if another substate class with the same name has already been defined.
305
381
  if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses):
306
- if is_testing_env:
382
+ if is_testing_env():
307
383
  # Clear existing subclass with same name when app is reloaded via
308
384
  # utils.prerequisites.get_app(reload=True)
309
385
  parent_state.class_subclasses = set(
@@ -321,15 +397,32 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
321
397
  # Track this new subclass in the parent state's subclasses set.
322
398
  parent_state.class_subclasses.add(cls)
323
399
 
324
- cls.new_backend_vars = {
400
+ # Get computed vars.
401
+ computed_vars = cls._get_computed_vars()
402
+
403
+ new_backend_vars = {
325
404
  name: value
326
405
  for name, value in cls.__dict__.items()
327
406
  if types.is_backend_variable(name, cls)
407
+ and name not in RESERVED_BACKEND_VAR_NAMES
328
408
  and name not in cls.inherited_backend_vars
329
409
  and not isinstance(value, FunctionType)
410
+ and not isinstance(value, ComputedVar)
330
411
  }
331
412
 
332
- cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}
413
+ # Get backend computed vars
414
+ backend_computed_vars = {
415
+ v._var_name: v._var_set_state(cls)
416
+ for v in computed_vars
417
+ if types.is_backend_variable(v._var_name, cls)
418
+ and v._var_name not in cls.inherited_backend_vars
419
+ }
420
+
421
+ cls.backend_vars = {
422
+ **cls.inherited_backend_vars,
423
+ **new_backend_vars,
424
+ **backend_computed_vars,
425
+ }
333
426
 
334
427
  # Set the base and computed vars.
335
428
  cls.base_vars = {
@@ -339,11 +432,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
339
432
  for f in cls.get_fields().values()
340
433
  if f.name not in cls.get_skip_vars()
341
434
  }
342
- cls.computed_vars = {
343
- v._var_name: v._var_set_state(cls)
344
- for v in cls.__dict__.values()
345
- if isinstance(v, ComputedVar)
346
- }
435
+ cls.computed_vars = {v._var_name: v._var_set_state(cls) for v in computed_vars}
347
436
  cls.vars = {
348
437
  **cls.inherited_vars,
349
438
  **cls.base_vars,
@@ -466,7 +555,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
466
555
  # track that this substate depends on its parent for this var
467
556
  state_name = cls.get_name()
468
557
  parent_state = cls.get_parent_state()
469
- while parent_state is not None and var in parent_state.vars:
558
+ while parent_state is not None and var in {
559
+ **parent_state.vars,
560
+ **parent_state.backend_vars,
561
+ }:
470
562
  parent_state._substate_var_dependencies[var].add(state_name)
471
563
  state_name, parent_state = (
472
564
  parent_state.get_name(),
@@ -481,7 +573,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
481
573
  )
482
574
 
483
575
  # Any substate containing a ComputedVar with cache=False always needs to be recomputed
484
- cls._always_dirty_substates = set()
485
576
  if cls._always_dirty_computed_vars:
486
577
  # Tell parent classes that this substate has always dirty computed vars
487
578
  state_name = cls.get_name()
@@ -920,8 +1011,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
920
1011
  **super().__getattribute__("inherited_vars"),
921
1012
  **super().__getattribute__("inherited_backend_vars"),
922
1013
  }
923
- if name in inherited_vars:
924
- return getattr(super().__getattribute__("parent_state"), name)
1014
+
1015
+ # For now, handle router_data updates as a special case.
1016
+ if name in inherited_vars or name == constants.ROUTER_DATA:
1017
+ parent_state = super().__getattribute__("parent_state")
1018
+ if parent_state is not None:
1019
+ return getattr(parent_state, name)
925
1020
 
926
1021
  backend_vars = super().__getattribute__("_backend_vars")
927
1022
  if name in backend_vars:
@@ -977,9 +1072,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
977
1072
  if name == constants.ROUTER_DATA:
978
1073
  self.dirty_vars.add(name)
979
1074
  self._mark_dirty()
980
- # propagate router_data updates down the state tree
981
- for substate in self.substates.values():
982
- setattr(substate, name, value)
983
1075
 
984
1076
  def reset(self):
985
1077
  """Reset all the base vars to their default values."""
@@ -988,7 +1080,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
988
1080
  for prop_name in self.base_vars:
989
1081
  if prop_name == constants.ROUTER:
990
1082
  continue # never reset the router data
991
- setattr(self, prop_name, copy.deepcopy(fields[prop_name].default))
1083
+ field = fields[prop_name]
1084
+ if default_factory := field.default_factory:
1085
+ default = default_factory()
1086
+ else:
1087
+ default = copy.deepcopy(field.default)
1088
+ setattr(self, prop_name, default)
992
1089
 
993
1090
  # Recursively reset the substates.
994
1091
  for substate in self.substates.values():
@@ -1033,6 +1130,170 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1033
1130
  raise ValueError(f"Invalid path: {path}")
1034
1131
  return self.substates[path[0]].get_substate(path[1:])
1035
1132
 
1133
+ @classmethod
1134
+ def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
1135
+ """Find the name of the nearest common ancestor shared by this and the other state.
1136
+
1137
+ Args:
1138
+ other: The other state.
1139
+
1140
+ Returns:
1141
+ Full name of the nearest common ancestor.
1142
+ """
1143
+ common_ancestor_parts = []
1144
+ for part1, part2 in zip(
1145
+ cls.get_full_name().split("."),
1146
+ other.get_full_name().split("."),
1147
+ ):
1148
+ if part1 != part2:
1149
+ break
1150
+ common_ancestor_parts.append(part1)
1151
+ return ".".join(common_ancestor_parts)
1152
+
1153
+ @classmethod
1154
+ def _determine_missing_parent_states(
1155
+ cls, target_state_cls: Type[BaseState]
1156
+ ) -> tuple[str, list[str]]:
1157
+ """Determine the missing parent states between the target_state_cls and common ancestor of this state.
1158
+
1159
+ Args:
1160
+ target_state_cls: The class of the state to find missing parent states for.
1161
+
1162
+ Returns:
1163
+ The name of the common ancestor and the list of missing parent states.
1164
+ """
1165
+ common_ancestor_name = cls._get_common_ancestor(target_state_cls)
1166
+ common_ancestor_parts = common_ancestor_name.split(".")
1167
+ target_state_parts = tuple(target_state_cls.get_full_name().split("."))
1168
+ relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :]
1169
+
1170
+ # Determine which parent states to fetch from the common ancestor down to the target_state_cls.
1171
+ fetch_parent_states = [common_ancestor_name]
1172
+ for ix, relative_parent_state_name in enumerate(relative_target_state_parts):
1173
+ fetch_parent_states.append(
1174
+ ".".join([*fetch_parent_states[: ix + 1], relative_parent_state_name])
1175
+ )
1176
+
1177
+ return common_ancestor_name, fetch_parent_states[1:-1]
1178
+
1179
+ def _get_parent_states(self) -> list[tuple[str, BaseState]]:
1180
+ """Get all parent state instances up to the root of the state tree.
1181
+
1182
+ Returns:
1183
+ A list of tuples containing the name and the instance of each parent state.
1184
+ """
1185
+ parent_states_with_name = []
1186
+ parent_state = self
1187
+ while parent_state.parent_state is not None:
1188
+ parent_state = parent_state.parent_state
1189
+ parent_states_with_name.append((parent_state.get_full_name(), parent_state))
1190
+ return parent_states_with_name
1191
+
1192
+ async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
1193
+ """Populate substates in the tree between the target_state_cls and common ancestor of this state.
1194
+
1195
+ Args:
1196
+ target_state_cls: The class of the state to populate parent states for.
1197
+
1198
+ Returns:
1199
+ The parent state instance of target_state_cls.
1200
+
1201
+ Raises:
1202
+ RuntimeError: If redis is not used in this backend process.
1203
+ """
1204
+ state_manager = get_state_manager()
1205
+ if not isinstance(state_manager, StateManagerRedis):
1206
+ raise RuntimeError(
1207
+ f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. "
1208
+ "(All states should already be available -- this is likely a bug).",
1209
+ )
1210
+
1211
+ # Find the missing parent states up to the common ancestor.
1212
+ (
1213
+ common_ancestor_name,
1214
+ missing_parent_states,
1215
+ ) = self._determine_missing_parent_states(target_state_cls)
1216
+
1217
+ # Fetch all missing parent states and link them up to the common ancestor.
1218
+ parent_states_by_name = dict(self._get_parent_states())
1219
+ parent_state = parent_states_by_name[common_ancestor_name]
1220
+ for parent_state_name in missing_parent_states:
1221
+ parent_state = await state_manager.get_state(
1222
+ token=_substate_key(
1223
+ self.router.session.client_token, parent_state_name
1224
+ ),
1225
+ top_level=False,
1226
+ get_substates=False,
1227
+ parent_state=parent_state,
1228
+ )
1229
+
1230
+ # Return the direct parent of target_state_cls for subsequent linking.
1231
+ return parent_state
1232
+
1233
+ def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
1234
+ """Get a state instance from the cache.
1235
+
1236
+ Args:
1237
+ state_cls: The class of the state.
1238
+
1239
+ Returns:
1240
+ The instance of state_cls associated with this state's client_token.
1241
+ """
1242
+ if self.parent_state is None:
1243
+ root_state = self
1244
+ else:
1245
+ root_state = self._get_parent_states()[-1][1]
1246
+ return root_state.get_substate(state_cls.get_full_name().split("."))
1247
+
1248
+ async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
1249
+ """Get a state instance from redis.
1250
+
1251
+ Args:
1252
+ state_cls: The class of the state.
1253
+
1254
+ Returns:
1255
+ The instance of state_cls associated with this state's client_token.
1256
+
1257
+ Raises:
1258
+ RuntimeError: If redis is not used in this backend process.
1259
+ """
1260
+ # Fetch all missing parent states from redis.
1261
+ parent_state_of_state_cls = await self._populate_parent_states(state_cls)
1262
+
1263
+ # Then get the target state and all its substates.
1264
+ state_manager = get_state_manager()
1265
+ if not isinstance(state_manager, StateManagerRedis):
1266
+ raise RuntimeError(
1267
+ f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
1268
+ "(All states should already be available -- this is likely a bug).",
1269
+ )
1270
+ return await state_manager.get_state(
1271
+ token=_substate_key(self.router.session.client_token, state_cls),
1272
+ top_level=False,
1273
+ get_substates=True,
1274
+ parent_state=parent_state_of_state_cls,
1275
+ )
1276
+
1277
+ async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
1278
+ """Get an instance of the state associated with this token.
1279
+
1280
+ Allows for arbitrary access to sibling states from within an event handler.
1281
+
1282
+ Args:
1283
+ state_cls: The class of the state.
1284
+
1285
+ Returns:
1286
+ The instance of state_cls associated with this state's client_token.
1287
+ """
1288
+ # Fast case - if this state instance is already cached, get_substate from root state.
1289
+ try:
1290
+ return self._get_state_from_cache(state_cls)
1291
+ except ValueError:
1292
+ pass
1293
+
1294
+ # Slow case - fetch missing parent states from redis.
1295
+ return await self._get_state_from_redis(state_cls)
1296
+
1036
1297
  def _get_event_handler(
1037
1298
  self, event: Event
1038
1299
  ) -> tuple[BaseState | StateProxy, EventHandler]:
@@ -1235,6 +1496,27 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1235
1496
  for cvar in self._computed_var_dependencies[dirty_var]
1236
1497
  )
1237
1498
 
1499
+ @classmethod
1500
+ def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
1501
+ """Determine substates which could be affected by dirty vars in this state.
1502
+
1503
+ Returns:
1504
+ Set of State classes that may need to be fetched to recalc computed vars.
1505
+ """
1506
+ # _always_dirty_substates need to be fetched to recalc computed vars.
1507
+ fetch_substates = set(
1508
+ cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
1509
+ for substate_name in cls._always_dirty_substates
1510
+ )
1511
+ for dependent_substates in cls._substate_var_dependencies.values():
1512
+ fetch_substates.update(
1513
+ set(
1514
+ cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
1515
+ for substate_name in dependent_substates
1516
+ )
1517
+ )
1518
+ return fetch_substates
1519
+
1238
1520
  def get_delta(self) -> Delta:
1239
1521
  """Get the delta for the state.
1240
1522
 
@@ -1266,8 +1548,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1266
1548
  # Recursively find the substate deltas.
1267
1549
  substates = self.substates
1268
1550
  for substate in self.dirty_substates.union(self._always_dirty_substates):
1269
- if substate not in substates:
1270
- continue # substate not loaded at this time, no delta
1271
1551
  delta.update(substates[substate].get_delta())
1272
1552
 
1273
1553
  # Format the delta.
@@ -1289,20 +1569,45 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1289
1569
  # have to mark computed vars dirty to allow access to newly computed
1290
1570
  # values within the same ComputedVar function
1291
1571
  self._mark_dirty_computed_vars()
1572
+ self._mark_dirty_substates()
1292
1573
 
1293
- # Propagate dirty var / computed var status into substates
1574
+ def _mark_dirty_substates(self):
1575
+ """Propagate dirty var / computed var status into substates."""
1294
1576
  substates = self.substates
1295
1577
  for var in self.dirty_vars:
1296
1578
  for substate_name in self._substate_var_dependencies[var]:
1297
1579
  self.dirty_substates.add(substate_name)
1298
- if substate_name not in substates:
1299
- continue
1300
1580
  substate = substates[substate_name]
1301
1581
  substate.dirty_vars.add(var)
1302
1582
  substate._mark_dirty()
1303
1583
 
1584
+ def _update_was_touched(self):
1585
+ """Update the _was_touched flag based on dirty_vars."""
1586
+ if self.dirty_vars and not self._was_touched:
1587
+ for var in self.dirty_vars:
1588
+ if var in self.base_vars or var in self._backend_vars:
1589
+ self._was_touched = True
1590
+ break
1591
+
1592
+ def _get_was_touched(self) -> bool:
1593
+ """Check current dirty_vars and flag to determine if state instance was modified.
1594
+
1595
+ If any dirty vars belong to this state, mark _was_touched.
1596
+
1597
+ This flag determines whether this state instance should be persisted to redis.
1598
+
1599
+ Returns:
1600
+ Whether this state instance was ever modified.
1601
+ """
1602
+ # Ensure the flag is up to date based on the current dirty_vars
1603
+ self._update_was_touched()
1604
+ return self._was_touched
1605
+
1304
1606
  def _clean(self):
1305
1607
  """Reset the dirty vars."""
1608
+ # Update touched status before cleaning dirty_vars.
1609
+ self._update_was_touched()
1610
+
1306
1611
  # Recursively clean the substates.
1307
1612
  for substate in self.dirty_substates:
1308
1613
  if substate not in self.substates:
@@ -1328,11 +1633,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1328
1633
  return super().get_value(key.__wrapped__)
1329
1634
  return super().get_value(key)
1330
1635
 
1331
- def dict(self, include_computed: bool = True, **kwargs) -> dict[str, Any]:
1636
+ def dict(
1637
+ self, include_computed: bool = True, initial: bool = False, **kwargs
1638
+ ) -> dict[str, Any]:
1332
1639
  """Convert the object to a dictionary.
1333
1640
 
1334
1641
  Args:
1335
1642
  include_computed: Whether to include computed vars.
1643
+ initial: Whether to get the initial value of computed vars.
1336
1644
  **kwargs: Kwargs to pass to the pydantic dict method.
1337
1645
 
1338
1646
  Returns:
@@ -1348,21 +1656,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1348
1656
  prop_name: self.get_value(getattr(self, prop_name))
1349
1657
  for prop_name in self.base_vars
1350
1658
  }
1351
- computed_vars = (
1352
- {
1659
+ if initial:
1660
+ computed_vars = {
1661
+ # Include initial computed vars.
1662
+ prop_name: cv._initial_value
1663
+ if isinstance(cv, ComputedVar)
1664
+ and not isinstance(cv._initial_value, types.Unset)
1665
+ else self.get_value(getattr(self, prop_name))
1666
+ for prop_name, cv in self.computed_vars.items()
1667
+ }
1668
+ elif include_computed:
1669
+ computed_vars = {
1353
1670
  # Include the computed vars.
1354
1671
  prop_name: self.get_value(getattr(self, prop_name))
1355
1672
  for prop_name in self.computed_vars
1356
1673
  }
1357
- if include_computed
1358
- else {}
1359
- )
1674
+ else:
1675
+ computed_vars = {}
1360
1676
  variables = {**base_vars, **computed_vars}
1361
1677
  d = {
1362
1678
  self.get_full_name(): {k: variables[k] for k in sorted(variables)},
1363
1679
  }
1364
1680
  for substate_d in [
1365
- v.dict(include_computed=include_computed, **kwargs)
1681
+ v.dict(include_computed=include_computed, initial=initial, **kwargs)
1366
1682
  for v in self.substates.values()
1367
1683
  ]:
1368
1684
  d.update(substate_d)
@@ -1408,6 +1724,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1408
1724
  state["__dict__"] = state["__dict__"].copy()
1409
1725
  state["__dict__"]["parent_state"] = None
1410
1726
  state["__dict__"]["substates"] = {}
1727
+ state["__dict__"].pop("_was_touched", None)
1411
1728
  return state
1412
1729
 
1413
1730
 
@@ -1417,6 +1734,35 @@ class State(BaseState):
1417
1734
  # The hydrated bool.
1418
1735
  is_hydrated: bool = False
1419
1736
 
1737
+
1738
+ class UpdateVarsInternalState(State):
1739
+ """Substate for handling internal state var updates."""
1740
+
1741
+ async def update_vars_internal(self, vars: dict[str, Any]) -> None:
1742
+ """Apply updates to fully qualified state vars.
1743
+
1744
+ The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
1745
+ and each value will be set on the appropriate substate instance.
1746
+
1747
+ This function is primarily used to apply cookie and local storage
1748
+ updates from the frontend to the appropriate substate.
1749
+
1750
+ Args:
1751
+ vars: The fully qualified vars and values to update.
1752
+ """
1753
+ for var, value in vars.items():
1754
+ state_name, _, var_name = var.rpartition(".")
1755
+ var_state_cls = State.get_class_substate(tuple(state_name.split(".")))
1756
+ var_state = await self.get_state(var_state_cls)
1757
+ setattr(var_state, var_name, value)
1758
+
1759
+
1760
+ class OnLoadInternalState(State):
1761
+ """Substate for handling on_load event enumeration.
1762
+
1763
+ This is a separate substate to avoid deserializing the entire state tree for every page navigation.
1764
+ """
1765
+
1420
1766
  def on_load_internal(self) -> list[Event | EventSpec] | None:
1421
1767
  """Queue on_load handlers for the current page.
1422
1768
 
@@ -1428,6 +1774,9 @@ class State(BaseState):
1428
1774
  load_events = app.get_load_events(self.router.page.path)
1429
1775
  if not load_events and self.is_hydrated:
1430
1776
  return # Fast path for page-to-page navigation
1777
+ if not load_events:
1778
+ self.is_hydrated = True
1779
+ return # Fast path for initial hydrate with no on_load events defined.
1431
1780
  self.is_hydrated = False
1432
1781
  return [
1433
1782
  *fix_events(
@@ -1435,26 +1784,9 @@ class State(BaseState):
1435
1784
  self.router.session.client_token,
1436
1785
  router_data=self.router_data,
1437
1786
  ),
1438
- type(self).set_is_hydrated(True), # type: ignore
1787
+ State.set_is_hydrated(True), # type: ignore
1439
1788
  ]
1440
1789
 
1441
- def update_vars_internal(self, vars: dict[str, Any]) -> None:
1442
- """Apply updates to fully qualified state vars.
1443
-
1444
- The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
1445
- and each value will be set on the appropriate substate instance.
1446
-
1447
- This function is primarily used to apply cookie and local storage
1448
- updates from the frontend to the appropriate substate.
1449
-
1450
- Args:
1451
- vars: The fully qualified vars and values to update.
1452
- """
1453
- for var, value in vars.items():
1454
- state_name, _, var_name = var.rpartition(".")
1455
- var_state = self.get_substate(state_name.split("."))
1456
- setattr(var_state, var_name, value)
1457
-
1458
1790
 
1459
1791
  class StateProxy(wrapt.ObjectProxy):
1460
1792
  """Proxy of a state instance to control mutability of vars for a background task.
@@ -1508,9 +1840,10 @@ class StateProxy(wrapt.ObjectProxy):
1508
1840
  This StateProxy instance in mutable mode.
1509
1841
  """
1510
1842
  self._self_actx = self._self_app.modify_state(
1511
- self.__wrapped__.router.session.client_token
1512
- + "_"
1513
- + ".".join(self._self_substate_path)
1843
+ token=_substate_key(
1844
+ self.__wrapped__.router.session.client_token,
1845
+ self._self_substate_path,
1846
+ )
1514
1847
  )
1515
1848
  mutable_state = await self._self_actx.__aenter__()
1516
1849
  super().__setattr__(
@@ -1560,7 +1893,15 @@ class StateProxy(wrapt.ObjectProxy):
1560
1893
 
1561
1894
  Returns:
1562
1895
  The value of the attribute.
1896
+
1897
+ Raises:
1898
+ ImmutableStateError: If the state is not in mutable mode.
1563
1899
  """
1900
+ if name in ["substates", "parent_state"] and not self._self_mutable:
1901
+ raise ImmutableStateError(
1902
+ "Background task StateProxy is immutable outside of a context "
1903
+ "manager. Use `async with self` to modify state."
1904
+ )
1564
1905
  value = super().__getattr__(name)
1565
1906
  if not name.startswith("_self_") and isinstance(value, MutableProxy):
1566
1907
  # ensure mutations to these containers are blocked unless proxy is _mutable
@@ -1608,6 +1949,60 @@ class StateProxy(wrapt.ObjectProxy):
1608
1949
  "manager. Use `async with self` to modify state."
1609
1950
  )
1610
1951
 
1952
def get_substate(self, path: Sequence[str]) -> BaseState:
    """Return a substate of the wrapped state, but only while the lock is held.

    The proxy must be in mutable mode (i.e. inside ``async with self``);
    otherwise substate access is refused.

    Args:
        path: The path to the substate.

    Returns:
        The substate.

    Raises:
        ImmutableStateError: If the state is not in mutable mode.
    """
    if self._self_mutable:
        return self.__wrapped__.get_substate(path)
    raise ImmutableStateError(
        "Background task StateProxy is immutable outside of a context "
        "manager. Use `async with self` to modify state."
    )
1970
+
1971
async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
    """Fetch the instance of ``state_cls`` associated with this token.

    Only permitted while the proxy is in mutable mode (inside
    ``async with self``); the request is delegated to the wrapped state.

    Args:
        state_cls: The class of the state.

    Returns:
        The state.

    Raises:
        ImmutableStateError: If the state is not in mutable mode.
    """
    if self._self_mutable:
        return await self.__wrapped__.get_state(state_cls)
    raise ImmutableStateError(
        "Background task StateProxy is immutable outside of a context "
        "manager. Use `async with self` to modify state."
    )
1989
+
1990
def _as_state_update(self, *args, **kwargs) -> StateUpdate:
    """Temporarily allow mutability to access parent_state.

    Mutability is switched on for the duration of the delegated call and
    then restored to whatever it was before. Restoring (rather than forcing
    ``False``) matters when this is called while the proxy is already
    mutable, e.g. inside an ``async with self`` block: forcing ``False``
    would silently make the proxy immutable for the rest of that block.

    Args:
        *args: The args to pass to the underlying state instance.
        **kwargs: The kwargs to pass to the underlying state instance.

    Returns:
        The state update.
    """
    original_mutable = self._self_mutable
    self._self_mutable = True
    try:
        return self.__wrapped__._as_state_update(*args, **kwargs)
    finally:
        # Restore the caller's mode instead of unconditionally locking.
        self._self_mutable = original_mutable
2005
+
1611
2006
 
1612
2007
  class StateUpdate(Base):
1613
2008
  """A state update sent to the frontend."""
@@ -1708,9 +2103,9 @@ class StateManagerMemory(StateManager):
1708
2103
  The state for the token.
1709
2104
  """
1710
2105
  # Memory state manager ignores the substate suffix and always returns the top-level state.
1711
- token = token.partition("_")[0]
2106
+ token = _split_substate_key(token)[0]
1712
2107
  if token not in self.states:
1713
- self.states[token] = self.state()
2108
+ self.states[token] = self.state(_reflex_internal_init=True)
1714
2109
  return self.states[token]
1715
2110
 
1716
2111
  async def set_state(self, token: str, state: BaseState):
@@ -1733,7 +2128,7 @@ class StateManagerMemory(StateManager):
1733
2128
  The state for the token.
1734
2129
  """
1735
2130
  # Memory state manager ignores the substate suffix and always returns the top-level state.
1736
- token = token.partition("_")[0]
2131
+ token = _split_substate_key(token)[0]
1737
2132
  if token not in self._states_locks:
1738
2133
  async with self._state_manager_lock:
1739
2134
  if token not in self._states_locks:
@@ -1773,6 +2168,81 @@ class StateManagerRedis(StateManager):
1773
2168
  b"evicted",
1774
2169
  }
1775
2170
 
2171
def _get_root_state(self, state: BaseState) -> BaseState:
    """Chase parent_state pointers to find an instance of the top-level state.

    The walk stops at the first instance whose exact class is ``self.state``
    (class identity, not ``isinstance``), or at an instance with no parent.

    Args:
        state: The state to start from.

    Returns:
        An instance of the top-level state (self.state), or the topmost
        reachable ancestor if the chain ends before reaching it.
    """
    while type(state) is not self.state and state.parent_state is not None:
        state = state.parent_state
    return state
2183
+
2184
async def _get_parent_state(self, token: str) -> BaseState | None:
    """Look up the parent of the state addressed by ``token``.

    Args:
        token: The token to get the state for (_substate_key).

    Returns:
        The parent state for the state requested by the token, or None when
        the requested state path has no parent component.
    """
    client_token, state_path = _split_substate_key(token)
    parent_path, _, _ = state_path.rpartition(".")
    if not parent_path:
        return None
    # Retrieve the parent state to populate event handlers onto this substate;
    # its own substates are not needed here.
    return await self.get_state(
        token=_substate_key(client_token, parent_path),
        top_level=False,
        get_substates=False,
    )
2204
+
2205
async def _populate_substates(
    self,
    token: str,
    state: BaseState,
    all_substates: bool = False,
):
    """Fetch and link substates for the given state instance.

    There is no return value; the side-effect is that `state` will have `substates` populated,
    and each substate will have its `parent_state` set to `state`.

    Substates are fetched concurrently. If any fetch fails (or this coroutine
    is cancelled), the outstanding fetch tasks are cancelled before the error
    propagates, so no orphaned task keeps running and linking itself into
    `state` after the caller has already seen the failure.

    Args:
        token: The token to get the state for.
        state: The state instance to populate substates for.
        all_substates: Whether to fetch all substates or just required substates.
    """
    client_token, _ = _split_substate_key(token)

    if all_substates:
        # All substates are requested.
        fetch_substates = state.get_substates()
    else:
        # Only _potentially_dirty_substates need to be fetched to recalc computed vars.
        fetch_substates = state._potentially_dirty_substates()

    # Retrieve the necessary substates from redis, concurrently.
    tasks = {
        substate_cls.get_name(): asyncio.create_task(
            self.get_state(
                token=_substate_key(client_token, substate_cls),
                top_level=False,
                get_substates=all_substates,
                parent_state=state,
            )
        )
        for substate_cls in fetch_substates
    }

    try:
        for substate_name, substate_task in tasks.items():
            state.substates[substate_name] = await substate_task
    except BaseException:
        # Don't leave sibling fetches running in the background; cancel() is
        # a no-op for tasks that have already finished.
        for task in tasks.values():
            task.cancel()
        raise
2245
+
1776
2246
  async def get_state(
1777
2247
  self,
1778
2248
  token: str,
@@ -1784,8 +2254,8 @@ class StateManagerRedis(StateManager):
1784
2254
 
1785
2255
  Args:
1786
2256
  token: The token to get the state for.
1787
- top_level: If true, return an instance of the top-level state.
1788
- get_substates: If true, also retrieve substates
2257
+ top_level: If true, return an instance of the top-level state (self.state).
2258
+ get_substates: If true, also retrieve substates.
1789
2259
  parent_state: If provided, use this parent_state instead of getting it from redis.
1790
2260
 
1791
2261
  Returns:
@@ -1795,7 +2265,7 @@ class StateManagerRedis(StateManager):
1795
2265
  RuntimeError: when the state_cls is not specified in the token
1796
2266
  """
1797
2267
  # Split the actual token from the fully qualified substate name.
1798
- client_token, _, state_path = token.partition("_")
2268
+ _, state_path = _split_substate_key(token)
1799
2269
  if state_path:
1800
2270
  # Get the State class associated with the given path.
1801
2271
  state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
@@ -1811,66 +2281,49 @@ class StateManagerRedis(StateManager):
1811
2281
  # Deserialize the substate.
1812
2282
  state = cloudpickle.loads(redis_state)
1813
2283
 
1814
- # Populate parent and substates if requested.
2284
+ # Populate parent state if missing and requested.
1815
2285
  if parent_state is None:
1816
- # Retrieve the parent state from redis.
1817
- parent_state_name = state_path.rpartition(".")[0]
1818
- if parent_state_name:
1819
- parent_state_key = token.rpartition(".")[0]
1820
- parent_state = await self.get_state(
1821
- parent_state_key, top_level=False, get_substates=False
1822
- )
2286
+ parent_state = await self._get_parent_state(token)
1823
2287
  # Set up Bidirectional linkage between this state and its parent.
1824
2288
  if parent_state is not None:
1825
2289
  parent_state.substates[state.get_name()] = state
1826
2290
  state.parent_state = parent_state
1827
- if get_substates:
1828
- # Retrieve all substates from redis.
1829
- for substate_cls in state_cls.get_substates():
1830
- substate_name = substate_cls.get_name()
1831
- substate_key = token + "." + substate_name
1832
- state.substates[substate_name] = await self.get_state(
1833
- substate_key, top_level=False, parent_state=state
1834
- )
2291
+ # Populate substates if requested.
2292
+ await self._populate_substates(token, state, all_substates=get_substates)
2293
+
1835
2294
  # To retain compatibility with previous implementation, by default, we return
1836
2295
  # the top-level state by chasing `parent_state` pointers up the tree.
1837
2296
  if top_level:
1838
- while type(state) != self.state and state.parent_state is not None:
1839
- state = state.parent_state
2297
+ return self._get_root_state(state)
1840
2298
  return state
1841
2299
 
1842
- # Key didn't exist so we have to create a new entry for this token.
2300
+ # TODO: dedupe the following logic with the above block
2301
+ # Key didn't exist so we have to create a new instance for this token.
1843
2302
  if parent_state is None:
1844
- parent_state_name = state_path.rpartition(".")[0]
1845
- if parent_state_name:
1846
- # Retrieve the parent state to populate event handlers onto this substate.
1847
- parent_state_key = client_token + "_" + parent_state_name
1848
- parent_state = await self.get_state(
1849
- parent_state_key, top_level=False, get_substates=False
1850
- )
1851
- # Persist the new state class to redis.
1852
- await self.set_state(
1853
- token,
1854
- state_cls(
1855
- parent_state=parent_state,
1856
- init_substates=False,
1857
- ),
1858
- )
1859
- # After creating the state key, recursively call `get_state` to populate substates.
1860
- return await self.get_state(
1861
- token,
1862
- top_level=top_level,
1863
- get_substates=get_substates,
2303
+ parent_state = await self._get_parent_state(token)
2304
+ # Instantiate the new state class (but don't persist it yet).
2305
+ state = state_cls(
1864
2306
  parent_state=parent_state,
2307
+ init_substates=False,
2308
+ _reflex_internal_init=True,
1865
2309
  )
2310
+ # Set up Bidirectional linkage between this state and its parent.
2311
+ if parent_state is not None:
2312
+ parent_state.substates[state.get_name()] = state
2313
+ state.parent_state = parent_state
2314
+ # Populate substates for the newly created state.
2315
+ await self._populate_substates(token, state, all_substates=get_substates)
2316
+ # To retain compatibility with previous implementation, by default, we return
2317
+ # the top-level state by chasing `parent_state` pointers up the tree.
2318
+ if top_level:
2319
+ return self._get_root_state(state)
2320
+ return state
1866
2321
 
1867
2322
  async def set_state(
1868
2323
  self,
1869
2324
  token: str,
1870
2325
  state: BaseState,
1871
2326
  lock_id: bytes | None = None,
1872
- set_substates: bool = True,
1873
- set_parent_state: bool = True,
1874
2327
  ):
1875
2328
  """Set the state for a token.
1876
2329
 
@@ -1878,11 +2331,10 @@ class StateManagerRedis(StateManager):
1878
2331
  token: The token to set the state for.
1879
2332
  state: The state to set.
1880
2333
  lock_id: If provided, the lock_key must be set to this value to set the state.
1881
- set_substates: If True, write substates to redis
1882
- set_parent_state: If True, write parent state to redis
1883
2334
 
1884
2335
  Raises:
1885
2336
  LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
2337
+ RuntimeError: If the state instance doesn't match the state name in the token.
1886
2338
  """
1887
2339
  # Check that we're holding the lock.
1888
2340
  if (
@@ -1894,28 +2346,36 @@ class StateManagerRedis(StateManager):
1894
2346
  f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
1895
2347
  "or use `@rx.background` decorator for long-running tasks."
1896
2348
  )
1897
- # Find the substate associated with the token.
1898
- state_path = token.partition("_")[2]
1899
- if state_path and state.get_full_name() != state_path:
1900
- state = state.get_substate(tuple(state_path.split(".")))
1901
- # Persist the parent state separately, if requested.
1902
- if state.parent_state is not None and set_parent_state:
1903
- parent_state_key = token.rpartition(".")[0]
1904
- await self.set_state(
1905
- parent_state_key,
1906
- state.parent_state,
1907
- lock_id=lock_id,
1908
- set_substates=False,
2349
+ client_token, substate_name = _split_substate_key(token)
2350
+ # If the substate name on the token doesn't match the instance name, it cannot have a parent.
2351
+ if state.parent_state is not None and state.get_full_name() != substate_name:
2352
+ raise RuntimeError(
2353
+ f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
1909
2354
  )
1910
- # Persist the substates separately, if requested.
1911
- if set_substates:
1912
- for substate_name, substate in state.substates.items():
1913
- substate_key = token + "." + substate_name
1914
- await self.set_state(
1915
- substate_key, substate, lock_id=lock_id, set_parent_state=False
2355
+
2356
+ # Recursively set_state on all known substates.
2357
+ tasks = []
2358
+ for substate in state.substates.values():
2359
+ tasks.append(
2360
+ asyncio.create_task(
2361
+ self.set_state(
2362
+ token=_substate_key(client_token, substate),
2363
+ state=substate,
2364
+ lock_id=lock_id,
2365
+ )
1916
2366
  )
2367
+ )
1917
2368
  # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
1918
- await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
2369
+ if state._get_was_touched():
2370
+ await self.redis.set(
2371
+ _substate_key(client_token, state),
2372
+ cloudpickle.dumps(state),
2373
+ ex=self.token_expiration,
2374
+ )
2375
+
2376
+ # Wait for substates to be persisted.
2377
+ for t in tasks:
2378
+ await t
1919
2379
 
1920
2380
  @contextlib.asynccontextmanager
1921
2381
  async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
@@ -1943,7 +2403,7 @@ class StateManagerRedis(StateManager):
1943
2403
  The redis lock key for the token.
1944
2404
  """
1945
2405
  # All substates share the same lock domain, so ignore any substate path suffix.
1946
- client_token = token.partition("_")[0]
2406
+ client_token = _split_substate_key(token)[0]
1947
2407
  return f"{client_token}_lock".encode()
1948
2408
 
1949
2409
  async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
@@ -2038,6 +2498,16 @@ class StateManagerRedis(StateManager):
2038
2498
  await self.redis.close(close_connection_pool=True)
2039
2499
 
2040
2500
 
2501
def get_state_manager() -> StateManager:
    """Return the state manager of the app that is currently running.

    The app object is read from the object returned by
    ``prerequisites.get_app()``, using the attribute named by
    ``constants.CompileVars.APP``.

    Returns:
        The state manager.
    """
    app_module = prerequisites.get_app()
    running_app = getattr(app_module, constants.CompileVars.APP)
    return running_app.state_manager
2509
+
2510
+
2041
2511
  class ClientStorageBase:
2042
2512
  """Base class for client-side storage."""
2043
2513