brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/_state.py CHANGED
@@ -1,1605 +1,1652 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from __future__ import annotations
17
-
18
- import contextlib
19
- import threading
20
- from functools import partial
21
- from typing import (
22
- Any,
23
- Union,
24
- Callable,
25
- Generic,
26
- TypeVar,
27
- Optional,
28
- TYPE_CHECKING,
29
- Tuple,
30
- Dict,
31
- List,
32
- Sequence,
33
- Generator,
34
- )
35
-
36
- import jax
37
- import numpy as np
38
- from jax.api_util import shaped_abstractify
39
- from jax.extend import source_info_util
40
-
41
- from brainstate.typing import ArrayLike, PyTree, Missing, Filter
42
- from brainstate.util import DictManager, PrettyObject
43
- from brainstate.util.filter import Nothing
44
-
45
# Names exported by ``from ... import *`` on this module.
__all__ = [
    'State', 'ShortTermState', 'LongTermState', 'HiddenState',
    'ParamState', 'BatchState', 'TreefyState', 'FakeState',

    'StateDictManager', 'StateTraceStack', 'check_state_value_tree',
    'check_state_jax_tracer', 'catch_new_states', 'maybe_state',
]

# Generic type variables shared by the annotations in this module.
A, B, T = TypeVar('A'), TypeVar('B'), TypeVar('T')
F = TypeVar('F', bound=Callable[..., Any])

# NOTE(review): despite its name, this is the ``np.iinfo`` object for int32,
# not an integer; the numeric maximum is ``max_int.max``.
max_int = np.iinfo(np.int32)
69
-
70
-
71
- # The global state of the state stack is accessed by a thread-local object.
72
- # This allows concurrent tracing in separate threads; passing traced objects
73
- # between threads is forbidden.
74
- class ThreadLocalStack(threading.local):
75
- """
76
- A thread-local storage class for managing state-related information.
77
-
78
- This class provides thread-local storage for various state management components,
79
- ensuring that each thread has its own isolated set of state-related data structures.
80
-
81
- Attributes:
82
- state_stack (List[StateTraceStack]): A list to store StateTraceStack objects for the current thread.
83
- tree_check (List[bool]): A list of boolean flags for tree structure checking, initialized with [False].
84
- jax_tracer_check (List[bool]): A list of boolean flags for JAX tracer checking, initialized with [False].
85
- new_state_catcher (List[StateCatcher]): A list to store Catcher objects for capturing new states in the current thread.
86
- """
87
-
88
- def __init__(self):
89
- """
90
- Initialize the ThreadLocalStack with empty data structures.
91
-
92
- This constructor sets up the initial state for each thread-local instance,
93
- creating empty lists for state stack, tree checking, JAX tracer checking,
94
- and new state catching.
95
- """
96
- self.state_stack: List[StateTraceStack] = []
97
- self.tree_check: List[bool] = [False]
98
- self.jax_tracer_check: List[bool] = [False]
99
- self.new_state_catcher: List[StateCatcher] = []
100
-
101
-
102
- TRACE_CONTEXT = ThreadLocalStack()
103
-
104
-
105
@contextlib.contextmanager
def check_state_value_tree(val: bool = True) -> Generator[None, None, None]:
    """
    Context manager that controls tree-structure checking on state assignment.

    Once a :py:class:`~.State` is created, the tree structure of its value is
    expected to stay fixed. By default this is not verified, which avoids
    re-flattening the value on every assignment.

    Inside this context (with ``val=True``), each assignment to ``State.value``
    compares the tree structure of the new value against that of the current
    value. On a mismatch it raises a ``ValueError``. Only the tree structure is
    compared, not leaf shapes or dtypes.

    The setting is pushed onto the thread-local ``TRACE_CONTEXT.tree_check``
    flag stack. On exit the previous setting is restored, even if the body
    raises, so contexts can be nested.

    Args:
        val: Whether tree-structure checking is enabled inside the context.
            Defaults to ``True``.

    Example::

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> state = brainstate.ShortTermState(jnp.zeros((2, 3)))
        >>> with brainstate.check_state_value_tree():
        ...     # Same tree structure: accepted.
        ...     state.value = jnp.zeros((2, 3))
        ...     # Different tree structure (tuple vs. array): raises ValueError.
        ...     state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
    """
    # Push before entering ``try`` so that ``finally`` only ever pops the
    # entry this call actually pushed.
    TRACE_CONTEXT.tree_check.append(val)
    try:
        yield
    finally:
        TRACE_CONTEXT.tree_check.pop()
133
-
134
-
135
def maybe_state(val: Any) -> Any:
    """
    Return the payload of ``val`` when it is a :class:`State`, else ``val`` itself.

    This lets callers accept either a :class:`State` or a plain value and treat
    both uniformly. Unwrapping goes through the public ``State.value`` property,
    so the read is recorded by the active trace stacks like any other access.

    Args:
        val: Either a :class:`State` instance or an arbitrary object.

    Returns:
        ``val.value`` when ``val`` is a :class:`State`; otherwise ``val`` unchanged.
    """
    if not isinstance(val, State):
        return val
    return val.value
154
-
155
-
156
@contextlib.contextmanager
def check_state_jax_tracer(val: bool = True) -> Generator[None, None, None]:
    """
    Context manager that sets the JAX-tracer checking flag for states.

    For the duration of the context, ``val`` is the innermost entry of the
    thread-local ``TRACE_CONTEXT.jax_tracer_check`` flag stack. The previous
    setting is restored on exit, even if the body raises, so contexts can be
    nested.

    NOTE(review): the code that reads this flag is not part of this function.
    Confirm there exactly which checks are performed while it is set.

    Args:
        val: The flag value to push for the duration of the context.
            Defaults to ``True``.

    Example::

        >>> import jax
        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> a = brainstate.ShortTermState(jnp.zeros((2, 3)))
        >>> @jax.jit
        ... def run_state(b):
        ...     a.value = b
        ...     return a.value
        >>> run_state(jnp.ones((2, 3)))           # flag not set
        >>> with brainstate.check_state_jax_tracer():
        ...     run_state(jnp.ones((2, 3)))       # flag set while tracing
    """
    # Push before entering ``try`` so that ``finally`` only ever pops the
    # entry this call actually pushed.
    TRACE_CONTEXT.jax_tracer_check.append(val)
    try:
        yield
    finally:
        TRACE_CONTEXT.jax_tracer_check.pop()
186
-
187
-
188
def _get_trace_stack_level() -> int:
    """Return how many state trace stacks are currently open in this thread."""
    open_stacks = TRACE_CONTEXT.state_stack
    return len(open_stacks)
