brainstate 0.1.0.post20250212__py2.py3-none-any.whl → 0.1.0.post20250217__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. brainstate/_state.py +853 -90
  2. brainstate/_state_test.py +1 -3
  3. brainstate/augment/__init__.py +2 -2
  4. brainstate/augment/_autograd.py +257 -115
  5. brainstate/augment/_autograd_test.py +2 -3
  6. brainstate/augment/_eval_shape.py +3 -4
  7. brainstate/augment/_mapping.py +582 -62
  8. brainstate/augment/_mapping_test.py +114 -30
  9. brainstate/augment/_random.py +61 -7
  10. brainstate/compile/_ad_checkpoint.py +2 -3
  11. brainstate/compile/_conditions.py +4 -5
  12. brainstate/compile/_conditions_test.py +1 -2
  13. brainstate/compile/_error_if.py +1 -2
  14. brainstate/compile/_error_if_test.py +1 -2
  15. brainstate/compile/_jit.py +23 -16
  16. brainstate/compile/_jit_test.py +1 -2
  17. brainstate/compile/_loop_collect_return.py +18 -10
  18. brainstate/compile/_loop_collect_return_test.py +1 -1
  19. brainstate/compile/_loop_no_collection.py +5 -5
  20. brainstate/compile/_make_jaxpr.py +23 -21
  21. brainstate/compile/_make_jaxpr_test.py +1 -2
  22. brainstate/compile/_progress_bar.py +1 -2
  23. brainstate/compile/_unvmap.py +1 -0
  24. brainstate/compile/_util.py +4 -2
  25. brainstate/environ.py +4 -4
  26. brainstate/environ_test.py +1 -2
  27. brainstate/functional/_activations.py +1 -2
  28. brainstate/functional/_activations_test.py +1 -1
  29. brainstate/functional/_normalization.py +1 -2
  30. brainstate/functional/_others.py +1 -2
  31. brainstate/functional/_spikes.py +136 -20
  32. brainstate/graph/_graph_node.py +2 -43
  33. brainstate/graph/_graph_operation.py +4 -20
  34. brainstate/graph/_graph_operation_test.py +3 -4
  35. brainstate/init/_base.py +1 -2
  36. brainstate/init/_generic.py +1 -2
  37. brainstate/nn/__init__.py +8 -0
  38. brainstate/nn/_collective_ops.py +351 -48
  39. brainstate/nn/_collective_ops_test.py +36 -0
  40. brainstate/nn/_common.py +193 -0
  41. brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
  42. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
  43. brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
  44. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
  45. brainstate/nn/_dyn_impl/_inputs.py +1 -2
  46. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
  47. brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
  48. brainstate/nn/_dyn_impl/_readout.py +2 -3
  49. brainstate/nn/_dyn_impl/_readout_test.py +1 -2
  50. brainstate/nn/_dynamics/_dynamics_base.py +6 -1
  51. brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
  52. brainstate/nn/_dynamics/_state_delay.py +3 -3
  53. brainstate/nn/_dynamics/_synouts_test.py +1 -2
  54. brainstate/nn/_elementwise/_dropout.py +6 -7
  55. brainstate/nn/_elementwise/_dropout_test.py +1 -2
  56. brainstate/nn/_elementwise/_elementwise.py +1 -2
  57. brainstate/nn/_exp_euler.py +1 -2
  58. brainstate/nn/_exp_euler_test.py +1 -2
  59. brainstate/nn/_interaction/_conv.py +1 -2
  60. brainstate/nn/_interaction/_conv_test.py +1 -0
  61. brainstate/nn/_interaction/_linear.py +1 -2
  62. brainstate/nn/_interaction/_linear_test.py +1 -2
  63. brainstate/nn/_interaction/_normalizations.py +1 -2
  64. brainstate/nn/_interaction/_poolings.py +3 -4
  65. brainstate/nn/_module.py +68 -19
  66. brainstate/nn/_module_test.py +1 -2
  67. brainstate/nn/_utils.py +89 -0
  68. brainstate/nn/metrics.py +3 -4
  69. brainstate/optim/_lr_scheduler.py +1 -2
  70. brainstate/optim/_lr_scheduler_test.py +2 -3
  71. brainstate/optim/_optax_optimizer_test.py +1 -2
  72. brainstate/optim/_sgd_optimizer.py +2 -3
  73. brainstate/random/_rand_funs.py +1 -2
  74. brainstate/random/_rand_funs_test.py +2 -3
  75. brainstate/random/_rand_seed.py +2 -3
  76. brainstate/random/_rand_seed_test.py +1 -2
  77. brainstate/random/_rand_state.py +3 -4
  78. brainstate/surrogate.py +5 -5
  79. brainstate/transform.py +0 -3
  80. brainstate/typing.py +28 -25
  81. brainstate/util/__init__.py +9 -7
  82. brainstate/util/_caller.py +1 -2
  83. brainstate/util/_error.py +27 -0
  84. brainstate/util/_others.py +60 -15
  85. brainstate/util/{_dict.py → _pretty_pytree.py} +2 -2
  86. brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
  87. brainstate/util/_pretty_repr.py +1 -2
  88. brainstate/util/_pretty_table.py +2900 -0
  89. brainstate/util/_struct.py +11 -11
  90. brainstate/util/filter.py +472 -0
  91. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/METADATA +2 -2
  92. brainstate-0.1.0.post20250217.dist-info/RECORD +128 -0
  93. brainstate/util/_filter.py +0 -178
  94. brainstate-0.1.0.post20250212.dist-info/RECORD +0 -124
  95. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/LICENSE +0 -0
  96. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/WHEEL +0 -0
  97. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/top_level.txt +0 -0
brainstate/_state.py CHANGED
@@ -17,31 +17,47 @@ from __future__ import annotations
17
17
 
18
18
  import contextlib
19
19
  import dataclasses
20
- import threading
21
- from functools import wraps, partial
22
- from typing import (
23
- Any, Union, Callable, Generic, Mapping,
24
- TypeVar, Optional, TYPE_CHECKING, Tuple, Dict, List, Sequence
25
- )
26
-
27
20
  import jax
28
21
  import numpy as np
22
+ import threading
23
+ from functools import wraps, partial
29
24
  from jax.api_util import shaped_abstractify
30
25
  from jax.extend import source_info_util
26
+ from typing import (
27
+ Any,
28
+ Union,
29
+ Callable,
30
+ Generic,
31
+ Mapping,
32
+ TypeVar,
33
+ Optional,
34
+ TYPE_CHECKING,
35
+ Tuple,
36
+ Dict,
37
+ List,
38
+ Sequence,
39
+ Generator,
40
+ )
31
41
 
32
- from brainstate.typing import ArrayLike, PyTree, Missing
42
+ from brainstate.typing import ArrayLike, PyTree, Missing, Filter
33
43
  from brainstate.util import DictManager, PrettyObject
44
+ from brainstate.util.filter import Nothing
34
45
 
