brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/_state.py CHANGED
@@ -13,408 +13,852 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from __future__ import annotations
17
+
16
18
  import contextlib
19
+ import dataclasses
17
20
  import threading
18
- from typing import Any, Tuple, Dict, List, Callable, Optional
21
+ from functools import wraps, partial
22
+ from typing import (Any, Union, Callable, Generic, Mapping,
23
+ TypeVar, Optional, TYPE_CHECKING, Tuple, Dict, List, Sequence)
19
24
 
20
25
  import jax
21
26
  import numpy as np
22
27
  from jax.api_util import shaped_abstractify
23
28
  from jax.extend import source_info_util
24
29
 
25
- from brainstate.typing import ArrayLike, PyTree
26
- from brainstate.util import DictManager
30
+ from brainstate.typing import ArrayLike, PyTree, Missing
31
+ from brainstate.util import DictManager, PrettyRepr, PrettyType, PrettyAttr, TraceContextError
32
+ from brainstate.util._tracers import StateJaxTracer
27
33
 
28
34
  __all__ = [
29
- 'State', 'ShortTermState', 'LongTermState', 'ParamState',
30
- 'StateDictManager',
31
- 'StateTrace',
32
- 'visible_state_dict',
33
- 'check_state_value_tree',
34
- ]
35
+ 'State', 'ShortTermState', 'LongTermState', 'HiddenState', 'ParamState', 'TreefyState',
35
36
 
36
- _pytree_registered_objects = set()
37
- max_int = np.iinfo(np.int32)
37
+ 'StateDictManager', 'StateTraceStack', 'check_state_value_tree', 'check_state_jax_tracer', 'catch_new_states',
38
+ ]
38
39
 
40
+ A = TypeVar('A')
41
+ B = TypeVar('B')
42
+ F = TypeVar('F', bound=Callable[..., Any])
39
43
 
40
- def _register_pytree_cls(cls):
41
- if cls not in _pytree_registered_objects:
42
- jax.tree_util.register_pytree_node_class(cls)
43
- _pytree_registered_objects.add(cls)
44
+ max_int = np.iinfo(np.int32)
44
45
 
45
46
 
46
47
  # The global state of the state stack is accessed by a thread-local object.
47
48
  # This allows concurrent tracing in separate threads; passing traced objects
48
49
  # between threads is forbidden.
49
50
  class ThreadLocalStack(threading.local):
50
- def __init__(self):
51
- self.stack: List[StateTrace] = []
52
-
51
+ def __init__(self):
52
+ self.state_stack: List[StateTraceStack] = []
53
+ self.tree_check: List[bool] = [False]
54
+ self.jax_tracer_check: List[bool] = [False]
55
+ self.new_state_catcher: List[Catcher] = []
53
56
 
54
- thread_local_stack = ThreadLocalStack()
55
57
 
56
- _global_context_to_check_state_tree = [False]
58
+ TRACE_CONTEXT = ThreadLocalStack()
57
59
 
58
60
 
59
61
  @contextlib.contextmanager
60
- def check_state_value_tree() -> None:
61
- """
62
- The contex manager to check weather the tree structure of the state value keeps consistently.
63
-
64
- Once a :py:class:`~.State` is created, the tree structure of the value is fixed. In default,
65
- the tree structure of the value is not checked to avoid off the repeated evaluation.
66
- If you want to check the tree structure of the value once the new value is assigned,
67
- you can use this context manager.
68
-
69
- Example::
70
-
71
- ```python
72
- state = brainstate.ShortTermState(jnp.zeros((2, 3)))
73
- with check_state_value_tree():
74
- state.value = jnp.zeros((2, 3))
75
-
76
- # The following code will raise an error.
77
- state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
78
- ```
79
-
80
- """
81
- try:
82
- _global_context_to_check_state_tree.append(True)
83
- yield
84
- finally:
85
- _global_context_to_check_state_tree.pop()
86
-
87
-
88
- class State(object):
89
- """
90
- The pointer to specify the dynamical data.
91
-
92
- To implement a new subclass of :py:class:`~.State`, you only need to inherent this class:
93
-
94
- Example::
95
-
96
- class MyState(State):
97
- pass
98
-
99
- The typical examples of :py:class:`~.State` subclass are:
100
-
101
- - :py:class:`~.ShortTermState`: The short-term state, which is used to store the short-term data in the program.
102
- - :py:class:`~.LongTermState`: The long-term state, which is used to store the long-term data in the program.
103
- - :py:class:`~.ParamState`: The parameter state, which is used to store the parameters in the program.
104
- - :py:class:`~.RandomState`: The random generator state, which is used to store the random key in the program.
105
-
106
- Args:
107
- value: PyTree. It can be anything as a pyTree.
108
- """
109
- __module__ = 'brainstate'
110
- __slots__ = ('_value', '_name', '_tree', '_level', '_source_info', '_check_tree')
111
-
112
- def __init__(self, value: PyTree[ArrayLike], name: Optional[str] = None):
113
- if isinstance(value, State):
114
- value = value.value
115
- self._value = value
116
- self._tree = jax.tree.structure(value)
117
- self._check_tree = False
118
- self._level = len(thread_local_stack.stack)
119
- self._source_info = source_info_util.current()
120
- self._name = name
121
-
122
- @property
123
- def name(self) -> Optional[str]:
62
+ def check_state_value_tree(val: bool = True) -> None:
124
63
  """
125
- The name of the state.
126
- """
127
- return self._name
64
+ The contex manager to check weather the tree structure of the state value keeps consistently.
65
+
66
+ Once a :py:class:`~.State` is created, the tree structure of the value is fixed. In default,
67
+ the tree structure of the value is not checked to avoid off the repeated evaluation.
68
+ If you want to check the tree structure of the value once the new value is assigned,
69
+ you can use this context manager.
70
+
71
+ Example::
72
+
73
+ >>> import brainstate as bst
74
+ >>> import jax.numpy as jnp
75
+ >>> state = bst.ShortTermState(jnp.zeros((2, 3)))
76
+ >>> with bst.check_state_value_tree():
77
+ >>> # The line below will not raise an error.
78
+ >>> state.value = jnp.zeros((2, 3))
79
+ ...
80
+ >>> # The following code will raise an error, since it changes the tree structure.
81
+ >>> state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
128
82
 