190
-
191
-
192
class State(Generic[A], PrettyObject):
    """
    A generic container for a piece of dynamic data in the BrainState framework.

    ``State`` is the base class for the state objects used to store and track
    mutable data in a program. The tracing hooks work as follows:

    - Reading :attr:`value` calls ``record_state_value_read``, which notifies
      the trace stacks at or above the state's stack level.
    - Writing :attr:`value` calls ``record_state_value_write`` in the same way.
    - Constructing a state calls ``record_state_init``, which notifies every
      active state catcher.

    Type Parameters:
        A: The type of the value stored in the state.

    Attributes:
        name (Optional[str]): An optional name for the state.
        value (PyTree): The value stored in the state (read/write property).
        tag (Optional[str]): An optional tag for categorizing or grouping states.

    Args:
        value (PyTree[ArrayLike] | State): The initial value. If another
            :class:`State` is given, its current ``value`` is used.
        name (Optional[str]): An optional name for the state.
        **metadata: Additional metadata. ``tag`` is extracted and stored as
            :attr:`tag`; every other key is stored as an instance attribute.

    Example::

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> class MyState(brainstate.State):
        ...     pass
        >>> state = MyState(jnp.zeros((3, 3)), name="my_matrix")
        >>> print(state.value)
        [[0. 0. 0.]
         [0. 0. 0.]
         [0. 0. 0.]]

    Note:
        Subclasses of :class:`State` are typically used for specific purposes
        in a program, e.g.:

        - :py:class:`ShortTermState`: short-term data in the program.
        - :py:class:`LongTermState`: long-term data in the program.
        - :py:class:`ParamState`: the parameters in the program.
        - :py:class:`RandomState`: the random key in the program.
    """
    __module__ = 'brainstate'
    _level: int  # index into TRACE_CONTEXT.state_stack from which accesses are recorded
    _source_info: source_info_util.SourceInfo  # where the state was created (for error messages)
    _name: Optional[str]
    _value: PyTree
    _been_writen: bool  # (sic) set True by the `value` setter; useful in `unflatten` and `flatten` graph processing
    tag: Optional[str]

    def __init__(
        self,
        value: PyTree[ArrayLike],
        name: Optional[str] = None,
        **metadata: Any
    ):
        """
        Initialize a new :class:`State` instance.

        Args:
            value: The initial value, a PyTree of array-like objects. If a
                :class:`State` is passed, its ``value`` is read and used
                (that read is recorded like any other).
            name: An optional name for the state. Defaults to ``None``.
            **metadata: Additional metadata:

                - ``tag`` (Optional[str]): a tag for categorizing or grouping
                  states; defaults to ``None``.
                - any other key is stored directly as an instance attribute.

        Note:
            The stack level is taken from the current trace-stack depth, the
            source info from the current call site, and ``_been_writen`` starts
            as ``False``. Finally, the new state is reported to all active
            state catchers via ``record_state_init``.
        """
        tag = metadata.pop('tag', None)

        # set the value and metadata
        if isinstance(value, State):
            value = value.value

        # update metadata
        metadata.update(
            _value=value,
            _level=_get_trace_stack_level(),
            _source_info=source_info_util.current(),
            _name=name,
            _been_writen=False,
            tag=tag,
        )

        # write straight into the instance ``__dict__``, bypassing any
        # ``__setattr__`` hooks (and the checks they would perform)
        vars(self).update(metadata)

        # record the state initialization
        record_state_init(self)

    def decrease_stack_level(self) -> None:
        """
        Decrease the stack level of the state by one, never going below zero.

        Typically used when exiting a nested context or scope.
        """
        self._level = max(self._level - 1, 0)

    def increase_stack_level(self) -> None:
        """
        Increase the stack level of the state by one.

        Typically used when entering a nested context or scope.
        """
        self._level = self._level + 1

    @property
    def name(self) -> Optional[str]:
        """
        The name of the state, or ``None`` if unnamed.
        """
        return self._name

    @name.setter
    def name(self, name: Optional[str]) -> None:
        """
        Set the name of the state.
        """
        self._name = name

    @property
    def value(self) -> PyTree[ArrayLike]:
        """
        The value of the state.

        Reading records the access in the relevant trace stacks before
        delegating to :meth:`_read_value`.
        """
        record_state_value_read(self)
        return self._read_value()

    @value.setter
    def value(self, v) -> None:
        """
        Set the value of the state.

        Args:
            v: The new value. Must not be a :class:`State`.

        Raises:
            ValueError: If ``v`` is a :class:`State`, or if tree-structure
                checking is enabled and ``v`` has a different tree structure.
        """
        # NOTE: the following order is important: validate first, then record
        # the write, then flag and store, so a rejected value leaves no trace.

        if isinstance(v, State):  # value checking
            # NOTE(review): no ``copy_from`` method is defined on this class;
            # the hint in this message may be stale -- confirm.
            raise ValueError('Cannot set value to a State, ' 'use `copy_from` method instead')
        self._check_value_tree(v)  # check the tree structure
        record_state_value_write(self)  # record the value by the stack (>= level)
        self._been_writen = True  # set the flag
        self._write_value(v)  # write the value

    @property
    def stack_level(self) -> int:
        """
        The stack level of the state.

        Returns:
            The index into the thread's trace-stack list from which reads and
            writes of this state are recorded.
        """
        return self._level

    @stack_level.setter
    def stack_level(self, level: int) -> None:
        """
        Set the stack level of the state.

        Args:
            level: The stack level.
        """
        self._level = level

    def _read_value(self) -> PyTree[ArrayLike]:
        """
        Hook for subclasses to customize how the stored value is read.
        """
        self.check_if_deleted()
        return self._value

    def _write_value(self, v) -> None:
        """
        Hook for subclasses to customize how a new value is stored.
        """
        self._value = v

    def restore_value(self, v) -> None:
        """
        Restore the value of the state without marking it as written.

        Unlike the ``value`` setter, this always checks the tree structure,
        records the access as a restore (a read), leaves ``_been_writen``
        untouched, and assigns ``_value`` directly instead of going through
        :meth:`_write_value`.

        Args:
            v: The value. Must not be a :class:`State`.

        Raises:
            ValueError: If ``v`` is a :class:`State` or has a different tree
                structure from the current value.
        """
        # value checking
        if isinstance(v, State):
            raise ValueError('Cannot set value to a State, ' 'use `copy_from` method instead')
        # force the tree-structure check regardless of the caller's setting
        with check_state_value_tree():
            self._check_value_tree(v)
        # record the value by the stack (>= level)
        record_state_value_restore(self)
        # set the value
        self._value = v

    def value_call(self, func: Callable[..., Any]) -> Any:
        """
        Apply ``func`` to every leaf of the state's value and return the mapped tree.

        The value is read through the ``value`` property, so the read is recorded.
        """
        return jax.tree.map(func, self.value)

    def _check_value_tree(self, v):
        """
        Raise ``ValueError`` if ``v``'s tree structure differs from the current value's.

        The check is only performed when the innermost ``tree_check`` flag
        (see ``check_state_value_tree``) is ``True``.
        """
        if TRACE_CONTEXT.tree_check[-1]:
            in_tree = jax.tree.structure(v)
            self_tree = jax.tree.structure(self._value)
            if in_tree != self_tree:
                self.raise_error_with_source_info(
                    ValueError(f'The given value {in_tree} does not match with the origin tree structure {self_tree}.')
                )

    def raise_error_with_source_info(self, error: Exception):
        """
        Raise ``error`` attributed to the source location where this state was created.

        The current JAX name stack is combined with the state's own name stack
        so that the reported traceback points at the state's definition.
        """
        name_stack = source_info_util.current_name_stack() + self.source_info.name_stack
        with source_info_util.user_context(self.source_info.traceback, name_stack=name_stack):
            raise error

    def check_if_deleted(self) -> None:
        """
        Hook called before every value read; a no-op in the base class.
        """
        pass

    @property
    def source_info(self) -> source_info_util.SourceInfo:
        """
        The source information captured when the state was created.

        Useful for identifying the code location that defined the state.

        Returns:
            The source information.
        """
        return self._source_info

    def update_from_ref(self, state_ref: TreefyState[A]) -> None:
        """
        Update the state in place from a :py:class:`TreefyState` reference.

        All metadata of ``state_ref`` is copied into this instance. The value
        is then applied as follows:

        - If the reference's ``_been_writen`` flag is true (or missing), it
          goes through the ``value`` setter, which records a write.
        - Otherwise it goes through :meth:`restore_value`, which records only
          a read.

        Args:
            state_ref: The state reference.
        """
        metadata = state_ref.get_metadata()
        variable_vars = vars(self)
        variable_vars.update(**metadata)
        # NOTE(review): ``pop`` mutates the dict returned by ``get_metadata()``;
        # confirm it returns a copy rather than the reference's own dict.
        if metadata.pop('_been_writen', True):
            self.value = state_ref.value
        else:
            self.restore_value(state_ref.value)

    def replace(self, value: Any = Missing, **kwargs) -> State[Any]:
        """
        Return a new state of the same type with some attributes replaced.

        Behaviour depends on what is passed:

        - ``value`` is stored as ``_value``, and other keyword arguments
          override the matching instance attributes.
        - If the new value is itself a :class:`State`, that state is returned
          instead. When extra keyword arguments are also given, it is returned
          via ``value.replace(**kwargs)``.

        The new instance is built with ``object.__new__``, so ``__init__`` is
        not run and the new state is not reported to state catchers.

        Raises:
            ValueError: If a :class:`State` value of a different type is given.
        """
        if value is not Missing:
            kwargs['_value'] = value

        # return `value` if it is a State
        if '_value' in kwargs and isinstance(value := kwargs['_value'], State):
            # remove value from kwargs
            kwargs.pop('_value')
            if type(self) is not type(value):
                raise ValueError('Cannot replace value from incompatible container, '
                                 f'expected {type(self).__name__}, got {type(value).__name__}')
            # if kwargs aren't empty, recursively call replace
            # else return variable value
            if kwargs:
                return value.replace(**kwargs)
            else:
                return value

        # get and update attributes
        attributes = vars(self).copy()
        attributes.update(**kwargs)
        # return new instance with updated attributes
        obj = object.__new__(type(self))
        vars(obj).update(attributes)
        return obj

    def copy(self: State[A]) -> State[A]:
        """
        Return a shallow copy of the state.

        The copy's attributes differ from the original in three ways:

        - It gets the current trace-stack level and current source info.
        - It does not carry the ``_been_writen`` attribute.
        - Like :meth:`replace`, it is created without running ``__init__``,
          so it is not reported to state catchers.
        """
        obj = object.__new__(type(self))
        attributes = vars(self).copy()
        # keep its own trace state and stack level
        attributes['_level'] = _get_trace_stack_level()
        attributes['_source_info'] = source_info_util.current()
        attributes.pop('_been_writen', None)
        # update the metadata
        vars(obj).update(attributes)
        return obj

    def to_state_ref(self: State[A]) -> TreefyState[A]:
        """
        Build a :py:class:`TreefyState` holding this state's type, value and
        remaining instance attributes as metadata.
        """
        metadata = vars(self).copy()
        del metadata['_value']
        return TreefyState(type(self), self._value, **metadata)

    def __pretty_repr_item__(self, k, v):
        """
        Choose how attribute ``k`` appears in the pretty representation.

        Returns:
            A ``(key, value)`` pair to display, or ``None`` to hide the item.

            - Internal bookkeeping (``_level``, ``_source_info``,
              ``_been_writen``) is hidden.
            - The value is shown as its abstract (shape/dtype) form.
            - ``name`` and ``tag`` are hidden when they are ``None``.
        """
        if k in ['_level', '_source_info', '_been_writen']:
            return None
        if k == '_value':
            return 'value', jax.tree.map(shaped_abstractify, v)

        if k == '_name':
            if self.name is None:
                return None
            else:
                return 'name', v

        if k == 'tag':
            if self.tag is None:
                return None
            else:
                return 'tag', v

        return k, v

    # NOTE: value-based equality is intentionally disabled; states compare and
    # hash by identity (see ``__hash__``).
    # def __eq__(self, other: object) -> bool:
    #     return type(self) is type(other) and vars(other) == vars(self)

    def __hash__(self) -> int:
        """
        Hash the state by object identity.
        """
        return hash(id(self))

    def numel(self) -> int:
        """
        Calculate the total number of elements in the state value.

        Every leaf of the (possibly nested) value is flattened, and the
        ``jax.numpy.size`` of each leaf is summed. The stored value is read
        directly, so this access is not recorded by the trace stacks.

        Returns:
            int: The total number of elements across all leaves.

            - For a scalar value, this is 1.
            - For arrays or nested structures, it is the sum of the sizes of
              all contained arrays.
        """
        sizes = [jax.numpy.size(val) for val in jax.tree.leaves(self._value)]
        return sum(sizes)
557
-
558
-
559
def record_state_init(st: State[A]):
    """
    Hand a freshly constructed :class:`State` to every active state catcher.

    Each catcher registered in the current thread's
    ``TRACE_CONTEXT.new_state_catcher`` receives ``st`` via its ``append``
    method. Nothing is returned.

    Args:
        st: The newly initialized state (called from ``State.__init__``).
    """
    catchers = TRACE_CONTEXT.new_state_catcher
    for catcher in catchers:
        catcher.append(st)
577
-
578
-
579
def record_state_value_read(st: State[A]):
    """
    Tell the relevant trace stacks that ``st``'s value has been read.

    The stacks notified are those in the current thread's
    ``TRACE_CONTEXT.state_stack`` whose index is at or above
    ``st.stack_level``. Each one receives a ``read_its_value(st)`` call.
    Nothing is returned.

    Args:
        st: The state whose value was just read.
    """
    first_level = st.stack_level
    for stack in TRACE_CONTEXT.state_stack[first_level:]:
        stack.read_its_value(st)
599
-
600
-
601
def record_state_value_write(st: State[A]):
    """
    Tell the relevant trace stacks that ``st``'s value has been written.

    The stacks notified are those in the current thread's
    ``TRACE_CONTEXT.state_stack`` whose index is at or above
    ``st.stack_level``. Each one receives a ``write_its_value(st)`` call.
    Nothing is returned.

    Args:
        st: The state whose value is being written.
    """
    first_level = st.stack_level
    for stack in TRACE_CONTEXT.state_stack[first_level:]:
        stack.write_its_value(st)
621
-
622
-
623
def record_state_value_restore(st: State[A]):
    """
    Record that ``st``'s value has been restored to an earlier value.

    Restores are bookkept exactly like reads: the state is reported to the
    relevant trace stacks via :func:`record_state_value_read`, but it is not
    marked as written. This function only records the event; it does not
    change the state's value itself.

    Args:
        st: The state whose value was restored.
    """
    # A restore is deliberately treated as a read, not a write.
    record_state_value_read(st)
644
-
645
-
646
class ShortTermState(State):
    """
    State holding transient data that matters only for a short span of execution.

    ``ShortTermState`` is a semantic marker on top of :class:`State`. It adds
    no attributes or methods of its own, but lets code tell short-lived data
    apart from other kinds of state.

    A typical example is the gradients of a model during training. They are
    computed and consumed within each iteration rather than carried across
    iterations.

    Example::

        >>> import numpy as np
        >>> import brainstate
        >>> gradient = brainstate.ShortTermState(np.zeros(100), name="model_gradient")
        >>> activations = brainstate.ShortTermState({}, name="layer_activations")
    """

    __module__ = 'brainstate'
673
-
674
-
675
class LongTermState(State):
    """
    State holding data that persists across many iterations of a process.

    ``LongTermState`` is a semantic marker on top of :class:`State`. It adds
    no attributes or methods of its own, but lets code tell long-lived data
    apart from other kinds of state.

    A typical example is the weights of a model, which are updated and kept
    throughout the whole training procedure.

    Example::

        >>> import numpy as np
        >>> import brainstate
        >>> weights = brainstate.LongTermState(np.random.randn(100, 100), name="model_weights")
        >>> opt_state = brainstate.LongTermState({}, name="optimizer_state")
    """

    __module__ = 'brainstate'
