reflex 0.6.8a2__py3-none-any.whl → 0.7.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflex might be problematic. Click here for more details.

Files changed (154):
  1. reflex/.templates/jinja/custom_components/pyproject.toml.jinja2 +1 -1
  2. reflex/.templates/jinja/web/pages/_app.js.jinja2 +7 -7
  3. reflex/.templates/jinja/web/pages/utils.js.jinja2 +2 -2
  4. reflex/.templates/web/components/reflex/radix_themes_color_mode_provider.js +1 -4
  5. reflex/.templates/web/utils/state.js +65 -36
  6. reflex/__init__.py +4 -17
  7. reflex/__init__.pyi +1 -2
  8. reflex/app.py +244 -109
  9. reflex/app_mixins/lifespan.py +9 -9
  10. reflex/app_mixins/middleware.py +6 -6
  11. reflex/app_module_for_backend.py +3 -7
  12. reflex/base.py +7 -7
  13. reflex/compiler/compiler.py +8 -0
  14. reflex/compiler/utils.py +35 -6
  15. reflex/components/base/bare.py +1 -1
  16. reflex/components/base/error_boundary.py +2 -1
  17. reflex/components/base/error_boundary.pyi +2 -1
  18. reflex/components/base/meta.py +2 -2
  19. reflex/components/base/strict_mode.py +10 -0
  20. reflex/components/base/strict_mode.pyi +57 -0
  21. reflex/components/component.py +38 -77
  22. reflex/components/core/banner.py +83 -4
  23. reflex/components/core/banner.pyi +86 -0
  24. reflex/components/core/breakpoints.py +3 -1
  25. reflex/components/core/client_side_routing.py +1 -1
  26. reflex/components/core/client_side_routing.pyi +1 -1
  27. reflex/components/core/cond.py +9 -10
  28. reflex/components/core/debounce.py +1 -1
  29. reflex/components/core/foreach.py +23 -3
  30. reflex/components/core/html.py +1 -1
  31. reflex/components/core/match.py +5 -5
  32. reflex/components/core/sticky.py +160 -0
  33. reflex/components/core/sticky.pyi +449 -0
  34. reflex/components/core/upload.py +2 -2
  35. reflex/components/datadisplay/code.py +5 -14
  36. reflex/components/datadisplay/dataeditor.py +7 -4
  37. reflex/components/datadisplay/logo.py +13 -8
  38. reflex/components/datadisplay/shiki_code_block.py +14 -9
  39. reflex/components/dynamic.py +22 -3
  40. reflex/components/el/constants/reflex.py +1 -1
  41. reflex/components/el/element.py +1 -1
  42. reflex/components/el/elements/forms.py +4 -4
  43. reflex/components/el/elements/forms.pyi +4 -4
  44. reflex/components/lucide/icon.py +46 -8
  45. reflex/components/lucide/icon.pyi +54 -0
  46. reflex/components/markdown/markdown.py +10 -8
  47. reflex/components/moment/moment.py +2 -2
  48. reflex/components/next/image.py +16 -4
  49. reflex/components/next/image.pyi +4 -2
  50. reflex/components/next/link.py +1 -1
  51. reflex/components/plotly/plotly.py +5 -5
  52. reflex/components/props.py +3 -3
  53. reflex/components/radix/__init__.pyi +1 -1
  54. reflex/components/radix/primitives/accordion.py +9 -5
  55. reflex/components/radix/primitives/accordion.pyi +3 -1
  56. reflex/components/radix/primitives/drawer.py +5 -2
  57. reflex/components/radix/primitives/drawer.pyi +4 -4
  58. reflex/components/radix/primitives/form.pyi +6 -6
  59. reflex/components/radix/primitives/progress.py +1 -1
  60. reflex/components/radix/primitives/slider.py +1 -1
  61. reflex/components/radix/themes/color_mode.py +11 -9
  62. reflex/components/radix/themes/components/alert_dialog.py +3 -0
  63. reflex/components/radix/themes/components/card.py +1 -1
  64. reflex/components/radix/themes/components/card.pyi +1 -1
  65. reflex/components/radix/themes/components/context_menu.py +5 -0
  66. reflex/components/radix/themes/components/dialog.py +3 -0
  67. reflex/components/radix/themes/components/dropdown_menu.py +5 -0
  68. reflex/components/radix/themes/components/hover_card.py +3 -0
  69. reflex/components/radix/themes/components/icon_button.py +2 -2
  70. reflex/components/radix/themes/components/icon_button.pyi +1 -0
  71. reflex/components/radix/themes/components/popover.py +3 -0
  72. reflex/components/radix/themes/components/radio_cards.py +2 -0
  73. reflex/components/radix/themes/components/radio_group.py +1 -1
  74. reflex/components/radix/themes/components/select.py +3 -0
  75. reflex/components/radix/themes/components/tabs.py +3 -0
  76. reflex/components/radix/themes/components/text_area.py +12 -0
  77. reflex/components/radix/themes/components/text_area.pyi +2 -0
  78. reflex/components/radix/themes/components/text_field.py +1 -1
  79. reflex/components/radix/themes/components/tooltip.py +3 -1
  80. reflex/components/radix/themes/components/tooltip.pyi +1 -0
  81. reflex/components/radix/themes/layout/__init__.pyi +1 -1
  82. reflex/components/radix/themes/layout/list.py +2 -2
  83. reflex/components/radix/themes/layout/stack.py +2 -2
  84. reflex/components/radix/themes/typography/link.py +1 -1
  85. reflex/components/radix/themes/typography/text.py +2 -2
  86. reflex/components/react_player/react_player.py +1 -1
  87. reflex/components/recharts/__init__.py +2 -0
  88. reflex/components/recharts/__init__.pyi +2 -0
  89. reflex/components/recharts/charts.py +15 -15
  90. reflex/components/recharts/general.py +19 -4
  91. reflex/components/recharts/general.pyi +55 -4
  92. reflex/components/recharts/polar.py +2 -2
  93. reflex/components/recharts/recharts.py +4 -4
  94. reflex/components/sonner/toast.py +15 -13
  95. reflex/components/sonner/toast.pyi +6 -6
  96. reflex/components/suneditor/editor.py +6 -4
  97. reflex/components/suneditor/editor.pyi +2 -2
  98. reflex/components/tags/iter_tag.py +3 -3
  99. reflex/components/tags/tag.py +25 -3
  100. reflex/config.py +48 -15
  101. reflex/constants/__init__.py +1 -0
  102. reflex/constants/base.py +4 -1
  103. reflex/constants/compiler.py +5 -2
  104. reflex/constants/config.py +8 -1
  105. reflex/constants/installer.py +9 -9
  106. reflex/constants/style.py +1 -1
  107. reflex/custom_components/custom_components.py +9 -7
  108. reflex/event.py +130 -161
  109. reflex/experimental/__init__.py +19 -11
  110. reflex/experimental/client_state.py +53 -28
  111. reflex/experimental/hooks.py +5 -5
  112. reflex/experimental/layout.py +8 -5
  113. reflex/experimental/layout.pyi +1 -1
  114. reflex/experimental/misc.py +3 -3
  115. reflex/istate/wrappers.py +1 -1
  116. reflex/middleware/hydrate_middleware.py +2 -2
  117. reflex/model.py +11 -6
  118. reflex/page.py +3 -3
  119. reflex/reflex.py +90 -19
  120. reflex/route.py +1 -1
  121. reflex/state.py +358 -401
  122. reflex/style.py +27 -3
  123. reflex/testing.py +29 -23
  124. reflex/utils/build.py +6 -2
  125. reflex/utils/codespaces.py +1 -4
  126. reflex/utils/compat.py +6 -5
  127. reflex/utils/console.py +52 -16
  128. reflex/utils/exceptions.py +76 -26
  129. reflex/utils/exec.py +69 -74
  130. reflex/utils/export.py +6 -1
  131. reflex/utils/format.py +7 -39
  132. reflex/utils/imports.py +2 -2
  133. reflex/utils/lazy_loader.py +7 -1
  134. reflex/utils/path_ops.py +28 -14
  135. reflex/utils/prerequisites.py +324 -65
  136. reflex/utils/processes.py +45 -32
  137. reflex/utils/pyi_generator.py +30 -25
  138. reflex/utils/registry.py +4 -4
  139. reflex/utils/serializers.py +1 -1
  140. reflex/utils/telemetry.py +5 -4
  141. reflex/utils/types.py +42 -18
  142. reflex/vars/base.py +650 -333
  143. reflex/vars/datetime.py +6 -7
  144. reflex/vars/dep_tracking.py +344 -0
  145. reflex/vars/function.py +11 -5
  146. reflex/vars/number.py +31 -43
  147. reflex/vars/object.py +63 -62
  148. reflex/vars/sequence.py +79 -67
  149. {reflex-0.6.8a2.dist-info → reflex-0.7.0a1.dist-info}/METADATA +7 -8
  150. {reflex-0.6.8a2.dist-info → reflex-0.7.0a1.dist-info}/RECORD +153 -149
  151. {reflex-0.6.8a2.dist-info → reflex-0.7.0a1.dist-info}/WHEEL +1 -1
  152. reflex/experimental/assets.py +0 -37
  153. {reflex-0.6.8a2.dist-info → reflex-0.7.0a1.dist-info}/LICENSE +0 -0
  154. {reflex-0.6.8a2.dist-info → reflex-0.7.0a1.dist-info}/entry_points.txt +0 -0