129
- @name.setter
130
- def name(self, name: str) -> None:
131
- """
132
- Set the name of the state.
133
83
  """
134
- self._name = name
84
+ try:
85
+ TRACE_CONTEXT.tree_check.append(val)
86
+ yield
87
+ finally:
88
+ TRACE_CONTEXT.tree_check.pop()
89
+
90
+
91
+ @contextlib.contextmanager
92
+ def catch_new_states(tag: str = None) -> List:
93
+ try:
94
+ catcher = Catcher(tag)
95
+ TRACE_CONTEXT.new_state_catcher.append(catcher)
96
+ yield catcher
97
+ finally:
98
+ TRACE_CONTEXT.new_state_catcher.pop()
99
+
100
+
101
+ class Catcher:
102
+ def __init__(self, tag: str):
103
+ self.tag = tag
104
+ self.state_ids = set()
105
+ self.states = []
135
106
 
136
- @property
137
- def value(self) -> PyTree[ArrayLike]:
107
+ def append(self, state: State):
108
+ if id(state) not in self.state_ids:
109
+ self.state_ids.add(id(state))
110
+ self.states.append(state)
111
+ state.tag = self.tag
112
+
113
+
114
+ @contextlib.contextmanager
115
+ def check_state_jax_tracer(val: bool = True) -> None:
138
116
  """
139
- The data and its value.
117
+ The context manager to check whether the state is valid to trace.
118
+
119
+ Example::
120
+
121
+ >>> import jax
122
+ >>> import brainstate as bst
123
+ >>> import jax.numpy as jnp
124
+ >>>
125
+ >>> a = bst.ShortTermState(jnp.zeros((2, 3)))
126
+ >>>
127
+ >>> @jax.jit
128
+ >>> def run_state(b):
129
+ >>> a.value = b
130
+ >>> return a.value
131
+ >>>
132
+ >>> # The following code will not raise an error, since the state is valid to trace.
133
+ >>> run_state(jnp.ones((2, 3)))
134
+ >>>
135
+ >>> with check_state_jax_tracer():
136
+ >>> # The line below will not raise an error.
137
+ >>> run_state(jnp.ones((2, 4)))
140
138
  """
141
- self._check_if_deleted()
139
+ try:
140
+ TRACE_CONTEXT.jax_tracer_check.append(val)
141
+ yield
142
+ finally:
143
+ TRACE_CONTEXT.jax_tracer_check.pop()
142
144
 
143
- # read the value by the stack (>= level)
144
- trace: StateTrace
145
- for trace in thread_local_stack.stack[self._level:]:
146
- trace.read_its_value(self)
147
- # return the value
148
- return self._value
149
145
 
150
- @value.setter
151
- def value(self, v) -> None:
146
+ @dataclasses.dataclass
147
+ class StateMetadata(Generic[A]):
152
148
  """
153
- Set the value of the state.
149
+ The state metadata.
154
150
 
155
151
  Args:
156
- v: The value.
152
+ raw_value: The raw value.
153
+ metadata: The metadata.
157
154
  """
158
- # value checking
159
- v = v.value if isinstance(v, State) else v
160
- self._check_value_tree(v)
161
- # write the value by the stack (>= level)
162
- trace: StateTrace
163
- for trace in thread_local_stack.stack[self._level:]:
164
- trace.write_its_value(self)
165
- # set the value
166
- self._value = v
167
-
168
- def _check_value_tree(self, v):
169
- if self._check_tree or _global_context_to_check_state_tree[-1]:
170
- in_tree = jax.tree.structure(v)
171
- if in_tree != self._tree:
172
- self._raise_error_with_source_info(
173
- ValueError(f'The given value {in_tree} does not '
174
- f'match with the origin tree structure '
175
- f'{self._tree}.')
176
- )
155
+ raw_value: A
156
+ metadata: Mapping[str, Any] = dataclasses.field(default_factory=dict)
177
157
 
178
- def _raise_error_with_source_info(self, error: Exception):
179
- name_stack = source_info_util.current_name_stack() + self.source_info.name_stack
180
- with source_info_util.user_context(self.source_info.traceback, name_stack=name_stack):
181
- raise error
182
158
 