701
-
702
-
703
- class BatchState(LongTermState):
704
- """
705
- The batch state, which is used to store the batch data in the program.
706
-
707
- This class extends :class:`LongTermState` and is specifically designed to represent
708
- and manage batch data within a program. It provides a way to encapsulate
709
- batch-related information and associated metadata, facilitating operations
710
- like batch processing in machine learning or data analysis tasks.
711
-
712
- Attributes:
713
- Inherits all attributes from :class:`LongTermState`.
714
-
715
- Note:
716
- This class does not introduce new methods or attributes beyond those
717
- inherited from :class:`LongTermState`. Its primary purpose is to semantically
718
- distinguish batch states from other types of long-term states
719
- in the program.
720
-
721
- Example:
722
- >>> batch_data = BatchState(np.array([1, 2, 3, 4, 5]), name="current_batch")
723
- >>> batch_labels = BatchState(np.array([0, 1, 0, 1, 1]), name="batch_labels")
724
- """
725
-
726
- __module__ = 'brainstate'
727
-
728
-
729
- class HiddenState(ShortTermState):
730
- """
731
- Represents hidden state variables in neurons or synapses.
732
-
733
- This class extends :class:`ShortTermState` and is specifically designed to represent
734
- and manage hidden states within dynamic models, such as recurrent neural networks.
735
- It provides a way to encapsulate hidden state values and associated metadata,
736
- facilitating operations like state updates during model execution.
737
-
738
- Note:
739
- :class:`HiddenState` and :class:`ParamState` are two most important state types
740
- in brainstate. The former is used to store the hidden states in neurons, synapses,
741
- or networks. The latter is used to store the trainable parameters in the model,
742
- such as synaptic weights.
743
-
744
- Example:
745
- >>> lstm_hidden = HiddenState(np.zeros(128), name="lstm_hidden_state")
746
- >>> gru_hidden = HiddenState(np.zeros(64), name="gru_hidden_state")
747
- """
748
-
749
- __module__ = 'brainstate'
750
-
751
-
752
- class ParamState(LongTermState):
753
- """
754
- The parameter state, which is used to store the trainable parameters in the model.
755
-
756
- This class extends :class:`LongTermState` and is specifically designed to represent
757
- and manage trainable parameters within a neural network or machine learning model.
758
- It provides a way to encapsulate parameter values and associated metadata,
759
- facilitating operations like parameter updates during training.
760
-
761
- Note:
762
- :class:`HiddenState` and :class:`ParamState` are two most important state types
763
- in brainstate. The former is used to store the hidden states in neurons, synapses,
764
- or networks. The latter is used to store the trainable parameters in the model,
765
- such as synaptic weights.
766
-
767
- Example:
768
- >>> weight = ParamState(np.random.randn(10, 10), name="layer1_weights")
769
- >>> bias = ParamState(np.zeros(10), name="layer1_bias")
770
- """
771
-
772
- __module__ = 'brainstate'
773
-
774
-
775
- class FakeState:
776
- """
777
- The faked state, which is used to store the faked data in the program.
778
- """
779
-
780
- __module__ = 'brainstate'
781
-
782
- def __init__(self, value: Any, name: Optional[str] = None):
783
- """
784
- Initialize a FakeState instance.
785
-
786
- Args:
787
- value (Any): The value to be stored in the fake state.
788
- name (Optional[str], optional): The name of the fake state. Defaults to None.
789
- """
790
- self._value = value
791
- self._name = name
792
-
793
- @property
794
- def value(self) -> Any:
795
- """
796
- Get the value stored in the fake state.
797
-
798
- Returns:
799
- Any: The value stored in the fake state.
800
- """
801
- return self._value
802
-
803
- @value.setter
804
- def value(self, v) -> None:
805
- """
806
- Set the value of the fake state.
807
-
808
- Args:
809
- v (Any): The new value to be stored in the fake state.
810
- """
811
- self._value = v
812
-
813
- def __repr__(self) -> str:
814
- """
815
- Return a string representation of the FakeState instance.
816
-
817
- Returns:
818
- str: A string representation of the FakeState instance.
819
- """
820
- return f'FakedState(value={self._value})'
821
-
822
- @property
823
- def name(self) -> Optional[str]:
824
- """
825
- Get the name of the fake state.
826
-
827
- Returns:
828
- Optional[str]: The name of the fake state, or None if not set.
829
- """
830
- return self._name
831
-
832
- @name.setter
833
- def name(self, name: str) -> None:
834
- """
835
- Set the name of the fake state.
836
-
837
- Args:
838
- name (str): The new name for the fake state.
839
- """
840
- self._name = name
841
-
842
-
843
- class StateDictManager(DictManager):
844
- """
845
- State stack, for collecting all :py:class:`~.State` used in the program.
846
-
847
- :py:class:`~.StateDictManager` supports all features of python dict.
848
- """
849
-
850
- __module__ = 'brainstate'
851
-
852
- def assign_values(self, *args: Dict) -> None:
853
- """
854
- Assign the value for each element according to the given ``data``.
855
- """
856
- for arg in args:
857
- assert isinstance(arg, dict), 'Must be an instance of dict.'
858
- for k, v in arg.items():
859
- self._set_elem(k, v)
860
-
861
- def split_values(self, *filters: type) -> Tuple[Dict, ...]:
862
- """
863
- Split the values into several subsets of stack by the given types.
864
- """
865
- results = tuple(DictManager() for _ in range(len(filters) + 1))
866
- for k, v in self.items():
867
- for i, filt in enumerate(filters):
868
- if isinstance(v, filt):
869
- results[i][k] = v.value
870
- break
871
- else:
872
- results[-1][k] = v.value
873
- return results
874
-
875
- def collect_values(self) -> Dict:
876
- """
877
- Collect the values by the given types.
878
- """
879
- results = DictManager()
880
- for k, v in self.items():
881
- results[k] = v.value
882
- return results
883
-
884
- def split(self, first: type, *others: type) -> Tuple['StateDictManager', ...]:
885
- return super().split(first, *others)
886
-
887
- def to_dict_values(self) -> Dict:
888
- """
889
- Convert the values into a dict.
890
- """
891
- return {k: v.value for k, v in self.items()}
892
-
893
- def _check_elem(self, elem):
894
- assert isinstance(elem, State), f'must be instance of {State}'
895
-
896
- def _set_elem(self, key: Any, value: Any) -> None:
897
- self[key].value = value
898
-
899
-
900
- class StateTraceStack(Generic[A]):
901
- """
902
- A stack for tracing and managing states during program execution.
903
-
904
- ``StateTraceStack`` is used to automatically trace and manage State objects,
905
- keeping track of which states are read from or written to during the
906
- execution of a function or block of code. It provides methods for
907
- recording state accesses, retrieving state values, and managing the
908
- lifecycle of states within a tracing context.
909
-
910
- The class is generic over type A, allowing for type-safe usage with
911
- different types of State objects.
912
-
913
- Attributes:
914
- states (List[State]): A list of all State objects encountered during tracing.
915
- been_writen (List[bool]): A parallel list to states, indicating whether each state has been written to.
916
- _state_id_index (dict): A dictionary mapping state ids to their index in the states list.
917
- _original_state_values (List): A list of the original values of all states when first encountered.
918
- _jax_trace_new_arg (Callable): A function used to transform state values during tracing.
919
-
920
- Methods:
921
- __enter__: Enters a new tracing context.
922
- __exit__: Exits the current tracing context.
923
- read_its_value: Records a read operation on a state.
924
- write_its_value: Records a write operation on a state.
925
- get_state_values: Retrieves the current values of all traced states.
926
- recovery_original_values: Restores all states to their original values.
927
- merge: Merges multiple ``StateTraceStack`` instances.
928
- get_read_states: Retrieves states that were read during tracing.
929
- get_read_state_values: Retrieves values of states that were read during tracing.
930
-
931
- The ``StateTraceStack`` is a crucial component in implementing state-based
932
- computations and is particularly useful in scenarios involving automatic
933
- differentiation or other forms of program transformation.
934
- """
935
-
936
- def __init__(
937
- self,
938
- new_arg: Callable = None,
939
- name: Optional[str] = None,
940
- ):
941
- self.name = name
942
- self.states: List[State] = []
943
- self.been_writen: List[bool] = [] # False: read, True: write
944
- self._state_id_index = dict()
945
- self._original_state_values = []
946
- self._jax_trace_new_arg: Callable = new_arg
947
- self._stack_level = None
948
-
949
- def __str__(self) -> str:
950
- _stack_level = self.name if self._stack_level is None else self._stack_level
951
- if _stack_level is None:
952
- _stack_level = ''
953
- return f"{self.__class__.__name__}({_stack_level})"
954
-
955
- @property
956
- def original_state_values(self) -> Tuple[PyTree, ...]:
957
- """
958
- Get the original values of all states in the StateTraceStack.
959
-
960
- This property provides access to the initial values of all states
961
- that were captured when they were first added to the stack. It's
962
- useful for comparing current state values with their original values
963
- or for reverting states to their initial condition.
964
-
965
- Returns:
966
- Tuple[PyTree, ...]: A tuple containing the original values of all
967
- states in the order they were added to the stack. Each element
968
- is a PyTree representing the structure and values of a state.
969
- """
970
- return tuple(self._original_state_values)
971
-
972
- def set_new_arg(self, new_arg: Callable) -> None:
973
- self._jax_trace_new_arg = new_arg
974
-
975
- def new_arg(self, state: State) -> None:
976
- """
977
- Apply a transformation to the value of a given state using a predefined function.
978
-
979
- This method is used internally to transform the value of a state during tracing.
980
- If a transformation function (``_jax_trace_new_arg``) is defined, it applies this
981
- function to each element of the state's value using JAX's tree mapping.
982
-
983
- Args:
984
- state (State): The State object whose value needs to be transformed.
985
-
986
- Returns:
987
- None: This function modifies the state in-place and doesn't return anything.
988
-
989
- Note:
990
- This method is intended for internal use and relies on the presence of
991
- a ``_jax_trace_new_arg`` function, which should be set separately.
992
- """
993
- if self._jax_trace_new_arg is not None:
994
- # internal use
995
- state._value = jax.tree.map(self._jax_trace_new_arg, state._value)
996
-
997
- def __enter__(self) -> 'StateTraceStack':
998
- TRACE_CONTEXT.state_stack.append(self)
999
- self._stack_level = ' / '.join([st.name for st in TRACE_CONTEXT.state_stack if st.name is not None])
1000
- return self
1001
-
1002
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
1003
- TRACE_CONTEXT.state_stack.pop()
1004
-
1005
- def read_its_value(self, state: State) -> None:
1006
- """
1007
- Record that a state's value has been read during tracing.
1008
-
1009
- This method marks the given state as having been read in the current
1010
- tracing context. If the state hasn't been encountered before, it adds
1011
- it to the internal tracking structures and applies any necessary
1012
- transformations via the new_arg method.
1013
-
1014
- Args:
1015
- state (State): The State object whose value is being read.
1016
-
1017
- Returns:
1018
- None
1019
-
1020
- Note:
1021
- This method updates the internal tracking of state accesses.
1022
- It doesn't actually read or return the state's value.
1023
- """
1024
- id_ = id(state)
1025
- if id_ not in self._state_id_index:
1026
- self._state_id_index[id_] = len(self.states)
1027
- self.states.append(state)
1028
- self.been_writen.append(False)
1029
- self._original_state_values.append(state._value) # internal use
1030
- self.new_arg(state)
1031
-
1032
- def write_its_value(self, state: State) -> None:
1033
- """
1034
- Record that a state's value has been written to during tracing.
1035
-
1036
- This method marks the given state as having been written to in the current
1037
- tracing context. If the state hasn't been encountered before, it first
1038
- records it as being read before marking it as written.
1039
-
1040
- Args:
1041
- state (State): The State object whose value is being written to.
1042
-
1043
- Returns:
1044
- None
1045
-
1046
- Note:
1047
- This method updates the internal tracking of state modifications.
1048
- It doesn't actually modify the state's value.
1049
- """
1050
- id_ = id(state)
1051
- if id_ not in self._state_id_index:
1052
- self.read_its_value(state)
1053
- index = self._state_id_index[id_]
1054
- self.been_writen[index] = True
1055
-
1056
- def get_state_values(
1057
- self,
1058
- separate: bool = False,
1059
- replace: bool = False
1060
- ) -> Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
1061
- """
1062
- Retrieve the values of all states in the StateTraceStack.
1063
-
1064
- This method returns the values of all states, optionally separating them
1065
- into written and read states, and optionally replacing values with None
1066
- for states that weren't accessed in a particular way.
1067
-
1068
- Args:
1069
- separate (bool, optional): If True, separate the values into written
1070
- and read states. If False, return all values in a single sequence.
1071
- Defaults to False.
1072
- replace (bool, optional): If True and separate is True, replace values
1073
- with None for states that weren't written/read. If False, only
1074
- include values for states that were written/read. Defaults to False.
1075
-
1076
- Returns:
1077
- Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
1078
- If separate is False:
1079
- A sequence of all state values.
1080
- If separate is True:
1081
- A tuple containing two sequences:
1082
- - The first sequence contains values of written states.
1083
- - The second sequence contains values of read states.
1084
- If replace is True, these sequences will have None for
1085
- states that weren't written/read respectively.
1086
-
1087
- """
1088
- if separate:
1089
- if replace:
1090
- writes, reads = [], []
1091
- for st, been_writen in zip(self.states, self.been_writen):
1092
- if been_writen:
1093
- writes.append(st.value)
1094
- reads.append(None)
1095
- else:
1096
- reads.append(st.value)
1097
- writes.append(None)
1098
- return tuple(writes), tuple(reads)
1099
- else:
1100
- writes, reads = [], []
1101
- for st, been_writen in zip(self.states, self.been_writen):
1102
- if been_writen:
1103
- writes.append(st.value)
1104
- else:
1105
- reads.append(st.value)
1106
- return tuple(writes), tuple(reads)
1107
- else:
1108
- return tuple([st.value for st in self.states])
1109
-
1110
- def recovery_original_values(self) -> None:
1111
- """
1112
- Restore the original values of all states in the StateTraceStack.
1113
-
1114
- This method iterates through all states in the stack and restores
1115
- their values to the original ones that were captured when the states
1116
- were first added to the stack. This is useful for reverting changes
1117
- made during tracing or for resetting the states to their initial condition.
1118
-
1119
- Note:
1120
- This method modifies the states in-place.
1121
-
1122
- Returns:
1123
- None
1124
- """
1125
- for st, val in zip(self.states, self._original_state_values):
1126
- # internal use
1127
- st.restore_value(val)
1128
-
1129
- def merge(self, *traces) -> 'StateTraceStack':
1130
- """
1131
- Merge other state traces into the current ``StateTraceStack``.
1132
-
1133
- This method combines the states, their write status, and original values from
1134
- other ``StateTraceStack`` instances into the current one. If a state from another
1135
- trace is not present in the current trace, it is added. If a state is already
1136
- present, its write status is updated if necessary.
1137
-
1138
- Args:
1139
- *traces: Variable number of ``StateTraceStack`` instances to be merged into
1140
- the current instance.
1141
-
1142
- Returns:
1143
- StateTraceStack: The current ``StateTraceStack`` instance with merged traces.
1144
-
1145
- Note:
1146
- This method modifies the current ``StateTraceStack`` in-place and also returns it.
1147
- """
1148
- trace: StateTraceStack
1149
- for trace in traces:
1150
- for st, been_writen, org_val in zip(trace.states, trace.been_writen, trace._original_state_values):
1151
- if id(st) not in self._state_id_index: # read the value
1152
- self._state_id_index[id(st)] = len(self.states)
1153
- self._original_state_values.append(org_val) # add the original value
1154
- self.states.append(st) # append the state
1155
- self.been_writen.append(False)
1156
- if been_writen:
1157
- self.write_its_value(st)
1158
- return self
1159
-
1160
- def get_read_states(self, replace_writen: bool = False) -> Tuple[State, ...]:
1161
- """
1162
- Retrieve the states that were read during the function execution.
1163
-
1164
- This method returns the states that were accessed (read from) during
1165
- the traced function's execution. It can optionally replace written
1166
- states with None.
1167
-
1168
- Args:
1169
- replace_writen (bool, optional): If True, replace written states with None
1170
- in the returned tuple. If False, exclude written states entirely from
1171
- the result. Defaults to False.
1172
-
1173
- Returns:
1174
- Tuple[State, ...]: A tuple containing the read states.
1175
- If replace_writen is True, the tuple will have the same length as the
1176
- total number of states, with None for written states.
1177
- If replace_writen is False, the tuple will only contain read-only states.
1178
- """
1179
- if replace_writen:
1180
- return tuple([st if not been_writen else None
1181
- for st, been_writen in zip(self.states, self.been_writen)])
1182
- else:
1183
- return tuple([st for st, been_writen in zip(self.states, self.been_writen) if not been_writen])
1184
-
1185
- def get_read_state_values(self, replace_writen: bool = False) -> Tuple[PyTree, ...]:
1186
- """
1187
- Retrieve the values of states that were read during the function execution.
1188
-
1189
- This method returns the values of states that were accessed (read from) during
1190
- the traced function's execution. It can optionally replace written states with None.
1191
-
1192
- Args:
1193
- replace_writen (bool, optional): If True, replace the values of written
1194
- states with None in the returned tuple. If False, exclude written
1195
- states entirely from the result. Defaults to False.
1196
-
1197
- Returns:
1198
- Tuple[PyTree, ...]: A tuple containing the values of read states.
1199
- If replace_writen is True, the tuple will have the same length as the
1200
- total number of states, with None for written states.
1201
- If replace_writen is False, the tuple will only contain values of
1202
- read-only states.
1203
- """
1204
- if replace_writen:
1205
- return tuple(
1206
- [st.value if not been_writen else None
1207
- for st, been_writen in zip(self.states, self.been_writen)]
1208
- )
1209
- else:
1210
- return tuple([st.value for st, been_writen in zip(self.states, self.been_writen) if not been_writen])
1211
-
1212
- def get_write_states(self, replace_read: bool = False) -> Tuple[State, ...]:
1213
- """
1214
- Retrieve the states that were written during the function execution.
1215
-
1216
- This method returns the states that were modified (written to) during
1217
- the traced function's execution. It can optionally replace unwritten (read-only)
1218
- states with None.
1219
-
1220
- Args:
1221
- replace_read (bool, optional): If True, replace read-only states with None
1222
- in the returned tuple. If False, exclude read-only states entirely from
1223
- the result. Defaults to False.
1224
-
1225
- Returns:
1226
- Tuple[State, ...]: A tuple containing the written states.
1227
- If replace_read is True, the tuple will have the same length as the
1228
- total number of states, with None for read-only states.
1229
- If replace_read is False, the tuple will only contain written states.
1230
- """
1231
- if replace_read:
1232
- return tuple([st if been_writen else None
1233
- for st, been_writen in zip(self.states, self.been_writen)])
1234
- else:
1235
- return tuple([st for st, been_writen in zip(self.states, self.been_writen) if been_writen])
1236
-
1237
- def get_write_state_values(self, replace_read: bool = False) -> Tuple[PyTree, ...]:
1238
- """
1239
- Retrieve the values of states that were written during the function execution.
1240
-
1241
- This method returns the values of states that were modified (written to) during
1242
- the traced function's execution. It can optionally replace unwritten (read-only)
1243
- states with None.
1244
-
1245
- Args:
1246
- replace_read (bool, optional): If True, replace the values of read-only
1247
- states with None in the returned tuple. If False, exclude read-only
1248
- states entirely from the result. Defaults to False.
1249
-
1250
- Returns:
1251
- Tuple[PyTree, ...]: A tuple containing the values of written states.
1252
- If replace_read is True, the tuple will have the same length as the
1253
- total number of states, with None for read-only states.
1254
- If replace_read is False, the tuple will only contain values of
1255
- written states.
1256
-
1257
- """
1258
- if replace_read:
1259
- return tuple([st.value if been_writen else None for st, been_writen in zip(self.states, self.been_writen)])
1260
- else:
1261
- return tuple([st.value for st, been_writen in zip(self.states, self.been_writen) if been_writen])
1262
-
1263
- def __add__(self, other: 'StateTraceStack') -> 'StateTraceStack':
1264
- """
1265
- Support the syntax of `+` to merge the state traces.
1266
- """
1267
- return StateTraceStack().merge(self, other)
1268
-
1269
- def assign_state_vals(self, state_vals: Sequence[PyTree]) -> None:
1270
- """
1271
- Assign new values to the states tracked by this ``StateTraceStack``.
1272
-
1273
- This method updates the values of the states based on whether they were
1274
- written to or only read during the tracing process. For states that were
1275
- written to, it directly assigns the new value. For states that were only
1276
- read, it restores the value using the state's restore_value method.
1277
-
1278
- Args:
1279
- state_vals (Sequence[PyTree]): A sequence of new state values to be
1280
- assigned. Each element in this sequence corresponds to a state
1281
- in the ``StateTraceStack``'s states list.
1282
-
1283
- Raises:
1284
- ValueError: If the length of state_vals doesn't match the number of
1285
- states in the ``StateTraceStack``.
1286
-
1287
- Returns:
1288
- None
1289
-
1290
- Note:
1291
- The order of state_vals should match the order of states in the
1292
- ``StateTraceStack``'s states list.
1293
- """
1294
- if len(state_vals) != len(self.states):
1295
- raise ValueError('The length of the state values must be equal to the states. '
1296
- f'Bug got {len(state_vals)} and {len(self.states)}')
1297
- for st, written, val in zip(self.states, self.been_writen, state_vals):
1298
- if written:
1299
- st.value = val
1300
- else:
1301
- st.restore_value(val)
1302
-
1303
- def state_subset(self, state_type: type) -> List:
1304
- """
1305
- Get a subset of states of a specific type from the ``StateTraceStack``.
1306
-
1307
- This method filters the states in the ``StateTraceStack`` and returns only
1308
- those that match the specified state type.
1309
-
1310
- Args:
1311
- state_type (type): The type of state to filter by. This should be
1312
- a subclass of State or State itself.
1313
-
1314
- Returns:
1315
- List[State]: A list containing all states in the ``StateTraceStack``
1316
- that are instances of the specified state_type.
1317
-
1318
- Example:
1319
- >>> stack = StateTraceStack()
1320
- >>> # Assume stack has been populated with various state types
1321
- >>> short_term_states = stack.state_subset(ShortTermState)
1322
- """
1323
- return [st for st in self.states if isinstance(st, state_type)]
1324
-
1325
-
1326
- class TreefyState(Generic[A], PrettyObject):
1327
- """
1328
- The state as a pytree.
1329
- """
1330
-
1331
- def __init__(
1332
- self,
1333
- type: type[State[Any]],
1334
- value: A,
1335
- **metadata
1336
- ):
1337
- self.type = type
1338
- self.value = value
1339
- vars(self).update(metadata)
1340
-
1341
- if TYPE_CHECKING:
1342
- def __getattr__(self, name: str) -> None: ...
1343
-
1344
- def __setattr__(self, name: str, value: Any) -> None: ...
1345
-
1346
- def __delattr__(self, name: str) -> None: ...
1347
-
1348
- def __pretty_repr_item__(self, k, v):
1349
- if k in ['_level', '_source_info', '_been_writen']:
1350
- return None
1351
- if k == '_value':
1352
- return 'value', v
1353
-
1354
- if k == '_name':
1355
- return None if v is None else ('name', v)
1356
- return k, v
1357
-
1358
- @property
1359
- def name(self) -> Optional[str]:
1360
- """
1361
- The name of the state.
1362
- """
1363
- return self._name
1364
-
1365
- @name.setter
1366
- def name(self, name: str) -> None:
1367
- """
1368
- Set the name of the state.
1369
- """
1370
- self._name = name
1371
-
1372
- def replace(self, value: B) -> TreefyState[B]:
1373
- """
1374
- Replace the value of the state reference.
1375
- """
1376
- return TreefyState(self.type, value, **self.get_metadata())
1377
-
1378
- def to_state(self) -> State[A]:
1379
- """
1380
- Convert the state reference to the state.
1381
- """
1382
- # we use object.__new__ to avoid calling __init__ and bypass the
1383
- # __init__ logic which should not be called twice
1384
- metadata = self.get_metadata()
1385
- state = object.__new__(self.type)
1386
- metadata.pop('_value', None)
1387
- metadata.pop('_level', None)
1388
- vars(state).update(**metadata, _value=self.value, _level=_get_trace_stack_level())
1389
- return state
1390
-
1391
- def copy(self: TreefyState[A]) -> TreefyState[A]:
1392
- """
1393
- Copy the state reference.
1394
- """
1395
- return jax.tree.map(lambda x: x, self)
1396
-
1397
- def get_metadata(self) -> Dict[str, Any]:
1398
- """
1399
- Get the metadata of the state reference
1400
- """
1401
- metadata = vars(self).copy()
1402
- del metadata['type']
1403
- del metadata['value']
1404
- return metadata
1405
-
1406
-
1407
- def _state_ref_flatten(x: TreefyState[Any], *, with_keys: bool):
1408
- metadata = tuple(x.get_metadata().items())
1409
- if with_keys:
1410
- node = (jax.tree_util.GetAttrKey('value'), x.value)
1411
- else:
1412
- node = x.value
1413
- return (node,), (x.type, metadata)
1414
-
1415
-
1416
- def _state_ref_unflatten(
1417
- static: Tuple[type[State[A]], Tuple[Tuple[str, Any], ...]],
1418
- children: Tuple[A],
1419
- ) -> TreefyState[A]:
1420
- return TreefyState(type=static[0], value=children[0], **dict(static[1]))
1421
-
1422
-
1423
- jax.tree_util.register_pytree_with_keys(
1424
- TreefyState,
1425
- partial(_state_ref_flatten, with_keys=True), # type: ignore
1426
- _state_ref_unflatten, # type: ignore
1427
- flatten_func=partial(_state_ref_flatten, with_keys=False), # type: ignore
1428
- )
1429
-
1430
-
1431
- class StateCatcher(PrettyObject):
1432
- """
1433
- The catcher to catch and manage new states.
1434
-
1435
- This class provides functionality to collect and tag new State objects.
1436
- It ensures that each state is only added once and assigns a tag to each state.
1437
-
1438
- Attributes:
1439
- state_tag (str): A string identifier used to tag the caught states.
1440
- state_ids (set): A set of state IDs to ensure uniqueness.
1441
- states (list): A list to store the caught State objects.
1442
- """
1443
-
1444
- def __init__(
1445
- self,
1446
- state_tag: str,
1447
- state_to_exclude: Filter = Nothing()
1448
- ):
1449
- """
1450
- Initialize a new Catcher instance.
1451
-
1452
- Args:
1453
- state_tag (str): The tag to be assigned to caught states.
1454
- state_to_exclude (Filter, optional): A filter to exclude states from being caught.
1455
- """
1456
- if state_to_exclude is None:
1457
- state_to_exclude = Nothing()
1458
- self.state_to_exclude = state_to_exclude
1459
- self.state_tag = state_tag
1460
- self.state_ids = set()
1461
- self.states = []
1462
-
1463
- def get_state_values(self) -> List[PyTree]:
1464
- """
1465
- Get the values of the caught states.
1466
-
1467
- Returns:
1468
- list: A list of values of the caught states.
1469
- """
1470
- return [state.value for state in self.states]
1471
-
1472
- def get_states(self) -> List[State]:
1473
- """
1474
- Get the caught states.
1475
-
1476
- Returns:
1477
- list: A list of the caught states.
1478
- """
1479
- return self.states
1480
-
1481
- def append(self, state: State):
1482
- """
1483
- Add a new state to the catcher if it hasn't been added before.
1484
-
1485
- This method adds the state to the internal list, records its ID,
1486
- and assigns the catcher's tag to the state.
1487
-
1488
- Args:
1489
- state (State): The State object to be added.
1490
- """
1491
- if self.state_to_exclude((), state):
1492
- return
1493
- if id(state) not in self.state_ids:
1494
- self.state_ids.add(id(state))
1495
- self.states.append(state)
1496
- state.tag = self.state_tag
1497
-
1498
- def __iter__(self):
1499
- """
1500
- Allow iteration over the caught states.
1501
-
1502
- Returns:
1503
- iterator: An iterator over the list of caught states.
1504
- """
1505
- return iter(self.states)
1506
-
1507
- def __len__(self):
1508
- """
1509
- Return the number of caught states.
1510
-
1511
- Returns:
1512
- int: The number of caught states.
1513
- """
1514
- return len(self.states)
1515
-
1516
- def __getitem__(self, index):
1517
- """
1518
- Get a state by index.
1519
-
1520
- Args:
1521
- index (int): The index of the state to retrieve.
1522
-
1523
- Returns:
1524
- State: The state at the specified index.
1525
- """
1526
- return self.states[index]
1527
-
1528
- def clear(self):
1529
- """
1530
- Clear all caught states.
1531
- """
1532
- self.state_ids.clear()
1533
- self.states.clear()
1534
-
1535
- def get_by_tag(self, tag: str):
1536
- """
1537
- Get all states with a specific tag.
1538
-
1539
- Args:
1540
- tag (str): The tag to filter by.
1541
-
1542
- Returns:
1543
- list: A list of states with the specified tag.
1544
- """
1545
- return [state for state in self.states if state.tag == tag]
1546
-
1547
- def remove(self, state: State):
1548
- """
1549
- Remove a specific state from the catcher.
1550
-
1551
- Args:
1552
- state (State): The state to remove.
1553
- """
1554
- if id(state) in self.state_ids:
1555
- self.state_ids.remove(id(state))
1556
- self.states.remove(state)
1557
-
1558
- def __contains__(self, state: State):
1559
- """
1560
- Check if a state is in the catcher.
1561
-
1562
- Args:
1563
- state (State): The state to check for.
1564
-
1565
- Returns:
1566
- bool: True if the state is in the catcher, False otherwise.
1567
- """
1568
- return id(state) in self.state_ids
1569
-
1570
-
1571
@contextlib.contextmanager
def catch_new_states(
    state_tag: str = None,
    state_to_exclude: Filter = Nothing()
) -> Generator[StateCatcher, None, None]:
    """
    Context manager that collects the states created inside its scope.

    A new :class:`StateCatcher` is built and pushed onto
    ``TRACE_CONTEXT.new_state_catcher`` for the duration of the ``with`` block.
    It is popped again on exit, including when the body raises.

    Args:
        state_tag (str, optional): A tag to associate with the caught states;
            it is forwarded to the :class:`StateCatcher`. Defaults to None.
        state_to_exclude (Filter, optional): A filter selecting states that
            should not be caught. It is forwarded to the :class:`StateCatcher`.
            Defaults to ``Nothing()``, which excludes no states.

    Yields:
        StateCatcher: The catcher that accumulates the newly created states.

    Example::

        with catch_new_states("my_tag") as catcher:
            ...  # create new states here
        # access the caught states through ``catcher``
    """
    # Build and register the catcher *before* entering ``try``. If construction
    # failed inside ``try``, ``finally`` would pop a catcher that belongs to an
    # enclosing context, or raise IndexError and mask the original error.
    catcher = StateCatcher(state_tag=state_tag, state_to_exclude=state_to_exclude)
    TRACE_CONTEXT.new_state_catcher.append(catcher)
    try:
        yield catcher
    finally:
        TRACE_CONTEXT.new_state_catcher.pop()
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import contextlib
19
+ import threading
20
+ from functools import partial
21
+ from typing import (
22
+ Any,
23
+ Union,
24
+ Callable,
25
+ Generic,
26
+ TypeVar,
27
+ Optional,
28
+ TYPE_CHECKING,
29
+ Tuple,
30
+ Dict,
31
+ List,
32
+ Sequence,
33
+ Generator,
34
+ )
35
+
36
+ import jax
37
+ import numpy as np
38
+ from jax.api_util import shaped_abstractify
39
+ from jax.extend import source_info_util
40
+
41
+ from brainstate.typing import ArrayLike, PyTree, Missing, Filter
42
+ from brainstate.util import DictManager, PrettyObject
43
+ from brainstate.util.filter import Nothing
44
+
45
__all__ = [
    # state classes
    'State', 'ShortTermState', 'LongTermState', 'HiddenState',
    'ParamState', 'BatchState', 'TreefyState', 'FakeState',

    # containers, tracing helpers and context managers
    'StateDictManager', 'StateTraceStack', 'check_state_value_tree',
    'check_state_jax_tracer', 'catch_new_states', 'maybe_state',
]
62
+
63
# Generic type variables used by the annotations in this module.
A, B, T = TypeVar('A'), TypeVar('B'), TypeVar('T')
F = TypeVar('F', bound=Callable[..., Any])