reflex/state.py CHANGED
@@ -15,7 +15,6 @@ import time
15
15
  import typing
16
16
  import uuid
17
17
  from abc import ABC, abstractmethod
18
- from collections import defaultdict
19
18
  from hashlib import md5
20
19
  from pathlib import Path
21
20
  from types import FunctionType, MethodType
@@ -31,6 +30,7 @@ from typing import (
31
30
  Optional,
32
31
  Sequence,
33
32
  Set,
33
+ SupportsIndex,
34
34
  Tuple,
35
35
  Type,
36
36
  TypeVar,
@@ -93,17 +93,18 @@ from reflex.event import (
93
93
  )
94
94
  from reflex.utils import console, format, path_ops, prerequisites, types
95
95
  from reflex.utils.exceptions import (
96
- ComputedVarShadowsBaseVars,
97
- ComputedVarShadowsStateVar,
98
- DynamicComponentInvalidSignature,
99
- DynamicRouteArgShadowsStateVar,
100
- EventHandlerShadowsBuiltInStateMethod,
96
+ ComputedVarShadowsBaseVarsError,
97
+ ComputedVarShadowsStateVarError,
98
+ DynamicComponentInvalidSignatureError,
99
+ DynamicRouteArgShadowsStateVarError,
100
+ EventHandlerShadowsBuiltInStateMethodError,
101
101
  ImmutableStateError,
102
102
  InvalidLockWarningThresholdError,
103
- InvalidStateManagerMode,
103
+ InvalidStateManagerModeError,
104
104
  LockExpiredError,
105
105
  ReflexRuntimeError,
106
106
  SetUndefinedStateVarError,
107
+ StateMismatchError,
107
108
  StateSchemaMismatchError,
108
109
  StateSerializationError,
109
110
  StateTooLargeError,
@@ -327,6 +328,25 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
327
328
  )
328
329
 
329
330
 
331
+ async def _resolve_delta(delta: Delta) -> Delta:
332
+ """Await all coroutines in the delta.
333
+
334
+ Args:
335
+ delta: The delta to process.
336
+
337
+ Returns:
338
+ The same delta dict with all coroutines resolved to their return value.
339
+ """
340
+ tasks = {}
341
+ for state_name, state_delta in delta.items():
342
+ for var_name, value in state_delta.items():
343
+ if asyncio.iscoroutine(value):
344
+ tasks[state_name, var_name] = asyncio.create_task(value)
345
+ for (state_name, var_name), task in tasks.items():
346
+ delta[state_name][var_name] = await task
347
+ return delta
348
+
349
+
330
350
  class BaseState(Base, ABC, extra=pydantic.Extra.allow):
331
351
  """The state of the app."""
332
352
 
@@ -354,11 +374,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
354
374
  # A set of subclassses of this class.
355
375
  class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
356
376
 
357
- # Mapping of var name to set of computed variables that depend on it
358
- _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
359
-
360
- # Mapping of var name to set of substates that depend on it
361
- _substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
377
+ # Mapping of var name to set of (state_full_name, var_name) that depend on it.
378
+ _var_dependencies: ClassVar[Dict[str, Set[Tuple[str, str]]]] = {}
362
379
 
363
380
  # Set of vars which always need to be recomputed
364
381
  _always_dirty_computed_vars: ClassVar[Set[str]] = set()
@@ -366,6 +383,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
366
383
  # Set of substates which always need to be recomputed
367
384
  _always_dirty_substates: ClassVar[Set[str]] = set()
368
385
 
386
+ # Set of states which might need to be recomputed if vars in this state change.
387
+ _potentially_dirty_states: ClassVar[Set[str]] = set()
388
+
369
389
  # The parent state.
370
390
  parent_state: Optional[BaseState] = None
371
391
 
@@ -517,6 +537,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
517
537
 
518
538
  # Reset dirty substate tracking for this class.
519
539
  cls._always_dirty_substates = set()
540
+ cls._potentially_dirty_states = set()
520
541
 
521
542
  # Get the parent vars.
522
543
  parent_state = cls.get_parent_state()
@@ -586,7 +607,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
586
607
  if cls._item_is_event_handler(name, fn)
587
608
  }
588
609
 
589
- for mixin in cls._mixins():
610
+ for mixin in cls._mixins(): # pyright: ignore [reportAssignmentType]
590
611
  for name, value in mixin.__dict__.items():
591
612
  if name in cls.inherited_vars:
592
613
  continue
@@ -598,7 +619,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
598
619
  cls.computed_vars[newcv._js_expr] = newcv
599
620
  cls.vars[newcv._js_expr] = newcv
600
621
  continue
601
- if types.is_backend_base_variable(name, mixin):
622
+ if types.is_backend_base_variable(name, mixin): # pyright: ignore [reportArgumentType]
602
623
  cls.backend_vars[name] = copy.deepcopy(value)
603
624
  continue
604
625
  if events.get(name) is not None:
@@ -620,8 +641,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
620
641
  setattr(cls, name, handler)
621
642
 
622
643
  # Initialize per-class var dependency tracking.
623
- cls._computed_var_dependencies = defaultdict(set)
624
- cls._substate_var_dependencies = defaultdict(set)
644
+ cls._var_dependencies = {}
625
645
  cls._init_var_dependency_dicts()
626
646
 
627
647
  @staticmethod
@@ -766,26 +786,27 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
766
786
  Additional updates tracking dicts for vars and substates that always
767
787
  need to be recomputed.
768
788
  """
769
- inherited_vars = set(cls.inherited_vars).union(
770
- set(cls.inherited_backend_vars),
771
- )
772
789
  for cvar_name, cvar in cls.computed_vars.items():
773
- # Add the dependencies.
774
- for var in cvar._deps(objclass=cls):
775
- cls._computed_var_dependencies[var].add(cvar_name)
776
- if var in inherited_vars:
777
- # track that this substate depends on its parent for this var
778
- state_name = cls.get_name()
779
- parent_state = cls.get_parent_state()
780
- while parent_state is not None and var in {
781
- **parent_state.vars,
782
- **parent_state.backend_vars,
790
+ if not cvar._cache:
791
+ # Do not perform dep calculation when cache=False (these are always dirty).
792
+ continue
793
+ for state_name, dvar_set in cvar._deps(objclass=cls).items():
794
+ state_cls = cls.get_root_state().get_class_substate(state_name)
795
+ for dvar in dvar_set:
796
+ defining_state_cls = state_cls
797
+ while dvar in {
798
+ *defining_state_cls.inherited_vars,
799
+ *defining_state_cls.inherited_backend_vars,
783
800
  }:
784
- parent_state._substate_var_dependencies[var].add(state_name)
785
- state_name, parent_state = (
786
- parent_state.get_name(),
787
- parent_state.get_parent_state(),
788
- )
801
+ parent_state = defining_state_cls.get_parent_state()
802
+ if parent_state is not None:
803
+ defining_state_cls = parent_state
804
+ defining_state_cls._var_dependencies.setdefault(dvar, set()).add(
805
+ (cls.get_full_name(), cvar_name)
806
+ )
807
+ defining_state_cls._potentially_dirty_states.add(
808
+ cls.get_full_name()
809
+ )
789
810
 
790
811
  # ComputedVar with cache=False always need to be recomputed
791
812
  cls._always_dirty_computed_vars = {
@@ -814,7 +835,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
814
835
  """Check for shadow methods and raise error if any.
815
836
 
816
837
  Raises:
817
- EventHandlerShadowsBuiltInStateMethod: When an event handler shadows an inbuilt state method.
838
+ EventHandlerShadowsBuiltInStateMethodError: When an event handler shadows an inbuilt state method.
818
839
  """
819
840
  overridden_methods = set()
820
841
  state_base_functions = cls._get_base_functions()
@@ -828,7 +849,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
828
849
  overridden_methods.add(method.__name__)
829
850
 
830
851
  for method_name in overridden_methods:
831
- raise EventHandlerShadowsBuiltInStateMethod(
852
+ raise EventHandlerShadowsBuiltInStateMethodError(
832
853
  f"The event handler name `{method_name}` shadows a builtin State method; use a different name instead"
833
854
  )
834
855
 
@@ -837,11 +858,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
837
858
  """Check for shadow base vars and raise error if any.
838
859
 
839
860
  Raises:
840
- ComputedVarShadowsBaseVars: When a computed var shadows a base var.
861
+ ComputedVarShadowsBaseVarsError: When a computed var shadows a base var.
841
862
  """
842
863
  for computed_var_ in cls._get_computed_vars():
843
864
  if computed_var_._js_expr in cls.__annotations__:
844
- raise ComputedVarShadowsBaseVars(
865
+ raise ComputedVarShadowsBaseVarsError(
845
866
  f"The computed var name `{computed_var_._js_expr}` shadows a base var in {cls.__module__}.{cls.__name__}; use a different name instead"
846
867
  )
847
868
 
@@ -850,14 +871,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
850
871
  """Check for shadow computed vars and raise error if any.
851
872
 
852
873
  Raises:
853
- ComputedVarShadowsStateVar: When a computed var shadows another.
874
+ ComputedVarShadowsStateVarError: When a computed var shadows another.
854
875
  """
855
876
  for name, cv in cls.__dict__.items():
856
877
  if not is_computed_var(cv):
857
878
  continue
858
879
  name = cv._js_expr
859
880
  if name in cls.inherited_vars or name in cls.inherited_backend_vars:
860
- raise ComputedVarShadowsStateVar(
881
+ raise ComputedVarShadowsStateVarError(
861
882
  f"The computed var name `{cv._js_expr}` shadows a var in {cls.__module__}.{cls.__name__}; use a different name instead"
862
883
  )
863
884
 
@@ -898,7 +919,18 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
898
919
  ]
899
920
  if len(parent_states) >= 2:
900
921
  raise ValueError(f"Only one parent state is allowed {parent_states}.")
901
- return parent_states[0] if len(parent_states) == 1 else None # type: ignore
922
+ return parent_states[0] if len(parent_states) == 1 else None
923
+
924
+ @classmethod
925
+ @functools.lru_cache()
926
+ def get_root_state(cls) -> Type[BaseState]:
927
+ """Get the root state.
928
+
929
+ Returns:
930
+ The root state.
931
+ """
932
+ parent_state = cls.get_parent_state()
933
+ return cls if parent_state is None else parent_state.get_root_state()
902
934
 
903
935
  @classmethod
904
936
  def get_substates(cls) -> set[Type[BaseState]]:
@@ -1057,7 +1089,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1057
1089
  setattr(cls, prop._var_field_name, prop)
1058
1090
 
1059
1091
  @classmethod
1060
- def _create_event_handler(cls, fn):
1092
+ def _create_event_handler(cls, fn: Any):
1061
1093
  """Create an event handler for the given function.
1062
1094
 
1063
1095
  Args:
@@ -1175,14 +1207,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1175
1207
 
1176
1208
  cls._check_overwritten_dynamic_args(list(args.keys()))
1177
1209
 
1178
- def argsingle_factory(param):
1179
- def inner_func(self) -> str:
1210
+ def argsingle_factory(param: str):
1211
+ def inner_func(self: BaseState) -> str:
1180
1212
  return self.router.page.params.get(param, "")
1181
1213
 
1182
1214
  return inner_func
1183
1215
 
1184
- def arglist_factory(param):
1185
- def inner_func(self) -> List[str]:
1216
+ def arglist_factory(param: str):
1217
+ def inner_func(self: BaseState) -> List[str]:
1186
1218
  return self.router.page.params.get(param, [])
1187
1219
 
1188
1220
  return inner_func
@@ -1199,7 +1231,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1199
1231
  fget=func,
1200
1232
  auto_deps=False,
1201
1233
  deps=["router"],
1202
- cache=True,
1203
1234
  _js_expr=param,
1204
1235
  _var_data=VarData.from_state(cls),
1205
1236
  )
@@ -1218,14 +1249,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1218
1249
  args: a dict of args
1219
1250
 
1220
1251
  Raises:
1221
- DynamicRouteArgShadowsStateVar: If a dynamic arg is shadowing an existing var.
1252
+ DynamicRouteArgShadowsStateVarError: If a dynamic arg is shadowing an existing var.
1222
1253
  """
1223
1254
  for arg in args:
1224
1255
  if (
1225
1256
  arg in cls.computed_vars
1226
1257
  and not isinstance(cls.computed_vars[arg], DynamicRouteVar)
1227
1258
  ) or arg in cls.base_vars:
1228
- raise DynamicRouteArgShadowsStateVar(
1259
+ raise DynamicRouteArgShadowsStateVarError(
1229
1260
  f"Dynamic route arg '{arg}' is shadowing an existing var in {cls.__module__}.{cls.__name__}"
1230
1261
  )
1231
1262
  for substate in cls.get_substates():
@@ -1268,8 +1299,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1268
1299
  fn = _no_chain_background_task(type(self), name, handler.fn)
1269
1300
  else:
1270
1301
  fn = functools.partial(handler.fn, self)
1271
- fn.__module__ = handler.fn.__module__ # type: ignore
1272
- fn.__qualname__ = handler.fn.__qualname__ # type: ignore
1302
+ fn.__module__ = handler.fn.__module__
1303
+ fn.__qualname__ = handler.fn.__qualname__
1273
1304
  return fn
1274
1305
 
1275
1306
  backend_vars = super().__getattribute__("_backend_vars")
@@ -1341,19 +1372,16 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1341
1372
  if field.allow_none and not is_optional(field_type):
1342
1373
  field_type = Union[field_type, None]
1343
1374
  if not _isinstance(value, field_type):
1344
- console.deprecate(
1345
- "mismatched-type-assignment",
1346
- f"Tried to assign value {value} of type {type(value)} to field {type(self).__name__}.{name} of type {field_type}."
1347
- " This might lead to unexpected behavior.",
1348
- "0.6.5",
1349
- "0.7.0",
1375
+ console.error(
1376
+ f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}',"
1377
+ f" but got '{value}' of type '{type(value)}'."
1350
1378
  )
1351
1379
 
1352
1380
  # Set the attribute.
1353
1381
  super().__setattr__(name, value)
1354
1382
 
1355
1383
  # Add the var to the dirty list.
1356
- if name in self.vars or name in self._computed_var_dependencies:
1384
+ if name in self.base_vars:
1357
1385
  self.dirty_vars.add(name)
1358
1386
  self._mark_dirty()
1359
1387
 
@@ -1424,63 +1452,21 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1424
1452
  return self.substates[path[0]].get_substate(path[1:])
1425
1453
 
1426
1454
  @classmethod
1427
- def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
1428
- """Find the name of the nearest common ancestor shared by this and the other state.
1429
-
1430
- Args:
1431
- other: The other state.
1432
-
1433
- Returns:
1434
- Full name of the nearest common ancestor.
1435
- """
1436
- common_ancestor_parts = []
1437
- for part1, part2 in zip(
1438
- cls.get_full_name().split("."),
1439
- other.get_full_name().split("."),
1440
- ):
1441
- if part1 != part2:
1442
- break
1443
- common_ancestor_parts.append(part1)
1444
- return ".".join(common_ancestor_parts)
1445
-
1446
- @classmethod
1447
- def _determine_missing_parent_states(
1448
- cls, target_state_cls: Type[BaseState]
1449
- ) -> tuple[str, list[str]]:
1450
- """Determine the missing parent states between the target_state_cls and common ancestor of this state.
1451
-
1452
- Args:
1453
- target_state_cls: The class of the state to find missing parent states for.
1454
-
1455
- Returns:
1456
- The name of the common ancestor and the list of missing parent states.
1457
- """
1458
- common_ancestor_name = cls._get_common_ancestor(target_state_cls)
1459
- common_ancestor_parts = common_ancestor_name.split(".")
1460
- target_state_parts = tuple(target_state_cls.get_full_name().split("."))
1461
- relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :]
1462
-
1463
- # Determine which parent states to fetch from the common ancestor down to the target_state_cls.
1464
- fetch_parent_states = [common_ancestor_name]
1465
- for relative_parent_state_name in relative_target_state_parts:
1466
- fetch_parent_states.append(
1467
- ".".join((fetch_parent_states[-1], relative_parent_state_name))
1468
- )
1469
-
1470
- return common_ancestor_name, fetch_parent_states[1:-1]
1471
-
1472
- def _get_parent_states(self) -> list[tuple[str, BaseState]]:
1473
- """Get all parent state instances up to the root of the state tree.
1455
+ def _get_potentially_dirty_states(cls) -> set[type[BaseState]]:
1456
+ """Get substates which may have dirty vars due to dependencies.
1474
1457
 
1475
1458
  Returns:
1476
- A list of tuples containing the name and the instance of each parent state.
1459
+ The set of potentially dirty substate classes.
1477
1460
  """
1478
- parent_states_with_name = []
1479
- parent_state = self
1480
- while parent_state.parent_state is not None:
1481
- parent_state = parent_state.parent_state
1482
- parent_states_with_name.append((parent_state.get_full_name(), parent_state))
1483
- return parent_states_with_name
1461
+ return {
1462
+ cls.get_class_substate(substate_name)
1463
+ for substate_name in cls._always_dirty_substates
1464
+ }.union(
1465
+ {
1466
+ cls.get_root_state().get_class_substate(substate_name)
1467
+ for substate_name in cls._potentially_dirty_states
1468
+ }
1469
+ )
1484
1470
 