183
- def _check_if_deleted(self):
184
- pass
185
-
186
- @property
187
- def source_info(self) -> source_info_util.SourceInfo:
159
+ def with_metadata(initializer: F, **metadata: Any) -> F:
188
160
  """
189
- The source information of the state, can be useful to identify
190
- the source code where the definition of the state.
191
-
192
- Returns:
193
- The source information.
161
+ A decorator to add metadata to the state.
194
162
  """
195
- return self._source_info
196
163
 
197
- def tree_flatten(self):
198
- """Flattens this variable.
199
-
200
- Returns:
201
- A pair where the first element is a list of leaf values
202
- and the second element is a treedef representing the
203
- structure of the flattened tree.
204
- """
205
- return (self._value,), (self._level,)
164
+ @wraps(initializer)
165
+ def wrapper(*args):
166
+ return StateMetadata(initializer(*args), metadata=metadata)
206
167
 
207
- @classmethod
208
- def tree_unflatten(cls, aux_data, flat_contents):
209
- """Reconstructs a variable from the aux_data and the leaves.
168
+ return wrapper # type: ignore
210
169
 
211
- Args:
212
- aux_data:
213
- flat_contents:
214
170
 
215
- Returns:
216
- The variable.
217
- """
218
- (_level,) = aux_data
219
- self = cls(flat_contents[0])
220
- self._level = max_int
221
- return self
222
-
223
- def __repr__(self):
224
- leaves, tree = jax.tree.flatten(self._value)
225
- leaves_info = [ShapeDtype(leaf.shape, leaf.dtype) for leaf in leaves]
226
- tree_info = jax.tree.unflatten(tree, leaves_info)
227
- if self.name is None:
228
- return f'{self.__class__.__name__}({tree_info})'
229
- else:
230
- return f'{self.__class__.__name__}({self.name}: {tree_info})'
171
+ def _get_trace_stack_level() -> int:
172
+ return len(TRACE_CONTEXT.state_stack)
231
173
 
232
174
 
233
- class ShapeDtype:
234
- def __init__(self, shape, dtype):
235
- self.shape = shape
236
- self.dtype = dtype
237
- self.ndim = len(shape)
238
- self.size = np.prod(shape)
175
+ class State(Generic[A], PrettyRepr):
176
+ """
177
+ The pointer to specify the dynamical data.
239
178
 
240
- def __repr__(self):
241
- return f'{self.dtype}{list(self.shape)}'
179
+ To implement a new subclass of :py:class:`~.State`, you only need to inherent this class:
242
180
 
181
+ Example::
243
182
 
244
- class ShortTermState(State):
245
- """
246
- The short-term state, which is used to store the short-term data in the program.
183
+ >>> class MyState(State):
184
+ >>> pass
247
185
 