# NOTE(review): this holds the ``np.iinfo`` descriptor for int32, not an
# integer; the numeric limit itself is ``max_int.max``. Confirm that callers
# expect the descriptor.
max_int = np.iinfo(np.int32)
69
+
70
+
71
class ThreadLocalStack(threading.local):
    """
    Per-thread bookkeeping for state tracing.

    Because this class derives from :class:`threading.local`, ``__init__`` runs
    the first time each thread touches an instance. Every thread therefore sees
    its own independent containers. This allows concurrent tracing in separate
    threads; passing traced objects between threads is forbidden.

    Attributes:
        state_stack (List[StateTraceStack]): The active ``StateTraceStack``
            objects of this thread.
        tree_check (List[bool]): Stack of flags enabling value tree-structure
            checking; starts as ``[False]``.
        jax_tracer_check (List[bool]): Stack of flags enabling JAX tracer
            checking; starts as ``[False]``.
        new_state_catcher (List[StateCatcher]): The active catchers that record
            newly created states in this thread.
    """

    def __init__(self):
        """Create fresh, empty bookkeeping containers for the calling thread."""
        self.new_state_catcher: List[StateCatcher] = []
        self.jax_tracer_check: List[bool] = [False]
        self.tree_check: List[bool] = [False]
        self.state_stack: List[StateTraceStack] = []