35
46
  __all__ = [
36
47
  'State', 'ShortTermState', 'LongTermState', 'HiddenState', 'ParamState', 'TreefyState',
37
- 'FakedState',
48
+ 'FakeState',
38
49
 
39
- 'StateDictManager', 'StateTraceStack', 'check_state_value_tree', 'check_state_jax_tracer', 'catch_new_states',
50
+ 'StateDictManager',
51
+ 'StateTraceStack',
52
+ 'check_state_value_tree',
53
+ 'check_state_jax_tracer',
54
+ 'catch_new_states',
40
55
  'maybe_state',
41
56
  ]
42
57
 
43
58
  A = TypeVar('A')
44
59
  B = TypeVar('B')
60
+ T = TypeVar('T')
45
61
  F = TypeVar('F', bound=Callable[..., Any])
46
62
 
47
63
  max_int = np.iinfo(np.int32)
@@ -51,18 +67,38 @@ max_int = np.iinfo(np.int32)
51
67
  # This allows concurrent tracing in separate threads; passing traced objects
52
68
  # between threads is forbidden.
53
69
  class ThreadLocalStack(threading.local):
70
+ """
71
+ A thread-local storage class for managing state-related information.
72
+
73
+ This class provides thread-local storage for various state management components,
74
+ ensuring that each thread has its own isolated set of state-related data structures.
75
+
76
+ Attributes:
77
+ state_stack (List[StateTraceStack]): A list to store StateTraceStack objects for the current thread.
78
+ tree_check (List[bool]): A list of boolean flags for tree structure checking, initialized with [False].
79
+ jax_tracer_check (List[bool]): A list of boolean flags for JAX tracer checking, initialized with [False].
80
+ new_state_catcher (List[StateCatcher]): A list to store StateCatcher objects for capturing new states in the current thread.
81
+ """
82
+
54
83
  def __init__(self):
84
+ """
85
+ Initialize the ThreadLocalStack with empty data structures.
86
+
87
+ This constructor sets up the initial state for each thread-local instance,
88
+ creating empty lists for state stack, tree checking, JAX tracer checking,
89
+ and new state catching.
90
+ """
55
91
  self.state_stack: List[StateTraceStack] = []
56
92
  self.tree_check: List[bool] = [False]
57
93
  self.jax_tracer_check: List[bool] = [False]
58
- self.new_state_catcher: List[Catcher] = []
94
+ self.new_state_catcher: List[StateCatcher] = []
59
95
 
60
96
 
61
97
  TRACE_CONTEXT = ThreadLocalStack()
62
98
 
63
99
 
64
100
  @contextlib.contextmanager
65
- def check_state_value_tree(val: bool = True) -> None:
101
+ def check_state_value_tree(val: bool = True) -> Generator[None, None, None]:
66
102
  """
67
103
  The context manager to check whether the tree structure of the state value stays consistent.
68
104
 
@@ -91,34 +127,198 @@ def check_state_value_tree(val: bool = True) -> None:
91
127
  TRACE_CONTEXT.tree_check.pop()
92
128
 
93
129
 
94
- @contextlib.contextmanager
95
- def catch_new_states(tag: str = None) -> List:
96
- try:
97
- catcher = Catcher(tag)
98
- TRACE_CONTEXT.new_state_catcher.append(catcher)
99
- yield catcher
100
- finally:
101
- TRACE_CONTEXT.new_state_catcher.pop()
130
+ class StateCatcher(PrettyObject):
131
+ """
132
+ The catcher to catch and manage new states.
102
133
 
134
+ This class provides functionality to collect and tag new State objects.
135
+ It ensures that each state is only added once and assigns a tag to each state.
103
136
 
104
- class Catcher:
105
- """
106
- The catcher to catch the new states.
137
+ Attributes:
138
+ state_tag (str): A string identifier used to tag the caught states.
139
+ state_ids (set): A set of state IDs to ensure uniqueness.
140
+ states (list): A list to store the caught State objects.
107
141
  """
108
142
 
109
- def __init__(self, tag: str):
110
- self.tag = tag
143
+ def __init__(
144
+ self,
145
+ state_tag: str,
146
+ state_to_exclude: Filter = Nothing()
147
+ ):
148
+ """
149
+ Initialize a new StateCatcher instance.
150
+
151
+ Args:
152
+ state_tag (str): The tag to be assigned to caught states.
153
+ state_to_exclude (Filter, optional): A filter to exclude states from being caught.
154
+ """
155
+ if state_to_exclude is None:
156
+ state_to_exclude = Nothing()
157
+ self.state_to_exclude = state_to_exclude
158
+ self.state_tag = state_tag
111
159
  self.state_ids = set()
112
160
  self.states = []
113
161
 
162
+ def get_state_values(self) -> List[PyTree]:
163
+ """
164
+ Get the values of the caught states.
165
+
166
+ Returns:
167
+ list: A list of values of the caught states.
168
+ """
169
+ return [state.value for state in self.states]
170
+
171
+ def get_states(self) -> List[State]:
172
+ """
173
+ Get the caught states.
174
+
175
+ Returns:
176
+ list: A list of the caught states.
177
+ """
178
+ return self.states
179
+
114
180
  def append(self, state: State):
181
+ """
182
+ Add a new state to the catcher if it hasn't been added before.
183
+
184
+ This method adds the state to the internal list, records its ID,
185
+ and assigns the catcher's tag to the state.
186
+
187
+ Args:
188
+ state (State): The State object to be added.
189
+ """
190
+ if self.state_to_exclude((), state):
191
+ return
115
192
  if id(state) not in self.state_ids:
116
193
  self.state_ids.add(id(state))
117
194
  self.states.append(state)