248
- For example, in a training process, the gradients of the model are short-term states.
249
- """
186
+ The typical examples of :py:class:`~.State` subclass are:
250
187
 
251
- __module__ = 'brainstate'
188
+ - :py:class:`~.ShortTermState`: The short-term state, which is used to store the short-term data in the program.
189
+ - :py:class:`~.LongTermState`: The long-term state, which is used to store the long-term data in the program.
190
+ - :py:class:`~.ParamState`: The parameter state, which is used to store the parameters in the program.
191
+ - :py:class:`~.RandomState`: The random generator state, which is used to store the random key in the program.
252
192
 
193
+ Args:
194
+ value: PyTree. It can be anything as a pyTree.
195
+ """
196
+ __module__ = 'brainstate'
197
+ _trace_state: StateJaxTracer
198
+ _level: int
199
+ _source_info: source_info_util.SourceInfo
200
+ _name: Optional[str]
201
+ _value: PyTree
202
+ _been_writen: bool # useful in `unflatten` and `flatten` graph processing
203
+ tag: Optional[str]
204
+
205
+ def __init__(
206
+ self,
207
+ value: Union[PyTree[ArrayLike], StateMetadata[PyTree[ArrayLike]]],
208
+ name: Optional[str] = None,
209
+ **metadata: Any
210
+ ):
211
+ tag = metadata.pop('tag', None)
212
+
213
+ # avoid using self._setattr to avoid the check
214
+ vars(self)['_trace_state'] = StateJaxTracer()
215
+
216
+ # set the value and metadata
217
+ if isinstance(value, StateMetadata):
218
+ metadata.update(dict(value.metadata))
219
+ value = value.raw_value
220
+ if isinstance(value, State):
221
+ value = value.value
222
+
223
+ # update metadata
224
+ metadata.update(_value=value,
225
+ _level=_get_trace_stack_level(),
226
+ _source_info=source_info_util.current(),
227
+ _name=name,
228
+ tag=tag,
229
+ _been_writen=False)
230
+
231
+ # avoid using self._setattr to avoid the check
232
+ vars(self).update(metadata)
233
+
234
+ record_state_init(self)
235
+
236
+ if not TYPE_CHECKING:
237
+ def __setattr__(self, name: str, value: Any) -> None:
238
+ return self._setattr(name, value)
239
+
240
+ def _setattr(self, name: str, value: Any):
241
+ """
242
+ Check if the state is valid to mutate.
243
+ """
244
+ if TRACE_CONTEXT.jax_tracer_check[-1]:
245
+ self.check_valid_trace(lambda: f'Cannot mutate {type(self).__name__} from a different trace level')
246
+ object.__setattr__(self, name, value)
247
+
248
+ def _setattr_no_check(self, name: str, value: Any):
249
+ """
250
+ Set the attribute without checking the trace level.
251
+ """
252
+ vars(self)[name] = value
253
+
254
+ def check_valid_trace(self, error_msg: Callable[[], str]):
255
+ """
256
+ Check if the state is valid to trace.
257
+ """
258
+ if not self._trace_state.is_valid():
259
+ raise TraceContextError(error_msg())
260
+
261
+ @property
262
+ def name(self) -> Optional[str]:
263
+ """
264
+ The name of the state.
265
+ """
266
+ return self._name
267
+
268
+ @name.setter
269
+ def name(self, name: str) -> None:
270
+ """
271
+ Set the name of the state.
272
+ """
273
+ self._setattr_no_check('_name', name)
274
+
275
+ @property
276
+ def value(self) -> PyTree[ArrayLike]:
277
+ """
278
+ The data and its value.
279
+ """
280
+ self.check_if_deleted()
281
+ record_state_value_read(self)
282
+ return self._value
283
+
284
+ @value.setter
285
+ def value(self, v) -> None:
286
+ """
287
+ Set the value of the state.
288
+
289
+ Args:
290
+ v: The value.
291
+ """
292
+ self.write_value(v)
293
+ self._been_writen = True
294
+
295
+ def write_value(self, v) -> None:
296
+ # value checking
297
+ if isinstance(v, State):
298
+ raise ValueError('Cannot set value to a State, ' 'use `copy_from` method instead')
299
+ self._check_value_tree(v)
300
+ # write the value by the stack (>= level)
301
+ record_state_value_write(self)
302
+ # set the value
303
+ self._value = v
304
+
305
+ def restore_value(self, v) -> None:
306
+ """
307
+ Restore the value of the state.
308
+
309
+ Args:
310
+ v: The value.
311
+ """
312
+ # value checking
313
+ if isinstance(v, State):
314
+ raise ValueError('Cannot set value to a State, ' 'use `copy_from` method instead')
315
+ with check_state_value_tree():
316
+ self._check_value_tree(v)
317
+ # record the value by the stack (>= level)
318
+ record_state_value_restore(self)
319
+ # set the value
320
+ self._value = v
321
+
322
+ def value_call(self, func: Callable[..., Any]) -> Any:
323
+ """
324
+ Call the function with the value of the state.
325
+ """
326
+ return jax.tree.map(func, self.value)
327
+
328
+ def _check_value_tree(self, v):
329
+ """
330
+ Check if the value tree structure is consistent.
331
+ """
332
+ if TRACE_CONTEXT.tree_check[-1]:
333
+ in_tree = jax.tree.structure(v)
334
+ self_tree = jax.tree.structure(self._value)
335
+ if in_tree != self_tree:
336
+ self._raise_error_with_source_info(
337
+ ValueError(f'The given value {in_tree} does not match with the origin tree structure {self_tree}.')
338
+ )
339
+
340
+ def _raise_error_with_source_info(self, error: Exception):
341
+ """
342
+ Raise an error with the source information for easy debugging.
343
+ """
344
+ name_stack = source_info_util.current_name_stack() + self.source_info.name_stack
345
+ with source_info_util.user_context(self.source_info.traceback, name_stack=name_stack):
346
+ raise error
347
+
348
+ def check_if_deleted(self):
349
+ pass
350
+
351
+ @property
352
+ def source_info(self) -> source_info_util.SourceInfo:
353
+ """
354
+ The source information of the state, can be useful to identify
355
+ the source code where the definition of the state.
356
+
357
+ Returns:
358
+ The source information.
359
+ """
360
+ return self._source_info
361
+
362
+ def update_from_ref(self, state_ref: TreefyState[A]) -> None:
363
+ """
364
+ Update the state from the state reference :py:class:`TreefyState`.
365
+
366
+ Args:
367
+ state_ref: The state reference.
368
+ """
369
+ metadata = state_ref.get_metadata()
370
+ variable_vars = vars(self)
371
+ variable_vars.update(**metadata)
372
+ if metadata.pop('_been_writen', True):
373
+ self.value = state_ref.value
374
+ else:
375
+ self.restore_value(state_ref.value)
376
+
377
+ def replace(self, value: Any = Missing, **kwargs) -> State[Any]:
378
+ """
379
+ Replace the attribute of the state.
380
+ """
381
+ if value is not Missing:
382
+ kwargs['_value'] = value
383
+
384
+ # return `value` if it is a State
385
+ if '_value' in kwargs and isinstance(value := kwargs['_value'], State):
386
+ # remove value from kwargs
387
+ kwargs.pop('_value')
388
+ if type(self) is not type(value):
389
+ raise ValueError('Cannot replace value from incompatible container, '
390
+ f'expected {type(self).__name__}, got {type(value).__name__}')
391
+ # if kwargs aren't empty, recursively call replace
392
+ # else return variable value
393
+ if kwargs:
394
+ return value.replace(**kwargs)
395
+ else:
396
+ return value
397
+
398
+ # get and update attributes
399
+ attributes = vars(self).copy()
400
+ attributes.update(**kwargs)
401
+ # return new instance with updated attributes
402
+ obj = object.__new__(type(self))
403
+ vars(obj).update(attributes)
404
+ return obj
405
+
406
+ def copy(self: State[A]) -> State[A]:
407
+ """
408
+ Copy the state.
409
+ """
410
+ obj = object.__new__(type(self))
411
+ attributes = vars(self).copy()
412
+ # keep its own trace state and stack level
413
+ attributes['_trace_state'] = StateJaxTracer()
414
+ attributes['_level'] = _get_trace_stack_level()
415
+ attributes['_source_info'] = source_info_util.current()
416
+ attributes.pop('_been_writen', None)
417
+ # update the metadata
418
+ vars(obj).update(attributes)
419
+ return obj
420
+
421
+ def to_state_ref(self: State[A]) -> TreefyState[A]:
422
+ metadata = vars(self).copy()
423
+ del metadata['_value']
424
+ del metadata['_trace_state']
425
+ del metadata['_level']
426
+ return TreefyState(type(self), self._value, **metadata)
427
+
428
+ def __pretty_repr__(self):
429
+ yield PrettyType(type=type(self))
430
+ for name, value in vars(self).items():
431
+ if name == '_value':
432
+ name = 'value'
433
+ if name == '_name':
434
+ if value is None:
435
+ continue
436
+ else:
437
+ name = 'name'
438
+ if name == 'tag' and value is None:
439
+ continue
440
+ if name in ['_trace_state', '_level', '_source_info', '_been_writen']:
441
+ continue
442
+ yield PrettyAttr(name, repr(value))
443
+
444
+ def __treescope_repr__(self, path, subtree_renderer):
445
+ children = {}
446
+ for name, value in vars(self).items():
447
+ if name == '_value':
448
+ name = 'value'
449
+ if name == '_name':
450
+ if value is None:
451
+ continue
452
+ else:
453
+ name = 'name'
454
+ if name == 'tag' and value is None:
455
+ continue
456
+ if name in ['_trace_state', '_level', '_source_info', '_been_writen']:
457
+ continue
458
+ children[name] = value
459
+
460
+ import treescope # type: ignore[import-not-found,import-untyped]
461
+ return treescope.repr_lib.render_object_constructor(
462
+ object_type=type(self),
463
+ attributes=children,
464
+ path=path,
465
+ subtree_renderer=subtree_renderer,
466
+ )
253
467
 
