reflex 0.7.9a1__py3-none-any.whl → 0.7.10__py3-none-any.whl

This diff shows the contents of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between those released versions.

Potentially problematic release.



reflex/state.py CHANGED
@@ -9,24 +9,19 @@ import copy
9
9
  import dataclasses
10
10
  import functools
11
11
  import inspect
12
- import json
13
12
  import pickle
14
13
  import sys
15
- import time
16
14
  import typing
17
- import uuid
18
15
  import warnings
19
- from abc import ABC, abstractmethod
16
+ from abc import ABC
20
17
  from collections.abc import AsyncIterator, Callable, Sequence
21
18
  from hashlib import md5
22
- from pathlib import Path
23
- from types import FunctionType, MethodType
19
+ from types import FunctionType
24
20
  from typing import (
25
21
  TYPE_CHECKING,
26
22
  Any,
27
23
  BinaryIO,
28
24
  ClassVar,
29
- SupportsIndex,
30
25
  TypeVar,
31
26
  cast,
32
27
  get_args,
@@ -34,22 +29,16 @@ from typing import (
34
29
  )
35
30
 
36
31
  import pydantic.v1 as pydantic
37
- import wrapt
38
32
  from pydantic import BaseModel as BaseModelV2
39
33
  from pydantic.v1 import BaseModel as BaseModelV1
40
- from pydantic.v1 import validator
41
34
  from pydantic.v1.fields import ModelField
42
- from redis.asyncio import Redis
43
- from redis.asyncio.client import PubSub
44
- from redis.exceptions import ResponseError
45
35
  from rich.markup import escape
46
- from sqlalchemy.orm import DeclarativeBase
47
36
  from typing_extensions import Self
48
37
 
49
38
  import reflex.istate.dynamic
50
39
  from reflex import constants, event
51
40
  from reflex.base import Base
52
- from reflex.config import PerformanceMode, environment, get_config
41
+ from reflex.config import PerformanceMode, environment
53
42
  from reflex.event import (
54
43
  BACKGROUND_TASK_MARKER,
55
44
  Event,
@@ -58,19 +47,17 @@ from reflex.event import (
58
47
  fix_events,
59
48
  )
60
49
  from reflex.istate.data import RouterData
50
+ from reflex.istate.proxy import ImmutableMutableProxy as ImmutableMutableProxy
51
+ from reflex.istate.proxy import MutableProxy, StateProxy
61
52
  from reflex.istate.storage import ClientStorageBase
62
53
  from reflex.model import Model
63
- from reflex.utils import console, format, path_ops, prerequisites, types
54
+ from reflex.utils import console, format, prerequisites, types
64
55
  from reflex.utils.exceptions import (
65
56
  ComputedVarShadowsBaseVarsError,
66
57
  ComputedVarShadowsStateVarError,
67
58
  DynamicComponentInvalidSignatureError,
68
59
  DynamicRouteArgShadowsStateVarError,
69
60
  EventHandlerShadowsBuiltInStateMethodError,
70
- ImmutableStateError,
71
- InvalidLockWarningThresholdError,
72
- InvalidStateManagerModeError,
73
- LockExpiredError,
74
61
  ReflexRuntimeError,
75
62
  SetUndefinedStateVarError,
76
63
  StateMismatchError,
@@ -79,13 +66,12 @@ from reflex.utils.exceptions import (
79
66
  StateTooLargeError,
80
67
  UnretrievableVarValueError,
81
68
  )
69
+ from reflex.utils.exceptions import ImmutableStateError as ImmutableStateError
82
70
  from reflex.utils.exec import is_testing_env
83
- from reflex.utils.serializers import serializer
84
71
  from reflex.utils.types import (
85
72
  _isinstance,
86
73
  get_origin,
87
74
  is_union,
88
- override,
89
75
  true_type_for_pydantic_field,
90
76
  value_inside_optional,
91
77
  )
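
Taken together, the import changes above capture the main refactor in this release: StateProxy, MutableProxy, and ImmutableMutableProxy now come from reflex.istate.proxy, and ImmutableStateError is explicitly re-exported from reflex.utils.exceptions, so the historical reflex.state import paths keep resolving. A minimal compatibility sketch, assuming the aliases behave as plain re-exports (the local alias names below are illustrative, not part of the diff):

    # Both import paths should name the same class after this change.
    from reflex.istate.proxy import StateProxy as proxy_module_state_proxy
    from reflex.state import StateProxy as state_module_state_proxy

    assert proxy_module_state_proxy is state_module_state_proxy
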
@@ -627,6 +613,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
627
613
 
628
614
  all_base_state_classes[cls.get_full_name()] = None
629
615
 
616
+ @classmethod
617
+ def _add_event_handler(
618
+ cls,
619
+ name: str,
620
+ fn: Callable,
621
+ ):
622
+ """Add an event handler dynamically to the state.
623
+
624
+ Args:
625
+ name: The name of the event handler.
626
+ fn: The function to call when the event is triggered.
627
+ """
628
+ handler = cls._create_event_handler(fn)
629
+ cls.event_handlers[name] = handler
630
+ setattr(cls, name, handler)
631
+
630
632
  @staticmethod
631
633
  def _copy_fn(fn: Callable) -> Callable:
632
634
  """Copy a function. Used to copy ComputedVars and EventHandlers from mixins.
@@ -2268,6 +2270,35 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
2268
2270
  return state
2269
2271
 
2270
2272
 
2273
+ def _serialize_type(type_: Any) -> str:
2274
+ """Serialize a type.
2275
+
2276
+ Args:
2277
+ type_: The type to serialize.
2278
+
2279
+ Returns:
2280
+ The serialized type.
2281
+ """
2282
+ if not inspect.isclass(type_):
2283
+ return f"{type_}"
2284
+ return f"{type_.__module__}.{type_.__qualname__}"
2285
+
2286
+
2287
+ def is_serializable(value: Any) -> bool:
2288
+ """Check if a value is serializable.
2289
+
2290
+ Args:
2291
+ value: The value to check.
2292
+
2293
+ Returns:
2294
+ Whether the value is serializable.
2295
+ """
2296
+ try:
2297
+ return bool(pickle.dumps(value))
2298
+ except Exception:
2299
+ return False
2300
+
2301
+
2271
2302
  T_STATE = TypeVar("T_STATE", bound=BaseState)
2272
2303
 
2273
2304
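
The two helpers added above (previously defined further down the module, as the removed block later in this diff shows) name a type as module.qualname and test whether a value survives pickling. A hypothetical usage sketch, assuming they remain importable from reflex.state in this release:

    from reflex.state import _serialize_type, is_serializable

    print(_serialize_type(dict))         # "builtins.dict": classes use module.qualname
    print(_serialize_type(123))          # "123": non-classes fall back to str()
    print(is_serializable({"a": 1}))     # True: plain dicts pickle fine
    print(is_serializable(lambda x: x))  # False: lambdas cannot be pickled
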
 
@@ -2507,278 +2538,6 @@ class ComponentState(State, mixin=True):
2507
2538
  return component
2508
2539
 
2509
2540
 
2510
- class StateProxy(wrapt.ObjectProxy):
2511
- """Proxy of a state instance to control mutability of vars for a background task.
2512
-
2513
- Since a background task runs against a state instance without holding the
2514
- state_manager lock for the token, the reference may become stale if the same
2515
- state is modified by another event handler.
2516
-
2517
- The proxy object ensures that writes to the state are blocked unless
2518
- explicitly entering a context which refreshes the state from state_manager
2519
- and holds the lock for the token until exiting the context. After exiting
2520
- the context, a StateUpdate may be emitted to the frontend to notify the
2521
- client of the state change.
2522
-
2523
- A background task will be passed the `StateProxy` as `self`, so mutability
2524
- can be safely performed inside an `async with self` block.
2525
-
2526
- class State(rx.State):
2527
- counter: int = 0
2528
-
2529
- @rx.event(background=True)
2530
- async def bg_increment(self):
2531
- await asyncio.sleep(1)
2532
- async with self:
2533
- self.counter += 1
2534
- """
2535
-
2536
- def __init__(
2537
- self,
2538
- state_instance: BaseState,
2539
- parent_state_proxy: StateProxy | None = None,
2540
- ):
2541
- """Create a proxy for a state instance.
2542
-
2543
- If `get_state` is used on a StateProxy, the resulting state will be
2544
- linked to the given state via parent_state_proxy. The first state in the
2545
- chain is the state that initiated the background task.
2546
-
2547
- Args:
2548
- state_instance: The state instance to proxy.
2549
- parent_state_proxy: The parent state proxy, for linked mutability and context tracking.
2550
- """
2551
- super().__init__(state_instance)
2552
- # compile is not relevant to backend logic
2553
- self._self_app = prerequisites.get_and_validate_app().app
2554
- self._self_substate_path = tuple(state_instance.get_full_name().split("."))
2555
- self._self_actx = None
2556
- self._self_mutable = False
2557
- self._self_actx_lock = asyncio.Lock()
2558
- self._self_actx_lock_holder = None
2559
- self._self_parent_state_proxy = parent_state_proxy
2560
-
2561
- def _is_mutable(self) -> bool:
2562
- """Check if the state is mutable.
2563
-
2564
- Returns:
2565
- Whether the state is mutable.
2566
- """
2567
- if self._self_parent_state_proxy is not None:
2568
- return self._self_parent_state_proxy._is_mutable() or self._self_mutable
2569
- return self._self_mutable
2570
-
2571
- async def __aenter__(self) -> StateProxy:
2572
- """Enter the async context manager protocol.
2573
-
2574
- Sets mutability to True and enters the `App.modify_state` async context,
2575
- which refreshes the state from state_manager and holds the lock for the
2576
- given state token until exiting the context.
2577
-
2578
- Background tasks should avoid blocking calls while inside the context.
2579
-
2580
- Returns:
2581
- This StateProxy instance in mutable mode.
2582
-
2583
- Raises:
2584
- ImmutableStateError: If the state is already mutable.
2585
- """
2586
- if self._self_parent_state_proxy is not None:
2587
- parent_state = (
2588
- await self._self_parent_state_proxy.__aenter__()
2589
- ).__wrapped__
2590
- super().__setattr__(
2591
- "__wrapped__",
2592
- await parent_state.get_state(
2593
- State.get_class_substate(self._self_substate_path)
2594
- ),
2595
- )
2596
- return self
2597
- current_task = asyncio.current_task()
2598
- if (
2599
- self._self_actx_lock.locked()
2600
- and current_task == self._self_actx_lock_holder
2601
- ):
2602
- raise ImmutableStateError(
2603
- "The state is already mutable. Do not nest `async with self` blocks."
2604
- )
2605
- await self._self_actx_lock.acquire()
2606
- self._self_actx_lock_holder = current_task
2607
- self._self_actx = self._self_app.modify_state(
2608
- token=_substate_key(
2609
- self.__wrapped__.router.session.client_token,
2610
- self._self_substate_path,
2611
- )
2612
- )
2613
- mutable_state = await self._self_actx.__aenter__()
2614
- super().__setattr__(
2615
- "__wrapped__", mutable_state.get_substate(self._self_substate_path)
2616
- )
2617
- self._self_mutable = True
2618
- return self
2619
-
2620
- async def __aexit__(self, *exc_info: Any) -> None:
2621
- """Exit the async context manager protocol.
2622
-
2623
- Sets proxy mutability to False and persists any state changes.
2624
-
2625
- Args:
2626
- exc_info: The exception info tuple.
2627
- """
2628
- if self._self_parent_state_proxy is not None:
2629
- await self._self_parent_state_proxy.__aexit__(*exc_info)
2630
- return
2631
- if self._self_actx is None:
2632
- return
2633
- self._self_mutable = False
2634
- try:
2635
- await self._self_actx.__aexit__(*exc_info)
2636
- finally:
2637
- self._self_actx_lock_holder = None
2638
- self._self_actx_lock.release()
2639
- self._self_actx = None
2640
-
2641
- def __enter__(self):
2642
- """Enter the regular context manager protocol.
2643
-
2644
- This is not supported for background tasks, and exists only to raise a more useful exception
2645
- when the StateProxy is used incorrectly.
2646
-
2647
- Raises:
2648
- TypeError: always, because only async contextmanager protocol is supported.
2649
- """
2650
- raise TypeError("Background task must use `async with self` to modify state.")
2651
-
2652
- def __exit__(self, *exc_info: Any) -> None:
2653
- """Exit the regular context manager protocol.
2654
-
2655
- Args:
2656
- exc_info: The exception info tuple.
2657
- """
2658
- pass
2659
-
2660
- def __getattr__(self, name: str) -> Any:
2661
- """Get the attribute from the underlying state instance.
2662
-
2663
- Args:
2664
- name: The name of the attribute.
2665
-
2666
- Returns:
2667
- The value of the attribute.
2668
-
2669
- Raises:
2670
- ImmutableStateError: If the state is not in mutable mode.
2671
- """
2672
- if name in ["substates", "parent_state"] and not self._is_mutable():
2673
- raise ImmutableStateError(
2674
- "Background task StateProxy is immutable outside of a context "
2675
- "manager. Use `async with self` to modify state."
2676
- )
2677
- value = super().__getattr__(name)
2678
- if not name.startswith("_self_") and isinstance(value, MutableProxy):
2679
- # ensure mutations to these containers are blocked unless proxy is _mutable
2680
- return ImmutableMutableProxy(
2681
- wrapped=value.__wrapped__,
2682
- state=self,
2683
- field_name=value._self_field_name,
2684
- )
2685
- if isinstance(value, functools.partial) and value.args[0] is self.__wrapped__:
2686
- # Rebind event handler to the proxy instance
2687
- value = functools.partial(
2688
- value.func,
2689
- self,
2690
- *value.args[1:],
2691
- **value.keywords,
2692
- )
2693
- if isinstance(value, MethodType) and value.__self__ is self.__wrapped__:
2694
- # Rebind methods to the proxy instance
2695
- value = type(value)(value.__func__, self)
2696
- return value
2697
-
2698
- def __setattr__(self, name: str, value: Any) -> None:
2699
- """Set the attribute on the underlying state instance.
2700
-
2701
- If the attribute is internal, set it on the proxy instance instead.
2702
-
2703
- Args:
2704
- name: The name of the attribute.
2705
- value: The value of the attribute.
2706
-
2707
- Raises:
2708
- ImmutableStateError: If the state is not in mutable mode.
2709
- """
2710
- if (
2711
- name.startswith("_self_") # wrapper attribute
2712
- or self._is_mutable() # lock held
2713
- # non-persisted state attribute
2714
- or name in self.__wrapped__.get_skip_vars()
2715
- ):
2716
- super().__setattr__(name, value)
2717
- return
2718
-
2719
- raise ImmutableStateError(
2720
- "Background task StateProxy is immutable outside of a context "
2721
- "manager. Use `async with self` to modify state."
2722
- )
2723
-
2724
- def get_substate(self, path: Sequence[str]) -> BaseState:
2725
- """Only allow substate access with lock held.
2726
-
2727
- Args:
2728
- path: The path to the substate.
2729
-
2730
- Returns:
2731
- The substate.
2732
-
2733
- Raises:
2734
- ImmutableStateError: If the state is not in mutable mode.
2735
- """
2736
- if not self._is_mutable():
2737
- raise ImmutableStateError(
2738
- "Background task StateProxy is immutable outside of a context "
2739
- "manager. Use `async with self` to modify state."
2740
- )
2741
- return self.__wrapped__.get_substate(path)
2742
-
2743
- async def get_state(self, state_cls: type[BaseState]) -> BaseState:
2744
- """Get an instance of the state associated with this token.
2745
-
2746
- Args:
2747
- state_cls: The class of the state.
2748
-
2749
- Returns:
2750
- The state.
2751
-
2752
- Raises:
2753
- ImmutableStateError: If the state is not in mutable mode.
2754
- """
2755
- if not self._is_mutable():
2756
- raise ImmutableStateError(
2757
- "Background task StateProxy is immutable outside of a context "
2758
- "manager. Use `async with self` to modify state."
2759
- )
2760
- return type(self)(
2761
- await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
2762
- )
2763
-
2764
- async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
2765
- """Temporarily allow mutability to access parent_state.
2766
-
2767
- Args:
2768
- *args: The args to pass to the underlying state instance.
2769
- **kwargs: The kwargs to pass to the underlying state instance.
2770
-
2771
- Returns:
2772
- The state update.
2773
- """
2774
- original_mutable = self._self_mutable
2775
- self._self_mutable = True
2776
- try:
2777
- return await self.__wrapped__._as_state_update(*args, **kwargs)
2778
- finally:
2779
- self._self_mutable = original_mutable
2780
-
2781
-
2782
2541
  @dataclasses.dataclass(
2783
2542
  frozen=True,
2784
2543
  )
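
The hunk above removes StateProxy from reflex/state.py; per the import changes at the top of the file, it is now provided by reflex.istate.proxy and re-exported here. The background-task pattern documented in the removed docstring is unchanged, so code like the following sketch (adapted from that docstring, with the imports it assumes) should behave the same after the move:

    import asyncio

    import reflex as rx


    class State(rx.State):
        counter: int = 0

        @rx.event(background=True)
        async def bg_increment(self):
            await asyncio.sleep(1)
            async with self:  # refresh the state and hold the lock while mutating
                self.counter += 1
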
@@ -2803,1344 +2562,54 @@ class StateUpdate:
2803
2562
  return format.json_dumps(self)
2804
2563
 
2805
2564
 
2806
- class StateManager(Base, ABC):
2807
- """A class to manage many client states."""
2808
-
2809
- # The state class to use.
2810
- state: type[BaseState]
2811
-
2812
- @classmethod
2813
- def create(cls, state: type[BaseState]):
2814
- """Create a new state manager.
2815
-
2816
- Args:
2817
- state: The state class to use.
2818
-
2819
- Raises:
2820
- InvalidStateManagerModeError: If the state manager mode is invalid.
2821
-
2822
- Returns:
2823
- The state manager (either disk, memory or redis).
2824
- """
2825
- config = get_config()
2826
- if prerequisites.parse_redis_url() is not None:
2827
- config.state_manager_mode = constants.StateManagerMode.REDIS
2828
- if config.state_manager_mode == constants.StateManagerMode.MEMORY:
2829
- return StateManagerMemory(state=state)
2830
- if config.state_manager_mode == constants.StateManagerMode.DISK:
2831
- return StateManagerDisk(state=state)
2832
- if config.state_manager_mode == constants.StateManagerMode.REDIS:
2833
- redis = prerequisites.get_redis()
2834
- if redis is not None:
2835
- # make sure expiration values are obtained only from the config object on creation
2836
- return StateManagerRedis(
2837
- state=state,
2838
- redis=redis,
2839
- token_expiration=config.redis_token_expiration,
2840
- lock_expiration=config.redis_lock_expiration,
2841
- lock_warning_threshold=config.redis_lock_warning_threshold,
2842
- )
2843
- raise InvalidStateManagerModeError(
2844
- f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
2845
- )
2846
-
2847
- @abstractmethod
2848
- async def get_state(self, token: str) -> BaseState:
2849
- """Get the state for a token.
2850
-
2851
- Args:
2852
- token: The token to get the state for.
2853
-
2854
- Returns:
2855
- The state for the token.
2856
- """
2857
- pass
2858
-
2859
- @abstractmethod
2860
- async def set_state(self, token: str, state: BaseState):
2861
- """Set the state for a token.
2862
-
2863
- Args:
2864
- token: The token to set the state for.
2865
- state: The state to set.
2866
- """
2867
- pass
2868
-
2869
- @abstractmethod
2870
- @contextlib.asynccontextmanager
2871
- async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
2872
- """Modify the state for a token while holding exclusive lock.
2873
-
2874
- Args:
2875
- token: The token to modify the state for.
2876
-
2877
- Yields:
2878
- The state for the token.
2879
- """
2880
- yield self.state()
2881
-
2882
-
2883
- class StateManagerMemory(StateManager):
2884
- """A state manager that stores states in memory."""
2885
-
2886
- # The mapping of client ids to states.
2887
- states: dict[str, BaseState] = {}
2888
-
2889
- # The mutex ensures the dict of mutexes is updated exclusively
2890
- _state_manager_lock = asyncio.Lock()
2891
-
2892
- # The dict of mutexes for each client
2893
- _states_locks: dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
2894
-
2895
- class Config: # pyright: ignore [reportIncompatibleVariableOverride]
2896
- """The Pydantic config."""
2897
-
2898
- fields = {
2899
- "_states_locks": {"exclude": True},
2900
- }
2901
-
2902
- @override
2903
- async def get_state(self, token: str) -> BaseState:
2904
- """Get the state for a token.
2905
-
2906
- Args:
2907
- token: The token to get the state for.
2908
-
2909
- Returns:
2910
- The state for the token.
2911
- """
2912
- # Memory state manager ignores the substate suffix and always returns the top-level state.
2913
- token = _split_substate_key(token)[0]
2914
- if token not in self.states:
2915
- self.states[token] = self.state(_reflex_internal_init=True)
2916
- return self.states[token]
2917
-
2918
- @override
2919
- async def set_state(self, token: str, state: BaseState):
2920
- """Set the state for a token.
2921
-
2922
- Args:
2923
- token: The token to set the state for.
2924
- state: The state to set.
2925
- """
2926
- pass
2927
-
2928
- @override
2929
- @contextlib.asynccontextmanager
2930
- async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
2931
- """Modify the state for a token while holding exclusive lock.
2932
-
2933
- Args:
2934
- token: The token to modify the state for.
2935
-
2936
- Yields:
2937
- The state for the token.
2938
- """
2939
- # Memory state manager ignores the substate suffix and always returns the top-level state.
2940
- token = _split_substate_key(token)[0]
2941
- if token not in self._states_locks:
2942
- async with self._state_manager_lock:
2943
- if token not in self._states_locks:
2944
- self._states_locks[token] = asyncio.Lock()
2945
-
2946
- async with self._states_locks[token]:
2947
- state = await self.get_state(token)
2948
- yield state
2949
- await self.set_state(token, state)
2950
-
2565
+ def code_uses_state_contexts(javascript_code: str) -> bool:
2566
+ """Check if the rendered Javascript uses state contexts.
2951
2567
 
2952
- def _default_token_expiration() -> int:
2953
- """Get the default token expiration time.
2568
+ Args:
2569
+ javascript_code: The Javascript code to check.
2954
2570
 
2955
2571
  Returns:
2956
- The default token expiration time.
2572
+ True if the code attempts to access a member of StateContexts.
2957
2573
  """
2958
- return get_config().redis_token_expiration
2574
+ return bool("useContext(StateContexts" in javascript_code)
2959
2575
 
2960
2576
 
2961
- def _serialize_type(type_: Any) -> str:
2962
- """Serialize a type.
2577
+ def reload_state_module(
2578
+ module: str,
2579
+ state: type[BaseState] = State,
2580
+ ) -> None:
2581
+ """Reset rx.State subclasses to avoid conflict when reloading.
2963
2582
 
2964
2583
  Args:
2965
- type_: The type to serialize.
2584
+ module: The module to reload.
2585
+ state: Recursive argument for the state class to reload.
2966
2586
 
2967
- Returns:
2968
- The serialized type.
2969
2587
  """
2970
- if not inspect.isclass(type_):
2971
- return f"{type_}"
2972
- return f"{type_.__module__}.{type_.__qualname__}"
2973
-
2974
-
2975
- def is_serializable(value: Any) -> bool:
2976
- """Check if a value is serializable.
2588
+ # Clean out all potentially dirty states of reloaded modules.
2589
+ for pd_state in tuple(state._potentially_dirty_states):
2590
+ with contextlib.suppress(ValueError):
2591
+ if (
2592
+ state.get_root_state().get_class_substate(pd_state).__module__ == module
2593
+ and module is not None
2594
+ ):
2595
+ state._potentially_dirty_states.remove(pd_state)
2596
+ for subclass in tuple(state.class_subclasses):
2597
+ reload_state_module(module=module, state=subclass)
2598
+ if subclass.__module__ == module and module is not None:
2599
+ all_base_state_classes.pop(subclass.get_full_name(), None)
2600
+ state.class_subclasses.remove(subclass)
2601
+ state._always_dirty_substates.discard(subclass.get_name())
2602
+ state._var_dependencies = {}
2603
+ state._init_var_dependency_dicts()
2604
+ state.get_class_substate.cache_clear()
2977
2605
 
2978
- Args:
2979
- value: The value to check.
2980
2606
 
2981
- Returns:
2982
- Whether the value is serializable.
2983
- """
2984
- try:
2985
- return bool(pickle.dumps(value))
2986
- except Exception:
2987
- return False
2988
-
2989
-
2990
- def reset_disk_state_manager():
2991
- """Reset the disk state manager."""
2992
- states_directory = prerequisites.get_states_dir()
2993
- if states_directory.exists():
2994
- for path in states_directory.iterdir():
2995
- path.unlink()
2996
-
2997
-
2998
- class StateManagerDisk(StateManager):
2999
- """A state manager that stores states in memory."""
3000
-
3001
- # The mapping of client ids to states.
3002
- states: dict[str, BaseState] = {}
3003
-
3004
- # The mutex ensures the dict of mutexes is updated exclusively
3005
- _state_manager_lock = asyncio.Lock()
3006
-
3007
- # The dict of mutexes for each client
3008
- _states_locks: dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
3009
-
3010
- # The token expiration time (s).
3011
- token_expiration: int = pydantic.Field(default_factory=_default_token_expiration)
3012
-
3013
- class Config: # pyright: ignore [reportIncompatibleVariableOverride]
3014
- """The Pydantic config."""
3015
-
3016
- fields = {
3017
- "_states_locks": {"exclude": True},
3018
- }
3019
- keep_untouched = (functools.cached_property,)
3020
-
3021
- def __init__(self, state: type[BaseState]):
3022
- """Create a new state manager.
3023
-
3024
- Args:
3025
- state: The state class to use.
3026
- """
3027
- super().__init__(state=state)
3028
-
3029
- path_ops.mkdir(self.states_directory)
3030
-
3031
- self._purge_expired_states()
3032
-
3033
- @functools.cached_property
3034
- def states_directory(self) -> Path:
3035
- """Get the states directory.
3036
-
3037
- Returns:
3038
- The states directory.
3039
- """
3040
- return prerequisites.get_states_dir()
3041
-
3042
- def _purge_expired_states(self):
3043
- """Purge expired states from the disk."""
3044
- import time
3045
-
3046
- for path in path_ops.ls(self.states_directory):
3047
- # check path is a pickle file
3048
- if path.suffix != ".pkl":
3049
- continue
3050
-
3051
- # load last edited field from file
3052
- last_edited = path.stat().st_mtime
3053
-
3054
- # check if the file is older than the token expiration time
3055
- if time.time() - last_edited > self.token_expiration:
3056
- # remove the file
3057
- path.unlink()
3058
-
3059
- def token_path(self, token: str) -> Path:
3060
- """Get the path for a token.
3061
-
3062
- Args:
3063
- token: The token to get the path for.
3064
-
3065
- Returns:
3066
- The path for the token.
3067
- """
3068
- return (
3069
- self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl"
3070
- ).absolute()
3071
-
3072
- async def load_state(self, token: str) -> BaseState | None:
3073
- """Load a state object based on the provided token.
3074
-
3075
- Args:
3076
- token: The token used to identify the state object.
3077
-
3078
- Returns:
3079
- The loaded state object or None.
3080
- """
3081
- token_path = self.token_path(token)
3082
-
3083
- if token_path.exists():
3084
- try:
3085
- with token_path.open(mode="rb") as file:
3086
- return BaseState._deserialize(fp=file)
3087
- except Exception:
3088
- pass
3089
-
3090
- async def populate_substates(
3091
- self, client_token: str, state: BaseState, root_state: BaseState
3092
- ):
3093
- """Populate the substates of a state object.
3094
-
3095
- Args:
3096
- client_token: The client token.
3097
- state: The state object to populate.
3098
- root_state: The root state object.
3099
- """
3100
- for substate in state.get_substates():
3101
- substate_token = _substate_key(client_token, substate)
3102
-
3103
- fresh_instance = await root_state.get_state(substate)
3104
- instance = await self.load_state(substate_token)
3105
- if instance is not None:
3106
- # Ensure all substates exist, even if they weren't serialized previously.
3107
- instance.substates = fresh_instance.substates
3108
- else:
3109
- instance = fresh_instance
3110
- state.substates[substate.get_name()] = instance
3111
- instance.parent_state = state
3112
-
3113
- await self.populate_substates(client_token, instance, root_state)
3114
-
3115
- @override
3116
- async def get_state(
3117
- self,
3118
- token: str,
3119
- ) -> BaseState:
3120
- """Get the state for a token.
3121
-
3122
- Args:
3123
- token: The token to get the state for.
3124
-
3125
- Returns:
3126
- The state for the token.
3127
- """
3128
- client_token = _split_substate_key(token)[0]
3129
- root_state = self.states.get(client_token)
3130
- if root_state is not None:
3131
- # Retrieved state from memory.
3132
- return root_state
3133
-
3134
- # Deserialize root state from disk.
3135
- root_state = await self.load_state(_substate_key(client_token, self.state))
3136
- # Create a new root state tree with all substates instantiated.
3137
- fresh_root_state = self.state(_reflex_internal_init=True)
3138
- if root_state is None:
3139
- root_state = fresh_root_state
3140
- else:
3141
- # Ensure all substates exist, even if they were not serialized previously.
3142
- root_state.substates = fresh_root_state.substates
3143
- self.states[client_token] = root_state
3144
- await self.populate_substates(client_token, root_state, root_state)
3145
- return root_state
3146
-
3147
- async def set_state_for_substate(self, client_token: str, substate: BaseState):
3148
- """Set the state for a substate.
3149
-
3150
- Args:
3151
- client_token: The client token.
3152
- substate: The substate to set.
3153
- """
3154
- substate_token = _substate_key(client_token, substate)
3155
-
3156
- if substate._get_was_touched():
3157
- substate._was_touched = False # Reset the touched flag after serializing.
3158
- pickle_state = substate._serialize()
3159
- if pickle_state:
3160
- if not self.states_directory.exists():
3161
- self.states_directory.mkdir(parents=True, exist_ok=True)
3162
- self.token_path(substate_token).write_bytes(pickle_state)
3163
-
3164
- for substate_substate in substate.substates.values():
3165
- await self.set_state_for_substate(client_token, substate_substate)
3166
-
3167
- @override
3168
- async def set_state(self, token: str, state: BaseState):
3169
- """Set the state for a token.
3170
-
3171
- Args:
3172
- token: The token to set the state for.
3173
- state: The state to set.
3174
- """
3175
- client_token, substate = _split_substate_key(token)
3176
- await self.set_state_for_substate(client_token, state)
3177
-
3178
- @override
3179
- @contextlib.asynccontextmanager
3180
- async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
3181
- """Modify the state for a token while holding exclusive lock.
3182
-
3183
- Args:
3184
- token: The token to modify the state for.
3185
-
3186
- Yields:
3187
- The state for the token.
3188
- """
3189
- # Memory state manager ignores the substate suffix and always returns the top-level state.
3190
- client_token, substate = _split_substate_key(token)
3191
- if client_token not in self._states_locks:
3192
- async with self._state_manager_lock:
3193
- if client_token not in self._states_locks:
3194
- self._states_locks[client_token] = asyncio.Lock()
3195
-
3196
- async with self._states_locks[client_token]:
3197
- state = await self.get_state(token)
3198
- yield state
3199
- await self.set_state(token, state)
3200
-
3201
-
3202
- def _default_lock_expiration() -> int:
3203
- """Get the default lock expiration time.
3204
-
3205
- Returns:
3206
- The default lock expiration time.
3207
- """
3208
- return get_config().redis_lock_expiration
3209
-
3210
-
3211
- def _default_lock_warning_threshold() -> int:
3212
- """Get the default lock warning threshold.
3213
-
3214
- Returns:
3215
- The default lock warning threshold.
3216
- """
3217
- return get_config().redis_lock_warning_threshold
3218
-
3219
-
3220
- class StateManagerRedis(StateManager):
3221
- """A state manager that stores states in redis."""
3222
-
3223
- # The redis client to use.
3224
- redis: Redis
3225
-
3226
- # The token expiration time (s).
3227
- token_expiration: int = pydantic.Field(default_factory=_default_token_expiration)
3228
-
3229
- # The maximum time to hold a lock (ms).
3230
- lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration)
3231
-
3232
- # The maximum time to hold a lock (ms) before warning.
3233
- lock_warning_threshold: int = pydantic.Field(
3234
- default_factory=_default_lock_warning_threshold
3235
- )
3236
-
3237
- # The keyspace subscription string when redis is waiting for lock to be released.
3238
- _redis_notify_keyspace_events: str = (
3239
- "K" # Enable keyspace notifications (target a particular key)
3240
- "g" # For generic commands (DEL, EXPIRE, etc)
3241
- "x" # For expired events
3242
- "e" # For evicted events (i.e. maxmemory exceeded)
3243
- )
3244
-
3245
- # These events indicate that a lock is no longer held.
3246
- _redis_keyspace_lock_release_events: set[bytes] = {
3247
- b"del",
3248
- b"expire",
3249
- b"expired",
3250
- b"evicted",
3251
- }
3252
-
3253
- # Whether keyspace notifications have been enabled.
3254
- _redis_notify_keyspace_events_enabled: bool = False
3255
-
3256
- # The logical database number used by the redis client.
3257
- _redis_db: int = 0
3258
-
3259
- def _get_required_state_classes(
3260
- self,
3261
- target_state_cls: type[BaseState],
3262
- subclasses: bool = False,
3263
- required_state_classes: set[type[BaseState]] | None = None,
3264
- ) -> set[type[BaseState]]:
3265
- """Recursively determine which states are required to fetch the target state.
3266
-
3267
- This will always include potentially dirty substates that depend on vars
3268
- in the target_state_cls.
3269
-
3270
- Args:
3271
- target_state_cls: The target state class being fetched.
3272
- subclasses: Whether to include subclasses of the target state.
3273
- required_state_classes: Recursive argument tracking state classes that have already been seen.
3274
-
3275
- Returns:
3276
- The set of state classes required to fetch the target state.
3277
- """
3278
- if required_state_classes is None:
3279
- required_state_classes = set()
3280
- # Get the substates if requested.
3281
- if subclasses:
3282
- for substate in target_state_cls.get_substates():
3283
- self._get_required_state_classes(
3284
- substate,
3285
- subclasses=True,
3286
- required_state_classes=required_state_classes,
3287
- )
3288
- if target_state_cls in required_state_classes:
3289
- return required_state_classes
3290
- required_state_classes.add(target_state_cls)
3291
-
3292
- # Get dependent substates.
3293
- for pd_substates in target_state_cls._get_potentially_dirty_states():
3294
- self._get_required_state_classes(
3295
- pd_substates,
3296
- subclasses=False,
3297
- required_state_classes=required_state_classes,
3298
- )
3299
-
3300
- # Get the parent state if it exists.
3301
- if parent_state := target_state_cls.get_parent_state():
3302
- self._get_required_state_classes(
3303
- parent_state,
3304
- subclasses=False,
3305
- required_state_classes=required_state_classes,
3306
- )
3307
- return required_state_classes
3308
-
3309
- def _get_populated_states(
3310
- self,
3311
- target_state: BaseState,
3312
- populated_states: dict[str, BaseState] | None = None,
3313
- ) -> dict[str, BaseState]:
3314
- """Recursively determine which states from target_state are already fetched.
3315
-
3316
- Args:
3317
- target_state: The state to check for populated states.
3318
- populated_states: Recursive argument tracking states seen in previous calls.
3319
-
3320
- Returns:
3321
- A dictionary of state full name to state instance.
3322
- """
3323
- if populated_states is None:
3324
- populated_states = {}
3325
- if target_state.get_full_name() in populated_states:
3326
- return populated_states
3327
- populated_states[target_state.get_full_name()] = target_state
3328
- for substate in target_state.substates.values():
3329
- self._get_populated_states(substate, populated_states=populated_states)
3330
- if target_state.parent_state is not None:
3331
- self._get_populated_states(
3332
- target_state.parent_state, populated_states=populated_states
3333
- )
3334
- return populated_states
3335
-
3336
- @override
3337
- async def get_state(
3338
- self,
3339
- token: str,
3340
- top_level: bool = True,
3341
- for_state_instance: BaseState | None = None,
3342
- ) -> BaseState:
3343
- """Get the state for a token.
3344
-
3345
- Args:
3346
- token: The token to get the state for.
3347
- top_level: If true, return an instance of the top-level state (self.state).
3348
- for_state_instance: If provided, attach the requested states to this existing state tree.
3349
-
3350
- Returns:
3351
- The state for the token.
3352
-
3353
- Raises:
3354
- RuntimeError: when the state_cls is not specified in the token, or when the parent state for a
3355
- requested state was not fetched.
3356
- """
3357
- # Split the actual token from the fully qualified substate name.
3358
- token, state_path = _split_substate_key(token)
3359
- if state_path:
3360
- # Get the State class associated with the given path.
3361
- state_cls = self.state.get_class_substate(state_path)
3362
- else:
3363
- raise RuntimeError(
3364
- f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
3365
- )
3366
-
3367
- # Determine which states we already have.
3368
- flat_state_tree: dict[str, BaseState] = (
3369
- self._get_populated_states(for_state_instance) if for_state_instance else {}
3370
- )
3371
-
3372
- # Determine which states from the tree need to be fetched.
3373
- required_state_classes = sorted(
3374
- self._get_required_state_classes(state_cls, subclasses=True)
3375
- - {type(s) for s in flat_state_tree.values()},
3376
- key=lambda x: x.get_full_name(),
3377
- )
3378
-
3379
- redis_pipeline = self.redis.pipeline()
3380
- for state_cls in required_state_classes:
3381
- redis_pipeline.get(_substate_key(token, state_cls))
3382
-
3383
- for state_cls, redis_state in zip(
3384
- required_state_classes,
3385
- await redis_pipeline.execute(),
3386
- strict=False,
3387
- ):
3388
- state = None
3389
-
3390
- if redis_state is not None:
3391
- # Deserialize the substate.
3392
- with contextlib.suppress(StateSchemaMismatchError):
3393
- state = BaseState._deserialize(data=redis_state)
3394
- if state is None:
3395
- # Key didn't exist or schema mismatch so create a new instance for this token.
3396
- state = state_cls(
3397
- init_substates=False,
3398
- _reflex_internal_init=True,
3399
- )
3400
- flat_state_tree[state.get_full_name()] = state
3401
- if state.get_parent_state() is not None:
3402
- parent_state_name, _dot, state_name = state.get_full_name().rpartition(
3403
- "."
3404
- )
3405
- parent_state = flat_state_tree.get(parent_state_name)
3406
- if parent_state is None:
3407
- raise RuntimeError(
3408
- f"Parent state for {state.get_full_name()} was not found "
3409
- "in the state tree, but should have already been fetched. "
3410
- "This is a bug",
3411
- )
3412
- parent_state.substates[state_name] = state
3413
- state.parent_state = parent_state
3414
-
3415
- # To retain compatibility with previous implementation, by default, we return
3416
- # the top-level state which should always be fetched or already cached.
3417
- if top_level:
3418
- return flat_state_tree[self.state.get_full_name()]
3419
- return flat_state_tree[state_cls.get_full_name()]
3420
-
3421
- @override
3422
- async def set_state(
3423
- self,
3424
- token: str,
3425
- state: BaseState,
3426
- lock_id: bytes | None = None,
3427
- ):
3428
- """Set the state for a token.
3429
-
3430
- Args:
3431
- token: The token to set the state for.
3432
- state: The state to set.
3433
- lock_id: If provided, the lock_key must be set to this value to set the state.
3434
-
3435
- Raises:
3436
- LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
3437
- RuntimeError: If the state instance doesn't match the state name in the token.
3438
- """
3439
- # Check that we're holding the lock.
3440
- if (
3441
- lock_id is not None
3442
- and await self.redis.get(self._lock_key(token)) != lock_id
3443
- ):
3444
- raise LockExpiredError(
3445
- f"Lock expired for token {token} while processing. Consider increasing "
3446
- f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
3447
- "or use `@rx.event(background=True)` decorator for long-running tasks."
3448
- )
3449
- elif lock_id is not None:
3450
- time_taken = self.lock_expiration / 1000 - (
3451
- await self.redis.ttl(self._lock_key(token))
3452
- )
3453
- if time_taken > self.lock_warning_threshold / 1000:
3454
- console.warn(
3455
- f"Lock for token {token} was held too long {time_taken=}s, "
3456
- f"use `@rx.event(background=True)` decorator for long-running tasks.",
3457
- dedupe=True,
3458
- )
3459
-
3460
- client_token, substate_name = _split_substate_key(token)
3461
- # If the substate name on the token doesn't match the instance name, it cannot have a parent.
3462
- if state.parent_state is not None and state.get_full_name() != substate_name:
3463
- raise RuntimeError(
3464
- f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
3465
- )
3466
-
3467
- # Recursively set_state on all known substates.
3468
- tasks = [
3469
- asyncio.create_task(
3470
- self.set_state(
3471
- _substate_key(client_token, substate),
3472
- substate,
3473
- lock_id,
3474
- )
3475
- )
3476
- for substate in state.substates.values()
3477
- ]
3478
- # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
3479
- if state._get_was_touched():
3480
- pickle_state = state._serialize()
3481
- if pickle_state:
3482
- await self.redis.set(
3483
- _substate_key(client_token, state),
3484
- pickle_state,
3485
- ex=self.token_expiration,
3486
- )
3487
-
3488
- # Wait for substates to be persisted.
3489
- for t in tasks:
3490
- await t
3491
-
3492
- @override
3493
- @contextlib.asynccontextmanager
3494
- async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
3495
- """Modify the state for a token while holding exclusive lock.
3496
-
3497
- Args:
3498
- token: The token to modify the state for.
3499
-
3500
- Yields:
3501
- The state for the token.
3502
- """
3503
- async with self._lock(token) as lock_id:
3504
- state = await self.get_state(token)
3505
- yield state
3506
- await self.set_state(token, state, lock_id)
3507
-
3508
- @validator("lock_warning_threshold")
3509
- @classmethod
3510
- def validate_lock_warning_threshold(
3511
- cls, lock_warning_threshold: int, values: dict[str, int]
3512
- ):
3513
- """Validate the lock warning threshold.
3514
-
3515
- Args:
3516
- lock_warning_threshold: The lock warning threshold.
3517
- values: The validated attributes.
3518
-
3519
- Returns:
3520
- The lock warning threshold.
3521
-
3522
- Raises:
3523
- InvalidLockWarningThresholdError: If the lock warning threshold is invalid.
3524
- """
3525
- if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]):
3526
- raise InvalidLockWarningThresholdError(
3527
- f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
3528
- )
3529
- return lock_warning_threshold
3530
-
3531
- @staticmethod
3532
- def _lock_key(token: str) -> bytes:
3533
- """Get the redis key for a token's lock.
3534
-
3535
- Args:
3536
- token: The token to get the lock key for.
3537
-
3538
- Returns:
3539
- The redis lock key for the token.
3540
- """
3541
- # All substates share the same lock domain, so ignore any substate path suffix.
3542
- client_token = _split_substate_key(token)[0]
3543
- return f"{client_token}_lock".encode()
3544
-
3545
- async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
3546
- """Try to get a redis lock for a token.
3547
-
3548
- Args:
3549
- lock_key: The redis key for the lock.
3550
- lock_id: The ID of the lock.
3551
-
3552
- Returns:
3553
- True if the lock was obtained.
3554
- """
3555
- return await self.redis.set(
3556
- lock_key,
3557
- lock_id,
3558
- px=self.lock_expiration,
3559
- nx=True, # only set if it doesn't exist
3560
- )
3561
-
3562
- async def _get_pubsub_message(
3563
- self, pubsub: PubSub, timeout: float | None = None
3564
- ) -> None:
3565
- """Get lock release events from the pubsub.
3566
-
3567
- Args:
3568
- pubsub: The pubsub to get a message from.
3569
- timeout: Remaining time to wait for a message.
3570
-
3571
- Returns:
3572
- The message.
3573
- """
3574
- if timeout is None:
3575
- timeout = self.lock_expiration / 1000.0
3576
-
3577
- started = time.time()
3578
- message = await pubsub.get_message(
3579
- ignore_subscribe_messages=True,
3580
- timeout=timeout,
3581
- )
3582
- if (
3583
- message is None
3584
- or message["data"] not in self._redis_keyspace_lock_release_events
3585
- ):
3586
- remaining = timeout - (time.time() - started)
3587
- if remaining <= 0:
3588
- return
3589
- await self._get_pubsub_message(pubsub, timeout=remaining)
3590
-
3591
- async def _enable_keyspace_notifications(self):
3592
- """Enable keyspace notifications for the redis server.
3593
-
3594
- Raises:
3595
- ResponseError: when the keyspace config cannot be set.
3596
- """
3597
- if self._redis_notify_keyspace_events_enabled:
3598
- return
3599
- # Find out which logical database index is being used.
3600
- self._redis_db = self.redis.get_connection_kwargs().get("db", self._redis_db)
3601
-
3602
- try:
3603
- await self.redis.config_set(
3604
- "notify-keyspace-events",
3605
- self._redis_notify_keyspace_events,
3606
- )
3607
- except ResponseError:
3608
- # Some redis servers only allow out-of-band configuration, so ignore errors here.
3609
- if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
3610
- raise
3611
- self._redis_notify_keyspace_events_enabled = True
3612
-
3613
- async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
3614
- """Wait for a redis lock to be released via pubsub.
3615
-
3616
- Coroutine will not return until the lock is obtained.
3617
-
3618
- Args:
3619
- lock_key: The redis key for the lock.
3620
- lock_id: The ID of the lock.
3621
- """
3622
- # Enable keyspace notifications for the lock key, so we know when it is available.
3623
- await self._enable_keyspace_notifications()
3624
- lock_key_channel = f"__keyspace@{self._redis_db}__:{lock_key.decode()}"
3625
- async with self.redis.pubsub() as pubsub:
3626
- await pubsub.psubscribe(lock_key_channel)
3627
- # wait for the lock to be released
3628
- while True:
3629
- # fast path
3630
- if await self._try_get_lock(lock_key, lock_id):
3631
- return
3632
- # wait for lock events
3633
- await self._get_pubsub_message(pubsub)
3634
-
3635
- @contextlib.asynccontextmanager
3636
- async def _lock(self, token: str):
3637
- """Obtain a redis lock for a token.
3638
-
3639
- Args:
3640
- token: The token to obtain a lock for.
3641
-
3642
- Yields:
3643
- The ID of the lock (to be passed to set_state).
3644
-
3645
- Raises:
3646
- LockExpiredError: If the lock has expired while processing the event.
3647
- """
3648
- lock_key = self._lock_key(token)
3649
- lock_id = uuid.uuid4().hex.encode()
3650
-
3651
- if not await self._try_get_lock(lock_key, lock_id):
3652
- # Missed the fast-path to get lock, subscribe for lock delete/expire events
3653
- await self._wait_lock(lock_key, lock_id)
3654
- state_is_locked = True
3655
-
3656
- try:
3657
- yield lock_id
3658
- except LockExpiredError:
3659
- state_is_locked = False
3660
- raise
3661
- finally:
3662
- if state_is_locked:
3663
- # only delete our lock
3664
- await self.redis.delete(lock_key)
3665
-
3666
- async def close(self):
3667
- """Explicitly close the redis connection and connection_pool.
3668
-
3669
- It is necessary in testing scenarios to close between asyncio test cases
3670
- to avoid having lingering redis connections associated with event loops
3671
- that will be closed (each test case uses its own event loop).
3672
-
3673
- Note: Connections will be automatically reopened when needed.
3674
- """
3675
- await self.redis.aclose(close_connection_pool=True)
3676
-
3677
-
3678
- def get_state_manager() -> StateManager:
3679
- """Get the state manager for the app that is currently running.
3680
-
3681
- Returns:
3682
- The state manager.
3683
- """
3684
- return prerequisites.get_and_validate_app().app.state_manager
3685
-
3686
-
3687
- class MutableProxy(wrapt.ObjectProxy):
3688
- """A proxy for a mutable object that tracks changes."""
3689
-
3690
- # Hint for finding the base class of the proxy.
3691
- __base_proxy__ = "MutableProxy"
3692
-
3693
- # Methods on wrapped objects which should mark the state as dirty.
3694
- __mark_dirty_attrs__ = {
3695
- "add",
3696
- "append",
3697
- "clear",
3698
- "difference_update",
3699
- "discard",
3700
- "extend",
3701
- "insert",
3702
- "intersection_update",
3703
- "pop",
3704
- "popitem",
3705
- "remove",
3706
- "reverse",
3707
- "setdefault",
3708
- "sort",
3709
- "symmetric_difference_update",
3710
- "update",
3711
- }
3712
-
3713
- # Methods on wrapped objects might return mutable objects that should be tracked.
3714
- __wrap_mutable_attrs__ = {
3715
- "get",
3716
- "setdefault",
3717
- }
3718
-
3719
- # These internal attributes on rx.Base should NOT be wrapped in a MutableProxy.
3720
- __never_wrap_base_attrs__ = set(Base.__dict__) - {"set"} | set(
3721
- pydantic.BaseModel.__dict__
3722
- )
3723
-
3724
- # These types will be wrapped in MutableProxy
3725
- __mutable_types__ = (
3726
- list,
3727
- dict,
3728
- set,
3729
- Base,
3730
- DeclarativeBase,
3731
- BaseModelV2,
3732
- BaseModelV1,
3733
- )
3734
-
3735
- # Dynamically generated classes for tracking dataclass mutations.
3736
- __dataclass_proxies__: dict[type, type] = {}
3737
-
3738
- def __new__(cls, wrapped: Any, *args, **kwargs) -> MutableProxy:
3739
- """Create a proxy instance for a mutable object that tracks changes.
3740
-
3741
- Args:
3742
- wrapped: The object to proxy.
3743
- *args: Other args passed to MutableProxy (ignored).
3744
- **kwargs: Other kwargs passed to MutableProxy (ignored).
3745
-
3746
- Returns:
3747
- The proxy instance.
3748
- """
3749
- if dataclasses.is_dataclass(wrapped):
3750
- wrapped_cls = type(wrapped)
3751
- wrapper_cls_name = wrapped_cls.__name__ + cls.__name__
3752
- # Find the associated class
3753
- if wrapper_cls_name not in cls.__dataclass_proxies__:
3754
- # Create a new class that has the __dataclass_fields__ defined
3755
- cls.__dataclass_proxies__[wrapper_cls_name] = type(
3756
- wrapper_cls_name,
3757
- (cls,),
3758
- {
3759
- dataclasses._FIELDS: getattr( # pyright: ignore [reportAttributeAccessIssue]
3760
- wrapped_cls,
3761
- dataclasses._FIELDS, # pyright: ignore [reportAttributeAccessIssue]
3762
- ),
3763
- },
3764
- )
3765
- cls = cls.__dataclass_proxies__[wrapper_cls_name]
3766
- return super().__new__(cls)
3767
-
3768
- def __init__(self, wrapped: Any, state: BaseState, field_name: str):
3769
- """Create a proxy for a mutable object that tracks changes.
3770
-
3771
- Args:
3772
- wrapped: The object to proxy.
3773
- state: The state to mark dirty when the object is changed.
3774
- field_name: The name of the field on the state associated with the
3775
- wrapped object.
3776
- """
3777
- super().__init__(wrapped)
3778
- self._self_state = state
3779
- self._self_field_name = field_name
3780
-
3781
- def __repr__(self) -> str:
3782
- """Get the representation of the wrapped object.
3783
-
3784
- Returns:
3785
- The representation of the wrapped object.
3786
- """
3787
- return f"{type(self).__name__}({self.__wrapped__})"
3788
-
3789
- def _mark_dirty(
3790
- self,
3791
- wrapped: Callable | None = None,
3792
- instance: BaseState | None = None,
3793
- args: tuple = (),
3794
- kwargs: dict | None = None,
3795
- ) -> Any:
3796
- """Mark the state as dirty, then call a wrapped function.
3797
-
3798
- Intended for use with `FunctionWrapper` from the `wrapt` library.
3799
-
3800
- Args:
3801
- wrapped: The wrapped function.
3802
- instance: The instance of the wrapped function.
3803
- args: The args for the wrapped function.
3804
- kwargs: The kwargs for the wrapped function.
3805
-
3806
- Returns:
3807
- The result of the wrapped function.
3808
- """
3809
- self._self_state.dirty_vars.add(self._self_field_name)
3810
- self._self_state._mark_dirty()
3811
- if wrapped is not None:
3812
- return wrapped(*args, **(kwargs or {}))
3813
-
3814
- @classmethod
3815
- def _is_mutable_type(cls, value: Any) -> bool:
3816
- """Check if a value is of a mutable type and should be wrapped.
3817
-
3818
- Args:
3819
- value: The value to check.
3820
-
3821
- Returns:
3822
- Whether the value is of a mutable type.
3823
- """
3824
- return isinstance(value, cls.__mutable_types__) or (
3825
- dataclasses.is_dataclass(value) and not isinstance(value, Var)
3826
- )
3827
-
3828
- @staticmethod
3829
- def _is_called_from_dataclasses_internal() -> bool:
3830
- """Check if the current function is called from dataclasses helper.
3831
-
3832
- Returns:
3833
- Whether the current function is called from dataclasses internal code.
3834
- """
3835
- # Walk up the stack a bit to see if we are called from dataclasses
3836
- # internal code, for example `asdict` or `astuple`.
3837
- frame = inspect.currentframe()
3838
- for _ in range(5):
3839
- # Why not `inspect.stack()` -- this is much faster!
3840
- if not (frame := frame and frame.f_back):
3841
- break
3842
- if inspect.getfile(frame) == dataclasses.__file__:
3843
- return True
3844
- return False
3845
-
3846
- def _wrap_recursive(self, value: Any) -> Any:
3847
- """Wrap a value recursively if it is mutable.
3848
-
3849
- Args:
3850
- value: The value to wrap.
3851
-
3852
- Returns:
3853
- The wrapped value.
3854
- """
3855
- # When called from dataclasses internal code, return the unwrapped value
3856
- if self._is_called_from_dataclasses_internal():
3857
- return value
3858
- # Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
3859
- if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
3860
- base_cls = globals()[self.__base_proxy__]
3861
- return base_cls(
3862
- wrapped=value,
3863
- state=self._self_state,
3864
- field_name=self._self_field_name,
3865
- )
3866
- return value
3867
-
3868
- def _wrap_recursive_decorator(
3869
- self, wrapped: Callable, instance: BaseState, args: list, kwargs: dict
3870
- ) -> Any:
3871
- """Wrap a function that returns a possibly mutable value.
3872
-
3873
- Intended for use with `FunctionWrapper` from the `wrapt` library.
3874
-
3875
- Args:
3876
- wrapped: The wrapped function.
3877
- instance: The instance of the wrapped function.
3878
- args: The args for the wrapped function.
3879
- kwargs: The kwargs for the wrapped function.
3880
-
3881
- Returns:
3882
- The result of the wrapped function (possibly wrapped in a MutableProxy).
3883
- """
3884
- return self._wrap_recursive(wrapped(*args, **kwargs))
3885
-
3886
- def __getattr__(self, __name: str) -> Any:
3887
- """Get the attribute on the proxied object and return a proxy if mutable.
3888
-
3889
- Args:
3890
- __name: The name of the attribute.
3891
-
3892
- Returns:
3893
- The attribute value.
3894
- """
3895
- value = super().__getattr__(__name)
3896
-
3897
- if callable(value):
3898
- if __name in self.__mark_dirty_attrs__:
3899
- # Wrap special callables, like "append", which should mark state dirty.
3900
- value = wrapt.FunctionWrapper(value, self._mark_dirty)
3901
-
3902
- if __name in self.__wrap_mutable_attrs__:
3903
- # Wrap methods that may return mutable objects tied to the state.
3904
- value = wrapt.FunctionWrapper(
3905
- value,
3906
- self._wrap_recursive_decorator,
3907
- )
3908
-
3909
- if (
3910
- isinstance(self.__wrapped__, Base)
3911
- and __name not in self.__never_wrap_base_attrs__
3912
- and hasattr(value, "__func__")
3913
- ):
3914
- # Wrap methods called on Base subclasses, which might do _anything_
3915
- return wrapt.FunctionWrapper(
3916
- functools.partial(value.__func__, self), # pyright: ignore [reportFunctionMemberAccess]
3917
- self._wrap_recursive_decorator,
3918
- )
3919
-
3920
- if self._is_mutable_type(value) and __name not in (
3921
- "__wrapped__",
3922
- "_self_state",
3923
- ):
3924
- # Recursively wrap mutable attribute values retrieved through this proxy.
3925
- return self._wrap_recursive(value)
3926
-
3927
- return value
3928
-
3929
- def __getitem__(self, key: Any) -> Any:
3930
- """Get the item on the proxied object and return a proxy if mutable.
3931
-
3932
- Args:
3933
- key: The key of the item.
3934
-
3935
- Returns:
3936
- The item value.
3937
- """
3938
- value = super().__getitem__(key)
3939
- # Recursively wrap mutable items retrieved through this proxy.
3940
- return self._wrap_recursive(value)
3941
-
3942
- def __iter__(self) -> Any:
3943
- """Iterate over the proxied object and return a proxy if mutable.
3944
-
3945
- Yields:
3946
- Each item value (possibly wrapped in MutableProxy).
3947
- """
3948
- for value in super().__iter__():
3949
- # Recursively wrap mutable items retrieved through this proxy.
3950
- yield self._wrap_recursive(value)
3951
-
3952
- def __delattr__(self, name: str):
3953
- """Delete the attribute on the proxied object and mark state dirty.
3954
-
3955
- Args:
3956
- name: The name of the attribute.
3957
- """
3958
- self._mark_dirty(super().__delattr__, args=(name,))
3959
-
3960
- def __delitem__(self, key: str):
3961
- """Delete the item on the proxied object and mark state dirty.
3962
-
3963
- Args:
3964
- key: The key of the item.
3965
- """
3966
- self._mark_dirty(super().__delitem__, args=(key,))
3967
-
3968
- def __setitem__(self, key: str, value: Any):
3969
- """Set the item on the proxied object and mark state dirty.
3970
-
3971
- Args:
3972
- key: The key of the item.
3973
- value: The value of the item.
3974
- """
3975
- self._mark_dirty(super().__setitem__, args=(key, value))
3976
-
3977
- def __setattr__(self, name: str, value: Any):
3978
- """Set the attribute on the proxied object and mark state dirty.
3979
-
3980
- If the attribute starts with "_self_", then the state is NOT marked
3981
- dirty as these are internal proxy attributes.
3982
-
3983
- Args:
3984
- name: The name of the attribute.
3985
- value: The value of the attribute.
3986
- """
3987
- if name.startswith("_self_"):
3988
- # Special case attributes of the proxy itself, not applied to the wrapped object.
3989
- super().__setattr__(name, value)
3990
- return
3991
- self._mark_dirty(super().__setattr__, args=(name, value))
3992
-
3993
- def __copy__(self) -> Any:
3994
- """Return a copy of the proxy.
3995
-
3996
- Returns:
3997
- A copy of the wrapped object, unconnected to the proxy.
3998
- """
3999
- return copy.copy(self.__wrapped__)
4000
-
4001
- def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Any:
4002
- """Return a deepcopy of the proxy.
4003
-
4004
- Args:
4005
- memo: The memo dict to use for the deepcopy.
4006
-
4007
- Returns:
4008
- A deepcopy of the wrapped object, unconnected to the proxy.
4009
- """
4010
- return copy.deepcopy(self.__wrapped__, memo=memo)
4011
-
4012
- def __reduce_ex__(self, protocol_version: SupportsIndex):
4013
- """Get the state for redis serialization.
4014
-
4015
- This method is called by cloudpickle to serialize the object.
4016
-
4017
- It explicitly serializes the wrapped object, stripping off the mutable proxy.
4018
-
4019
- Args:
4020
- protocol_version: The protocol version.
4021
-
4022
- Returns:
4023
- Tuple of (wrapped class, empty args, class __getstate__)
4024
- """
4025
- return self.__wrapped__.__reduce_ex__(protocol_version)
4026
-
4027
-
4028
- @serializer
4029
- def serialize_mutable_proxy(mp: MutableProxy):
4030
- """Return the wrapped value of a MutableProxy.
4031
-
4032
- Args:
4033
- mp: The MutableProxy to serialize.
4034
-
4035
- Returns:
4036
- The wrapped object.
4037
- """
4038
- return mp.__wrapped__
4039
-
4040
-
4041
- _orig_json_encoder_default = json.JSONEncoder.default
4042
-
4043
-
4044
- def _json_encoder_default_wrapper(self: json.JSONEncoder, o: Any) -> Any:
4045
- """Wrap JSONEncoder.default to handle MutableProxy objects.
4046
-
4047
- Args:
4048
- self: the JSONEncoder instance.
4049
- o: the object to serialize.
4050
-
4051
- Returns:
4052
- A JSON-able object.
4053
- """
4054
- try:
4055
- return o.__wrapped__
4056
- except AttributeError:
4057
- pass
4058
- return _orig_json_encoder_default(self, o)
4059
-
4060
-
4061
- json.JSONEncoder.default = _json_encoder_default_wrapper
4062
-
4063
-
4064
- class ImmutableMutableProxy(MutableProxy):
4065
- """A proxy for a mutable object that tracks changes.
4066
-
4067
- This wrapper comes from StateProxy, and will raise an exception if an attempt is made
4068
- to modify the wrapped object when the StateProxy is immutable.
4069
- """
4070
-
4071
- # Ensure that recursively wrapped proxies use ImmutableMutableProxy as base.
4072
- __base_proxy__ = "ImmutableMutableProxy"
4073
-
4074
- def _mark_dirty(
4075
- self,
4076
- wrapped: Callable | None = None,
4077
- instance: BaseState | None = None,
4078
- args: tuple = (),
4079
- kwargs: dict | None = None,
4080
- ) -> Any:
4081
- """Raise an exception when an attempt is made to modify the object.
4082
-
4083
- Intended for use with `FunctionWrapper` from the `wrapt` library.
4084
-
4085
- Args:
4086
- wrapped: The wrapped function.
4087
- instance: The instance of the wrapped function.
4088
- args: The args for the wrapped function.
4089
- kwargs: The kwargs for the wrapped function.
4090
-
4091
- Returns:
4092
- The result of the wrapped function.
4093
-
4094
- Raises:
4095
- ImmutableStateError: if the StateProxy is not mutable.
4096
- """
4097
- if not self._self_state._is_mutable():
4098
- raise ImmutableStateError(
4099
- "Background task StateProxy is immutable outside of a context "
4100
- "manager. Use `async with self` to modify state."
4101
- )
4102
- return super()._mark_dirty(
4103
- wrapped=wrapped, instance=instance, args=args, kwargs=kwargs
4104
- )
4105
-
4106
-
4107
- def code_uses_state_contexts(javascript_code: str) -> bool:
4108
- """Check if the rendered Javascript uses state contexts.
4109
-
4110
- Args:
4111
- javascript_code: The Javascript code to check.
4112
-
4113
- Returns:
4114
- True if the code attempts to access a member of StateContexts.
4115
- """
4116
- return bool("useContext(StateContexts" in javascript_code)
4117
-
4118
-
4119
- def reload_state_module(
4120
- module: str,
4121
- state: type[BaseState] = State,
4122
- ) -> None:
4123
- """Reset rx.State subclasses to avoid conflict when reloading.
4124
-
4125
- Args:
4126
- module: The module to reload.
4127
- state: Recursive argument for the state class to reload.
4128
-
4129
- """
4130
- # Clean out all potentially dirty states of reloaded modules.
4131
- for pd_state in tuple(state._potentially_dirty_states):
4132
- with contextlib.suppress(ValueError):
4133
- if (
4134
- state.get_root_state().get_class_substate(pd_state).__module__ == module
4135
- and module is not None
4136
- ):
4137
- state._potentially_dirty_states.remove(pd_state)
4138
- for subclass in tuple(state.class_subclasses):
4139
- reload_state_module(module=module, state=subclass)
4140
- if subclass.__module__ == module and module is not None:
4141
- all_base_state_classes.pop(subclass.get_full_name(), None)
4142
- state.class_subclasses.remove(subclass)
4143
- state._always_dirty_substates.discard(subclass.get_name())
4144
- state._var_dependencies = {}
4145
- state._init_var_dependency_dicts()
4146
- state.get_class_substate.cache_clear()
2607
+ from reflex.istate.manager import LockExpiredError as LockExpiredError # noqa: E402
2608
+ from reflex.istate.manager import StateManager as StateManager # noqa: E402
2609
+ from reflex.istate.manager import StateManagerDisk as StateManagerDisk # noqa: E402
2610
+ from reflex.istate.manager import StateManagerMemory as StateManagerMemory # noqa: E402
2611
+ from reflex.istate.manager import StateManagerRedis as StateManagerRedis # noqa: E402
2612
+ from reflex.istate.manager import get_state_manager as get_state_manager # noqa: E402
2613
+ from reflex.istate.manager import ( # noqa: E402
2614
+ reset_disk_state_manager as reset_disk_state_manager,
2615
+ )
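
The aliased imports above keep the old reflex.state paths for the state managers working even though the implementations now live in reflex.istate.manager. A minimal compatibility sketch, assuming the aliases are plain re-exports (the local names below are illustrative):

    # Old and new import paths should resolve to the same class after this release.
    from reflex.istate.manager import StateManagerRedis as manager_module_cls
    from reflex.state import StateManagerRedis as state_module_cls

    assert manager_module_cls is state_module_cls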