1485
1471
  def _get_root_state(self) -> BaseState:
1486
1472
  """Get the root state of the state tree.
@@ -1493,71 +1479,42 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1493
1479
  parent_state = parent_state.parent_state
1494
1480
  return parent_state
1495
1481
 
1496
- async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
1497
- """Populate substates in the tree between the target_state_cls and common ancestor of this state.
1482
+ async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
1483
+ """Get a state instance from redis.
1498
1484
 
1499
1485
  Args:
1500
- target_state_cls: The class of the state to populate parent states for.
1486
+ state_cls: The class of the state.
1501
1487
 
1502
1488
  Returns:
1503
- The parent state instance of target_state_cls.
1489
+ The instance of state_cls associated with this state's client_token.
1504
1490
 
1505
1491
  Raises:
1506
1492
  RuntimeError: If redis is not used in this backend process.
1493
+ StateMismatchError: If the state instance is not of the expected type.
1507
1494
  """
1495
+ # Then get the target state and all its substates.
1508
1496
  state_manager = get_state_manager()
1509
1497
  if not isinstance(state_manager, StateManagerRedis):
1510
1498
  raise RuntimeError(
1511
- f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. "
1499
+ f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
1512
1500
  "(All states should already be available -- this is likely a bug).",
1513
1501
  )
1502
+ state_in_redis = await state_manager.get_state(
1503
+ token=_substate_key(self.router.session.client_token, state_cls),
1504
+ top_level=False,
1505
+ for_state_instance=self,
1506
+ )
1514
1507
 