118
- state.tag = self.tag
195
+ state.tag = self.state_tag
196
+
197
+ def __iter__(self):
198
+ """
199
+ Allow iteration over the caught states.
200
+
201
+ Returns:
202
+ iterator: An iterator over the list of caught states.
203
+ """
204
+ return iter(self.states)
205
+
206
+ def __len__(self):
207
+ """
208
+ Return the number of caught states.
209
+
210
+ Returns:
211
+ int: The number of caught states.
212
+ """
213
+ return len(self.states)
214
+
215
+ def __getitem__(self, index):
216
+ """
217
+ Get a state by index.
218
+
219
+ Args:
220
+ index (int): The index of the state to retrieve.
221
+
222
+ Returns:
223
+ State: The state at the specified index.
224
+ """
225
+ return self.states[index]
226
+
227
+ def clear(self):
228
+ """
229
+ Clear all caught states.
230
+ """
231
+ self.state_ids.clear()
232
+ self.states.clear()
233
+
234
+ def get_by_tag(self, tag: str):
235
+ """
236
+ Get all states with a specific tag.
237
+
238
+ Args:
239
+ tag (str): The tag to filter by.
240
+
241
+ Returns:
242
+ list: A list of states with the specified tag.
243
+ """
244
+ return [state for state in self.states if state.tag == tag]
245
+
246
+ def remove(self, state: State):
247
+ """
248
+ Remove a specific state from the catcher.
249
+
250
+ Args:
251
+ state (State): The state to remove.
252
+ """
253
+ if id(state) in self.state_ids:
254
+ self.state_ids.remove(id(state))
255
+ self.states.remove(state)
256
+
257
+ def __contains__(self, state: State):
258
+ """
259
+ Check if a state is in the catcher.
260
+
261
+ Args:
262
+ state (State): The state to check for.
263
+
264
+ Returns:
265
+ bool: True if the state is in the catcher, False otherwise.
266
+ """
267
+ return id(state) in self.state_ids
268
+
269
+
270
+ @contextlib.contextmanager
271
+ def catch_new_states(
272
+ state_tag: str = None,
273
+ state_to_exclude: Filter = Nothing()
274
+ ) -> Generator[StateCatcher, None, None]:
275
+ """
276
+ A context manager that catches and tracks new states created within its scope.
277
+
278
+ This function creates a new StateCatcher object and adds it to the TRACE_CONTEXT's
279
+ new_state_catcher list. It allows for tracking and managing new states created
280
+ within the context.
281
+
282
+ Args:
283
+ state_tag (str, optional): A string tag to associate with the caught states.
284
+ Defaults to None.
285
+ state_to_exclude (Filter, optional): A filter object to specify which states
286
+ should be excluded from catching. Defaults to Nothing(), which excludes no states.
287
+
288
+ Yields:
289
+ StateCatcher: A StateCatcher object that can be used to access and manage the
290
+ newly created states within the context.
291
+
292
+ Example::
293
+
294
+ with catch_new_states("my_tag") as catcher:
295
+ # Create new states here
296
+ # They will be caught and tagged with "my_tag"
297
+ # Access caught states through catcher object
298
+ """
299
+ try:
300
+ catcher = StateCatcher(state_tag=state_tag, state_to_exclude=state_to_exclude)
301
+ TRACE_CONTEXT.new_state_catcher.append(catcher)
302
+ yield catcher
303
+ finally:
304
+ TRACE_CONTEXT.new_state_catcher.pop()
305
+
306
+
307
+ def maybe_state(val: Any) -> Any:
308
+ """
309
+ Extracts the value from a State object if given, otherwise returns the input value.
119
310
 
311
+ This function is useful for handling both State objects and raw values uniformly.
312
+ If the input is a State object, it returns the value stored in that State.
313
+ If the input is not a State object, it returns the input as is.
120
314
 
121
- def maybe_state(val: Any):
315
+ Args:
316
+ val (Any): The input value, which can be either a State object or any other type.
317
+
318
+ Returns:
319
+ Any: The value stored in the State if the input is a State object,
320
+ otherwise the input value itself.
321
+ """
122
322
  if isinstance(val, State):
123
323
  return val.value
124
324
  else:
@@ -126,7 +326,7 @@ def maybe_state(val: Any):
126
326
 
127
327
 
128
328
  @contextlib.contextmanager
129
- def check_state_jax_tracer(val: bool = True) -> None:
329
+ def check_state_jax_tracer(val: bool = True) -> Generator[None, None, None]:
130
330
  """
131
331
  The context manager to check whether the state is valid to trace.
132
332
 
@@ -160,11 +360,16 @@ def check_state_jax_tracer(val: bool = True) -> None:
160
360
  @dataclasses.dataclass
161
361
  class StateMetadata(Generic[A]):
162
362
  """
163
- The state metadata.
363
+ A dataclass representing metadata for a state object.
164
364
 
165
- Args:
166
- raw_value: The raw value.
167
- metadata: The metadata.
365
+ This class encapsulates the raw value of a state along with associated metadata.
366
+ It is generic over the type of the raw value.
367
+
368
+ Attributes:
369
+ raw_value (A): The raw value of the state. The type A is a generic type parameter.
370
+ metadata (Mapping[str, Any]): A mapping of string keys to arbitrary values,
371
+ representing additional metadata for the state.
372
+ Defaults to an empty dictionary.
168
373
  """
169
374
  raw_value: A
170
375
  metadata: Mapping[str, Any] = dataclasses.field(default_factory=dict)
@@ -172,7 +377,30 @@ class StateMetadata(Generic[A]):
172
377
 
173
378
  def with_metadata(initializer: F, **metadata: Any) -> F:
174
379
  """
175
- A decorator to add metadata to the state.
380
+ A decorator that adds metadata to a state initialization function.
381
+
382
+ This decorator wraps the given initializer function, allowing additional
383
+ metadata to be associated with the state it creates. The metadata is
384
+ incorporated into a StateMetadata object along with the state's value.
385
+
386
+ Args:
387
+ initializer (F): The original state initialization function to be wrapped.
388
+ **metadata (Any): Arbitrary keyword arguments representing metadata
389
+ to be associated with the state.
390
+
391
+ Returns:
392
+ F: A wrapped version of the initializer function that returns a
393
+ StateMetadata object containing both the original state value
394
+ and the provided metadata.
395
+
396
+ Example::
397
+ @with_metadata(tag='model_param')
398
+ def init_weights(shape):
399
+ return np.zeros(shape)
400
+
401
+ state = init_weights((100, 100))
402
+ # state is now a StateMetadata object with the initialized weights
403
+ # and the 'tag' metadata
176
404
  """
177
405
 
178
406
  @wraps(initializer)
@@ -188,24 +416,54 @@ def _get_trace_stack_level() -> int:
188
416
 
189
417
  class State(Generic[A], PrettyObject):
190
418
  """
191
- The pointer to specify the dynamical data.
419
+ A generic class representing a dynamic data pointer in the BrainState framework.
192
420
 
193
- To implement a new subclass of :py:class:`~.State`, you only need to inherent this class:
194
-
195
- Example::
421
+ The State class serves as a base for various types of state objects used to
422
+ manage and track dynamic data within a program. It provides mechanisms for
423
+ value storage, metadata management, and integration with the BrainState
424
+ tracing system.
196
425
 
197
- >>> class MyState(State):
198
- >>> pass
426
+ Type Parameters:
427
+ A: The type of the value stored in the state.
199
428
 
200
- The typical examples of :py:class:`~.State` subclass are:
429
+ Attributes:
430
+ name (Optional[str]): An optional name for the state.
431
+ value (PyTree): The actual value stored in the state.
432
+ tag (Optional[str]): An optional tag for categorizing or grouping states.
201
433
 