# The module-wide tracing context; every helper in this module reads it.
TRACE_CONTEXT = ThreadLocalStack()
103
+
104
+
105
@contextlib.contextmanager
def check_state_value_tree(val: bool = True) -> Generator[None, None, None]:
    """
    Context manager that toggles tree-structure checking of state assignments.

    Once a :py:class:`~.State` is created, the tree structure of its value is
    meant to stay fixed. By default, newly assigned values are not checked, to
    avoid evaluating the structure on every assignment. Inside this context
    (with ``val=True``), each assignment to ``State.value`` compares the tree
    structure of the new value with that of the current one and raises a
    :class:`ValueError` on mismatch.

    Args:
        val (bool): The flag pushed for the duration of the context. Only the
            innermost flag is consulted, so ``False`` disables checking inside
            an enclosing enabled region. Defaults to ``True``.

    Examples
    --------

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> state = brainstate.ShortTermState(jnp.zeros((2, 3)))
        >>> with brainstate.check_state_value_tree():
        ...     # Same tree structure: no error.
        ...     state.value = jnp.zeros((2, 3))
        ...     # Different tree structure: raises ValueError.
        ...     state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
    """
    # Push before ``try`` so that ``finally`` only pops a flag this call pushed.
    TRACE_CONTEXT.tree_check.append(val)
    try:
        yield
    finally:
        TRACE_CONTEXT.tree_check.pop()
136
+
137
+
138
def maybe_state(val: Any) -> Any:
    """
    Unwrap ``val`` if it is a :class:`State`; otherwise return it unchanged.

    This lets callers treat states and raw values uniformly. Unwrapping goes
    through ``State.value``, so the read is recorded like any other read.

    Args:
        val (Any): A :class:`State` or any other object.

    Returns:
        Any: ``val.value`` for a :class:`State`, else ``val`` itself.
    """
    return val.value if isinstance(val, State) else val
157
+
158
+
159
@contextlib.contextmanager
def check_state_jax_tracer(val: bool = True) -> Generator[None, None, None]:
    """
    Context manager that enables checking whether states are valid to trace.

    The flag is pushed onto the thread-local ``TRACE_CONTEXT.jax_tracer_check``
    stack for the duration of the context and popped on exit, including when
    the body raises.

    Args:
        val (bool): The flag pushed for the duration of the context.
            Defaults to ``True``.

    Example
    -------

    .. code-block:: python

        >>> import jax
        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> a = brainstate.ShortTermState(jnp.zeros((2, 3)))
        >>>
        >>> @jax.jit
        ... def run_state(b):
        ...     a.value = b
        ...     return a.value
        >>>
        >>> # The state is valid to trace, so this call does not raise.
        >>> run_state(jnp.ones((2, 3)))
        >>>
        >>> with brainstate.check_state_jax_tracer():
        ...     # A new input shape forces a re-trace, which now runs with
        ...     # the tracer check enabled.
        ...     run_state(jnp.ones((2, 4)))
    """
    # Push before ``try`` so that ``finally`` only pops a flag this call pushed.
    TRACE_CONTEXT.jax_tracer_check.append(val)
    try:
        yield
    finally:
        TRACE_CONTEXT.jax_tracer_check.pop()
192
+
193
+
194
def _get_trace_stack_level() -> int:
    """Return the current depth of this thread's state trace stack."""
    active_stacks = TRACE_CONTEXT.state_stack
    return len(active_stacks)