1515
- # Find the missing parent states up to the common ancestor.
1516
- (
1517
- common_ancestor_name,
1518
- missing_parent_states,
1519
- ) = self._determine_missing_parent_states(target_state_cls)
1520
-
1521
- # Fetch all missing parent states and link them up to the common ancestor.
1522
- parent_states_tuple = self._get_parent_states()
1523
- root_state = parent_states_tuple[-1][1]
1524
- parent_states_by_name = dict(parent_states_tuple)
1525
- parent_state = parent_states_by_name[common_ancestor_name]
1526
- for parent_state_name in missing_parent_states:
1527
- try:
1528
- parent_state = root_state.get_substate(parent_state_name.split("."))
1529
- # The requested state is already cached, do NOT fetch it again.
1530
- continue
1531
- except ValueError:
1532
- # The requested state is missing, fetch from redis.
1533
- pass
1534
- parent_state = await state_manager.get_state(
1535
- token=_substate_key(
1536
- self.router.session.client_token, parent_state_name
1537
- ),
1538
- top_level=False,
1539
- get_substates=False,
1540
- parent_state=parent_state,
1508
+ if not isinstance(state_in_redis, state_cls):
1509
+ raise StateMismatchError(
1510
+ f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
1541
1511
  )
1542
1512
 
1543
- # Return the direct parent of target_state_cls for subsequent linking.
1544
- return parent_state
1513
+ return state_in_redis
1545
1514
 
1546
- def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
1515
+ def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
1547
1516
  """Get a state instance from the cache.
1548
1517
 
1549
- Args:
1550
- state_cls: The class of the state.
1551
-
1552
- Returns:
1553
- The instance of state_cls associated with this state's client_token.
1554
- """
1555
- root_state = self._get_root_state()
1556
- return root_state.get_substate(state_cls.get_full_name().split("."))
1557
-
1558
- async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
1559
- """Get a state instance from redis.
1560
-
1561
1518
  Args:
1562
1519
  state_cls: The class of the state.
1563
1520
 
@@ -1565,26 +1522,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1565
1522
  The instance of state_cls associated with this state's client_token.
1566
1523
 
1567
1524
  Raises:
1568
- RuntimeError: If redis is not used in this backend process.
1525
+ StateMismatchError: If the state instance is not of the expected type.
1569
1526
  """
1570
- # Fetch all missing parent states from redis.
1571
- parent_state_of_state_cls = await self._populate_parent_states(state_cls)
1572
-
1573
- # Then get the target state and all its substates.
1574
- state_manager = get_state_manager()
1575
- if not isinstance(state_manager, StateManagerRedis):
1576
- raise RuntimeError(
1577
- f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
1578
- "(All states should already be available -- this is likely a bug).",
1527
+ root_state = self._get_root_state()
1528
+ substate = root_state.get_substate(state_cls.get_full_name().split("."))
1529
+ if not isinstance(substate, state_cls):
1530
+ raise StateMismatchError(
1531
+ f"Searched for state {state_cls.get_full_name()} but found {substate}."
1579
1532
  )
1580
- return await state_manager.get_state(
1581
- token=_substate_key(self.router.session.client_token, state_cls),
1582
- top_level=False,
1583
- get_substates=True,
1584
- parent_state=parent_state_of_state_cls,
1585
- )
1533
+ return substate
1586
1534
 
1587
- async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
1535
+ async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
1588
1536
  """Get an instance of the state associated with this token.
1589
1537
 
1590
1538
  Allows for arbitrary access to sibling states from within an event handler.
@@ -1619,11 +1567,13 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1619
1567
  """
1620
1568
  # Oopsie case: you didn't give me a Var... so get what you give.
1621
1569
  if not isinstance(var, Var):
1622
- return var # type: ignore
1570
+ return var
1571
+
1572
+ unset = object()
1623
1573
 
1624
1574
  # Fast case: this is a literal var and the value is known.
1625
- if hasattr(var, "_var_value"):
1626
- return var._var_value
1575
+ if (var_value := getattr(var, "_var_value", unset)) is not unset:
1576
+ return var_value # pyright: ignore [reportReturnType]
1627
1577
 
1628
1578
  var_data = var._get_all_var_data()
1629
1579
  if var_data is None or not var_data.state:
@@ -1720,7 +1670,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1720
1670
  f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
1721
1671
  )
1722
1672
 
1723
- def _as_state_update(
1673
+ async def _as_state_update(
1724
1674
  self,
1725
1675
  handler: EventHandler,
1726
1676
  events: EventSpec | list[EventSpec] | None,
@@ -1748,7 +1698,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1748
1698
 
1749
1699
  try:
1750
1700
  # Get the delta after processing the event.
1751
- delta = state.get_delta()
1701
+ delta = await _resolve_delta(state.get_delta())
1752
1702
  state._clean()
1753
1703
 
1754
1704
  return StateUpdate(
@@ -1759,9 +1709,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1759
1709
  except Exception as ex:
1760
1710
  state._clean()
1761
1711
 
1762
- app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
1763
-
1764
- event_specs = app_instance.backend_exception_handler(ex)
1712
+ event_specs = (
1713
+ prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
1714
+ )
1765
1715
 
1766
1716
  if event_specs is None:
1767
1717
  return StateUpdate()
@@ -1814,7 +1764,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1814
1764
  if (
1815
1765
  isinstance(value, dict)
1816
1766
  and inspect.isclass(hinted_args)
1817
- and not types.is_generic_alias(hinted_args) # py3.9-py3.10
1767
+ and not types.is_generic_alias(hinted_args) # py3.10
1818
1768
  ):
1819
1769
  if issubclass(hinted_args, Model):
1820
1770
  # Remove non-fields from the payload
@@ -1848,34 +1798,38 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1848
1798
  # Handle async generators.
1849
1799
  if inspect.isasyncgen(events):
1850
1800
  async for event in events:
1851
- yield state._as_state_update(handler, event, final=False)
1852
- yield state._as_state_update(handler, events=None, final=True)
1801
+ yield await state._as_state_update(handler, event, final=False)
1802
+ yield await state._as_state_update(handler, events=None, final=True)
1853
1803
 
1854
1804
  # Handle regular generators.
1855
1805
  elif inspect.isgenerator(events):
1856
1806
  try:
1857
1807
  while True:
1858
- yield state._as_state_update(handler, next(events), final=False)
1808
+ yield await state._as_state_update(
1809
+ handler, next(events), final=False
1810
+ )
1859
1811
  except StopIteration as si:
1860
1812
  # the "return" value of the generator is not available
1861
1813
  # in the loop, we must catch StopIteration to access it
1862
1814
  if si.value is not None:
1863
- yield state._as_state_update(handler, si.value, final=False)
1864
- yield state._as_state_update(handler, events=None, final=True)
1815
+ yield await state._as_state_update(
1816
+ handler, si.value, final=False
1817
+ )
1818
+ yield await state._as_state_update(handler, events=None, final=True)
1865
1819
 
1866
1820
  # Handle regular event chains.
1867
1821
  else:
1868
- yield state._as_state_update(handler, events, final=True)
1822
+ yield await state._as_state_update(handler, events, final=True)
1869
1823
 
1870
1824
  # If an error occurs, throw a window alert.
1871
1825
  except Exception as ex:
1872
1826
  telemetry.send_error(ex, context="backend")
1873
1827
 
1874
- app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
1875
-
1876
- event_specs = app_instance.backend_exception_handler(ex)
1828
+ event_specs = (
1829
+ prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
1830
+ )
1877
1831
 
1878
- yield state._as_state_update(
1832
+ yield await state._as_state_update(
1879
1833
  handler,
1880
1834
  event_specs,
1881
1835
  final=True,
@@ -1883,15 +1837,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1883
1837
 
1884
1838
  def _mark_dirty_computed_vars(self) -> None:
1885
1839
  """Mark ComputedVars that need to be recalculated based on dirty_vars."""
1840
+ # Append expired computed vars to dirty_vars to trigger recalculation
1841
+ self.dirty_vars.update(self._expired_computed_vars())
1842
+ # Append always dirty computed vars to dirty_vars to trigger recalculation
1843
+ self.dirty_vars.update(self._always_dirty_computed_vars)
1844
+
1886
1845
  dirty_vars = self.dirty_vars
1887
1846
  while dirty_vars:
1888
1847
  calc_vars, dirty_vars = dirty_vars, set()
1889
- for cvar in self._dirty_computed_vars(from_vars=calc_vars):
1890
- self.dirty_vars.add(cvar)
1848
+ for state_name, cvar in self._dirty_computed_vars(from_vars=calc_vars):
1849
+ if state_name == self.get_full_name():
1850
+ defining_state = self
1851
+ else:
1852
+ defining_state = self._get_root_state().get_substate(
1853
+ tuple(state_name.split("."))
1854
+ )
1855
+ defining_state.dirty_vars.add(cvar)
1891
1856
  dirty_vars.add(cvar)
1892
- actual_var = self.computed_vars.get(cvar)
1857
+ actual_var = defining_state.computed_vars.get(cvar)
1893
1858
  if actual_var is not None:
1894
- actual_var.mark_dirty(instance=self)
1859
+ actual_var.mark_dirty(instance=defining_state)
1860
+ if defining_state is not self:
1861
+ defining_state._mark_dirty()
1895
1862
 
1896
1863
  def _expired_computed_vars(self) -> set[str]:
1897
1864
  """Determine ComputedVars that need to be recalculated based on the expiration time.
@@ -1907,7 +1874,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1907
1874
 
1908
1875
  def _dirty_computed_vars(
1909
1876
  self, from_vars: set[str] | None = None, include_backend: bool = True
1910
- ) -> set[str]:
1877
+ ) -> set[tuple[str, str]]:
1911
1878
  """Determine ComputedVars that need to be recalculated based on the given vars.
1912
1879
 
1913
1880
  Args:
@@ -1918,33 +1885,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1918
1885
  Set of computed vars to include in the delta.
1919
1886
  """
1920
1887
  return {
1921
- cvar
1888
+ (state_name, cvar)
1922
1889
  for dirty_var in from_vars or self.dirty_vars
1923
- for cvar in self._computed_var_dependencies[dirty_var]
1890
+ for state_name, cvar in self._var_dependencies.get(dirty_var, set())
1924
1891
  if include_backend or not self.computed_vars[cvar]._backend
1925
1892
  }
1926
1893
 
1927
- @classmethod
1928
- def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
1929
- """Determine substates which could be affected by dirty vars in this state.
1930
-
1931
- Returns:
1932
- Set of State classes that may need to be fetched to recalc computed vars.
1933
- """
1934
- # _always_dirty_substates need to be fetched to recalc computed vars.
1935
- fetch_substates = {
1936
- cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
1937
- for substate_name in cls._always_dirty_substates
1938
- }
1939
- for dependent_substates in cls._substate_var_dependencies.values():
1940
- fetch_substates.update(
1941
- {
1942
- cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
1943
- for substate_name in dependent_substates
1944
- }
1945
- )
1946
- return fetch_substates
1947
-
1948
1894
  def get_delta(self) -> Delta:
1949
1895
  """Get the delta for the state.
1950
1896
 
@@ -1953,21 +1899,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1953
1899
  """
1954
1900
  delta = {}
1955
1901
 
1956
- # Apply dirty variables down into substates
1957
- self.dirty_vars.update(self._always_dirty_computed_vars)
1958
- self._mark_dirty()
1959
-
1902
+ self._mark_dirty_computed_vars()
1960
1903
  frontend_computed_vars: set[str] = {
1961
1904
  name for name, cv in self.computed_vars.items() if not cv._backend
1962
1905
  }
1963
1906
 
1964
1907
  # Return the dirty vars for this instance, any cached/dependent computed vars,
1965
1908
  # and always dirty computed vars (cache=False)
1966
- delta_vars = (
1967
- self.dirty_vars.intersection(self.base_vars)
1968
- .union(self.dirty_vars.intersection(frontend_computed_vars))
1969
- .union(self._dirty_computed_vars(include_backend=False))
1970
- .union(self._always_dirty_computed_vars)
1909
+ delta_vars = self.dirty_vars.intersection(self.base_vars).union(
1910
+ self.dirty_vars.intersection(frontend_computed_vars)
1971
1911
  )
1972
1912
 
1973
1913
  subdelta: Dict[str, Any] = {
@@ -1997,23 +1937,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
1997
1937
  self.parent_state.dirty_substates.add(self.get_name())
1998
1938
  self.parent_state._mark_dirty()
1999
1939
 
2000
- # Append expired computed vars to dirty_vars to trigger recalculation
2001
- self.dirty_vars.update(self._expired_computed_vars())
2002
-
2003
1940
  # have to mark computed vars dirty to allow access to newly computed
2004
1941
  # values within the same ComputedVar function
2005
1942
  self._mark_dirty_computed_vars()
2006
- self._mark_dirty_substates()
2007
-
2008
- def _mark_dirty_substates(self):
2009
- """Propagate dirty var / computed var status into substates."""
2010
- substates = self.substates
2011
- for var in self.dirty_vars:
2012
- for substate_name in self._substate_var_dependencies[var]:
2013
- self.dirty_substates.add(substate_name)
2014
- substate = substates[substate_name]
2015
- substate.dirty_vars.add(var)
2016
- substate._mark_dirty()
2017
1943
 
2018
1944
  def _update_was_touched(self):
2019
1945
  """Update the _was_touched flag based on dirty_vars."""
@@ -2085,11 +2011,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
2085
2011
  The object as a dictionary.
2086
2012
  """
2087
2013
  if include_computed:
2088
- # Apply dirty variables down into substates to allow never-cached ComputedVar to
2089
- # trigger recalculation of dependent vars
2090
- self.dirty_vars.update(self._always_dirty_computed_vars)
2091
- self._mark_dirty()
2092
-
2014
+ self._mark_dirty_computed_vars()
2093
2015
  base_vars = {
2094
2016
  prop_name: self.get_value(prop_name) for prop_name in self.base_vars
2095
2017
  }
@@ -2316,6 +2238,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
2316
2238
  return state
2317
2239
 
2318
2240
 
2241
+ T_STATE = TypeVar("T_STATE", bound=BaseState)
2242
+
2243
+
2319
2244
  class State(BaseState):
2320
2245
  """The app Base State."""
2321
2246
 
@@ -2336,8 +2261,7 @@ def dynamic(func: Callable[[T], Component]):
2336
2261
  The dynamically generated component.
2337
2262
 
2338
2263
  Raises:
2339
- DynamicComponentInvalidSignature: If the function does not have exactly one parameter.
2340
- DynamicComponentInvalidSignature: If the function does not have a type hint for the state class.
2264
+ DynamicComponentInvalidSignatureError: If the function does not have exactly one parameter or a type hint for the state class.
2341
2265
  """
2342
2266
  number_of_parameters = len(inspect.signature(func).parameters)
2343
2267
 
@@ -2349,12 +2273,12 @@ def dynamic(func: Callable[[T], Component]):
2349
2273
  values = list(func_signature.values())
2350
2274
 
2351
2275
  if number_of_parameters != 1:
2352
- raise DynamicComponentInvalidSignature(
2276
+ raise DynamicComponentInvalidSignatureError(
2353
2277
  "The function must have exactly one parameter, which is the state class."
2354
2278
  )
2355
2279
 
2356
2280
  if len(values) != 1:
2357
- raise DynamicComponentInvalidSignature(
2281
+ raise DynamicComponentInvalidSignatureError(
2358
2282
  "You must provide a type hint for the state class in the function."
2359
2283
  )
2360
2284
 
@@ -2383,8 +2307,9 @@ class FrontendEventExceptionState(State):
2383
2307
  component_stack: The stack trace of the component where the exception occurred.
2384
2308
 
2385
2309
  """
2386
- app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
2387
- app_instance.frontend_exception_handler(Exception(stack))
2310
+ prerequisites.get_and_validate_app().app.frontend_exception_handler(
2311
+ Exception(stack)
2312
+ )
2388
2313
 
2389
2314
 
2390
2315
  class UpdateVarsInternalState(State):
@@ -2422,19 +2347,20 @@ class OnLoadInternalState(State):
2422
2347
  The list of events to queue for on load handling.
2423
2348
  """
2424
2349
  # Do not app._compile()! It should be already compiled by now.
2425
- app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
2426
- load_events = app.get_load_events(self.router.page.path)
2350
+ load_events = prerequisites.get_and_validate_app().app.get_load_events(
2351
+ self.router.page.path
2352
+ )
2427
2353
  if not load_events:
2428
2354
  self.is_hydrated = True
2429
2355
  return # Fast path for navigation with no on_load events defined.
2430
2356
  self.is_hydrated = False
2431
2357
  return [
2432
2358
  *fix_events(
2433
- load_events,
2359
+ cast(list[Union[EventSpec, EventHandler]], load_events),
2434
2360
  self.router.session.client_token,
2435
2361
  router_data=self.router_data,
2436
2362
  ),
2437
- State.set_is_hydrated(True), # type: ignore
2363
+ State.set_is_hydrated(True), # pyright: ignore [reportAttributeAccessIssue]
2438
2364
  ]
2439
2365
 
2440
2366
 
@@ -2575,7 +2501,9 @@ class StateProxy(wrapt.ObjectProxy):
2575
2501
  """
2576
2502
 
2577
2503
  def __init__(
2578
- self, state_instance, parent_state_proxy: Optional["StateProxy"] = None
2504
+ self,
2505
+ state_instance: BaseState,
2506
+ parent_state_proxy: Optional["StateProxy"] = None,
2579
2507
  ):
2580
2508
  """Create a proxy for a state instance.
2581
2509
 
@@ -2589,7 +2517,7 @@ class StateProxy(wrapt.ObjectProxy):
2589
2517
  """
2590
2518
  super().__init__(state_instance)
2591
2519
  # compile is not relevant to backend logic
2592
- self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
2520
+ self._self_app = prerequisites.get_and_validate_app().app
2593
2521
  self._self_substate_path = tuple(state_instance.get_full_name().split("."))
2594
2522
  self._self_actx = None
2595
2523
  self._self_mutable = False
@@ -2718,7 +2646,7 @@ class StateProxy(wrapt.ObjectProxy):
2718
2646
  # ensure mutations to these containers are blocked unless proxy is _mutable
2719
2647
  return ImmutableMutableProxy(
2720
2648
  wrapped=value.__wrapped__,
2721
- state=self, # type: ignore
2649
+ state=self,
2722
2650
  field_name=value._self_field_name,
2723
2651
  )
2724
2652
  if isinstance(value, functools.partial) and value.args[0] is self.__wrapped__:
@@ -2731,7 +2659,7 @@ class StateProxy(wrapt.ObjectProxy):
2731
2659
  )
2732
2660
  if isinstance(value, MethodType) and value.__self__ is self.__wrapped__:
2733
2661
  # Rebind methods to the proxy instance
2734
- value = type(value)(value.__func__, self) # type: ignore
2662
+ value = type(value)(value.__func__, self)
2735
2663
  return value
2736
2664
 
2737
2665
  def __setattr__(self, name: str, value: Any) -> None:
@@ -2800,7 +2728,7 @@ class StateProxy(wrapt.ObjectProxy):
2800
2728
  await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
2801
2729
  )
2802
2730
 
2803
- def _as_state_update(self, *args, **kwargs) -> StateUpdate:
2731
+ async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
2804
2732
  """Temporarily allow mutability to access parent_state.
2805
2733
 
2806
2734
  Args:
@@ -2813,7 +2741,7 @@ class StateProxy(wrapt.ObjectProxy):
2813
2741
  original_mutable = self._self_mutable
2814
2742
  self._self_mutable = True
2815
2743
  try:
2816
- return self.__wrapped__._as_state_update(*args, **kwargs)
2744
+ return await self.__wrapped__._as_state_update(*args, **kwargs)
2817
2745
  finally:
2818
2746
  self._self_mutable = original_mutable
2819
2747
 
@@ -2856,7 +2784,7 @@ class StateManager(Base, ABC):
2856
2784
  state: The state class to use.
2857
2785
 
2858
2786
  Raises:
2859
- InvalidStateManagerMode: If the state manager mode is invalid.
2787
+ InvalidStateManagerModeError: If the state manager mode is invalid.
2860
2788
 
2861
2789
  Returns:
2862
2790
  The state manager (either disk, memory or redis).
@@ -2879,7 +2807,7 @@ class StateManager(Base, ABC):
2879
2807
  lock_expiration=config.redis_lock_expiration,
2880
2808
  lock_warning_threshold=config.redis_lock_warning_threshold,
2881
2809
  )
2882
- raise InvalidStateManagerMode(
2810
+ raise InvalidStateManagerModeError(
2883
2811
  f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
2884
2812
  )
2885
2813
 
@@ -2931,7 +2859,7 @@ class StateManagerMemory(StateManager):
2931
2859
  # The dict of mutexes for each client
2932
2860
  _states_locks: Dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
2933
2861
 
2934
- class Config:
2862
+ class Config: # pyright: ignore [reportIncompatibleVariableOverride]
2935
2863
  """The Pydantic config."""
2936
2864
 
2937
2865
  fields = {
@@ -3028,7 +2956,7 @@ def is_serializable(value: Any) -> bool:
3028
2956
 
3029
2957
  def reset_disk_state_manager():
3030
2958
  """Reset the disk state manager."""
3031
- states_directory = prerequisites.get_web_dir() / constants.Dirs.STATES
2959
+ states_directory = prerequisites.get_states_dir()
3032
2960
  if states_directory.exists():
3033
2961
  for path in states_directory.iterdir():
3034
2962
  path.unlink()
@@ -3049,7 +2977,7 @@ class StateManagerDisk(StateManager):
3049
2977
  # The token expiration time (s).
3050
2978
  token_expiration: int = pydantic.Field(default_factory=_default_token_expiration)
3051
2979
 
3052
- class Config:
2980
+ class Config: # pyright: ignore [reportIncompatibleVariableOverride]
3053
2981
  """The Pydantic config."""
3054
2982
 
3055
2983
  fields = {
@@ -3076,7 +3004,7 @@ class StateManagerDisk(StateManager):
3076
3004
  Returns:
3077
3005
  The states directory.
3078
3006
  """
3079
- return prerequisites.get_web_dir() / constants.Dirs.STATES
3007
+ return prerequisites.get_states_dir()
3080
3008
 
3081
3009
  def _purge_expired_states(self):
3082
3010
  """Purge expired states from the disk."""
@@ -3289,103 +3217,106 @@ class StateManagerRedis(StateManager):
3289
3217
  b"evicted",
3290
3218
  }
3291
3219
 
3292
- async def _get_parent_state(
3293
- self, token: str, state: BaseState | None = None
3294
- ) -> BaseState | None:
3295
- """Get the parent state for the state requested in the token.
3220
+ def _get_required_state_classes(
3221
+ self,
3222
+ target_state_cls: Type[BaseState],
3223
+ subclasses: bool = False,
3224
+ required_state_classes: set[Type[BaseState]] | None = None,
3225
+ ) -> set[Type[BaseState]]:
3226
+ """Recursively determine which states are required to fetch the target state.
3227
+
3228
+ This will always include potentially dirty substates that depend on vars
3229
+ in the target_state_cls.
3296
3230
 
3297
3231
  Args:
3298
- token: The token to get the state for (_substate_key).
3299
- state: The state instance to get parent state for.
3232
+ target_state_cls: The target state class being fetched.
3233
+ subclasses: Whether to include subclasses of the target state.
3234
+ required_state_classes: Recursive argument tracking state classes that have already been seen.
3300
3235
 
3301
3236
  Returns:
3302
- The parent state for the state requested by the token or None if there is no such parent.
3303
- """
3304
- parent_state = None
3305
- client_token, state_path = _split_substate_key(token)
3306
- parent_state_name = state_path.rpartition(".")[0]
3307
- if parent_state_name:
3308
- cached_substates = None
3309
- if state is not None:
3310
- cached_substates = [state]
3311
- # Retrieve the parent state to populate event handlers onto this substate.
3312
- parent_state = await self.get_state(
3313
- token=_substate_key(client_token, parent_state_name),
3314
- top_level=False,
3315
- get_substates=False,
3316
- cached_substates=cached_substates,
3237
+ The set of state classes required to fetch the target state.
3238
+ """
3239
+ if required_state_classes is None:
3240
+ required_state_classes = set()
3241
+ # Get the substates if requested.
3242
+ if subclasses:
3243
+ for substate in target_state_cls.get_substates():
3244
+ self._get_required_state_classes(
3245
+ substate,
3246
+ subclasses=True,
3247
+ required_state_classes=required_state_classes,
3248
+ )
3249
+ if target_state_cls in required_state_classes:
3250
+ return required_state_classes
3251
+ required_state_classes.add(target_state_cls)
3252
+
3253
+ # Get dependent substates.
3254
+ for pd_substates in target_state_cls._get_potentially_dirty_states():
3255
+ self._get_required_state_classes(
3256
+ pd_substates,
3257
+ subclasses=False,
3258
+ required_state_classes=required_state_classes,
3317
3259
  )
3318
- return parent_state
3319
3260
 
3320
- async def _populate_substates(
3321
- self,
3322
- token: str,
3323
- state: BaseState,
3324
- all_substates: bool = False,
3325
- ):
3326
- """Fetch and link substates for the given state instance.
3261
+ # Get the parent state if it exists.
3262
+ if parent_state := target_state_cls.get_parent_state():
3263
+ self._get_required_state_classes(
3264
+ parent_state,
3265
+ subclasses=False,
3266
+ required_state_classes=required_state_classes,
3267
+ )
3268
+ return required_state_classes
3327
3269
 
3328
- There is no return value; the side-effect is that `state` will have `substates` populated,
3329
- and each substate will have its `parent_state` set to `state`.
3270
+ def _get_populated_states(
3271
+ self,
3272
+ target_state: BaseState,
3273
+ populated_states: dict[str, BaseState] | None = None,
3274
+ ) -> dict[str, BaseState]:
3275
+ """Recursively determine which states from target_state are already fetched.
3330
3276
 
3331
3277
  Args:
3332
- token: The token to get the state for.
3333
- state: The state instance to populate substates for.
3334
- all_substates: Whether to fetch all substates or just required substates.
3335
- """
3336
- client_token, _ = _split_substate_key(token)
3337
-
3338
- if all_substates:
3339
- # All substates are requested.
3340
- fetch_substates = state.get_substates()
3341
- else:
3342
- # Only _potentially_dirty_substates need to be fetched to recalc computed vars.
3343
- fetch_substates = state._potentially_dirty_substates()
3278
+ target_state: The state to check for populated states.
3279
+ populated_states: Recursive argument tracking states seen in previous calls.
3344
3280
 
3345
- tasks = {}
3346
- # Retrieve the necessary substates from redis.
3347
- for substate_cls in fetch_substates:
3348
- if substate_cls.get_name() in state.substates:
3349
- continue
3350
- substate_name = substate_cls.get_name()
3351
- tasks[substate_name] = asyncio.create_task(
3352
- self.get_state(
3353
- token=_substate_key(client_token, substate_cls),
3354
- top_level=False,
3355
- get_substates=all_substates,
3356
- parent_state=state,
3357
- )
3281
+ Returns:
3282
+ A dictionary of state full name to state instance.
3283
+ """
3284
+ if populated_states is None:
3285
+ populated_states = {}
3286
+ if target_state.get_full_name() in populated_states:
3287
+ return populated_states
3288
+ populated_states[target_state.get_full_name()] = target_state
3289
+ for substate in target_state.substates.values():
3290
+ self._get_populated_states(substate, populated_states=populated_states)
3291
+ if target_state.parent_state is not None:
3292
+ self._get_populated_states(
3293
+ target_state.parent_state, populated_states=populated_states
3358
3294
  )
3359
-
3360
- for substate_name, substate_task in tasks.items():
3361
- state.substates[substate_name] = await substate_task
3295
+ return populated_states
3362
3296
 
3363
3297
  @override
3364
3298
  async def get_state(
3365
3299
  self,
3366
3300
  token: str,
3367
3301
  top_level: bool = True,
3368
- get_substates: bool = True,
3369
- parent_state: BaseState | None = None,
3370
- cached_substates: list[BaseState] | None = None,
3302
+ for_state_instance: BaseState | None = None,
3371
3303
  ) -> BaseState:
3372
3304
  """Get the state for a token.
3373
3305
 
3374
3306
  Args:
3375
3307
  token: The token to get the state for.
3376
3308
  top_level: If true, return an instance of the top-level state (self.state).
3377
- get_substates: If true, also retrieve substates.
3378
- parent_state: If provided, use this parent_state instead of getting it from redis.
3379
- cached_substates: If provided, attach these substates to the state.
3309
+ for_state_instance: If provided, attach the requested states to this existing state tree.
3380
3310
 
3381
3311
  Returns:
3382
3312
  The state for the token.
3383
3313
 
3384
3314
  Raises:
3385
- RuntimeError: when the state_cls is not specified in the token
3315
+ RuntimeError: when the state_cls is not specified in the token, or when the parent state for a
3316
+ requested state was not fetched.
3386
3317
  """
3387
3318
  # Split the actual token from the fully qualified substate name.
3388
- _, state_path = _split_substate_key(token)
3319
+ token, state_path = _split_substate_key(token)
3389
3320
  if state_path:
3390
3321
  # Get the State class associated with the given path.
3391
3322
  state_cls = self.state.get_class_substate(state_path)
@@ -3394,43 +3325,59 @@ class StateManagerRedis(StateManager):
3394
3325
  f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
3395
3326
  )
3396
3327
 
3397
- # The deserialized or newly created (sub)state instance.
3398
- state = None
3399
-
3400
- # Fetch the serialized substate from redis.
3401
- redis_state = await self.redis.get(token)
3402
-
3403
- if redis_state is not None:
3404
- # Deserialize the substate.
3405
- with contextlib.suppress(StateSchemaMismatchError):
3406
- state = BaseState._deserialize(data=redis_state)
3407
- if state is None:
3408
- # Key didn't exist or schema mismatch so create a new instance for this token.
3409
- state = state_cls(
3410
- init_substates=False,
3411
- _reflex_internal_init=True,
3412
- )
3413
- # Populate parent state if missing and requested.
3414
- if parent_state is None:
3415
- parent_state = await self._get_parent_state(token, state)
3416
- # Set up Bidirectional linkage between this state and its parent.
3417
- if parent_state is not None:
3418
- parent_state.substates[state.get_name()] = state
3419
- state.parent_state = parent_state
3420
- # Avoid fetching substates multiple times.
3421
- if cached_substates:
3422
- for substate in cached_substates:
3423
- state.substates[substate.get_name()] = substate
3424
- if substate.parent_state is None:
3425
- substate.parent_state = state
3426
- # Populate substates if requested.
3427
- await self._populate_substates(token, state, all_substates=get_substates)
3328
+ # Determine which states we already have.
3329
+ flat_state_tree: dict[str, BaseState] = (
3330
+ self._get_populated_states(for_state_instance) if for_state_instance else {}
3331
+ )
3332
+
3333
+ # Determine which states from the tree need to be fetched.
3334
+ required_state_classes = sorted(
3335
+ self._get_required_state_classes(state_cls, subclasses=True)
3336
+ - {type(s) for s in flat_state_tree.values()},
3337
+ key=lambda x: x.get_full_name(),
3338
+ )
3339
+
3340
+ redis_pipeline = self.redis.pipeline()
3341
+ for state_cls in required_state_classes:
3342
+ redis_pipeline.get(_substate_key(token, state_cls))
3343
+
3344
+ for state_cls, redis_state in zip(
3345
+ required_state_classes,
3346
+ await redis_pipeline.execute(),
3347
+ strict=False,
3348
+ ):
3349
+ state = None
3350
+
3351
+ if redis_state is not None:
3352
+ # Deserialize the substate.
3353
+ with contextlib.suppress(StateSchemaMismatchError):
3354
+ state = BaseState._deserialize(data=redis_state)
3355
+ if state is None:
3356
+ # Key didn't exist or schema mismatch so create a new instance for this token.
3357
+ state = state_cls(
3358
+ init_substates=False,
3359
+ _reflex_internal_init=True,
3360
+ )
3361
+ flat_state_tree[state.get_full_name()] = state
3362
+ if state.get_parent_state() is not None:
3363
+ parent_state_name, _dot, state_name = state.get_full_name().rpartition(
3364
+ "."
3365
+ )
3366
+ parent_state = flat_state_tree.get(parent_state_name)
3367
+ if parent_state is None:
3368
+ raise RuntimeError(
3369
+ f"Parent state for {state.get_full_name()} was not found "
3370
+ "in the state tree, but should have already been fetched. "
3371
+ "This is a bug",
3372
+ )
3373
+ parent_state.substates[state_name] = state
3374
+ state.parent_state = parent_state
3428
3375
 
3429
3376
  # To retain compatibility with previous implementation, by default, we return
3430
- # the top-level state by chasing `parent_state` pointers up the tree.
3377
+ # the top-level state which should always be fetched or already cached.
3431
3378
  if top_level:
3432
- return state._get_root_state()
3433
- return state
3379
+ return flat_state_tree[self.state.get_full_name()]
3380
+ return flat_state_tree[state_cls.get_full_name()]
3434
3381
 
3435
3382
  @override
3436
3383
  async def set_state(
@@ -3521,7 +3468,9 @@ class StateManagerRedis(StateManager):
3521
3468
 
3522
3469
  @validator("lock_warning_threshold")
3523
3470
  @classmethod
3524
- def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values):
3471
+ def validate_lock_warning_threshold(
3472
+ cls, lock_warning_threshold: int, values: dict[str, int]
3473
+ ):
3525
3474
  """Validate the lock warning threshold.
3526
3475
 
3527
3476
  Args:
@@ -3682,8 +3631,7 @@ def get_state_manager() -> StateManager:
3682
3631
  Returns:
3683
3632
  The state manager.
3684
3633
  """
3685
- app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
3686
- return app.state_manager
3634
+ return prerequisites.get_and_validate_app().app.state_manager
3687
3635
 
3688
3636
 
3689
3637
  class MutableProxy(wrapt.ObjectProxy):
@@ -3758,9 +3706,9 @@ class MutableProxy(wrapt.ObjectProxy):
3758
3706
  wrapper_cls_name,
3759
3707
  (cls,),
3760
3708
  {
3761
- dataclasses._FIELDS: getattr( # pyright: ignore [reportGeneralTypeIssues]
3709
+ dataclasses._FIELDS: getattr( # pyright: ignore [reportAttributeAccessIssue]
3762
3710
  wrapped_cls,
3763
- dataclasses._FIELDS, # pyright: ignore [reportGeneralTypeIssues]
3711
+ dataclasses._FIELDS, # pyright: ignore [reportAttributeAccessIssue]
3764
3712
  ),
3765
3713
  },
3766
3714
  )
@@ -3790,10 +3738,10 @@ class MutableProxy(wrapt.ObjectProxy):
3790
3738
 
3791
3739
  def _mark_dirty(
3792
3740
  self,
3793
- wrapped=None,
3794
- instance=None,
3795
- args=(),
3796
- kwargs=None,
3741
+ wrapped: Callable | None = None,
3742
+ instance: BaseState | None = None,
3743
+ args: tuple = (),
3744
+ kwargs: dict | None = None,
3797
3745
  ) -> Any:
3798
3746
  """Mark the state as dirty, then call a wrapped function.
3799
3747
 
@@ -3867,7 +3815,9 @@ class MutableProxy(wrapt.ObjectProxy):
3867
3815
  )
3868
3816
  return value
3869
3817
 
3870
- def _wrap_recursive_decorator(self, wrapped, instance, args, kwargs) -> Any:
3818
+ def _wrap_recursive_decorator(
3819
+ self, wrapped: Callable, instance: BaseState, args: list, kwargs: dict
3820
+ ) -> Any:
3871
3821
  """Wrap a function that returns a possibly mutable value.
3872
3822
 
3873
3823
  Intended for use with `FunctionWrapper` from the `wrapt` library.
@@ -3913,7 +3863,7 @@ class MutableProxy(wrapt.ObjectProxy):
3913
3863
  ):
3914
3864
  # Wrap methods called on Base subclasses, which might do _anything_
3915
3865
  return wrapt.FunctionWrapper(
3916
- functools.partial(value.__func__, self),
3866
+ functools.partial(value.__func__, self), # pyright: ignore [reportFunctionMemberAccess]
3917
3867
  self._wrap_recursive_decorator,
3918
3868
  )
3919
3869
 
@@ -3926,7 +3876,7 @@ class MutableProxy(wrapt.ObjectProxy):
3926
3876
 
3927
3877
  return value
3928
3878
 
3929
- def __getitem__(self, key) -> Any:
3879
+ def __getitem__(self, key: Any) -> Any:
3930
3880
  """Get the item on the proxied object and return a proxy if mutable.
3931
3881
 
3932
3882
  Args:
@@ -3949,7 +3899,7 @@ class MutableProxy(wrapt.ObjectProxy):
3949
3899
  # Recursively wrap mutable items retrieved through this proxy.
3950
3900
  yield self._wrap_recursive(value)
3951
3901
 
3952
- def __delattr__(self, name):
3902
+ def __delattr__(self, name: str):
3953
3903
  """Delete the attribute on the proxied object and mark state dirty.
3954
3904
 
3955
3905
  Args:
@@ -3957,7 +3907,7 @@ class MutableProxy(wrapt.ObjectProxy):
3957
3907
  """
3958
3908
  self._mark_dirty(super().__delattr__, args=(name,))
3959
3909
 
3960
- def __delitem__(self, key):
3910
+ def __delitem__(self, key: str):
3961
3911
  """Delete the item on the proxied object and mark state dirty.
3962
3912
 
3963
3913
  Args:
@@ -3965,7 +3915,7 @@ class MutableProxy(wrapt.ObjectProxy):
3965
3915
  """
3966
3916
  self._mark_dirty(super().__delitem__, args=(key,))
3967
3917
 
3968
- def __setitem__(self, key, value):
3918
+ def __setitem__(self, key: str, value: Any):
3969
3919
  """Set the item on the proxied object and mark state dirty.
3970
3920
 
3971
3921
  Args:
@@ -3974,7 +3924,7 @@ class MutableProxy(wrapt.ObjectProxy):
3974
3924
  """
3975
3925
  self._mark_dirty(super().__setitem__, args=(key, value))
3976
3926
 
3977
- def __setattr__(self, name, value):
3927
+ def __setattr__(self, name: str, value: Any):
3978
3928
  """Set the attribute on the proxied object and mark state dirty.
3979
3929
 
3980
3930
  If the attribute starts with "_self_", then the state is NOT marked
@@ -3998,7 +3948,7 @@ class MutableProxy(wrapt.ObjectProxy):
3998
3948
  """
3999
3949
  return copy.copy(self.__wrapped__)
4000
3950
 
4001
- def __deepcopy__(self, memo=None) -> Any:
3951
+ def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Any:
4002
3952
  """Return a deepcopy of the proxy.
4003
3953
 
4004
3954
  Args:
@@ -4009,7 +3959,7 @@ class MutableProxy(wrapt.ObjectProxy):
4009
3959
  """
4010
3960
  return copy.deepcopy(self.__wrapped__, memo=memo)
4011
3961
 
4012
- def __reduce_ex__(self, protocol_version):
3962
+ def __reduce_ex__(self, protocol_version: SupportsIndex):
4013
3963
  """Get the state for redis serialization.
4014
3964
 
4015
3965
  This method is called by cloudpickle to serialize the object.
@@ -4038,10 +3988,10 @@ def serialize_mutable_proxy(mp: MutableProxy):
4038
3988
  return mp.__wrapped__
4039
3989
 
4040
3990
 
4041
- _orig_json_JSONEncoder_default = json.JSONEncoder.default
3991
+ _orig_json_encoder_default = json.JSONEncoder.default
4042
3992
 
4043
3993
 
4044
- def _json_JSONEncoder_default_wrapper(self: json.JSONEncoder, o: Any) -> Any:
3994
+ def _json_encoder_default_wrapper(self: json.JSONEncoder, o: Any) -> Any:
4045
3995
  """Wrap JSONEncoder.default to handle MutableProxy objects.
4046
3996
 
4047
3997
  Args:
@@ -4055,10 +4005,10 @@ def _json_JSONEncoder_default_wrapper(self: json.JSONEncoder, o: Any) -> Any:
4055
4005
  return o.__wrapped__
4056
4006
  except AttributeError:
4057
4007
  pass
4058
- return _orig_json_JSONEncoder_default(self, o)
4008
+ return _orig_json_encoder_default(self, o)
4059
4009
 
4060
4010
 
4061
- json.JSONEncoder.default = _json_JSONEncoder_default_wrapper
4011
+ json.JSONEncoder.default = _json_encoder_default_wrapper
4062
4012
 
4063
4013
 
4064
4014
  class ImmutableMutableProxy(MutableProxy):
@@ -4073,10 +4023,10 @@ class ImmutableMutableProxy(MutableProxy):
4073
4023
 
4074
4024
  def _mark_dirty(
4075
4025
  self,
4076
- wrapped=None,
4077
- instance=None,
4078
- args=(),
4079
- kwargs=None,
4026
+ wrapped: Callable | None = None,
4027
+ instance: BaseState | None = None,
4028
+ args: tuple = (),
4029
+ kwargs: dict | None = None,
4080
4030
  ) -> Any:
4081
4031
  """Raise an exception when an attempt is made to modify the object.
4082
4032
 
@@ -4127,12 +4077,19 @@ def reload_state_module(
4127
4077
  state: Recursive argument for the state class to reload.
4128
4078
 
4129
4079
  """
4080
+ # Clean out all potentially dirty states of reloaded modules.
4081
+ for pd_state in tuple(state._potentially_dirty_states):
4082
+ with contextlib.suppress(ValueError):
4083
+ if (
4084
+ state.get_root_state().get_class_substate(pd_state).__module__ == module
4085
+ and module is not None
4086
+ ):
4087
+ state._potentially_dirty_states.remove(pd_state)
4130
4088
  for subclass in tuple(state.class_subclasses):
4131
4089
  reload_state_module(module=module, state=subclass)
4132
4090
  if subclass.__module__ == module and module is not None:
4133
4091
  state.class_subclasses.remove(subclass)
4134
4092
  state._always_dirty_substates.discard(subclass.get_name())
4135
- state._computed_var_dependencies = defaultdict(set)
4136
- state._substate_var_dependencies = defaultdict(set)
4093
+ state._var_dependencies = {}
4137
4094
  state._init_var_dependency_dicts()
4138
4095
  state.get_class_substate.cache_clear()