202
- - :py:class:`~.ShortTermState`: The short-term state, which is used to store the short-term data in the program.
203
- - :py:class:`~.LongTermState`: The long-term state, which is used to store the long-term data in the program.
204
- - :py:class:`~.ParamState`: The parameter state, which is used to store the parameters in the program.
205
- - :py:class:`~.RandomState`: The random generator state, which is used to store the random key in the program.
434
+ Args:
435
+ value (Union[PyTree[ArrayLike], StateMetadata[PyTree[ArrayLike]]]):
436
+ The initial value for the state. Can be a PyTree of array-like objects
437
+ or a StateMetadata object.
438
+ name (Optional[str]): An optional name for the state.
439
+ **metadata: Additional metadata to be stored with the state.
440
+
441
+ Example:
442
+ >>> class MyState(State):
443
+ ... pass
444
+ >>> state = MyState(jnp.zeros((3, 3)), name="my_matrix")
445
+ >>> print(state.value)
446
+ [[0. 0. 0.]
447
+ [0. 0. 0.]
448
+ [0. 0. 0.]]
449
+
450
+ Note:
451
+ - Subclasses of :class:`State` (e.g., ShortTermState, LongTermState, ParamState,
452
+ RandomState) are typically used for specific purposes in a program.
453
+ - The class integrates with BrainState's tracing system to track state
454
+ creation and modifications.
455
+
456
+ The typical examples of :py:class:`~.State` subclass are:
457
+
458
+ - :py:class:`ShortTermState`: The short-term state, which is used to store the short-term data in the program.
459
+ - :py:class:`LongTermState`: The long-term state, which is used to store the long-term data in the program.
460
+ - :py:class:`ParamState`: The parameter state, which is used to store the parameters in the program.
461
+ - :py:class:`RandomState`: The random generator state, which is used to store the random key in the program.
206
462
 
207
463
  Args:
208
- value: PyTree. It can be anything as a pyTree.
464
+ value: PyTree. It can be anything as a pyTree.
465
+ name: Optional[str]. The name of the state.
466
+ tag: Optional[str]. The tag of the state.
209
467
  """
210
468
  __module__ = 'brainstate'
211
469
  _level: int
@@ -221,6 +479,25 @@ class State(Generic[A], PrettyObject):
221
479
  name: Optional[str] = None,
222
480
  **metadata: Any
223
481
  ):
482
+ """
483
+ Initialize a new State instance.
484
+
485
+ This constructor sets up the initial value and metadata of the state,
486
+ handling various input types and metadata.
487
+
488
+ Args:
489
+ value (Union[PyTree[ArrayLike], StateMetadata[PyTree[ArrayLike]]]):
490
+ The initial value for the state. Can be a PyTree of array-like objects
491
+ or a StateMetadata object containing both value and metadata.
492
+ name (Optional[str], optional): A name for the state. Defaults to None.
493
+ **metadata: Additional metadata to be stored with the state, including:
494
+ - tag (Optional[str]): A tag for categorizing or grouping states.
495
+ - Any other custom metadata fields.
496
+
497
+ Note:
498
+ This method initializes the state, processes the input value and metadata,
499
+ sets up internal attributes, and records the state initialization.
500
+ """
224
501
  tag = metadata.pop('tag', None)
225
502
 
226
503
  # set the value and metadata
@@ -231,12 +508,14 @@ class State(Generic[A], PrettyObject):
231
508
  value = value.value
232
509
 
233
510
  # update metadata
234
- metadata.update(_value=value,
235
- _level=_get_trace_stack_level(),
236
- _source_info=source_info_util.current(),
237
- _name=name,
238
- tag=tag,
239
- _been_writen=False)
511
+ metadata.update(
512
+ _value=value,
513
+ _level=_get_trace_stack_level(),
514
+ _source_info=source_info_util.current(),
515
+ _name=name,
516
+ _been_writen=False,
517
+ tag=tag,
518
+ )
240
519
 
241
520
  # avoid using self._setattr to avoid the check
242
521
  vars(self).update(metadata)
@@ -244,6 +523,26 @@ class State(Generic[A], PrettyObject):
244
523
  # record the state initialization
245
524
  record_state_init(self)
246
525
 
526
+ def decrease_stack_level(self):
527
+ """
528
+ Decrease the stack level of the state by one, ensuring it doesn't go below zero.
529
+
530
+ This method is used to adjust the stack level of the state, typically when
531
+ exiting a nested context or scope. It ensures that the level never becomes
532
+ negative.
533
+ """
534
+ self._level = max(self._level - 1, 0)
535
+
536
+ def increase_stack_level(self):
537
+ """
538
+ Increase the stack level of the state by one.
539
+
540
+ This method is used to adjust the stack level of the state, typically when
541
+ entering a nested context or scope. It increments the internal level counter
542
+ by one.
543
+ """
544
+ self._level = self._level + 1
545
+
247
546
  @property
248
547
  def name(self) -> Optional[str]:
249
548
  """
@@ -438,19 +737,19 @@ class State(Generic[A], PrettyObject):
438
737
 
439
738
  def __pretty_repr_item__(self, k, v):
440
739
  if k in ['_level', '_source_info', '_been_writen']:
441
- return None, None
740
+ return None
442
741
  if k == '_value':
443
- return 'value', v
742
+ return 'value', jax.tree.map(shaped_abstractify, v)
444
743
 
445
744
  if k == '_name':
446
745
  if self.name is None:
447
- return None, None
746
+ return None
448
747
  else:
449
748
  return 'name', v
450
749
 
451
750
  if k == 'tag':
452
751
  if self.tag is None:
453
- return None, None
752
+ return None
454
753
  else:
455
754
  return 'tag', v
456
755
 
@@ -486,32 +785,116 @@ class State(Generic[A], PrettyObject):
486
785
 
487
786
 
488
787
  def record_state_init(st: State[A]):
489
- trace: Catcher
788
+ """
789
+ Record the initialization of a new :class:`State` object.
790
+
791
+ This function iterates through all registered state catchers in the current
792
+ trace context and appends the newly initialized state to each catcher.
793
+
794
+ Args:
795
+ st (State[A]): The newly initialized :class:`State` object to be recorded.
796
+
797
+ Note:
798
+ This function is typically called internally when a new :class:`State` object
799
+ is created to ensure proper tracking and management of states within
800
+ the current execution context.
801
+ """
802
+ trace: StateCatcher
490
803
  for trace in TRACE_CONTEXT.new_state_catcher:
491
804
  trace.append(st)
492
805
 
493
806
 
494
807
  def record_state_value_read(st: State[A]):
808
+ """
809
+ Record that a state's value has been read in all relevant trace stacks.
810
+
811
+ This function iterates through all state trace stacks at or above the
812
+ state's stack level in the current trace context, and records that
813
+ the given state's value has been read.
814
+
815
+ Args:
816
+ st (State[A]): The state object whose value read is being recorded.
817
+ 'A' is a generic type parameter representing the
818
+ type of the state's value.
819
+
820
+ Note:
821
+ This function modifies the state trace stacks in the current
822
+ trace context but does not return any value.
823
+ """
495
824
  trace: StateTraceStack
496
825
  for trace in TRACE_CONTEXT.state_stack[st.stack_level:]:
497
826
  trace.read_its_value(st)
498
827
 
499
828
 
500
829
  def record_state_value_write(st: State[A]):