254
- class LongTermState(State):
255
- """
256
- The long-term state, which is used to store the long-term data in the program.
468
+ def __eq__(self, other: object) -> bool:
469
+ return type(self) is type(other) and vars(other) == vars(self)
257
470
 
258
- For example, in a training process, the weights of the model are long-term states.
259
471
 
260
- """
472
+ def record_state_init(st: State[A]):
473
+ trace: Catcher
474
+ for trace in TRACE_CONTEXT.new_state_catcher:
475
+ trace.append(st)
261
476
 
262
- __module__ = 'brainstate'
263
477
 
478
+ def record_state_value_read(st: State[A]):
479
+ trace: StateTraceStack
480
+ for trace in TRACE_CONTEXT.state_stack[st._level:]:
481
+ trace.read_its_value(st)
264
482
 
265
- class ParamState(LongTermState):
266
- """
267
- The parameter state, which is used to store the trainable parameters in the model.
268
- """
269
- __module__ = 'brainstate'
270
483
 
484
+ def record_state_value_write(st: State[A]):
485
+ trace: StateTraceStack
486
+ for trace in TRACE_CONTEXT.state_stack[st._level:]:
487
+ trace.write_its_value(st)
271
488
 
272
- class StateDictManager(DictManager):
273
- """
274
- State stack, for collecting all :py:class:`~.State` used in the program.
275
489
 
276
- :py:class:`~.StateDictManager` supports all features of python dict.
277
- """
490
+ def record_state_value_restore(st: State[A]):
491
+ record_state_value_read(st)
278
492
 
279
- __module__ = 'brainstate'
280
493
 
281
- def assign_values(self, *args: Dict) -> None:
282
- """
283
- Assign the value for each element according to the given ``data``.
494
+ class ShortTermState(State):
284
495
  """
285
- for arg in args:
286
- assert isinstance(arg, dict), 'Must be an instance of dict.'
287
- for k, v in arg.items():
288
- self._set_elem(k, v)
496
+ The short-term state, which is used to store the short-term data in the program.
289
497
 
290
- def split_values(self, *filters: type) -> Tuple[Dict, ...]:
291
- """
292
- Split the values into several subsets of stack by the given types.
293
- """
294
- results = tuple(DictManager() for _ in range(len(filters) + 1))
295
- for k, v in self.items():
296
- for i, filt in enumerate(filters):
297
- if isinstance(v, filt):
298
- results[i][k] = v.value
299
- break
300
- else:
301
- results[-1][k] = v.value
302
- return results
303
-
304
- def collect_values(self) -> Dict:
305
- """
306
- Collect the values by the given types.
498
+ For example, in a training process, the gradients of the model are short-term states.
307
499
  """
308
- results = DictManager()
309
- for k, v in self.items():
310
- results[k] = v.value
311
- return results
312
500
 
313
- def split(self, first: type, *others: type) -> Tuple['StateDictManager', ...]:
314
- return super().split(first, *others)
501
+ __module__ = 'brainstate'
315
502
 