196
+
197
+
198
class State(Generic[A], PrettyObject):
    """
    A generic container for a dynamic value tracked by the BrainState framework.

    ``State`` is the base class of the stateful objects in BrainState. Alongside
    a PyTree value, it stores bookkeeping metadata:

    - the trace-stack level at creation;
    - the source location of the creation;
    - an optional name and tag;
    - a "been written" flag.

    It reports its creation, reads and writes to the thread-local tracing
    machinery of this module.

    Type Parameters:
        A: The type of the value stored in the state.

    Args:
        value (PyTree[ArrayLike]): The initial value; any PyTree is accepted. If
            a :class:`State` is passed, its current ``value`` is stored instead.
        name (Optional[str]): An optional name for the state.
        **metadata: Additional attributes stored on the instance. The ``tag``
            entry, if present, becomes the :attr:`tag` attribute.

    Attributes:
        name (Optional[str]): The optional name of the state.
        value (PyTree): The stored value; reading and writing it is recorded.
        tag (Optional[str]): An optional tag for categorizing or grouping states.

    Example
    -------

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> class MyState(State):
        ...     pass
        >>> state = MyState(jnp.zeros((3, 3)), name="my_matrix")
        >>> print(state.value)
        [[0. 0. 0.]
         [0. 0. 0.]
         [0. 0. 0.]]

    Note:
        Subclasses of :class:`State` are typically used for specific purposes:

        - :py:class:`ShortTermState`: short-term data in the program.
        - :py:class:`LongTermState`: long-term data in the program.
        - :py:class:`ParamState`: the parameters in the program.
        - :py:class:`RandomState`: the random key in the program.
    """
    __module__ = 'brainstate'
    _level: int
    _source_info: source_info_util.SourceInfo
    _name: Optional[str]
    _value: PyTree
    _been_writen: bool  # useful in `unflatten` and `flatten` graph processing
    tag: Optional[str]

    def __init__(
        self,
        value: PyTree[ArrayLike],
        name: Optional[str] = None,
        **metadata: Any
    ):
        """
        Initialize a new State instance.

        Args:
            value (PyTree[ArrayLike]): The initial value of the state. If a
                :class:`State` is given, its current ``value`` is stored instead
                (that read is itself recorded).
            name (Optional[str], optional): A name for the state. Defaults to None.
            **metadata: Additional attributes to store on the instance. The
                ``tag`` entry (default None) is removed and stored as
                :attr:`tag`; every other entry becomes an instance attribute.

        Note:
            The method performs four steps:

            - The stack level is set to the current depth of the thread's trace
              stack.
            - The creation source info is captured.
            - All attributes are written straight into ``vars(self)``, which
              bypasses the instance's attribute-setting hook.
            - The new state is reported to every active state catcher.
        """
        tag = metadata.pop('tag', None)

        # set the value and metadata
        if isinstance(value, State):
            value = value.value

        # update metadata
        metadata.update(
            _value=value,
            _level=_get_trace_stack_level(),
            _source_info=source_info_util.current(),
            _name=name,
            _been_writen=False,
            tag=tag,
        )

        # avoid using self._setattr to avoid the check
        vars(self).update(metadata)

        # record the state initialization
        record_state_init(self)

    def decrease_stack_level(self):
        """
        Decrease the stack level of the state by one, never going below zero.

        Typically used when exiting a nested context or scope.
        """
        self._level = max(self._level - 1, 0)

    def increase_stack_level(self):
        """
        Increase the stack level of the state by one.

        Typically used when entering a nested context or scope.
        """
        self._level = self._level + 1

    @property
    def name(self) -> Optional[str]:
        """
        The name of the state.
        """
        return self._name

    @name.setter
    def name(self, name: str) -> None:
        """
        Set the name of the state.
        """
        self._name = name

    @property
    def value(self) -> PyTree[ArrayLike]:
        """
        The stored value.

        The read is recorded in every active trace stack at or above the
        state's stack level before the value is returned.
        """
        record_state_value_read(self)
        return self._read_value()

    @value.setter
    def value(self, v) -> None:
        """
        Set the value of the state.

        Args:
            v: The new value.

        Raises:
            ValueError: If ``v`` is a :class:`State` (use ``copy_from``
                instead), or if tree checking is enabled and the tree structure
                of ``v`` differs from the current one.
        """
        # NOTE: the following order is important

        if isinstance(v, State):  # value checking
            raise ValueError('Cannot set value to a State, ' 'use `copy_from` method instead')
        self._check_value_tree(v)  # check the tree structure
        record_state_value_write(self)  # record the value by the stack (>= level)
        self._been_writen = True  # set the flag
        self._write_value(v)  # write the value

    @property
    def stack_level(self):
        """
        The stack level of the state.

        Returns:
            The stack level.
        """
        return self._level

    @stack_level.setter
    def stack_level(self, level: int):
        """
        Set the stack level of the state.

        Args:
            level: The stack level.
        """
        self._level = level

    def _read_value(self) -> PyTree[ArrayLike]:
        """
        Hook for customizing how the value is read.

        Calls :meth:`check_if_deleted` first, then returns the raw value.
        """
        self.check_if_deleted()
        return self._value

    def _write_value(self, v) -> None:
        """
        Hook for customizing how the value is written.
        """
        self._value = v

    def restore_value(self, v) -> None:
        """
        Restore the value of the state without marking it as written.

        The tree structure of ``v`` is always checked, whatever the enclosing
        setting is. The operation is recorded as a read, not a write. The value
        is assigned directly and does not pass through :meth:`_write_value`.

        Args:
            v: The value to restore.

        Raises:
            ValueError: If ``v`` is a :class:`State` or its tree structure
                differs from the current one.
        """
        # value checking
        if isinstance(v, State):
            raise ValueError('Cannot set value to a State, ' 'use `copy_from` method instead')
        with check_state_value_tree():
            self._check_value_tree(v)
        # record the value by the stack (>= level)
        record_state_value_restore(self)
        # set the value
        self._value = v

    def value_call(self, func: Callable[..., Any]) -> Any:
        """
        Apply ``func`` to every leaf of the (recorded) value and return the mapped tree.
        """
        return jax.tree.map(func, self.value)

    def _check_value_tree(self, v):
        """
        Raise a ValueError if tree checking is enabled and the structure of ``v`` differs.

        Only the innermost flag pushed by :func:`check_state_value_tree` is consulted.
        """
        if TRACE_CONTEXT.tree_check[-1]:
            in_tree = jax.tree.structure(v)
            self_tree = jax.tree.structure(self._value)
            if in_tree != self_tree:
                self.raise_error_with_source_info(
                    ValueError(f'The given value {in_tree} does not match with the origin tree structure {self_tree}.')
                )

    def raise_error_with_source_info(self, error: Exception):
        """
        Raise ``error`` attributed to the source location where this state was created.
        """
        name_stack = source_info_util.current_name_stack() + self.source_info.name_stack
        with source_info_util.user_context(self.source_info.traceback, name_stack=name_stack):
            raise error

    def check_if_deleted(self):
        """
        Hook called before every value read; a no-op in the base class.
        """
        pass

    @property
    def source_info(self) -> source_info_util.SourceInfo:
        """
        The source information captured when the state was created.

        This is useful for identifying the code that defined the state.

        Returns:
            The source information.
        """
        return self._source_info

    def update_from_ref(self, state_ref: TreefyState[A]) -> None:
        """
        Update the state from the state reference :py:class:`TreefyState`.

        All metadata of ``state_ref`` is first copied onto this instance. The
        value is then applied in one of two ways:

        - through the recorded ``value`` setter, if the reference is marked as
          written or carries no ``_been_writen`` flag;
        - otherwise through :meth:`restore_value`.

        Args:
            state_ref: The state reference.
        """
        metadata = state_ref.get_metadata()
        variable_vars = vars(self)
        variable_vars.update(**metadata)
        if metadata.pop('_been_writen', True):
            self.value = state_ref.value
        else:
            self.restore_value(state_ref.value)

    def replace(self, value: Any = Missing, **kwargs) -> State[Any]:
        """
        Return a state of the same type with some attributes replaced.

        If the replacement value (``value`` or ``_value``) is itself a
        :class:`State`, the outcome depends on its type:

        - exactly the same type: that state is returned, after applying any
          remaining ``kwargs`` through its own :meth:`replace`;
        - a different type: :class:`ValueError` is raised.

        Otherwise a new instance is built from a copy of this state's attributes
        updated with ``kwargs``. ``__init__`` is not called, so the new instance
        is not reported to state catchers.
        """
        if value is not Missing:
            kwargs['_value'] = value

        # return `value` if it is a State
        if '_value' in kwargs and isinstance(value := kwargs['_value'], State):
            # remove value from kwargs
            kwargs.pop('_value')
            if type(self) is not type(value):
                raise ValueError('Cannot replace value from incompatible container, '
                                 f'expected {type(self).__name__}, got {type(value).__name__}')
            # if kwargs aren't empty, recursively call replace
            # else return variable value
            if kwargs:
                return value.replace(**kwargs)
            else:
                return value

        # get and update attributes
        attributes = vars(self).copy()
        attributes.update(**kwargs)
        # return new instance with updated attributes
        obj = object.__new__(type(self))
        vars(obj).update(attributes)
        return obj

    def copy(self: State[A]) -> State[A]:
        """
        Return a shallow copy of the state.

        In the copy:

        - the stack level is reset to the current trace depth;
        - the source info is re-captured;
        - the ``_been_writen`` flag is dropped.

        ``__init__`` is not called.
        """
        obj = object.__new__(type(self))
        attributes = vars(self).copy()
        # keep its own trace state and stack level
        attributes['_level'] = _get_trace_stack_level()
        attributes['_source_info'] = source_info_util.current()
        attributes.pop('_been_writen', None)
        # update the metadata
        vars(obj).update(attributes)
        return obj

    def to_state_ref(self: State[A]) -> TreefyState[A]:
        """
        Return a :py:class:`TreefyState` holding the type, raw value and other metadata.

        The value is read from ``_value`` directly, so the read is not recorded.
        """
        metadata = vars(self).copy()
        del metadata['_value']
        return TreefyState(type(self), self._value, **metadata)

    def __pretty_repr_item__(self, k, v):
        """
        Select and rename the attributes shown in the pretty representation.

        Internal bookkeeping fields are hidden. The value is shown as abstract
        shapes. The name and tag are omitted when they are None.
        """
        if k in ['_level', '_source_info', '_been_writen']:
            return None
        if k == '_value':
            return 'value', jax.tree.map(shaped_abstractify, v)

        if k == '_name':
            if self.name is None:
                return None
            else:
                return 'name', v

        if k == 'tag':
            if self.tag is None:
                return None
            else:
                return 'tag', v

        return k, v

    # def __eq__(self, other: object) -> bool:
    #   return type(self) is type(other) and vars(other) == vars(self)

    def __hash__(self):
        """
        Hash the state by object identity.
        """
        return hash(id(self))

    def numel(self) -> int:
        """
        Calculate the total number of elements in the state value.

        The raw value (which may be a nested PyTree) is flattened with
        ``jax.tree.leaves``, and ``jax.numpy.size`` of each leaf is summed.
        This access does not record a read.

        Returns:
            int: The total number of elements across all leaves. A scalar leaf
            counts as 1.
        """
        sizes = [jax.numpy.size(val) for val in jax.tree.leaves(self._value)]
        return sum(sizes)
567
+
568
+
569
def record_state_init(st: State[A]):
    """
    Report a freshly created :class:`State` to every active state catcher.

    Each catcher registered in ``TRACE_CONTEXT.new_state_catcher`` receives
    ``st`` through its ``append`` method.

    Args:
        st (State[A]): The newly created state.
    """
    for catcher in TRACE_CONTEXT.new_state_catcher:
        catcher.append(st)
587
+
588
+
589
def record_state_value_read(st: State[A]):
    """
    Mark ``st`` as read in every relevant trace stack.

    The relevant stacks are those in ``TRACE_CONTEXT.state_stack`` whose index is
    at or above the state's stack level. Each of them has ``read_its_value(st)``
    called on it.

    Args:
        st (State[A]): The state whose value was read.
    """
    level = st.stack_level
    for stack in TRACE_CONTEXT.state_stack[level:]:
        stack.read_its_value(st)
609
+
610
+
611
def record_state_value_write(st: State[A]):
    """
    Mark ``st`` as written in every relevant trace stack.

    The relevant stacks are those in ``TRACE_CONTEXT.state_stack`` whose index is
    at or above the state's stack level. Each of them has ``write_its_value(st)``
    called on it.

    Args:
        st (State[A]): The state whose value was written.
    """
    level = st.stack_level
    for stack in TRACE_CONTEXT.state_stack[level:]:
        stack.write_its_value(st)
631
+
632
+
633
def record_state_value_restore(st: State[A]):
    """
    Record that the value of ``st`` has been restored.

    A restore is recorded exactly like a read: this delegates to
    :func:`record_state_value_read`. It only records the event; it does not
    change the state's value.

    Args:
        st (State[A]): The state whose value was restored.
    """
    # Restores are tracked as reads, not writes.
    record_state_value_read(st)
654
+
655
+
656
class ShortTermState(State):
    """
    A :class:`State` holding transient data that matters only briefly.

    Use :class:`ShortTermState` for values that are recomputed frequently and
    need not survive across iterations. For example, gradients computed and
    consumed within a single training step fit this description.

    It adds no methods or attributes to :class:`State`. Its purpose is purely
    semantic: it lets short-term states be told apart from other kinds of state.

    Example:
        >>> gradient = ShortTermState(np.zeros(100), name="model_gradient")
        >>> intermediate_result = ShortTermState({}, name="layer_activations")
    """

    __module__ = 'brainstate'
683
+
684
+
685
class LongTermState(State):
    """
    A :class:`State` holding data that persists across iterations or epochs.

    Use :class:`LongTermState` for values that are maintained throughout an
    entire process. An example is a model's weights, which are updated and kept
    during the whole training procedure.

    It adds no methods or attributes to :class:`State`. Its purpose is purely
    semantic: it lets long-term states be told apart from other kinds of state.

    Example:
        >>> model_weights = LongTermState(np.random.randn(100, 100), name="model_weights")
        >>> optimizer_state = LongTermState({}, name="optimizer_state")
    """

    __module__ = 'brainstate'
711
+
712
+
713
class BatchState(LongTermState):
    """
    A :class:`LongTermState` that holds batch data.

    Use it to wrap batch-related values, such as the current inputs and labels
    of a training loop, together with their metadata.

    It adds no methods or attributes to :class:`LongTermState`. Its purpose is
    purely semantic: it lets batch states be told apart from other long-term
    states.

    Example:
        >>> batch_data = BatchState(np.array([1, 2, 3, 4, 5]), name="current_batch")
        >>> batch_labels = BatchState(np.array([0, 1, 0, 1, 1]), name="batch_labels")
    """

    __module__ = 'brainstate'
737
+
738
+
739
class HiddenState(ShortTermState):
    """
    A :class:`ShortTermState` for hidden variables of neurons, synapses or networks.

    Use it for the evolving internal state of dynamic models, such as the hidden
    vector of a recurrent network, which is updated as the model runs.

    Note:
        :class:`HiddenState` and :class:`ParamState` are the two most important
        state types in brainstate. The former stores hidden states of neurons,
        synapses or networks. The latter stores trainable model parameters, such
        as synaptic weights.

    Example:
        >>> lstm_hidden = HiddenState(np.zeros(128), name="lstm_hidden_state")
        >>> gru_hidden = HiddenState(np.zeros(64), name="gru_hidden_state")
    """

    __module__ = 'brainstate'
760
+
761
+
762
class ParamState(LongTermState):
    """
    A :class:`LongTermState` for the trainable parameters of a model.

    Use it to wrap weights, biases and other learnable values, so that they can
    be located and updated during training.

    Note:
        :class:`HiddenState` and :class:`ParamState` are the two most important
        state types in brainstate. The former stores hidden states of neurons,
        synapses or networks. The latter stores trainable model parameters, such
        as synaptic weights.

    Example:
        >>> weight = ParamState(np.random.randn(10, 10), name="layer1_weights")
        >>> bias = ParamState(np.zeros(10), name="layer1_bias")
    """

    __module__ = 'brainstate'
783
+
784
+
785
class FakeState:
    """
    A lightweight stand-in for :class:`State` that simply holds a value.

    It mirrors the ``value`` and ``name`` properties of :class:`State`, but it
    does not derive from it and records nothing. Reads, writes and creation are
    invisible to the tracing machinery.
    """

    __module__ = 'brainstate'

    def __init__(self, value: Any, name: Optional[str] = None):
        """
        Store ``value`` and an optional ``name``.

        Args:
            value (Any): The payload to hold.
            name (Optional[str], optional): A label for the fake state. Defaults to None.
        """
        self._name = name
        self._value = value

    @property
    def value(self) -> Any:
        """The held payload (plain attribute access, nothing is recorded)."""
        return self._value

    @value.setter
    def value(self, v) -> None:
        """Replace the held payload with ``v``."""
        self._value = v

    @property
    def name(self) -> Optional[str]:
        """The label of the fake state, or ``None`` if none was given."""
        return self._name

    @name.setter
    def name(self, name: str) -> None:
        """Relabel the fake state."""
        self._name = name

    def __repr__(self) -> str:
        """Show the held payload."""
        return f'FakedState(value={self._value})'
851
+
852
+
853
class StateDictManager(DictManager):
    """
    A dict-like collection of the :py:class:`~.State` objects used in a program.

    Every element must be a :py:class:`~.State`. The helper methods below read
    or write the states' values in bulk. All ordinary Python dict operations
    provided by ``DictManager`` remain available.
    """

    __module__ = 'brainstate'

    def assign_values(self, *args: Dict) -> None:
        """
        Write values into the stored states.

        For every ``{key: value}`` pair of every given dict, ``value`` is
        written to the state stored under ``key``.
        """
        for mapping in args:
            assert isinstance(mapping, dict), 'Must be an instance of dict.'
            for key, new_value in mapping.items():
                self._set_elem(key, new_value)

    def split_values(self, *filters: type) -> Tuple[Dict, ...]:
        """
        Partition the states' values by type.

        Returns ``len(filters) + 1`` ``DictManager`` objects. The value of each
        state goes to the bucket of the first filter type it is an instance of.
        States matching no filter go to the last bucket.
        """
        buckets = tuple(DictManager() for _ in range(len(filters) + 1))
        for key, st in self.items():
            slot = next(
                (i for i, kind in enumerate(filters) if isinstance(st, kind)),
                len(filters),
            )
            buckets[slot][key] = st.value
        return buckets

    def collect_values(self) -> Dict:
        """
        Return a new ``DictManager`` mapping each key to its state's current value.
        """
        collected = DictManager()
        for key, st in self.items():
            collected[key] = st.value
        return collected

    def split(self, first: type, *others: type) -> Tuple['StateDictManager', ...]:
        """Split the states by type; delegates to ``DictManager.split``."""
        return super().split(first, *others)

    def to_dict_values(self) -> Dict:
        """
        Return a plain :class:`dict` mapping each key to its state's current value.
        """
        return dict((key, st.value) for key, st in self.items())

    def _check_elem(self, elem):
        """Require that ``elem`` is a :class:`State`."""
        assert isinstance(elem, State), f'must be instance of {State}'

    def _set_elem(self, key: Any, value: Any) -> None:
        """Write ``value`` into the state stored under ``key``."""
        target = self[key]
        target.value = value
908
+
909
+
910
class StateTraceStack(Generic[A]):
    """
    A stack for tracing and managing states during program execution.

    ``StateTraceStack`` records every :class:`State` that is read from or
    written to while it is active (see :meth:`__enter__`). Each state is
    recorded once, in first-access order, and is de-duplicated by object
    identity (``id(state)``). For every recorded state the stack keeps:

    - a write flag in :attr:`been_writen` (``False``: only read, ``True``: written);
    - the state's ``_value`` captured when it was first recorded, exposed by
      :attr:`original_state_values`.

    These records are used to collect state values, to write new values back
    into the states, and to restore the states to their captured values.

    The class is generic over type A, allowing for type-safe usage with
    different types of State objects.
    """

    def __init__(
        self,
        new_arg: Optional[Callable] = None,
        name: Optional[str] = None,
    ):
        """
        Args:
            new_arg: Optional callable applied by :meth:`new_arg` to every newly
                recorded state; its return value replaces the state's ``_value``.
            name: Optional name, shown by :meth:`__str__` and used to build the
                nesting label in :meth:`__enter__`.
        """
        self.name = name
        self.states: List[State] = []
        # Parallel to ``states``. False: read, True: write.
        self.been_writen: List[bool] = []
        # Maps ``id(state)`` to the state's position in ``states``.
        self._state_id_index = dict()
        # ``_value`` of each state captured when it was first recorded (parallel to ``states``).
        self._original_state_values = []
        self._jax_trace_new_arg: Optional[Callable] = new_arg
        # Nesting label, filled in by ``__enter__``.
        self._stack_level = None

    def __str__(self) -> str:
        _stack_level = self.name if self._stack_level is None else self._stack_level
        if _stack_level is None:
            _stack_level = ''
        return f"{self.__class__.__name__}({_stack_level})"

    @property
    def original_state_values(self) -> Tuple[PyTree, ...]:
        """
        Get the original values of all states in the StateTraceStack.

        These are the ``_value`` attributes captured when each state was first
        recorded (before :meth:`new_arg` was applied to it). They are useful for
        comparing current values against the originals or for reverting the
        states (see :meth:`recovery_original_values`).

        Returns:
            Tuple[PyTree, ...]: The captured values, in the same order as
            :attr:`states`.
        """
        return tuple(self._original_state_values)

    def set_new_arg(self, new_arg: Callable) -> None:
        """
        Set or replace the callable used by :meth:`new_arg`.

        Args:
            new_arg: Callable receiving a newly recorded state and returning the
                value to store in that state's ``_value``.
        """
        self._jax_trace_new_arg = new_arg

    def new_arg(self, state: State) -> None:
        """
        Replace a state's value using the configured ``new_arg`` callable.

        If a callable was supplied (via the constructor or :meth:`set_new_arg`),
        it is called with the state object itself. Its return value is stored
        directly in ``state._value``. When no callable is set, this is a no-op.

        Args:
            state (State): The State object whose value is replaced.

        Returns:
            None: The state is modified in place.

        Note:
            Intended for internal use; :meth:`read_its_value` calls it the first
            time a state is recorded.
        """
        if self._jax_trace_new_arg is not None:
            # internal use: write the private attribute directly rather than via ``state.value``
            state._value = self._jax_trace_new_arg(state)

    def __enter__(self) -> 'StateTraceStack':
        """
        Push this stack onto ``TRACE_CONTEXT.state_stack``.

        The nesting label is built by joining the names of all named stacks
        currently on ``TRACE_CONTEXT.state_stack`` with ``' / '``.
        """
        TRACE_CONTEXT.state_stack.append(self)
        self._stack_level = ' / '.join([st.name for st in TRACE_CONTEXT.state_stack if st.name is not None])
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Pop the top entry of ``TRACE_CONTEXT.state_stack``; exceptions propagate."""
        TRACE_CONTEXT.state_stack.pop()

    def read_its_value(self, state: State) -> None:
        """
        Record that a state's value has been read during tracing.

        The first time a given state object is seen, it is appended to
        :attr:`states` with a ``False`` write flag. Its current ``_value`` is
        captured as the original value, and :meth:`new_arg` is applied to it.
        Later calls for the same state do nothing.

        Args:
            state (State): The State object whose value is being read.

        Returns:
            None

        Note:
            This only updates the bookkeeping; it does not return the value.
        """
        id_ = id(state)
        if id_ not in self._state_id_index:
            self._state_id_index[id_] = len(self.states)
            self.states.append(state)
            self.been_writen.append(False)
            self._original_state_values.append(state._value)  # internal use
            self.new_arg(state)

    def write_its_value(self, state: State) -> None:
        """
        Record that a state's value has been written to during tracing.

        If the state has not been recorded yet, it is first recorded as read
        (see :meth:`read_its_value`). Its write flag is then set to ``True``.

        Args:
            state (State): The State object whose value is being written to.

        Returns:
            None

        Note:
            This only updates the bookkeeping; it does not modify the value.
        """
        id_ = id(state)
        if id_ not in self._state_id_index:
            self.read_its_value(state)
        index = self._state_id_index[id_]
        self.been_writen[index] = True

    def get_state_values(
        self,
        separate: bool = False,
        replace: bool = False
    ) -> Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
        """
        Retrieve the current values of all recorded states.

        Args:
            separate (bool, optional): If True, return the values split into
                written and read-only groups. If False, return all values in
                a single tuple. Defaults to False.
            replace (bool, optional): Only used when ``separate`` is True. If
                True, both returned tuples have one slot per recorded state,
                with ``None`` in the slots that belong to the other group. If
                False, each tuple only contains its own group's values.
                Defaults to False.

        Returns:
            If ``separate`` is False:
                A tuple of all state values, in :attr:`states` order.
            If ``separate`` is True:
                A ``(writes, reads)`` pair of tuples holding the values of the
                written and the read-only states respectively.
        """
        if not separate:
            return tuple([st.value for st in self.states])
        writes, reads = [], []
        for st, been_writen in zip(self.states, self.been_writen):
            if been_writen:
                writes.append(st.value)
                if replace:
                    reads.append(None)
            else:
                reads.append(st.value)
                if replace:
                    writes.append(None)
        return tuple(writes), tuple(reads)

    def recovery_original_values(self) -> None:
        """
        Restore every recorded state to the value captured when it was first recorded.

        Each state receives its entry from :attr:`original_state_values`
        through ``State.restore_value``.

        Note:
            This method modifies the states in-place.

        Returns:
            None
        """
        for st, val in zip(self.states, self._original_state_values):
            # internal use
            st.restore_value(val)

    def merge(self, *traces) -> 'StateTraceStack':
        """
        Merge other state traces into the current ``StateTraceStack``.

        Each state from ``traces`` that is not yet recorded here is appended
        together with its captured original value. Unlike
        :meth:`read_its_value`, :meth:`new_arg` is not applied to such states.
        If a state was written in the other trace, it is marked as written
        here as well. Existing write flags are never cleared.

        Args:
            *traces: ``StateTraceStack`` instances to merge into this one.

        Returns:
            StateTraceStack: This instance, after merging.

        Note:
            This method modifies the current ``StateTraceStack`` in-place and also returns it.
        """
        trace: StateTraceStack
        for trace in traces:
            for st, been_writen, org_val in zip(trace.states, trace.been_writen, trace._original_state_values):
                if id(st) not in self._state_id_index:  # read the value
                    self._state_id_index[id(st)] = len(self.states)
                    self._original_state_values.append(org_val)  # add the original value
                    self.states.append(st)  # append the state
                    self.been_writen.append(False)
                if been_writen:
                    self.write_its_value(st)
        return self

    def get_read_states(self, replace_writen: bool = False) -> Tuple[State, ...]:
        """
        Retrieve the states that were only read (never written).

        Args:
            replace_writen (bool, optional): If True, return one slot per
                recorded state with ``None`` for written states. If False,
                omit written states. Defaults to False.

        Returns:
            Tuple[State, ...]: The read-only states, in :attr:`states` order.
        """
        if replace_writen:
            return tuple([st if not been_writen else None
                          for st, been_writen in zip(self.states, self.been_writen)])
        else:
            return tuple([st for st, been_writen in zip(self.states, self.been_writen) if not been_writen])

    def get_read_state_values(self, replace_writen: bool = False) -> Tuple[PyTree, ...]:
        """
        Retrieve the current values of the states that were only read.

        Args:
            replace_writen (bool, optional): If True, return one slot per
                recorded state with ``None`` for written states. If False,
                omit written states. Defaults to False.

        Returns:
            Tuple[PyTree, ...]: The values of the read-only states.
        """
        if replace_writen:
            return tuple(
                [st.value if not been_writen else None
                 for st, been_writen in zip(self.states, self.been_writen)]
            )
        else:
            return tuple([st.value for st, been_writen in zip(self.states, self.been_writen) if not been_writen])

    def get_write_states(self, replace_read: bool = False) -> Tuple[State, ...]:
        """
        Retrieve the states that were written.

        Args:
            replace_read (bool, optional): If True, return one slot per
                recorded state with ``None`` for read-only states. If False,
                omit read-only states. Defaults to False.

        Returns:
            Tuple[State, ...]: The written states, in :attr:`states` order.
        """
        if replace_read:
            return tuple([st if been_writen else None
                          for st, been_writen in zip(self.states, self.been_writen)])
        else:
            return tuple([st for st, been_writen in zip(self.states, self.been_writen) if been_writen])

    def get_write_state_values(self, replace_read: bool = False) -> Tuple[PyTree, ...]:
        """
        Retrieve the current values of the states that were written.

        Args:
            replace_read (bool, optional): If True, return one slot per
                recorded state with ``None`` for read-only states. If False,
                omit read-only states. Defaults to False.

        Returns:
            Tuple[PyTree, ...]: The values of the written states.
        """
        if replace_read:
            return tuple([st.value if been_writen else None for st, been_writen in zip(self.states, self.been_writen)])
        else:
            return tuple([st.value for st, been_writen in zip(self.states, self.been_writen) if been_writen])

    def __add__(self, other: 'StateTraceStack') -> 'StateTraceStack':
        """
        Support ``a + b``: return a new, unnamed ``StateTraceStack`` (without a
        ``new_arg`` callable) holding the merged records of ``a`` and ``b``.
        """
        return StateTraceStack().merge(self, other)

    def state_subset(self, state_type: type) -> List:
        """
        Get the recorded states that are instances of ``state_type``.

        Args:
            state_type (type): The type to filter by (checked with ``isinstance``).

        Returns:
            List[State]: Matching states, in :attr:`states` order.

        Example:
            >>> stack = StateTraceStack()
            >>> # Assume stack has been populated with various state types
            >>> short_term_states = stack.state_subset(ShortTermState)
        """
        return [st for st in self.states if isinstance(st, state_type)]

    def assign_state_vals(self, state_vals: Sequence[PyTree]) -> None:
        """
        Assign new values to the states tracked by this ``StateTraceStack``.

        Written states receive their new value through the ``value`` setter.
        Read-only states receive it through ``State.restore_value``.

        Args:
            state_vals (Sequence[PyTree]): One value per recorded state, in
                :attr:`states` order.

        Raises:
            ValueError: If ``len(state_vals)`` differs from the number of
                recorded states.

        Returns:
            None
        """
        if len(state_vals) != len(self.states):
            raise ValueError(
                'The length of the state values must be equal to the states. '
                f'But got {len(state_vals)} and {len(self.states)}'
            )
        for st, written, val in zip(self.states, self.been_writen, state_vals):
            if written:
                st.value = val
            else:
                st.restore_value(val)

    def assign_state_vals_v2(
        self: StateTraceStack,
        read_state_vals: Sequence[PyTree],
        write_state_vals: Sequence[PyTree],
    ):
        """
        Write back state values to their corresponding states after computation.

        For every recorded state, if it was written, its value is set from
        ``write_state_vals`` through the ``value`` setter. Otherwise, it is
        restored from ``read_state_vals`` through ``State.restore_value``. Both
        sequences must have one entry per recorded state, in :attr:`states`
        order. Entries that are not used for a given state may be ``None``.

        Parameters
        ----------
        read_state_vals : sequence of PyTree
            Values used for the read-only states.
        write_state_vals : sequence of PyTree
            Values used for the written states.

        Raises
        ------
        ValueError
            If the bookkeeping is inconsistent, or either sequence's length
            differs from the number of recorded states.

        Examples
        --------
        .. code-block:: python

            import brainstate
            import jax.numpy as jnp

            s1 = brainstate.State(jnp.array([1.0, 2.0]))
            s2 = brainstate.State(jnp.array([3.0, 4.0]))

            stack = StateTraceStack()
            stack.write_its_value(s1)   # s1 is recorded as written
            stack.read_its_value(s2)    # s2 is recorded as read-only

            # s1 takes its value from the write sequence,
            # s2 is restored from the read sequence.
            stack.assign_state_vals_v2(
                [None, jnp.array([3.0, 4.0])],
                [jnp.array([5.0, 6.0]), None],
            )
        """
        if len(self.states) != len(self.been_writen):
            raise ValueError(
                'The number of tracked states must be equal to the number of write flags. '
                f'But got {len(self.states)} and {len(self.been_writen)}'
            )
        if len(read_state_vals) != len(self.states):
            raise ValueError('The length of the read state values must be equal to the states. ')
        if len(write_state_vals) != len(self.states):
            raise ValueError('The length of the write state values must be equal to the states. ')
        for st, write, val_r, val_w in zip(
            self.states, self.been_writen, read_state_vals, write_state_vals
        ):
            if write:
                st.value = val_w
            else:
                st.restore_value(val_r)
1371
+
1372
+
1373
class TreefyState(Generic[A], PrettyObject):
    """
    A pytree-compatible snapshot of a :class:`State`.

    Stores the concrete ``State`` subclass in ``type`` and the payload in
    ``value``. Any extra keyword metadata is stored as plain instance
    attributes, which :meth:`get_metadata` returns and :meth:`to_state`
    copies onto the rebuilt state.
    """

    def __init__(
        self,
        type: type[State[Any]],
        value: A,
        **metadata
    ):
        self.type = type
        self.value = value
        # Extra metadata goes straight into the instance ``__dict__``.
        vars(self).update(metadata)

    if TYPE_CHECKING:
        def __getattr__(self, name: str) -> None: ...

        def __setattr__(self, name: str, value: Any) -> None: ...

        def __delattr__(self, name: str) -> None: ...

    def __pretty_repr_item__(self, k, v):
        # Bookkeeping fields are hidden from the pretty representation.
        if k in ('_level', '_source_info', '_been_writen'):
            return None
        # ``_name`` is shown as ``name``, but only when it is set.
        if k == '_name':
            return ('name', v) if v is not None else None
        if k == '_value':
            return 'value', v
        return k, v

    @property
    def name(self) -> Optional[str]:
        """
        The ``_name`` metadata entry of this snapshot.
        """
        return self._name

    @name.setter
    def name(self, name: str) -> None:
        """
        Store ``name`` as the ``_name`` metadata entry.
        """
        self._name = name

    def replace(self, value: B) -> TreefyState[B]:
        """
        Return a new snapshot with the same type and metadata but a different value.
        """
        meta = self.get_metadata()
        return TreefyState(self.type, value, **meta)

    def to_state(self) -> State[A]:
        """
        Rebuild a ``State`` instance of ``self.type`` from this snapshot.

        The instance is created with ``object.__new__``, so ``__init__`` is
        not run again. The snapshot's metadata is copied onto it, ``_value`` is
        set from ``self.value``, and ``_level`` from ``_get_trace_stack_level()``.
        """
        meta = self.get_metadata()
        # These two keys are always supplied explicitly below.
        for reserved in ('_value', '_level'):
            meta.pop(reserved, None)
        rebuilt = object.__new__(self.type)
        vars(rebuilt).update(**meta, _value=self.value, _level=_get_trace_stack_level())
        return rebuilt

    def copy(self: TreefyState[A]) -> TreefyState[A]:
        """
        Return a shallow copy made by identity-mapping over the pytree.
        """
        return jax.tree.map(lambda leaf: leaf, self)

    def get_metadata(self) -> Dict[str, Any]:
        """
        Return every instance attribute except ``type`` and ``value`` as a new dict.
        """
        meta = dict(vars(self))
        for field in ('type', 'value'):
            del meta[field]
        return meta
1452
+
1453
+
1454
def _state_ref_flatten(x: TreefyState[Any], *, with_keys: bool):
    # The single child is the payload ``value`` (keyed by attribute name when
    # ``with_keys`` is set); the type and metadata items form the static aux data.
    meta_items = tuple(x.get_metadata().items())
    leaf = x.value
    child = (jax.tree_util.GetAttrKey('value'), leaf) if with_keys else leaf
    return (child,), (x.type, meta_items)
1461
+
1462
+
1463
def _state_ref_unflatten(
    static: Tuple[type[State[A]], Tuple[Tuple[str, Any], ...]],
    children: Tuple[A],
) -> TreefyState[A]:
    # Inverse of ``_state_ref_flatten``: aux data is (state type, metadata items).
    node_type, meta_items = static[0], static[1]
    return TreefyState(type=node_type, value=children[0], **dict(meta_items))
1468
+
1469
+
1470
# Register TreefyState as a JAX pytree node: the keyed flatten is passed
# positionally, and the plain (key-less) flatten via ``flatten_func``.
jax.tree_util.register_pytree_with_keys(
    TreefyState,
    partial(_state_ref_flatten, with_keys=True),  # type: ignore
    _state_ref_unflatten,  # type: ignore
    flatten_func=partial(_state_ref_flatten, with_keys=False),  # type: ignore
)
1476
+
1477
+
1478
class StateCatcher(PrettyObject):
    """
    The catcher to catch and manage new states.

    Collects :class:`State` objects passed to :meth:`append`. Each state is
    kept at most once; uniqueness is by object identity (``id(state)``). Every
    caught state has its ``tag`` attribute set to :attr:`state_tag`.

    Attributes:
        state_tag (str): The tag assigned to every caught state.
        state_to_exclude (Filter): Called as ``state_to_exclude((), state)``;
            states for which it returns a truthy value are not caught.
        state_ids (set): ``id()`` of every caught state, used for uniqueness.
        states (list): The caught State objects, in the order they were caught.
    """

    def __init__(
        self,
        state_tag: str,
        state_to_exclude: Filter = Nothing()
    ):
        """
        Initialize a new StateCatcher instance.

        Args:
            state_tag (str): The tag to be assigned to caught states.
            state_to_exclude (Filter, optional): A filter selecting states that
                must not be caught. ``None`` is treated as ``Nothing()``, which
                excludes nothing.
        """
        if state_to_exclude is None:
            state_to_exclude = Nothing()
        self.state_to_exclude = state_to_exclude
        self.state_tag = state_tag
        self.state_ids = set()
        self.states = []

    def get_state_values(self) -> List[PyTree]:
        """
        Get the current ``.value`` of each caught state.

        Returns:
            list: The values, in the order the states were caught.
        """
        return [state.value for state in self.states]

    def get_states(self) -> List[State]:
        """
        Get the caught states.

        Returns:
            list: The internal list of caught states (not a copy).
        """
        return self.states

    def append(self, state: State):
        """
        Catch ``state`` unless it is excluded or already caught.

        An excluded state (see :attr:`state_to_exclude`) is ignored. Otherwise,
        if it has not been caught before, it is recorded and its ``tag`` is set
        to :attr:`state_tag`.

        Args:
            state (State): The State object to be added.
        """
        if self.state_to_exclude((), state):
            return
        if id(state) not in self.state_ids:
            self.state_ids.add(id(state))
            self.states.append(state)
            state.tag = self.state_tag

    def __iter__(self):
        """
        Iterate over the caught states in catch order.
        """
        return iter(self.states)

    def __len__(self):
        """
        Return the number of caught states.
        """
        return len(self.states)

    def __getitem__(self, index):
        """
        Index into the caught states (delegates to the underlying list).

        Args:
            index: An integer index or slice.

        Returns:
            The state at ``index`` (or a list of states for a slice).
        """
        return self.states[index]

    def clear(self):
        """
        Forget all caught states. Their ``tag`` attributes are left unchanged.
        """
        self.state_ids.clear()
        self.states.clear()

    def get_by_tag(self, tag: str):
        """
        Get the caught states whose current ``tag`` equals ``tag``.

        Args:
            tag (str): The tag to filter by.

        Returns:
            list: Matching states, in catch order.
        """
        return [state for state in self.states if state.tag == tag]

    def remove(self, state: State):
        """
        Remove a specific state from the catcher, if it was caught.

        Removal is by object identity, matching the ``id``-based bookkeeping in
        :attr:`state_ids`. ``list.remove`` is deliberately avoided: it matches
        by ``==`` and could drop a different, equal-comparing state.

        Args:
            state (State): The state to remove.
        """
        state_id = id(state)
        if state_id in self.state_ids:
            self.state_ids.remove(state_id)
            for i, caught in enumerate(self.states):
                if caught is state:
                    del self.states[i]
                    break

    def __contains__(self, state: State):
        """
        Check (by identity) whether a state has been caught.

        Args:
            state (State): The state to check for.

        Returns:
            bool: True if the state is in the catcher, False otherwise.
        """
        return id(state) in self.state_ids
1616
+
1617
+
1618
@contextlib.contextmanager
def catch_new_states(
    state_tag: Optional[str] = None,
    state_to_exclude: Filter = Nothing()
) -> Generator[StateCatcher, None, None]:
    """
    A context manager that catches and tracks new states created within its scope.

    A :class:`StateCatcher` is created and pushed onto
    ``TRACE_CONTEXT.new_state_catcher`` for the duration of the ``with`` block.
    It is popped again on exit, including when the block raises.

    Args:
        state_tag (str, optional): A string tag to associate with the caught states.
            Defaults to None.
        state_to_exclude (Filter, optional): A filter object to specify which states
            should be excluded from catching. Defaults to Nothing(), which excludes no states.

    Yields:
        StateCatcher: The catcher that can be used to access and manage the
        newly created states within the context.

    Example::

        with catch_new_states("my_tag") as catcher:
            # Create new states here
            # They will be caught and tagged with "my_tag"
            # Access caught states through catcher object
    """
    # Build and register the catcher *before* entering ``try``. Otherwise, if
    # construction failed, ``finally`` would pop a catcher this call never
    # pushed (an enclosing caller's), or raise IndexError on an empty list.
    catcher = StateCatcher(state_tag=state_tag, state_to_exclude=state_to_exclude)
    TRACE_CONTEXT.new_state_catcher.append(catcher)
    try:
        yield catcher
    finally:
        TRACE_CONTEXT.new_state_catcher.pop()