830
+ """
831
+ Record that a state's value has been written in all relevant trace stacks.
832
+
833
+ This function iterates through all state trace stacks at or above the
834
+ state's stack level in the current trace context, and records that
835
+ the given state's value has been written.
836
+
837
+ Args:
838
+ st (State[A]): The state object whose value write is being recorded.
839
+ 'A' is a generic type parameter representing the
840
+ type of the state's value.
841
+
842
+ Note:
843
+ This function modifies the state trace stacks in the current
844
+ trace context but does not return any value.
845
+ """
501
846
  trace: StateTraceStack
502
847
  for trace in TRACE_CONTEXT.state_stack[st.stack_level:]:
503
848
  trace.write_its_value(st)
504
849
 
505
850
 
506
851
  def record_state_value_restore(st: State[A]):
852
+ """
853
+ Record that a state's value has been restored.
854
+
855
+ This function is used to indicate that a state's value has been restored
856
+ to a previous value. It internally calls the record_state_value_read
857
+ function to mark the state as having been accessed.
858
+
859
+ Args:
860
+ st (State[A]): The state object whose value restoration is being recorded.
861
+ 'A' is a generic type parameter representing the
862
+ type of the state's value.
863
+
864
+ See Also:
865
+ record_state_value_read: Record that a state's value has been read.
866
+
867
+ Note:
868
+ This function does not actually restore the state's value; it only
869
+ records that a restoration has occurred.
870
+ """
507
871
  record_state_value_read(st)
508
872
 
509
873
 
510
874
  class ShortTermState(State):
511
875
  """
512
- The short-term state, which is used to store the short-term data in the program.
876
+ A class representing short-term state in a program.
877
+
878
+ :class:`ShortTermState` is used to store temporary or transient data that is only relevant
879
+ for a short duration within the program's execution. This class extends the base
880
+ State class, inheriting its properties and methods while specifically denoting
881
+ the short-term nature of the stored data.
882
+
883
+ For example, in a machine learning training process, the gradients of the model
884
+ would typically be represented as :class:`ShortTermState`, as they are computed and used
885
+ within each iteration but not necessarily preserved across iterations.
513
886
 
514
- For example, in a training process, the gradients of the model are short-term states.
887
+ Attributes:
888
+ Inherits all attributes from the base State class.
889
+
890
+ Note:
891
+ This class does not introduce new methods or attributes beyond those
892
+ inherited from the State class. Its primary purpose is to semantically
893
+ distinguish short-term states from other types of states in the program.
894
+
895
+ Example:
896
+ >>> gradient = ShortTermState(np.zeros(100), name="model_gradient")
897
+ >>> intermediate_result = ShortTermState({}, name="layer_activations")
515
898
  """
516
899
 
517
900
  __module__ = 'brainstate'
@@ -521,7 +904,25 @@ class LongTermState(State):
521
904
  """
522
905
  The long-term state, which is used to store the long-term data in the program.
523
906
 
524
- For example, in a training process, the weights of the model are long-term states.
907
+ This class extends the base :class:`State` class and is specifically designed to represent
908
+ and manage long-term data within a program. Long-term states are typically used
909
+ for data that persists across multiple iterations or epochs of a process.
910
+
911
+ For example, in a training process, the weights of the model are considered
912
+ long-term states as they are updated and maintained throughout the entire
913
+ training procedure.
914
+
915
+ Attributes:
916
+ Inherits all attributes from the base :class:`State` class.
917
+
918
+ Note:
919
+ This class does not introduce new methods or attributes beyond those
920
+ inherited from the :class:`State` class. Its primary purpose is to semantically
921
+ distinguish long-term states from other types of states in the program.
922
+
923
+ Example:
924
+ >>> model_weights = LongTermState(np.random.randn(100, 100), name="model_weights")
925
+ >>> optimizer_state = LongTermState({}, name="optimizer_state")
525
926
  """
526
927
 
527
928
  __module__ = 'brainstate'
@@ -530,6 +931,24 @@ class LongTermState(State):
530
931
  class BatchState(LongTermState):
531
932
  """
532
933
  The batch state, which is used to store the batch data in the program.
934
+
935
+ This class extends :class:`LongTermState` and is specifically designed to represent
936
+ and manage batch data within a program. It provides a way to encapsulate
937
+ batch-related information and associated metadata, facilitating operations
938
+ like batch processing in machine learning or data analysis tasks.
939
+
940
+ Attributes:
941
+ Inherits all attributes from :class:`LongTermState`.
942
+
943
+ Note:
944
+ This class does not introduce new methods or attributes beyond those
945
+ inherited from :class:`LongTermState`. Its primary purpose is to semantically
946
+ distinguish batch states from other types of long-term states
947
+ in the program.
948
+
949
+ Example:
950
+ >>> batch_data = BatchState(np.array([1, 2, 3, 4, 5]), name="current_batch")
951
+ >>> batch_labels = BatchState(np.array([0, 1, 0, 1, 1]), name="batch_labels")
533
952
  """
534
953
 
535
954
  __module__ = 'brainstate'
@@ -538,6 +957,24 @@ class BatchState(LongTermState):
538
957
class HiddenState(ShortTermState):
    """
    State container for the hidden data of a dynamic model.

    ``HiddenState`` is a thin subclass of :class:`ShortTermState`. Its body
    defines nothing beyond the ``__module__`` override, so every attribute
    and method comes from the parent class. The subclass exists as a
    semantic tag: it marks hidden states of dynamic models (e.g. the hidden
    vector of a recurrent network) so they can be told apart from other
    short-term states.

    Attributes:
        No new attributes; everything is inherited from
        :class:`ShortTermState`.

    Example:
        >>> lstm_hidden = HiddenState(np.zeros(128), name="lstm_hidden_state")
        >>> gru_hidden = HiddenState(np.zeros(64), name="gru_hidden_state")
    """

    # Report the public package as the defining module.
    __module__ = 'brainstate'
@@ -546,12 +983,30 @@ class HiddenState(ShortTermState):
546
983
class ParamState(LongTermState):
    """
    State container for the trainable parameters of a model.

    ``ParamState`` is a thin subclass of :class:`LongTermState`. Its body
    defines nothing beyond the ``__module__`` override, so every attribute
    and method comes from the parent class. The subclass exists as a
    semantic tag: it marks trainable parameters (weights, biases, ...) so
    they can be told apart from other long-term states.

    Attributes:
        No new attributes; everything is inherited from
        :class:`LongTermState`.

    Example:
        >>> weight = ParamState(np.random.randn(10, 10), name="layer1_weights")
        >>> bias = ParamState(np.zeros(10), name="layer1_bias")
    """

    # Report the public package as the defining module.
    __module__ = 'brainstate'
552
1007
 
553
1008
 