316
- def to_dict_values(self) -> Dict:
317
- """
318
- Convert the values into a dict.
319
- """
320
- return {k: v.value for k, v in self.items()}
321
503
 
322
- def _check_elem(self, elem):
323
- assert isinstance(elem, State), f'must be instance of {State}'
504
+ class LongTermState(State):
505
+ """
506
+ The long-term state, which is used to store the long-term data in the program.
324
507
 
325
- def _set_elem(self, key: Any, value: Any) -> None:
326
- self[key].value = value
508
+ For example, in a training process, the weights of the model are long-term states.
509
+ """
327
510
 
511
+ __module__ = 'brainstate'
328
512
 
329
- class visible_state_dict(StateDictManager):
330
- """
331
- The state dictionary whose elements are visible to ``.states()`` collection functions.
332
- """
333
- pass
334
513
 
514
+ class HiddenState(ShortTermState):
515
+ """
516
+ The hidden state, which is used to store the hidden data in a dynamic model.
517
+ """
335
518
 
336
- class StateTrace(object):
337
- """
338
- The state trace, which is used to trace the states automatically.
339
- """
519
+ __module__ = 'brainstate'
340
520
 
341
- def __init__(self, new_arg: Callable = None):
342
- self.states: List[State] = []
343
- self.types: List[str] = []
344
- self._id2index = dict()
345
- self._org_values = []
346
- self._jax_trace_new_arg = new_arg
347
- self._written_ids = set()
348
521
 
349
- def set_new_arg(self, new_arg: Callable) -> None:
350
- self._jax_trace_new_arg = new_arg
522
+ class ParamState(LongTermState):
523
+ """
524
+ The parameter state, which is used to store the trainable parameters in the model.
525
+ """
351
526
 
352
- def new_arg(self, state: State) -> None:
353
- if self._jax_trace_new_arg is not None:
354
- # internal use
355
- state._value = jax.tree.map(lambda x: self._jax_trace_new_arg(shaped_abstractify(x)), state._value)
527
+ __module__ = 'brainstate'
356
528
 
357
- def __enter__(self) -> 'StateTrace':
358
- thread_local_stack.stack.append(self)
359
- return self
360
529
 
361
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
362
- thread_local_stack.stack.pop()
530
+ class StateDictManager(DictManager):
531
+ """
532
+ State stack, for collecting all :py:class:`~.State` used in the program.
363
533
 
364
- def read_its_value(self, state: State) -> None:
534
+ :py:class:`~.StateDictManager` supports all features of python dict.
365
535
  """
366
- Read the value of the state.
367
536
 
368
- Args:
369
- state: The state.
537
+ __module__ = 'brainstate'
538
+
539
+ def assign_values(self, *args: Dict) -> None:
540
+ """
541
+ Assign the value for each element according to the given ``data``.
542
+ """
543
+ for arg in args:
544
+ assert isinstance(arg, dict), 'Must be an instance of dict.'
545
+ for k, v in arg.items():
546
+ self._set_elem(k, v)
547
+
548
+ def split_values(self, *filters: type) -> Tuple[Dict, ...]:
549
+ """
550
+ Split the values into several subsets of stack by the given types.
551
+ """
552
+ results = tuple(DictManager() for _ in range(len(filters) + 1))
553
+ for k, v in self.items():
554
+ for i, filt in enumerate(filters):
555
+ if isinstance(v, filt):
556
+ results[i][k] = v.value
557
+ break
558
+ else:
559
+ results[-1][k] = v.value
560
+ return results
561
+
562
+ def collect_values(self) -> Dict:
563
+ """
564
+ Collect the values by the given types.
565
+ """
566
+ results = DictManager()
567
+ for k, v in self.items():
568
+ results[k] = v.value
569
+ return results
570
+
571
+ def split(self, first: type, *others: type) -> Tuple['StateDictManager', ...]:
572
+ return super().split(first, *others)
573
+
574
+ def to_dict_values(self) -> Dict:
575
+ """
576
+ Convert the values into a dict.
577
+ """
578
+ return {k: v.value for k, v in self.items()}
579
+
580
+ def _check_elem(self, elem):
581
+ assert isinstance(elem, State), f'must be instance of {State}'
582
+
583
+ def _set_elem(self, key: Any, value: Any) -> None:
584
+ self[key].value = value
585
+
586
+
587
+ class StateTraceStack(Generic[A]):
370
588
  """
371
- id_ = id(state)
372
- if id_ not in self._id2index:
373
- self._id2index[id_] = len(self.states)
374
- self.states.append(state)
375
- self.types.append('read')
376
- self._org_values.append(state._value) # internal use
377
- self.new_arg(state)
378
-
379
- def write_its_value(self, state: State) -> None:
589
+ The state trace stack, which is used to trace the states automatically.
380
590
  """
381
- Write the value of the state.
382
591
 
383
- Args:
384
- state: The state.
592
+ def __init__(self, new_arg: Callable = None):
593
+ self.states: List[State] = []
594
+ self.been_writen: List[bool] = [] # False: read, True: write
595
+ self._state_id_index = dict()
596
+ self._original_state_values = []
597
+ self._jax_trace_new_arg: Callable = new_arg
598
+
599
+ @property
600
+ def original_state_values(self) -> Tuple[PyTree, ...]:
601
+ """
602
+ The original values of the states.
603
+ """
604
+ return tuple(self._original_state_values)
605
+
606
+ def set_new_arg(self, new_arg: Callable) -> None:
607
+ self._jax_trace_new_arg = new_arg
608
+
609
+ def new_arg(self, state: State) -> None:
610
+ if self._jax_trace_new_arg is not None:
611
+ # internal use
612
+ state._value = jax.tree.map(lambda x: self._jax_trace_new_arg(shaped_abstractify(x)), state._value)
613
+
614
+ def __enter__(self) -> 'StateTraceStack':
615
+ TRACE_CONTEXT.state_stack.append(self)
616
+ return self
617
+
618
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
619
+ TRACE_CONTEXT.state_stack.pop()
620
+
621
+ def read_its_value(self, state: State) -> None:
622
+ """
623
+ Read the value of the state.
624
+
625
+ Args:
626
+ state: The state.
627
+ """
628
+ id_ = id(state)
629
+ if id_ not in self._state_id_index:
630
+ self._state_id_index[id_] = len(self.states)
631
+ self.states.append(state)
632
+ self.been_writen.append(False)
633
+ self._original_state_values.append(state._value) # internal use
634
+ self.new_arg(state)
635
+
636
+ def write_its_value(self, state: State) -> None:
637
+ """
638
+ Write the value of the state.
639
+
640
+ Args:
641
+ state: The state.
642
+ """
643
+ id_ = id(state)
644
+ if id_ not in self._state_id_index:
645
+ self.read_its_value(state)
646
+ index = self._state_id_index[id_]
647
+ self.been_writen[index] = True
648
+
649
+ def get_state_values(self, separate: bool = False, replace: bool = False
650
+ ) -> Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
651
+ """
652
+ Get the values of the states.
653
+ """
654
+ if separate:
655
+ if replace:
656
+ writes, reads = [], []
657
+ for st, been_writen in zip(self.states, self.been_writen):
658
+ if been_writen:
659
+ writes.append(st.value)
660
+ reads.append(None)
661
+ else:
662
+ reads.append(st.value)
663
+ writes.append(None)
664
+ return tuple(writes), tuple(reads)
665
+ else:
666
+ writes, reads = [], []
667
+ for st, been_writen in zip(self.states, self.been_writen):
668
+ if been_writen:
669
+ writes.append(st.value)
670
+ else:
671
+ reads.append(st.value)
672
+ return tuple(writes), tuple(reads)
673
+ else:
674
+ return tuple([st.value for st in self.states])
675
+
676
+ def recovery_original_values(self) -> None:
677
+ """
678
+ Recovery the original values.
679
+ """
680
+ for st, val in zip(self.states, self._original_state_values):
681
+ # internal use
682
+ st._value = val
683
+
684
+ def merge(self, *traces) -> 'StateTraceStack':
685
+ """
686
+ Merge other state traces.
687
+ """
688
+ trace: StateTraceStack
689
+ for trace in traces:
690
+ for st, been_writen, org_val in zip(trace.states, trace.been_writen, trace._original_state_values):
691
+ if id(st) not in self._state_id_index: # read the value
692
+ self._state_id_index[id(st)] = len(self.states)
693
+ self._original_state_values.append(org_val) # add the original value
694
+ self.states.append(st) # append the state
695
+ self.been_writen.append(False)
696
+ if been_writen:
697
+ self.write_its_value(st)
698
+ return self
699
+
700
+ def get_read_states(self, replace_writen: bool = False) -> Tuple[State, ...]:
701
+ """
702
+ Read the states that are read by the function.
703
+
704
+ Returns:
705
+ The states that are read by the function.
706
+ """
707
+ if replace_writen:
708
+ return tuple([st if not been_writen else None
709
+ for st, been_writen in zip(self.states, self.been_writen)])
710
+ else:
711
+ return tuple([st for st, been_writen in zip(self.states, self.been_writen) if not been_writen])
712
+
713
+ def get_read_state_values(self, replace_writen: bool = False) -> Tuple[PyTree, ...]:
714
+ """
715
+ Read the states that are read by the function.
716
+
717
+ Returns:
718
+ The states that are read by the function.
719
+ """
720
+ if replace_writen:
721
+ return tuple(
722
+ [st.value if not been_writen else None for st, been_writen in zip(self.states, self.been_writen)])
723
+ else:
724
+ return tuple([st.value for st, been_writen in zip(self.states, self.been_writen) if not been_writen])
725
+
726
+ def get_write_states(self, replace_read: bool = False) -> Tuple[State, ...]:
727
+ """
728
+ Read the states that are written by the function.
729
+
730
+ Returns:
731
+ The states that are written by the function.
732
+ """
733
+ if replace_read:
734
+ return tuple([st if been_writen else None
735
+ for st, been_writen in zip(self.states, self.been_writen)])
736
+ else:
737
+ return tuple([st for st, been_writen in zip(self.states, self.been_writen) if been_writen])
738
+
739
+ def get_write_state_values(self, replace_read: bool = False) -> Tuple[PyTree, ...]:
740
+ """
741
+ Read the states that are written by the function.
742
+
743
+ Returns:
744
+ The states that are written by the function.
745
+ """
746
+ if replace_read:
747
+ return tuple([st.value if been_writen else None for st, been_writen in zip(self.states, self.been_writen)])
748
+ else:
749
+ return tuple([st.value for st, been_writen in zip(self.states, self.been_writen) if been_writen])
750
+
751
+ def __add__(self, other: 'StateTraceStack') -> 'StateTraceStack':
752
+ """
753
+ Support the syntax of `+` to merge the state traces.
754
+ """
755
+ return StateTraceStack().merge(self, other)
756
+
757
+
758
+ class TreefyState(Generic[A], PrettyRepr):
385
759
  """
386
- id_ = id(state)
387
- if id_ not in self._id2index:
388
- self.read_its_value(state)
389
- if id_ not in self._written_ids:
390
- index = self._id2index[id_]
391
- self.types[index] = 'write'
392
- self._written_ids.add(id_)
393
-
394
- def collect_values(self, *categories: str, check_val_tree: bool = False) -> Tuple:
760
+ The state as a pytree.
395
761
  """
396
- Collect the values by the given categories.
397
762
 
398
- Args:
399
- *categories: The categories.
400
- check_val_tree: Whether to check the tree structure of the value.
763
+ def __init__(
764
+ self,
765
+ type: type[State[Any]],
766
+ value: A,
767
+ **metadata
768
+ ):
769
+ self.type = type
770
+ self.value = value
771
+ vars(self).update(metadata)
772
+
773
+ if TYPE_CHECKING:
774
+ def __getattr__(self, name: str) -> None: ...
775
+
776
+ def __setattr__(self, name: str, value: Any) -> None: ...
777
+
778
+ def __delattr__(self, name: str) -> None: ...
779
+
780
+ def __pretty_repr__(self):
781
+ yield PrettyType(type=type(self))
782
+ yield PrettyAttr('type', self.type.__name__)
783
+ for name, value in vars(self).items():
784
+ if name == '_value':
785
+ name = 'value'
786
+ if name == '_name':
787
+ if value is None:
788
+ continue
789
+ else:
790
+ name = 'name'
791
+ if name in ['_trace_state', '_level', '_source_info', 'type']:
792
+ continue
793
+ yield PrettyAttr(name, repr(value))
794
+
795
+ def __treescope_repr__(self, path, subtree_renderer):
796
+ children = {'type': self.type}
797
+ for name, value in vars(self).items():
798
+ if name == 'type':
799
+ continue
800
+ children[name] = value
801
+
802
+ import treescope # type: ignore[import-not-found,import-untyped]
803
+ return treescope.repr_lib.render_object_constructor(
804
+ object_type=type(self),
805
+ attributes=children,
806
+ path=path,
807
+ subtree_renderer=subtree_renderer,
808
+ )
401
809
 
402
- Returns:
403
- results: The values.
404
- """
405
- results = []
406
- for st, ty in zip(self.states, self.types):
407
- if ty in categories:
408
- val = st.value
409
- if check_val_tree:
410
- st._check_value_tree(val)
411
- results.append(val)
412
- return tuple(results)
413
-
414
- def recovery_original_values(self) -> None:
415
- """
416
- Recovery the original values.
417
- """
418
- for st, val in zip(self.states, self._org_values):
419
- # internal use
420
- st._value = val
810
+ def replace(self, value: B) -> TreefyState[B]:
811
+ """
812
+ Replace the value of the state reference.
813
+ """
814
+ return TreefyState(self.type, value, **self.get_metadata())
815
+
816
+ def to_state(self) -> State[A]:
817
+ """
818
+ Convert the state reference to the state.
819
+ """
820
+ # we use object.__new__ to avoid calling __init__ and bypass the
821
+ # __init__ logic which should not be called twice
822
+ metadata = self.get_metadata()
823
+ state = object.__new__(self.type)
824
+ vars(state).update(metadata, _value=self.value, _trace_state=StateJaxTracer(), _level=_get_trace_stack_level())
825
+ return state
826
+
827
+ def copy(self: TreefyState[A]) -> TreefyState[A]:
828
+ """
829
+ Copy the state reference.
830
+ """
831
+ return jax.tree.map(lambda x: x, self)
832
+
833
+ def get_metadata(self) -> Dict[str, Any]:
834
+ """
835
+ Get the metadata of the state reference
836
+ """
837
+ metadata = vars(self).copy()
838
+ del metadata['type']
839
+ del metadata['value']
840
+ return metadata
841
+
842
+
843
+ def _state_ref_flatten(x: TreefyState[Any], *, with_keys: bool):
844
+ metadata = tuple(x.get_metadata().items())
845
+ if with_keys:
846
+ node = (jax.tree_util.GetAttrKey('value'), x.value)
847
+ else:
848
+ node = x.value
849
+ return (node,), (x.type, metadata)
850
+
851
+
852
+ def _state_ref_unflatten(
853
+ static: Tuple[type[State[A]], Tuple[Tuple[str, Any], ...]],
854
+ children: Tuple[A],
855
+ ) -> TreefyState[A]:
856
+ return TreefyState(type=static[0], value=children[0], **dict(static[1]))
857
+
858
+
859
+ jax.tree_util.register_pytree_with_keys(
860
+ TreefyState,
861
+ partial(_state_ref_flatten, with_keys=True), # type: ignore
862
+ _state_ref_unflatten, # type: ignore
863
+ flatten_func=partial(_state_ref_flatten, with_keys=False), # type: ignore
864
+ )