554
- class FakedState:
1009
class FakeState:
    """
    The fake state, which is used to store fake data in the program.

    This is a plain holder of a value and an optional name. The class does
    not inherit from any other class in this file; it only exposes ``value``
    and ``name`` as read/write properties over the private ``_value`` and
    ``_name`` attributes.

    Example:
        >>> st = FakeState(1, name="placeholder")
        >>> st.value
        1
        >>> st
        FakeState(value=1)
    """

    __module__ = 'brainstate'

    def __init__(self, value: Any, name: Optional[str] = None):
        """
        Initialize a FakeState instance.

        Args:
            value (Any): The value to be stored in the fake state.
            name (Optional[str], optional): The name of the fake state.
                Defaults to None.
        """
        self._value = value
        self._name = name

    @property
    def value(self) -> Any:
        """
        Get the value stored in the fake state.

        Returns:
            Any: The value stored in the fake state.
        """
        return self._value

    @value.setter
    def value(self, v) -> None:
        """
        Set the value of the fake state.

        Args:
            v (Any): The new value to be stored in the fake state.
        """
        self._value = v

    def __repr__(self) -> str:
        """
        Return a string representation of the instance.

        The class name is taken from the instance rather than hard-coded, so
        the representation cannot drift from the actual class name (the old
        literal still read ``FakedState`` after the rename to ``FakeState``).

        Returns:
            str: ``"<ClassName>(value=<value>)"``.
        """
        return f'{self.__class__.__name__}(value={self._value})'

    @property
    def name(self) -> Optional[str]:
        """
        Get the name of the fake state.

        Returns:
            Optional[str]: The name of the fake state, or None if not set.
        """
        return self._name

    @name.setter
    def name(self, name: Optional[str]) -> None:
        """
        Set the name of the fake state.

        Args:
            name (Optional[str]): The new name for the fake state.
        """
        self._name = name
583
1075
 
584
1076
 
@@ -641,20 +1133,73 @@ class StateDictManager(DictManager):
641
1133
 
642
1134
  class StateTraceStack(Generic[A]):
643
1135
  """
644
- The state trace stack, which is used to trace the states automatically.
1136
+ A stack for tracing and managing states during program execution.
1137
+
1138
+ ``StateTraceStack`` is used to automatically trace and manage State objects,
1139
+ keeping track of which states are read from or written to during the
1140
+ execution of a function or block of code. It provides methods for
1141
+ recording state accesses, retrieving state values, and managing the
1142
+ lifecycle of states within a tracing context.
1143
+
1144
+ The class is generic over type A, allowing for type-safe usage with
1145
+ different types of State objects.
1146
+
1147
+ Attributes:
1148
+ states (List[State]): A list of all State objects encountered during tracing.
1149
+ been_writen (List[bool]): A parallel list to states, indicating whether each state has been written to.
1150
+ _state_id_index (dict): A dictionary mapping state ids to their index in the states list.
1151
+ _original_state_values (List): A list of the original values of all states when first encountered.
1152
+ _jax_trace_new_arg (Callable): A function used to transform state values during tracing.
1153
+
1154
+ Methods:
1155
+ __enter__: Enters a new tracing context.
1156
+ __exit__: Exits the current tracing context.
1157
+ read_its_value: Records a read operation on a state.
1158
+ write_its_value: Records a write operation on a state.
1159
+ get_state_values: Retrieves the current values of all traced states.
1160
+ recovery_original_values: Restores all states to their original values.
1161
+ merge: Merges multiple ``StateTraceStack`` instances.
1162
+ get_read_states: Retrieves states that were read during tracing.
1163
+ get_read_state_values: Retrieves values of states that were read during tracing.
1164
+
1165
+ The ``StateTraceStack`` is a crucial component in implementing state-based
1166
+ computations and is particularly useful in scenarios involving automatic
1167
+ differentiation or other forms of program transformation.
645
1168
  """
646
1169
 
647
- def __init__(self, new_arg: Callable = None):
1170
+ def __init__(
1171
+ self,
1172
+ new_arg: Callable = None,
1173
+ name: Optional[str] = None,
1174
+ ):
1175
+ self.name = name
648
1176
  self.states: List[State] = []
649
1177
  self.been_writen: List[bool] = [] # False: read, True: write
650
1178
  self._state_id_index = dict()
651
1179
  self._original_state_values = []
652
1180
  self._jax_trace_new_arg: Callable = new_arg
1181
+ self._stack_level = None
1182
+
1183
+ def __str__(self) -> str:
1184
+ _stack_level = self.name if self._stack_level is None else self._stack_level
1185
+ if _stack_level is None:
1186
+ _stack_level = ''
1187
+ return f"{self.__class__.__name__}({_stack_level})"
653
1188
 
654
1189
  @property
655
1190
  def original_state_values(self) -> Tuple[PyTree, ...]:
656
1191
  """
657
- The original values of the states.
1192
+ Get the original values of all states in the StateTraceStack.
1193
+
1194
+ This property provides access to the initial values of all states
1195
+ that were captured when they were first added to the stack. It's
1196
+ useful for comparing current state values with their original values
1197
+ or for reverting states to their initial condition.
1198
+
1199
+ Returns:
1200
+ Tuple[PyTree, ...]: A tuple containing the original values of all
1201
+ states in the order they were added to the stack. Each element
1202
+ is a PyTree representing the structure and values of a state.
658
1203
  """
659
1204
  return tuple(self._original_state_values)
660
1205
 
@@ -662,12 +1207,30 @@ class StateTraceStack(Generic[A]):
662
1207
  self._jax_trace_new_arg = new_arg
663
1208
 
664
1209
  def new_arg(self, state: State) -> None:
1210
+ """
1211
+ Apply a transformation to the value of a given state using a predefined function.
1212
+
1213
+ This method is used internally to transform the value of a state during tracing.
1214
+ If a transformation function (``_jax_trace_new_arg``) is defined, it applies this
1215
+ function to each element of the state's value using JAX's tree mapping.
1216
+
1217
+ Args:
1218
+ state (State): The State object whose value needs to be transformed.
1219
+
1220
+ Returns:
1221
+ None: This function modifies the state in-place and doesn't return anything.
1222
+
1223
+ Note:
1224
+ This method is intended for internal use and relies on the presence of
1225
+ a ``_jax_trace_new_arg`` function, which should be set separately.
1226
+ """
665
1227
  if self._jax_trace_new_arg is not None:
666
1228
  # internal use
667
1229
  state._value = jax.tree.map(lambda x: self._jax_trace_new_arg(shaped_abstractify(x)), state._value)
668
1230
 
669
1231
  def __enter__(self) -> 'StateTraceStack':
670
1232
  TRACE_CONTEXT.state_stack.append(self)
1233
+ self._stack_level = ' / '.join([st.name for st in TRACE_CONTEXT.state_stack if st.name is not None])
671
1234
  return self
672
1235
 
673
1236
  def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
@@ -675,10 +1238,22 @@ class StateTraceStack(Generic[A]):
675
1238
 
676
1239
  def read_its_value(self, state: State) -> None:
677
1240
  """
678
- Read the value of the state.
679
-
1241
+ Record that a state's value has been read during tracing.
1242
+
1243
+ This method marks the given state as having been read in the current
1244
+ tracing context. If the state hasn't been encountered before, it adds
1245
+ it to the internal tracking structures and applies any necessary
1246
+ transformations via the new_arg method.
1247
+
680
1248
  Args:
681
- state: The state.
1249
+ state (State): The State object whose value is being read.
1250
+
1251
+ Returns:
1252
+ None
1253
+
1254
+ Note:
1255
+ This method updates the internal tracking of state accesses.
1256
+ It doesn't actually read or return the state's value.
682
1257
  """
683
1258
  id_ = id(state)
684
1259
  if id_ not in self._state_id_index:
@@ -690,10 +1265,21 @@ class StateTraceStack(Generic[A]):
690
1265
 
691
1266
  def write_its_value(self, state: State) -> None:
692
1267
  """
693
- Write the value of the state.
1268
+ Record that a state's value has been written to during tracing.
1269
+
1270
+ This method marks the given state as having been written to in the current
1271
+ tracing context. If the state hasn't been encountered before, it first
1272
+ records it as being read before marking it as written.
694
1273
 
695
1274
  Args:
696
- state: The state.
1275
+ state (State): The State object whose value is being written to.
1276
+
1277
+ Returns:
1278
+ None
1279
+
1280
+ Note:
1281
+ This method updates the internal tracking of state modifications.
1282
+ It doesn't actually modify the state's value.
697
1283
  """
698
1284
  id_ = id(state)
699
1285
  if id_ not in self._state_id_index:
@@ -701,10 +1287,37 @@ class StateTraceStack(Generic[A]):
701
1287
  index = self._state_id_index[id_]
702
1288
  self.been_writen[index] = True
703
1289
 
704
- def get_state_values(self, separate: bool = False, replace: bool = False
705
- ) -> Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
706
- """
707
- Get the values of the states.
1290
+ def get_state_values(
1291
+ self,
1292
+ separate: bool = False,
1293
+ replace: bool = False
1294
+ ) -> Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
1295
+ """
1296
+ Retrieve the values of all states in the StateTraceStack.
1297
+
1298
+ This method returns the values of all states, optionally separating them
1299
+ into written and read states, and optionally replacing values with None
1300
+ for states that weren't accessed in a particular way.
1301
+
1302
+ Args:
1303
+ separate (bool, optional): If True, separate the values into written
1304
+ and read states. If False, return all values in a single sequence.
1305
+ Defaults to False.
1306
+ replace (bool, optional): If True and separate is True, replace values
1307
+ with None for states that weren't written/read. If False, only
1308
+ include values for states that were written/read. Defaults to False.
1309
+
1310
+ Returns:
1311
+ Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
1312
+ If separate is False:
1313
+ A sequence of all state values.
1314
+ If separate is True:
1315
+ A tuple containing two sequences:
1316
+ - The first sequence contains values of written states.
1317
+ - The second sequence contains values of read states.
1318
+ If replace is True, these sequences will have None for
1319
+ states that weren't written/read respectively.
1320
+
708
1321
  """
709
1322
  if separate:
710
1323
  if replace:
@@ -730,7 +1343,18 @@ class StateTraceStack(Generic[A]):
730
1343
 
731
1344
  def recovery_original_values(self) -> None:
732
1345
  """
733
- Recovery the original values.
1346
+ Restore the original values of all states in the StateTraceStack.
1347
+
1348
+ This method iterates through all states in the stack and restores
1349
+ their values to the original ones that were captured when the states
1350
+ were first added to the stack. This is useful for reverting changes
1351
+ made during tracing or for resetting the states to their initial condition.
1352
+
1353
+ Note:
1354
+ This method modifies the states in-place.
1355
+
1356
+ Returns:
1357
+ None
734
1358
  """
735
1359
  for st, val in zip(self.states, self._original_state_values):
736
1360
  # internal use
@@ -738,7 +1362,22 @@ class StateTraceStack(Generic[A]):
738
1362
 
739
1363
  def merge(self, *traces) -> 'StateTraceStack':
740
1364
  """
741
- Merge other state traces.
1365
+ Merge other state traces into the current ``StateTraceStack``.
1366
+
1367
+ This method combines the states, their write status, and original values from
1368
+ other ``StateTraceStack`` instances into the current one. If a state from another
1369
+ trace is not present in the current trace, it is added. If a state is already
1370
+ present, its write status is updated if necessary.
1371
+
1372
+ Args:
1373
+ *traces: Variable number of ``StateTraceStack`` instances to be merged into
1374
+ the current instance.
1375
+
1376
+ Returns:
1377
+ StateTraceStack: The current ``StateTraceStack`` instance with merged traces.
1378
+
1379
+ Note:
1380
+ This method modifies the current ``StateTraceStack`` in-place and also returns it.
742
1381
  """
743
1382
  trace: StateTraceStack
744
1383
  for trace in traces:
@@ -754,10 +1393,22 @@ class StateTraceStack(Generic[A]):
754
1393
 
755
1394
  def get_read_states(self, replace_writen: bool = False) -> Tuple[State, ...]:
756
1395
  """
757
- Read the states that are read by the function.
758
-
1396
+ Retrieve the states that were read during the function execution.
1397
+
1398
+ This method returns the states that were accessed (read from) during
1399
+ the traced function's execution. It can optionally replace written
1400
+ states with None.
1401
+
1402
+ Args:
1403
+ replace_writen (bool, optional): If True, replace written states with None
1404
+ in the returned tuple. If False, exclude written states entirely from
1405
+ the result. Defaults to False.
1406
+
759
1407
  Returns:
760
- The states that are read by the function.
1408
+ Tuple[State, ...]: A tuple containing the read states.
1409
+ If replace_writen is True, the tuple will have the same length as the
1410
+ total number of states, with None for written states.
1411
+ If replace_writen is False, the tuple will only contain read-only states.
761
1412
  """
762
1413
  if replace_writen:
763
1414
  return tuple([st if not been_writen else None
@@ -767,23 +1418,49 @@ class StateTraceStack(Generic[A]):
767
1418
 
768
1419
  def get_read_state_values(self, replace_writen: bool = False) -> Tuple[PyTree, ...]:
769
1420
  """
770
- Read the states that are read by the function.
771
-
1421
+ Retrieve the values of states that were read during the function execution.
1422
+
1423
+ This method returns the values of states that were accessed (read from) during
1424
+ the traced function's execution. It can optionally replace written states with None.
1425
+
1426
+ Args:
1427
+ replace_writen (bool, optional): If True, replace the values of written
1428
+ states with None in the returned tuple. If False, exclude written
1429
+ states entirely from the result. Defaults to False.
1430
+
772
1431
  Returns:
773
- The states that are read by the function.
1432
+ Tuple[PyTree, ...]: A tuple containing the values of read states.
1433
+ If replace_writen is True, the tuple will have the same length as the
1434
+ total number of states, with None for written states.
1435
+ If replace_writen is False, the tuple will only contain values of
1436
+ read-only states.
774
1437
  """
775
1438
  if replace_writen:
776
1439
  return tuple(
777
- [st.value if not been_writen else None for st, been_writen in zip(self.states, self.been_writen)])
1440
+ [st.value if not been_writen else None
1441
+ for st, been_writen in zip(self.states, self.been_writen)]
1442
+ )
778
1443
  else:
779
1444
  return tuple([st.value for st, been_writen in zip(self.states, self.been_writen) if not been_writen])
780
1445
 
781
1446
  def get_write_states(self, replace_read: bool = False) -> Tuple[State, ...]:
782
1447
  """
783
- Read the states that are written by the function.
784
-
1448
+ Retrieve the states that were written during the function execution.
1449
+
1450
+ This method returns the states that were modified (written to) during
1451
+ the traced function's execution. It can optionally replace unwritten (read-only)
1452
+ states with None.
1453
+
1454
+ Args:
1455
+ replace_read (bool, optional): If True, replace read-only states with None
1456
+ in the returned tuple. If False, exclude read-only states entirely from
1457
+ the result. Defaults to False.
1458
+
785
1459
  Returns:
786
- The states that are written by the function.
1460
+ Tuple[State, ...]: A tuple containing the written states.
1461
+ If replace_read is True, the tuple will have the same length as the
1462
+ total number of states, with None for read-only states.
1463
+ If replace_read is False, the tuple will only contain written states.
787
1464
  """
788
1465
  if replace_read:
789
1466
  return tuple([st if been_writen else None
@@ -793,10 +1470,24 @@ class StateTraceStack(Generic[A]):
793
1470
 
794
1471
  def get_write_state_values(self, replace_read: bool = False) -> Tuple[PyTree, ...]:
795
1472
  """
796
- Read the states that are written by the function.
797
-
1473
+ Retrieve the values of states that were written during the function execution.
1474
+
1475
+ This method returns the values of states that were modified (written to) during
1476
+ the traced function's execution. It can optionally replace unwritten (read-only)
1477
+ states with None.
1478
+
1479
+ Args:
1480
+ replace_read (bool, optional): If True, replace the values of read-only
1481
+ states with None in the returned tuple. If False, exclude read-only
1482
+ states entirely from the result. Defaults to False.
1483
+
798
1484
  Returns:
799
- The states that are written by the function.
1485
+ Tuple[PyTree, ...]: A tuple containing the values of written states.
1486
+ If replace_read is True, the tuple will have the same length as the
1487
+ total number of states, with None for read-only states.
1488
+ If replace_read is False, the tuple will only contain values of
1489
+ written states.
1490
+
800
1491
  """
801
1492
  if replace_read:
802
1493
  return tuple([st.value if been_writen else None for st, been_writen in zip(self.states, self.been_writen)])
@@ -809,6 +1500,62 @@ class StateTraceStack(Generic[A]):
809
1500
  """
810
1501
  return StateTraceStack().merge(self, other)
811
1502
 
1503
+ def assign_state_vals(self, state_vals: Sequence[PyTree]) -> None:
1504
+ """
1505
+ Assign new values to the states tracked by this ``StateTraceStack``.
1506
+
1507
+ This method updates the values of the states based on whether they were
1508
+ written to or only read during the tracing process. For states that were
1509
+ written to, it directly assigns the new value. For states that were only
1510
+ read, it restores the value using the state's restore_value method.
1511
+
1512
+ Args:
1513
+ state_vals (Sequence[PyTree]): A sequence of new state values to be
1514
+ assigned. Each element in this sequence corresponds to a state
1515
+ in the ``StateTraceStack``'s states list.
1516
+
1517
+ Raises:
1518
+ ValueError: If the length of state_vals doesn't match the number of
1519
+ states in the ``StateTraceStack``.
1520
+
1521
+ Returns:
1522
+ None
1523
+
1524
+ Note:
1525
+ The order of state_vals should match the order of states in the
1526
+ ``StateTraceStack``'s states list.
1527
+ """
1528
+ if len(state_vals) != len(self.states):
1529
+ raise ValueError('The length of the state values must be equal to the states. '
1530
+ f'Bug got {len(state_vals)} and {len(self.states)}')
1531
+ for st, written, val in zip(self.states, self.been_writen, state_vals):
1532
+ if written:
1533
+ st.value = val
1534
+ else:
1535
+ st.restore_value(val)
1536
+
1537
+ def state_subset(self, state_type: type) -> List:
1538
+ """
1539
+ Get a subset of states of a specific type from the ``StateTraceStack``.
1540
+
1541
+ This method filters the states in the ``StateTraceStack`` and returns only
1542
+ those that match the specified state type.
1543
+
1544
+ Args:
1545
+ state_type (type): The type of state to filter by. This should be
1546
+ a subclass of State or State itself.
1547
+
1548
+ Returns:
1549
+ List[State]: A list containing all states in the ``StateTraceStack``
1550
+ that are instances of the specified state_type.
1551
+
1552
+ Example:
1553
+ >>> stack = StateTraceStack()
1554
+ >>> # Assume stack has been populated with various state types
1555
+ >>> short_term_states = stack.state_subset(ShortTermState)
1556
+ """
1557
+ return [st for st in self.states if isinstance(st, state_type)]
1558
+
812
1559
 
813
1560
  class TreefyState(Generic[A], PrettyObject):
814
1561
  """
@@ -834,14 +1581,28 @@ class TreefyState(Generic[A], PrettyObject):
834
1581
 
835
1582
  def __pretty_repr_item__(self, k, v):
836
1583
  if k in ['_level', '_source_info', '_been_writen']:
837
- return None, None
1584
+ return None
838
1585
  if k == '_value':
839
1586
  return 'value', v
840
1587
 
841
1588
  if k == '_name':
842
- return (None, None) if v is None else ('name', v)
1589
+ return None if v is None else ('name', v)
843
1590
  return k, v
844
1591
 
1592
+ @property
1593
+ def name(self) -> Optional[str]:
1594
+ """
1595
+ The name of the state.
1596
+ """
1597
+ return self._name
1598
+
1599
+ @name.setter
1600
+ def name(self, name: str) -> None:
1601
+ """
1602
+ Set the name of the state.
1603
+ """
1604
+ self._name = name
1605
+
845
1606
  def replace(self, value: B) -> TreefyState[B]:
846
1607
  """
847
1608
  Replace the value of the state reference.
@@ -856,7 +1617,9 @@ class TreefyState(Generic[A], PrettyObject):
856
1617
  # __init__ logic which should not be called twice
857
1618
  metadata = self.get_metadata()
858
1619
  state = object.__new__(self.type)
859
- vars(state).update(metadata, _value=self.value, _level=_get_trace_stack_level())
1620
+ metadata.pop('_value', None)
1621
+ metadata.pop('_level', None)
1622
+ vars(state).update(**metadata, _value=self.value, _level=_get_trace_stack_level())
860
1623
  return state
861
1624
 
862
1625
  def copy(self: TreefyState[A]) -> TreefyState[A]: