brainstate 0.1.0.post20250211__py2.py3-none-any.whl → 0.1.0.post20250216__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +875 -93
- brainstate/_state_test.py +1 -3
- brainstate/augment/__init__.py +2 -2
- brainstate/augment/_autograd.py +257 -115
- brainstate/augment/_autograd_test.py +2 -3
- brainstate/augment/_eval_shape.py +3 -4
- brainstate/augment/_mapping.py +582 -62
- brainstate/augment/_mapping_test.py +114 -30
- brainstate/augment/_random.py +61 -7
- brainstate/compile/_ad_checkpoint.py +2 -3
- brainstate/compile/_conditions.py +4 -5
- brainstate/compile/_conditions_test.py +1 -2
- brainstate/compile/_error_if.py +1 -2
- brainstate/compile/_error_if_test.py +1 -2
- brainstate/compile/_jit.py +23 -16
- brainstate/compile/_jit_test.py +1 -2
- brainstate/compile/_loop_collect_return.py +18 -10
- brainstate/compile/_loop_collect_return_test.py +1 -1
- brainstate/compile/_loop_no_collection.py +5 -5
- brainstate/compile/_make_jaxpr.py +23 -21
- brainstate/compile/_make_jaxpr_test.py +1 -2
- brainstate/compile/_progress_bar.py +1 -2
- brainstate/compile/_unvmap.py +1 -0
- brainstate/compile/_util.py +4 -2
- brainstate/environ.py +4 -4
- brainstate/environ_test.py +1 -2
- brainstate/functional/_activations.py +1 -2
- brainstate/functional/_activations_test.py +1 -1
- brainstate/functional/_normalization.py +1 -2
- brainstate/functional/_others.py +1 -2
- brainstate/functional/_spikes.py +136 -20
- brainstate/graph/_graph_node.py +2 -43
- brainstate/graph/_graph_operation.py +4 -20
- brainstate/graph/_graph_operation_test.py +3 -4
- brainstate/init/_base.py +1 -2
- brainstate/init/_generic.py +1 -2
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_collective_ops.py +351 -48
- brainstate/nn/_collective_ops_test.py +36 -0
- brainstate/nn/_common.py +194 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
- brainstate/nn/_dyn_impl/_inputs.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
- brainstate/nn/_dyn_impl/_readout.py +2 -3
- brainstate/nn/_dyn_impl/_readout_test.py +1 -2
- brainstate/nn/_dynamics/_dynamics_base.py +2 -3
- brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
- brainstate/nn/_dynamics/_state_delay.py +3 -3
- brainstate/nn/_dynamics/_synouts_test.py +1 -2
- brainstate/nn/_elementwise/_dropout.py +6 -7
- brainstate/nn/_elementwise/_dropout_test.py +1 -2
- brainstate/nn/_elementwise/_elementwise.py +1 -2
- brainstate/nn/_exp_euler.py +1 -2
- brainstate/nn/_exp_euler_test.py +1 -2
- brainstate/nn/_interaction/_conv.py +1 -2
- brainstate/nn/_interaction/_conv_test.py +1 -0
- brainstate/nn/_interaction/_linear.py +1 -2
- brainstate/nn/_interaction/_linear_test.py +1 -2
- brainstate/nn/_interaction/_normalizations.py +1 -2
- brainstate/nn/_interaction/_poolings.py +3 -4
- brainstate/nn/_module.py +63 -19
- brainstate/nn/_module_test.py +1 -2
- brainstate/nn/metrics.py +3 -4
- brainstate/optim/_lr_scheduler.py +1 -2
- brainstate/optim/_lr_scheduler_test.py +2 -3
- brainstate/optim/_optax_optimizer_test.py +1 -2
- brainstate/optim/_sgd_optimizer.py +2 -3
- brainstate/random/_rand_funs.py +1 -2
- brainstate/random/_rand_funs_test.py +2 -3
- brainstate/random/_rand_seed.py +2 -3
- brainstate/random/_rand_seed_test.py +1 -2
- brainstate/random/_rand_state.py +3 -4
- brainstate/surrogate.py +183 -35
- brainstate/transform.py +0 -3
- brainstate/typing.py +28 -25
- brainstate/util/__init__.py +9 -7
- brainstate/util/_caller.py +1 -2
- brainstate/util/_error.py +27 -0
- brainstate/util/_others.py +60 -15
- brainstate/util/{_dict.py → _pretty_pytree.py} +108 -29
- brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
- brainstate/util/_pretty_repr.py +128 -10
- brainstate/util/_pretty_table.py +2900 -0
- brainstate/util/_struct.py +11 -11
- brainstate/util/filter.py +472 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/METADATA +2 -2
- brainstate-0.1.0.post20250216.dist-info/RECORD +127 -0
- brainstate/util/_filter.py +0 -178
- brainstate-0.1.0.post20250211.dist-info/RECORD +0 -124
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/top_level.txt +0 -0
brainstate/_state.py
CHANGED
@@ -17,31 +17,47 @@ from __future__ import annotations
|
|
17
17
|
|
18
18
|
import contextlib
|
19
19
|
import dataclasses
|
20
|
-
import threading
|
21
|
-
from functools import wraps, partial
|
22
|
-
from typing import (
|
23
|
-
Any, Union, Callable, Generic, Mapping,
|
24
|
-
TypeVar, Optional, TYPE_CHECKING, Tuple, Dict, List, Sequence
|
25
|
-
)
|
26
|
-
|
27
20
|
import jax
|
28
21
|
import numpy as np
|
22
|
+
import threading
|
23
|
+
from functools import wraps, partial
|
29
24
|
from jax.api_util import shaped_abstractify
|
30
25
|
from jax.extend import source_info_util
|
26
|
+
from typing import (
|
27
|
+
Any,
|
28
|
+
Union,
|
29
|
+
Callable,
|
30
|
+
Generic,
|
31
|
+
Mapping,
|
32
|
+
TypeVar,
|
33
|
+
Optional,
|
34
|
+
TYPE_CHECKING,
|
35
|
+
Tuple,
|
36
|
+
Dict,
|
37
|
+
List,
|
38
|
+
Sequence,
|
39
|
+
Generator,
|
40
|
+
)
|
31
41
|
|
32
|
-
from brainstate.typing import ArrayLike, PyTree, Missing
|
33
|
-
from brainstate.util import DictManager,
|
42
|
+
from brainstate.typing import ArrayLike, PyTree, Missing, Filter
|
43
|
+
from brainstate.util import DictManager, PrettyObject
|
44
|
+
from brainstate.util.filter import Nothing
|
34
45
|
|
35
46
|
__all__ = [
|
36
47
|
'State', 'ShortTermState', 'LongTermState', 'HiddenState', 'ParamState', 'TreefyState',
|
37
|
-
'
|
48
|
+
'FakeState',
|
38
49
|
|
39
|
-
'StateDictManager',
|
50
|
+
'StateDictManager',
|
51
|
+
'StateTraceStack',
|
52
|
+
'check_state_value_tree',
|
53
|
+
'check_state_jax_tracer',
|
54
|
+
'catch_new_states',
|
40
55
|
'maybe_state',
|
41
56
|
]
|
42
57
|
|
43
58
|
A = TypeVar('A')
|
44
59
|
B = TypeVar('B')
|
60
|
+
T = TypeVar('T')
|
45
61
|
F = TypeVar('F', bound=Callable[..., Any])
|
46
62
|
|
47
63
|
max_int = np.iinfo(np.int32)
|
@@ -51,18 +67,38 @@ max_int = np.iinfo(np.int32)
|
|
51
67
|
# This allows concurrent tracing in separate threads; passing traced objects
|
52
68
|
# between threads is forbidden.
|
53
69
|
class ThreadLocalStack(threading.local):
|
70
|
+
"""
|
71
|
+
A thread-local storage class for managing state-related information.
|
72
|
+
|
73
|
+
This class provides thread-local storage for various state management components,
|
74
|
+
ensuring that each thread has its own isolated set of state-related data structures.
|
75
|
+
|
76
|
+
Attributes:
|
77
|
+
state_stack (List[StateTraceStack]): A list to store StateTraceStack objects for the current thread.
|
78
|
+
tree_check (List[bool]): A list of boolean flags for tree structure checking, initialized with [False].
|
79
|
+
jax_tracer_check (List[bool]): A list of boolean flags for JAX tracer checking, initialized with [False].
|
80
|
+
new_state_catcher (List[StateCatcher]): A list to store StateCatcher objects for capturing new states in the current thread.
|
81
|
+
"""
|
82
|
+
|
54
83
|
def __init__(self):
|
84
|
+
"""
|
85
|
+
Initialize the ThreadLocalStack with empty data structures.
|
86
|
+
|
87
|
+
This constructor sets up the initial state for each thread-local instance,
|
88
|
+
creating empty lists for the state stack and the new-state catchers, and
|
89
|
+
single-element ``[False]`` flag lists for tree checking and JAX tracer checking.
|
90
|
+
"""
|
55
91
|
self.state_stack: List[StateTraceStack] = []
|
56
92
|
self.tree_check: List[bool] = [False]
|
57
93
|
self.jax_tracer_check: List[bool] = [False]
|
58
|
-
self.new_state_catcher: List[
|
94
|
+
self.new_state_catcher: List[StateCatcher] = []
|
59
95
|
|
60
96
|
|
61
97
|
TRACE_CONTEXT = ThreadLocalStack()
|
62
98
|
|
63
99
|
|
64
100
|
@contextlib.contextmanager
|
65
|
-
def check_state_value_tree(val: bool = True) -> None:
|
101
|
+
def check_state_value_tree(val: bool = True) -> Generator[None, None, None]:
|
66
102
|
"""
|
67
103
|
The context manager to check whether the tree structure of the state value is kept consistent.
|
68
104
|
|
@@ -91,34 +127,198 @@ def check_state_value_tree(val: bool = True) -> None:
|
|
91
127
|
TRACE_CONTEXT.tree_check.pop()
|
92
128
|
|
93
129
|
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
catcher = Catcher(tag)
|
98
|
-
TRACE_CONTEXT.new_state_catcher.append(catcher)
|
99
|
-
yield catcher
|
100
|
-
finally:
|
101
|
-
TRACE_CONTEXT.new_state_catcher.pop()
|
130
|
+
class StateCatcher(PrettyObject):
|
131
|
+
"""
|
132
|
+
The catcher to catch and manage new states.
|
102
133
|
|
134
|
+
This class provides functionality to collect and tag new State objects.
|
135
|
+
It ensures that each state is only added once and assigns a tag to each state.
|
103
136
|
|
104
|
-
|
105
|
-
|
106
|
-
|
137
|
+
Attributes:
|
138
|
+
state_tag (str): A string identifier used to tag the caught states.
|
139
|
+
state_ids (set): A set of state IDs to ensure uniqueness.
|
140
|
+
states (list): A list to store the caught State objects.
|
107
141
|
"""
|
108
142
|
|
109
|
-
def __init__(
|
110
|
-
self
|
143
|
+
def __init__(
|
144
|
+
self,
|
145
|
+
state_tag: str,
|
146
|
+
state_to_exclude: Filter = Nothing()
|
147
|
+
):
|
148
|
+
"""
|
149
|
+
Initialize a new StateCatcher instance.
|
150
|
+
|
151
|
+
Args:
|
152
|
+
state_tag (str): The tag to be assigned to caught states.
|
153
|
+
state_to_exclude (Filter, optional): A filter to exclude states from being caught.
|
154
|
+
"""
|
155
|
+
if state_to_exclude is None:
|
156
|
+
state_to_exclude = Nothing()
|
157
|
+
self.state_to_exclude = state_to_exclude
|
158
|
+
self.state_tag = state_tag
|
111
159
|
self.state_ids = set()
|
112
160
|
self.states = []
|
113
161
|
|
162
|
+
def get_state_values(self) -> List[PyTree]:
|
163
|
+
"""
|
164
|
+
Get the values of the caught states.
|
165
|
+
|
166
|
+
Returns:
|
167
|
+
list: A list of values of the caught states.
|
168
|
+
"""
|
169
|
+
return [state.value for state in self.states]
|
170
|
+
|
171
|
+
def get_states(self) -> List[State]:
|
172
|
+
"""
|
173
|
+
Get the caught states.
|
174
|
+
|
175
|
+
Returns:
|
176
|
+
list: A list of the caught states.
|
177
|
+
"""
|
178
|
+
return self.states
|
179
|
+
|
114
180
|
def append(self, state: State):
|
181
|
+
"""
|
182
|
+
Add a new state to the catcher if it hasn't been added before.
|
183
|
+
|
184
|
+
This method adds the state to the internal list, records its ID,
|
185
|
+
and assigns the catcher's tag to the state.
|
186
|
+
|
187
|
+
Args:
|
188
|
+
state (State): The State object to be added.
|
189
|
+
"""
|
190
|
+
if self.state_to_exclude((), state):
|
191
|
+
return
|
115
192
|
if id(state) not in self.state_ids:
|
116
193
|
self.state_ids.add(id(state))
|
117
194
|
self.states.append(state)
|
118
|
-
state.tag = self.
|
195
|
+
state.tag = self.state_tag
|
196
|
+
|
197
|
+
def __iter__(self):
|
198
|
+
"""
|
199
|
+
Allow iteration over the caught states.
|
119
200
|
|
201
|
+
Returns:
|
202
|
+
iterator: An iterator over the list of caught states.
|
203
|
+
"""
|
204
|
+
return iter(self.states)
|
205
|
+
|
206
|
+
def __len__(self):
|
207
|
+
"""
|
208
|
+
Return the number of caught states.
|
209
|
+
|
210
|
+
Returns:
|
211
|
+
int: The number of caught states.
|
212
|
+
"""
|
213
|
+
return len(self.states)
|
214
|
+
|
215
|
+
def __getitem__(self, index):
|
216
|
+
"""
|
217
|
+
Get a state by index.
|
218
|
+
|
219
|
+
Args:
|
220
|
+
index (int): The index of the state to retrieve.
|
221
|
+
|
222
|
+
Returns:
|
223
|
+
State: The state at the specified index.
|
224
|
+
"""
|
225
|
+
return self.states[index]
|
226
|
+
|
227
|
+
def clear(self):
|
228
|
+
"""
|
229
|
+
Clear all caught states.
|
230
|
+
"""
|
231
|
+
self.state_ids.clear()
|
232
|
+
self.states.clear()
|
233
|
+
|
234
|
+
def get_by_tag(self, tag: str):
|
235
|
+
"""
|
236
|
+
Get all states with a specific tag.
|
237
|
+
|
238
|
+
Args:
|
239
|
+
tag (str): The tag to filter by.
|
120
240
|
|
121
|
-
|
241
|
+
Returns:
|
242
|
+
list: A list of states with the specified tag.
|
243
|
+
"""
|
244
|
+
return [state for state in self.states if state.tag == tag]
|
245
|
+
|
246
|
+
def remove(self, state: State):
|
247
|
+
"""
|
248
|
+
Remove a specific state from the catcher.
|
249
|
+
|
250
|
+
Args:
|
251
|
+
state (State): The state to remove.
|
252
|
+
"""
|
253
|
+
if id(state) in self.state_ids:
|
254
|
+
self.state_ids.remove(id(state))
|
255
|
+
self.states.remove(state)
|
256
|
+
|
257
|
+
def __contains__(self, state: State):
|
258
|
+
"""
|
259
|
+
Check if a state is in the catcher.
|
260
|
+
|
261
|
+
Args:
|
262
|
+
state (State): The state to check for.
|
263
|
+
|
264
|
+
Returns:
|
265
|
+
bool: True if the state is in the catcher, False otherwise.
|
266
|
+
"""
|
267
|
+
return id(state) in self.state_ids
|
268
|
+
|
269
|
+
|
270
|
+
@contextlib.contextmanager
|
271
|
+
def catch_new_states(
|
272
|
+
state_tag: str = None,
|
273
|
+
state_to_exclude: Filter = Nothing()
|
274
|
+
) -> Generator[StateCatcher, None, None]:
|
275
|
+
"""
|
276
|
+
A context manager that catches and tracks new states created within its scope.
|
277
|
+
|
278
|
+
This function creates a new StateCatcher object and adds it to the TRACE_CONTEXT's
|
279
|
+
new_state_catcher list. It allows for tracking and managing new states created
|
280
|
+
within the context.
|
281
|
+
|
282
|
+
Args:
|
283
|
+
state_tag (str, optional): A string tag to associate with the caught states.
|
284
|
+
Defaults to None.
|
285
|
+
state_to_exclude (Filter, optional): A filter object to specify which states
|
286
|
+
should be excluded from catching. Defaults to Nothing(), which excludes no states.
|
287
|
+
|
288
|
+
Yields:
|
289
|
+
StateCatcher: A StateCatcher object that can be used to access and manage the
|
290
|
+
newly created states within the context.
|
291
|
+
|
292
|
+
Example::
|
293
|
+
|
294
|
+
with catch_new_states("my_tag") as catcher:
|
295
|
+
# Create new states here
|
296
|
+
# They will be caught and tagged with "my_tag"
|
297
|
+
# Access caught states through catcher object
|
298
|
+
"""
|
299
|
+
try:
|
300
|
+
catcher = StateCatcher(state_tag=state_tag, state_to_exclude=state_to_exclude)
|
301
|
+
TRACE_CONTEXT.new_state_catcher.append(catcher)
|
302
|
+
yield catcher
|
303
|
+
finally:
|
304
|
+
TRACE_CONTEXT.new_state_catcher.pop()
|
305
|
+
|
306
|
+
|
307
|
+
def maybe_state(val: Any) -> Any:
|
308
|
+
"""
|
309
|
+
Extracts the value from a State object if given, otherwise returns the input value.
|
310
|
+
|
311
|
+
This function is useful for handling both State objects and raw values uniformly.
|
312
|
+
If the input is a State object, it returns the value stored in that State.
|
313
|
+
If the input is not a State object, it returns the input as is.
|
314
|
+
|
315
|
+
Args:
|
316
|
+
val (Any): The input value, which can be either a State object or any other type.
|
317
|
+
|
318
|
+
Returns:
|
319
|
+
Any: The value stored in the State if the input is a State object,
|
320
|
+
otherwise the input value itself.
|
321
|
+
"""
|
122
322
|
if isinstance(val, State):
|
123
323
|
return val.value
|
124
324
|
else:
|
@@ -126,7 +326,7 @@ def maybe_state(val: Any):
|
|
126
326
|
|
127
327
|
|
128
328
|
@contextlib.contextmanager
|
129
|
-
def check_state_jax_tracer(val: bool = True) -> None:
|
329
|
+
def check_state_jax_tracer(val: bool = True) -> Generator[None, None, None]:
|
130
330
|
"""
|
131
331
|
The context manager to check whether the state is valid to trace.
|
132
332
|
|
@@ -160,11 +360,16 @@ def check_state_jax_tracer(val: bool = True) -> None:
|
|
160
360
|
@dataclasses.dataclass
|
161
361
|
class StateMetadata(Generic[A]):
|
162
362
|
"""
|
163
|
-
|
363
|
+
A dataclass representing metadata for a state object.
|
164
364
|
|
165
|
-
|
166
|
-
|
167
|
-
|
365
|
+
This class encapsulates the raw value of a state along with associated metadata.
|
366
|
+
It is generic over the type of the raw value.
|
367
|
+
|
368
|
+
Attributes:
|
369
|
+
raw_value (A): The raw value of the state. The type A is a generic type parameter.
|
370
|
+
metadata (Mapping[str, Any]): A mapping of string keys to arbitrary values,
|
371
|
+
representing additional metadata for the state.
|
372
|
+
Defaults to an empty dictionary.
|
168
373
|
"""
|
169
374
|
raw_value: A
|
170
375
|
metadata: Mapping[str, Any] = dataclasses.field(default_factory=dict)
|
@@ -172,7 +377,30 @@ class StateMetadata(Generic[A]):
|
|
172
377
|
|
173
378
|
def with_metadata(initializer: F, **metadata: Any) -> F:
|
174
379
|
"""
|
175
|
-
A decorator
|
380
|
+
A decorator that adds metadata to a state initialization function.
|
381
|
+
|
382
|
+
This decorator wraps the given initializer function, allowing additional
|
383
|
+
metadata to be associated with the state it creates. The metadata is
|
384
|
+
incorporated into a StateMetadata object along with the state's value.
|
385
|
+
|
386
|
+
Args:
|
387
|
+
initializer (F): The original state initialization function to be wrapped.
|
388
|
+
**metadata (Any): Arbitrary keyword arguments representing metadata
|
389
|
+
to be associated with the state.
|
390
|
+
|
391
|
+
Returns:
|
392
|
+
F: A wrapped version of the initializer function that returns a
|
393
|
+
StateMetadata object containing both the original state value
|
394
|
+
and the provided metadata.
|
395
|
+
|
396
|
+
Example::
|
397
|
+
@with_metadata(tag='model_param')
|
398
|
+
def init_weights(shape):
|
399
|
+
return np.zeros(shape)
|
400
|
+
|
401
|
+
state = init_weights((100, 100))
|
402
|
+
# state is now a StateMetadata object with the initialized weights
|
403
|
+
# and the 'tag' metadata
|
176
404
|
"""
|
177
405
|
|
178
406
|
@wraps(initializer)
|
@@ -186,26 +414,56 @@ def _get_trace_stack_level() -> int:
|
|
186
414
|
return len(TRACE_CONTEXT.state_stack)
|
187
415
|
|
188
416
|
|
189
|
-
class State(Generic[A],
|
417
|
+
class State(Generic[A], PrettyObject):
|
190
418
|
"""
|
191
|
-
|
192
|
-
|
193
|
-
To implement a new subclass of :py:class:`~.State`, you only need to inherit this class:
|
419
|
+
A generic class representing a dynamic data pointer in the BrainState framework.
|
194
420
|
|
195
|
-
|
421
|
+
The State class serves as a base for various types of state objects used to
|
422
|
+
manage and track dynamic data within a program. It provides mechanisms for
|
423
|
+
value storage, metadata management, and integration with the BrainState
|
424
|
+
tracing system.
|
196
425
|
|
197
|
-
|
198
|
-
|
426
|
+
Type Parameters:
|
427
|
+
A: The type of the value stored in the state.
|
199
428
|
|
200
|
-
|
429
|
+
Attributes:
|
430
|
+
name (Optional[str]): An optional name for the state.
|
431
|
+
value (PyTree): The actual value stored in the state.
|
432
|
+
tag (Optional[str]): An optional tag for categorizing or grouping states.
|
201
433
|
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
434
|
+
Args:
|
435
|
+
value (Union[PyTree[ArrayLike], StateMetadata[PyTree[ArrayLike]]]):
|
436
|
+
The initial value for the state. Can be a PyTree of array-like objects
|
437
|
+
or a StateMetadata object.
|
438
|
+
name (Optional[str]): An optional name for the state.
|
439
|
+
**metadata: Additional metadata to be stored with the state.
|
440
|
+
|
441
|
+
Example:
|
442
|
+
>>> class MyState(State):
|
443
|
+
... pass
|
444
|
+
>>> state = MyState(jnp.zeros((3, 3)), name="my_matrix")
|
445
|
+
>>> print(state.value)
|
446
|
+
[[0. 0. 0.]
|
447
|
+
[0. 0. 0.]
|
448
|
+
[0. 0. 0.]]
|
449
|
+
|
450
|
+
Note:
|
451
|
+
- Subclasses of :class:`State` (e.g., ShortTermState, LongTermState, ParamState,
|
452
|
+
RandomState) are typically used for specific purposes in a program.
|
453
|
+
- The class integrates with BrainState's tracing system to track state
|
454
|
+
creation and modifications.
|
455
|
+
|
456
|
+
The typical examples of :py:class:`~.State` subclass are:
|
457
|
+
|
458
|
+
- :py:class:`ShortTermState`: The short-term state, which is used to store the short-term data in the program.
|
459
|
+
- :py:class:`LongTermState`: The long-term state, which is used to store the long-term data in the program.
|
460
|
+
- :py:class:`ParamState`: The parameter state, which is used to store the parameters in the program.
|
461
|
+
- :py:class:`RandomState`: The random generator state, which is used to store the random key in the program.
|
206
462
|
|
207
463
|
Args:
|
208
|
-
|
464
|
+
value: PyTree. It can be anything as a pyTree.
|
465
|
+
name: Optional[str]. The name of the state.
|
466
|
+
tag: Optional[str]. The tag of the state.
|
209
467
|
"""
|
210
468
|
__module__ = 'brainstate'
|
211
469
|
_level: int
|
@@ -221,6 +479,25 @@ class State(Generic[A], PrettyReprTree):
|
|
221
479
|
name: Optional[str] = None,
|
222
480
|
**metadata: Any
|
223
481
|
):
|
482
|
+
"""
|
483
|
+
Initialize a new State instance.
|
484
|
+
|
485
|
+
This constructor sets up the initial value of the state,
|
486
|
+
handling various input types and metadata.
|
487
|
+
|
488
|
+
Args:
|
489
|
+
value (Union[PyTree[ArrayLike], StateMetadata[PyTree[ArrayLike]]]):
|
490
|
+
The initial value for the state. Can be a PyTree of array-like objects
|
491
|
+
or a StateMetadata object containing both value and metadata.
|
492
|
+
name (Optional[str], optional): A name for the state. Defaults to None.
|
493
|
+
**metadata: Additional metadata to be stored with the state, including:
|
494
|
+
- tag (Optional[str]): A tag for categorizing or grouping states.
|
495
|
+
- Any other custom metadata fields.
|
496
|
+
|
497
|
+
Note:
|
498
|
+
This method initializes the state, processes the input value and metadata,
|
499
|
+
sets up internal attributes, and records the state initialization.
|
500
|
+
"""
|
224
501
|
tag = metadata.pop('tag', None)
|
225
502
|
|
226
503
|
# set the value and metadata
|
@@ -231,12 +508,14 @@ class State(Generic[A], PrettyReprTree):
|
|
231
508
|
value = value.value
|
232
509
|
|
233
510
|
# update metadata
|
234
|
-
metadata.update(
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
511
|
+
metadata.update(
|
512
|
+
_value=value,
|
513
|
+
_level=_get_trace_stack_level(),
|
514
|
+
_source_info=source_info_util.current(),
|
515
|
+
_name=name,
|
516
|
+
_been_writen=False,
|
517
|
+
tag=tag,
|
518
|
+
)
|
240
519
|
|
241
520
|
# avoid using self._setattr to avoid the check
|
242
521
|
vars(self).update(metadata)
|
@@ -244,6 +523,26 @@ class State(Generic[A], PrettyReprTree):
|
|
244
523
|
# record the state initialization
|
245
524
|
record_state_init(self)
|
246
525
|
|
526
|
+
def decrease_stack_level(self):
|
527
|
+
"""
|
528
|
+
Decrease the stack level of the state by one, ensuring it doesn't go below zero.
|
529
|
+
|
530
|
+
This method is used to adjust the stack level of the state, typically when
|
531
|
+
exiting a nested context or scope. It ensures that the level never becomes
|
532
|
+
negative.
|
533
|
+
"""
|
534
|
+
self._level = max(self._level - 1, 0)
|
535
|
+
|
536
|
+
def increase_stack_level(self):
|
537
|
+
"""
|
538
|
+
Increase the stack level of the state by one.
|
539
|
+
|
540
|
+
This method is used to adjust the stack level of the state, typically when
|
541
|
+
entering a nested context or scope. It increments the internal level counter
|
542
|
+
by one.
|
543
|
+
"""
|
544
|
+
self._level = self._level + 1
|
545
|
+
|
247
546
|
@property
|
248
547
|
def name(self) -> Optional[str]:
|
249
548
|
"""
|
@@ -438,19 +737,19 @@ class State(Generic[A], PrettyReprTree):
|
|
438
737
|
|
439
738
|
def __pretty_repr_item__(self, k, v):
|
440
739
|
if k in ['_level', '_source_info', '_been_writen']:
|
441
|
-
return None
|
740
|
+
return None
|
442
741
|
if k == '_value':
|
443
|
-
return 'value', v
|
742
|
+
return 'value', jax.tree.map(shaped_abstractify, v)
|
444
743
|
|
445
744
|
if k == '_name':
|
446
745
|
if self.name is None:
|
447
|
-
return None
|
746
|
+
return None
|
448
747
|
else:
|
449
748
|
return 'name', v
|
450
749
|
|
451
750
|
if k == 'tag':
|
452
751
|
if self.tag is None:
|
453
|
-
return None
|
752
|
+
return None
|
454
753
|
else:
|
455
754
|
return 'tag', v
|
456
755
|
|
@@ -465,34 +764,137 @@ class State(Generic[A], PrettyReprTree):
|
|
465
764
|
"""
|
466
765
|
return hash(id(self))
|
467
766
|
|
767
|
+
def numel(self) -> int:
|
768
|
+
"""
|
769
|
+
Calculate the total number of elements in the state value.
|
770
|
+
|
771
|
+
This method traverses the state's value, which may be a nested structure (PyTree),
|
772
|
+
and computes the sum of sizes of all leaf nodes.
|
773
|
+
|
774
|
+
Returns:
|
775
|
+
int: The total number of elements across all arrays in the state value.
|
776
|
+
For scalar values, this will be 1. For arrays or nested structures,
|
777
|
+
it will be the sum of the sizes of all contained arrays.
|
778
|
+
|
779
|
+
Note:
|
780
|
+
This method uses jax.tree.leaves to flatten any nested structure in the state value,
|
781
|
+
and jax.numpy.size to compute the size of each leaf node.
|
782
|
+
"""
|
783
|
+
sizes = [jax.numpy.size(val) for val in jax.tree.leaves(self._value)]
|
784
|
+
return sum(sizes)
|
785
|
+
|
468
786
|
|
469
787
|
def record_state_init(st: State[A]):
|
470
|
-
|
788
|
+
"""
|
789
|
+
Record the initialization of a new :class:`State` object.
|
790
|
+
|
791
|
+
This function iterates through all registered state catchers in the current
|
792
|
+
trace context and appends the newly initialized state to each catcher.
|
793
|
+
|
794
|
+
Args:
|
795
|
+
st (State[A]): The newly initialized :class:`State` object to be recorded.
|
796
|
+
|
797
|
+
Note:
|
798
|
+
This function is typically called internally when a new :class:`State` object
|
799
|
+
is created to ensure proper tracking and management of states within
|
800
|
+
the current execution context.
|
801
|
+
"""
|
802
|
+
trace: StateCatcher
|
471
803
|
for trace in TRACE_CONTEXT.new_state_catcher:
|
472
804
|
trace.append(st)
|
473
805
|
|
474
806
|
|
475
807
|
def record_state_value_read(st: State[A]):
|
808
|
+
"""
|
809
|
+
Record that a state's value has been read in all relevant trace stacks.
|
810
|
+
|
811
|
+
This function iterates through all state trace stacks at or above the
|
812
|
+
state's stack level in the current trace context, and records that
|
813
|
+
the given state's value has been read.
|
814
|
+
|
815
|
+
Args:
|
816
|
+
st (State[A]): The state object whose value read is being recorded.
|
817
|
+
'A' is a generic type parameter representing the
|
818
|
+
type of the state's value.
|
819
|
+
|
820
|
+
Note:
|
821
|
+
This function modifies the state trace stacks in the current
|
822
|
+
trace context but does not return any value.
|
823
|
+
"""
|
476
824
|
trace: StateTraceStack
|
477
825
|
for trace in TRACE_CONTEXT.state_stack[st.stack_level:]:
|
478
826
|
trace.read_its_value(st)
|
479
827
|
|
480
828
|
|
481
829
|
def record_state_value_write(st: State[A]):
|
830
|
+
"""
|
831
|
+
Record that a state's value has been written in all relevant trace stacks.
|
832
|
+
|
833
|
+
This function iterates through all state trace stacks at or above the
|
834
|
+
state's stack level in the current trace context, and records that
|
835
|
+
the given state's value has been written.
|
836
|
+
|
837
|
+
Args:
|
838
|
+
st (State[A]): The state object whose value write is being recorded.
|
839
|
+
'A' is a generic type parameter representing the
|
840
|
+
type of the state's value.
|
841
|
+
|
842
|
+
Note:
|
843
|
+
This function modifies the state trace stacks in the current
|
844
|
+
trace context but does not return any value.
|
845
|
+
"""
|
482
846
|
trace: StateTraceStack
|
483
847
|
for trace in TRACE_CONTEXT.state_stack[st.stack_level:]:
|
484
848
|
trace.write_its_value(st)
|
485
849
|
|
486
850
|
|
487
851
|
def record_state_value_restore(st: State[A]):
|
852
|
+
"""
|
853
|
+
Record that a state's value has been restored.
|
854
|
+
|
855
|
+
This function is used to indicate that a state's value has been restored
|
856
|
+
to a previous value. It internally calls the record_state_value_read
|
857
|
+
function to mark the state as having been accessed.
|
858
|
+
|
859
|
+
Args:
|
860
|
+
st (State[A]): The state object whose value restoration is being recorded.
|
861
|
+
'A' is a generic type parameter representing the
|
862
|
+
type of the state's value.
|
863
|
+
|
864
|
+
See Also:
|
865
|
+
record_state_value_read: Record that a state's value has been read.
|
866
|
+
|
867
|
+
Note:
|
868
|
+
This function does not actually restore the state's value; it only
|
869
|
+
records that a restoration has occurred.
|
870
|
+
"""
|
488
871
|
record_state_value_read(st)
|
489
872
|
|
490
873
|
|
491
874
|
class ShortTermState(State):
|
492
875
|
"""
|
493
|
-
|
876
|
+
A class representing short-term state in a program.
|
494
877
|
|
495
|
-
|
878
|
+
:class:`ShortTermState` is used to store temporary or transient data that is only relevant
|
879
|
+
for a short duration within the program's execution. This class extends the base
|
880
|
+
State class, inheriting its properties and methods while specifically denoting
|
881
|
+
the short-term nature of the stored data.
|
882
|
+
|
883
|
+
For example, in a machine learning training process, the gradients of the model
|
884
|
+
would typically be represented as :class:`ShortTermState`, as they are computed and used
|
885
|
+
within each iteration but not necessarily preserved across iterations.
|
886
|
+
|
887
|
+
Attributes:
|
888
|
+
Inherits all attributes from the base State class.
|
889
|
+
|
890
|
+
Note:
|
891
|
+
This class does not introduce new methods or attributes beyond those
|
892
|
+
inherited from the State class. Its primary purpose is to semantically
|
893
|
+
distinguish short-term states from other types of states in the program.
|
894
|
+
|
895
|
+
Example:
|
896
|
+
>>> gradient = ShortTermState(np.zeros(100), name="model_gradient")
|
897
|
+
>>> intermediate_result = ShortTermState({}, name="layer_activations")
|
496
898
|
"""
|
497
899
|
|
498
900
|
__module__ = 'brainstate'
|
@@ -502,7 +904,25 @@ class LongTermState(State):
|
|
502
904
|
"""
|
503
905
|
The long-term state, which is used to store the long-term data in the program.
|
504
906
|
|
505
|
-
|
907
|
+
This class extends the base :class:`State` class and is specifically designed to represent
|
908
|
+
and manage long-term data within a program. Long-term states are typically used
|
909
|
+
for data that persists across multiple iterations or epochs of a process.
|
910
|
+
|
911
|
+
For example, in a training process, the weights of the model are considered
|
912
|
+
long-term states as they are updated and maintained throughout the entire
|
913
|
+
training procedure.
|
914
|
+
|
915
|
+
Attributes:
|
916
|
+
Inherits all attributes from the base :class:`State` class.
|
917
|
+
|
918
|
+
Note:
|
919
|
+
This class does not introduce new methods or attributes beyond those
|
920
|
+
inherited from the :class:`State` class. Its primary purpose is to semantically
|
921
|
+
distinguish long-term states from other types of states in the program.
|
922
|
+
|
923
|
+
Example:
|
924
|
+
>>> model_weights = LongTermState(np.random.randn(100, 100), name="model_weights")
|
925
|
+
>>> optimizer_state = LongTermState({}, name="optimizer_state")
|
506
926
|
"""
|
507
927
|
|
508
928
|
__module__ = 'brainstate'
|
@@ -511,6 +931,24 @@ class LongTermState(State):
|
|
511
931
|
class BatchState(LongTermState):
|
512
932
|
"""
|
513
933
|
The batch state, which is used to store the batch data in the program.
|
934
|
+
|
935
|
+
This class extends :class:`LongTermState` and is specifically designed to represent
|
936
|
+
and manage batch data within a program. It provides a way to encapsulate
|
937
|
+
batch-related information and associated metadata, facilitating operations
|
938
|
+
like batch processing in machine learning or data analysis tasks.
|
939
|
+
|
940
|
+
Attributes:
|
941
|
+
Inherits all attributes from :class:`LongTermState`.
|
942
|
+
|
943
|
+
Note:
|
944
|
+
This class does not introduce new methods or attributes beyond those
|
945
|
+
inherited from :class:`LongTermState`. Its primary purpose is to semantically
|
946
|
+
distinguish batch states from other types of long-term states
|
947
|
+
in the program.
|
948
|
+
|
949
|
+
Example:
|
950
|
+
>>> batch_data = BatchState(np.array([1, 2, 3, 4, 5]), name="current_batch")
|
951
|
+
>>> batch_labels = BatchState(np.array([0, 1, 0, 1, 1]), name="batch_labels")
|
514
952
|
"""
|
515
953
|
|
516
954
|
__module__ = 'brainstate'
|
@@ -519,6 +957,24 @@ class BatchState(LongTermState):
|
|
519
957
|
class HiddenState(ShortTermState):
|
520
958
|
"""
|
521
959
|
The hidden state, which is used to store the hidden data in a dynamic model.
|
960
|
+
|
961
|
+
This class extends :class:`ShortTermState` and is specifically designed to represent
|
962
|
+
and manage hidden states within dynamic models, such as recurrent neural networks.
|
963
|
+
It provides a way to encapsulate hidden state values and associated metadata,
|
964
|
+
facilitating operations like state updates during model execution.
|
965
|
+
|
966
|
+
Attributes:
|
967
|
+
Inherits all attributes from :class:`ShortTermState`.
|
968
|
+
|
969
|
+
Note:
|
970
|
+
This class does not introduce new methods or attributes beyond those
|
971
|
+
inherited from :class:`ShortTermState`. Its primary purpose is to semantically
|
972
|
+
distinguish hidden states from other types of short-term states
|
973
|
+
in dynamic models.
|
974
|
+
|
975
|
+
Example:
|
976
|
+
>>> lstm_hidden = HiddenState(np.zeros(128), name="lstm_hidden_state")
|
977
|
+
>>> gru_hidden = HiddenState(np.zeros(64), name="gru_hidden_state")
|
522
978
|
"""
|
523
979
|
|
524
980
|
__module__ = 'brainstate'
|
@@ -527,12 +983,30 @@ class HiddenState(ShortTermState):
|
|
527
983
|
class ParamState(LongTermState):
|
528
984
|
"""
|
529
985
|
The parameter state, which is used to store the trainable parameters in the model.
|
986
|
+
|
987
|
+
This class extends :class:`LongTermState` and is specifically designed to represent
|
988
|
+
and manage trainable parameters within a neural network or machine learning model.
|
989
|
+
It provides a way to encapsulate parameter values and associated metadata,
|
990
|
+
facilitating operations like parameter updates during training.
|
991
|
+
|
992
|
+
Attributes:
|
993
|
+
Inherits all attributes from :class:`LongTermState`.
|
994
|
+
|
995
|
+
Note:
|
996
|
+
This class does not introduce new methods or attributes beyond those
|
997
|
+
inherited from :class:`LongTermState`. Its primary purpose is to semantically
|
998
|
+
distinguish parameter states from other types of long-term states
|
999
|
+
in the model.
|
1000
|
+
|
1001
|
+
Example:
|
1002
|
+
>>> weight = ParamState(np.random.randn(10, 10), name="layer1_weights")
|
1003
|
+
>>> bias = ParamState(np.zeros(10), name="layer1_bias")
|
530
1004
|
"""
|
531
1005
|
|
532
1006
|
__module__ = 'brainstate'
|
533
1007
|
|
534
1008
|
|
535
|
-
class
|
1009
|
+
class FakeState:
|
536
1010
|
"""
|
537
1011
|
The faked state, which is used to store the faked data in the program.
|
538
1012
|
"""
|
@@ -540,26 +1014,63 @@ class FakedState:
|
|
540
1014
|
__module__ = 'brainstate'
|
541
1015
|
|
542
1016
|
def __init__(self, value: Any, name: Optional[str] = None):
|
1017
|
+
"""
|
1018
|
+
Initialize a FakeState instance.
|
1019
|
+
|
1020
|
+
Args:
|
1021
|
+
value (Any): The value to be stored in the fake state.
|
1022
|
+
name (Optional[str], optional): The name of the fake state. Defaults to None.
|
1023
|
+
"""
|
543
1024
|
self._value = value
|
544
1025
|
self._name = name
|
545
1026
|
|
546
1027
|
@property
|
547
1028
|
def value(self) -> Any:
|
1029
|
+
"""
|
1030
|
+
Get the value stored in the fake state.
|
1031
|
+
|
1032
|
+
Returns:
|
1033
|
+
Any: The value stored in the fake state.
|
1034
|
+
"""
|
548
1035
|
return self._value
|
549
1036
|
|
550
1037
|
@value.setter
|
551
1038
|
def value(self, v) -> None:
|
1039
|
+
"""
|
1040
|
+
Set the value of the fake state.
|
1041
|
+
|
1042
|
+
Args:
|
1043
|
+
v (Any): The new value to be stored in the fake state.
|
1044
|
+
"""
|
552
1045
|
self._value = v
|
553
1046
|
|
554
1047
|
def __repr__(self) -> str:
|
1048
|
+
"""
|
1049
|
+
Return a string representation of the FakeState instance.
|
1050
|
+
|
1051
|
+
Returns:
|
1052
|
+
str: A string representation of the FakeState instance.
|
1053
|
+
"""
|
555
1054
|
return f'FakedState(value={self._value})'
|
556
1055
|
|
557
1056
|
@property
|
558
1057
|
def name(self) -> Optional[str]:
|
1058
|
+
"""
|
1059
|
+
Get the name of the fake state.
|
1060
|
+
|
1061
|
+
Returns:
|
1062
|
+
Optional[str]: The name of the fake state, or None if not set.
|
1063
|
+
"""
|
559
1064
|
return self._name
|
560
1065
|
|
561
1066
|
@name.setter
|
562
1067
|
def name(self, name: str) -> None:
|
1068
|
+
"""
|
1069
|
+
Set the name of the fake state.
|
1070
|
+
|
1071
|
+
Args:
|
1072
|
+
name (str): The new name for the fake state.
|
1073
|
+
"""
|
563
1074
|
self._name = name
|
564
1075
|
|
565
1076
|
|
@@ -622,20 +1133,73 @@ class StateDictManager(DictManager):
|
|
622
1133
|
|
623
1134
|
class StateTraceStack(Generic[A]):
|
624
1135
|
"""
|
625
|
-
|
1136
|
+
A stack for tracing and managing states during program execution.
|
1137
|
+
|
1138
|
+
``StateTraceStack`` is used to automatically trace and manage State objects,
|
1139
|
+
keeping track of which states are read from or written to during the
|
1140
|
+
execution of a function or block of code. It provides methods for
|
1141
|
+
recording state accesses, retrieving state values, and managing the
|
1142
|
+
lifecycle of states within a tracing context.
|
1143
|
+
|
1144
|
+
The class is generic over type A, allowing for type-safe usage with
|
1145
|
+
different types of State objects.
|
1146
|
+
|
1147
|
+
Attributes:
|
1148
|
+
states (List[State]): A list of all State objects encountered during tracing.
|
1149
|
+
been_writen (List[bool]): A parallel list to states, indicating whether each state has been written to.
|
1150
|
+
_state_id_index (dict): A dictionary mapping state ids to their index in the states list.
|
1151
|
+
_original_state_values (List): A list of the original values of all states when first encountered.
|
1152
|
+
_jax_trace_new_arg (Callable): A function used to transform state values during tracing.
|
1153
|
+
|
1154
|
+
Methods:
|
1155
|
+
__enter__: Enters a new tracing context.
|
1156
|
+
__exit__: Exits the current tracing context.
|
1157
|
+
read_its_value: Records a read operation on a state.
|
1158
|
+
write_its_value: Records a write operation on a state.
|
1159
|
+
get_state_values: Retrieves the current values of all traced states.
|
1160
|
+
recovery_original_values: Restores all states to their original values.
|
1161
|
+
merge: Merges multiple ``StateTraceStack`` instances.
|
1162
|
+
get_read_states: Retrieves states that were read during tracing.
|
1163
|
+
get_read_state_values: Retrieves values of states that were read during tracing.
|
1164
|
+
|
1165
|
+
The ``StateTraceStack`` is a crucial component in implementing state-based
|
1166
|
+
computations and is particularly useful in scenarios involving automatic
|
1167
|
+
differentiation or other forms of program transformation.
|
626
1168
|
"""
|
627
1169
|
|
628
|
-
def __init__(
|
1170
|
+
def __init__(
|
1171
|
+
self,
|
1172
|
+
new_arg: Callable = None,
|
1173
|
+
name: Optional[str] = None,
|
1174
|
+
):
|
1175
|
+
self.name = name
|
629
1176
|
self.states: List[State] = []
|
630
1177
|
self.been_writen: List[bool] = [] # False: read, True: write
|
631
1178
|
self._state_id_index = dict()
|
632
1179
|
self._original_state_values = []
|
633
1180
|
self._jax_trace_new_arg: Callable = new_arg
|
1181
|
+
self._stack_level = None
|
1182
|
+
|
1183
|
+
def __str__(self) -> str:
|
1184
|
+
_stack_level = self.name if self._stack_level is None else self._stack_level
|
1185
|
+
if _stack_level is None:
|
1186
|
+
_stack_level = ''
|
1187
|
+
return f"{self.__class__.__name__}({_stack_level})"
|
634
1188
|
|
635
1189
|
@property
|
636
1190
|
def original_state_values(self) -> Tuple[PyTree, ...]:
|
637
1191
|
"""
|
638
|
-
|
1192
|
+
Get the original values of all states in the StateTraceStack.
|
1193
|
+
|
1194
|
+
This property provides access to the initial values of all states
|
1195
|
+
that were captured when they were first added to the stack. It's
|
1196
|
+
useful for comparing current state values with their original values
|
1197
|
+
or for reverting states to their initial condition.
|
1198
|
+
|
1199
|
+
Returns:
|
1200
|
+
Tuple[PyTree, ...]: A tuple containing the original values of all
|
1201
|
+
states in the order they were added to the stack. Each element
|
1202
|
+
is a PyTree representing the structure and values of a state.
|
639
1203
|
"""
|
640
1204
|
return tuple(self._original_state_values)
|
641
1205
|
|
@@ -643,12 +1207,30 @@ class StateTraceStack(Generic[A]):
|
|
643
1207
|
self._jax_trace_new_arg = new_arg
|
644
1208
|
|
645
1209
|
def new_arg(self, state: State) -> None:
|
1210
|
+
"""
|
1211
|
+
Apply a transformation to the value of a given state using a predefined function.
|
1212
|
+
|
1213
|
+
This method is used internally to transform the value of a state during tracing.
|
1214
|
+
If a transformation function (``_jax_trace_new_arg``) is defined, it applies this
|
1215
|
+
function to each element of the state's value using JAX's tree mapping.
|
1216
|
+
|
1217
|
+
Args:
|
1218
|
+
state (State): The State object whose value needs to be transformed.
|
1219
|
+
|
1220
|
+
Returns:
|
1221
|
+
None: This function modifies the state in-place and doesn't return anything.
|
1222
|
+
|
1223
|
+
Note:
|
1224
|
+
This method is intended for internal use and relies on the presence of
|
1225
|
+
a ``_jax_trace_new_arg`` function, which should be set separately.
|
1226
|
+
"""
|
646
1227
|
if self._jax_trace_new_arg is not None:
|
647
1228
|
# internal use
|
648
1229
|
state._value = jax.tree.map(lambda x: self._jax_trace_new_arg(shaped_abstractify(x)), state._value)
|
649
1230
|
|
650
1231
|
def __enter__(self) -> 'StateTraceStack':
|
651
1232
|
TRACE_CONTEXT.state_stack.append(self)
|
1233
|
+
self._stack_level = ' / '.join([st.name for st in TRACE_CONTEXT.state_stack if st.name is not None])
|
652
1234
|
return self
|
653
1235
|
|
654
1236
|
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
@@ -656,10 +1238,22 @@ class StateTraceStack(Generic[A]):
|
|
656
1238
|
|
657
1239
|
def read_its_value(self, state: State) -> None:
|
658
1240
|
"""
|
659
|
-
|
660
|
-
|
1241
|
+
Record that a state's value has been read during tracing.
|
1242
|
+
|
1243
|
+
This method marks the given state as having been read in the current
|
1244
|
+
tracing context. If the state hasn't been encountered before, it adds
|
1245
|
+
it to the internal tracking structures and applies any necessary
|
1246
|
+
transformations via the new_arg method.
|
1247
|
+
|
661
1248
|
Args:
|
662
|
-
|
1249
|
+
state (State): The State object whose value is being read.
|
1250
|
+
|
1251
|
+
Returns:
|
1252
|
+
None
|
1253
|
+
|
1254
|
+
Note:
|
1255
|
+
This method updates the internal tracking of state accesses.
|
1256
|
+
It doesn't actually read or return the state's value.
|
663
1257
|
"""
|
664
1258
|
id_ = id(state)
|
665
1259
|
if id_ not in self._state_id_index:
|
@@ -671,10 +1265,21 @@ class StateTraceStack(Generic[A]):
|
|
671
1265
|
|
672
1266
|
def write_its_value(self, state: State) -> None:
|
673
1267
|
"""
|
674
|
-
|
1268
|
+
Record that a state's value has been written to during tracing.
|
1269
|
+
|
1270
|
+
This method marks the given state as having been written to in the current
|
1271
|
+
tracing context. If the state hasn't been encountered before, it first
|
1272
|
+
records it as being read before marking it as written.
|
675
1273
|
|
676
1274
|
Args:
|
677
|
-
|
1275
|
+
state (State): The State object whose value is being written to.
|
1276
|
+
|
1277
|
+
Returns:
|
1278
|
+
None
|
1279
|
+
|
1280
|
+
Note:
|
1281
|
+
This method updates the internal tracking of state modifications.
|
1282
|
+
It doesn't actually modify the state's value.
|
678
1283
|
"""
|
679
1284
|
id_ = id(state)
|
680
1285
|
if id_ not in self._state_id_index:
|
@@ -682,10 +1287,37 @@ class StateTraceStack(Generic[A]):
|
|
682
1287
|
index = self._state_id_index[id_]
|
683
1288
|
self.been_writen[index] = True
|
684
1289
|
|
685
|
-
def get_state_values(
|
686
|
-
|
687
|
-
|
688
|
-
|
1290
|
+
def get_state_values(
|
1291
|
+
self,
|
1292
|
+
separate: bool = False,
|
1293
|
+
replace: bool = False
|
1294
|
+
) -> Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
|
1295
|
+
"""
|
1296
|
+
Retrieve the values of all states in the StateTraceStack.
|
1297
|
+
|
1298
|
+
This method returns the values of all states, optionally separating them
|
1299
|
+
into written and read states, and optionally replacing values with None
|
1300
|
+
for states that weren't accessed in a particular way.
|
1301
|
+
|
1302
|
+
Args:
|
1303
|
+
separate (bool, optional): If True, separate the values into written
|
1304
|
+
and read states. If False, return all values in a single sequence.
|
1305
|
+
Defaults to False.
|
1306
|
+
replace (bool, optional): If True and separate is True, replace values
|
1307
|
+
with None for states that weren't written/read. If False, only
|
1308
|
+
include values for states that were written/read. Defaults to False.
|
1309
|
+
|
1310
|
+
Returns:
|
1311
|
+
Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
|
1312
|
+
If separate is False:
|
1313
|
+
A sequence of all state values.
|
1314
|
+
If separate is True:
|
1315
|
+
A tuple containing two sequences:
|
1316
|
+
- The first sequence contains values of written states.
|
1317
|
+
- The second sequence contains values of read states.
|
1318
|
+
If replace is True, these sequences will have None for
|
1319
|
+
states that weren't written/read respectively.
|
1320
|
+
|
689
1321
|
"""
|
690
1322
|
if separate:
|
691
1323
|
if replace:
|
@@ -711,7 +1343,18 @@ class StateTraceStack(Generic[A]):
|
|
711
1343
|
|
712
1344
|
def recovery_original_values(self) -> None:
|
713
1345
|
"""
|
714
|
-
|
1346
|
+
Restore the original values of all states in the StateTraceStack.
|
1347
|
+
|
1348
|
+
This method iterates through all states in the stack and restores
|
1349
|
+
their values to the original ones that were captured when the states
|
1350
|
+
were first added to the stack. This is useful for reverting changes
|
1351
|
+
made during tracing or for resetting the states to their initial condition.
|
1352
|
+
|
1353
|
+
Note:
|
1354
|
+
This method modifies the states in-place.
|
1355
|
+
|
1356
|
+
Returns:
|
1357
|
+
None
|
715
1358
|
"""
|
716
1359
|
for st, val in zip(self.states, self._original_state_values):
|
717
1360
|
# internal use
|
@@ -719,7 +1362,22 @@ class StateTraceStack(Generic[A]):
|
|
719
1362
|
|
720
1363
|
def merge(self, *traces) -> 'StateTraceStack':
|
721
1364
|
"""
|
722
|
-
Merge other state traces
|
1365
|
+
Merge other state traces into the current ``StateTraceStack``.
|
1366
|
+
|
1367
|
+
This method combines the states, their write status, and original values from
|
1368
|
+
other ``StateTraceStack`` instances into the current one. If a state from another
|
1369
|
+
trace is not present in the current trace, it is added. If a state is already
|
1370
|
+
present, its write status is updated if necessary.
|
1371
|
+
|
1372
|
+
Args:
|
1373
|
+
*traces: Variable number of ``StateTraceStack`` instances to be merged into
|
1374
|
+
the current instance.
|
1375
|
+
|
1376
|
+
Returns:
|
1377
|
+
StateTraceStack: The current ``StateTraceStack`` instance with merged traces.
|
1378
|
+
|
1379
|
+
Note:
|
1380
|
+
This method modifies the current ``StateTraceStack`` in-place and also returns it.
|
723
1381
|
"""
|
724
1382
|
trace: StateTraceStack
|
725
1383
|
for trace in traces:
|
@@ -735,10 +1393,22 @@ class StateTraceStack(Generic[A]):
|
|
735
1393
|
|
736
1394
|
def get_read_states(self, replace_writen: bool = False) -> Tuple[State, ...]:
|
737
1395
|
"""
|
738
|
-
|
739
|
-
|
1396
|
+
Retrieve the states that were read during the function execution.
|
1397
|
+
|
1398
|
+
This method returns the states that were accessed (read from) during
|
1399
|
+
the traced function's execution. It can optionally replace written
|
1400
|
+
states with None.
|
1401
|
+
|
1402
|
+
Args:
|
1403
|
+
replace_writen (bool, optional): If True, replace written states with None
|
1404
|
+
in the returned tuple. If False, exclude written states entirely from
|
1405
|
+
the result. Defaults to False.
|
1406
|
+
|
740
1407
|
Returns:
|
741
|
-
|
1408
|
+
Tuple[State, ...]: A tuple containing the read states.
|
1409
|
+
If replace_writen is True, the tuple will have the same length as the
|
1410
|
+
total number of states, with None for written states.
|
1411
|
+
If replace_writen is False, the tuple will only contain read-only states.
|
742
1412
|
"""
|
743
1413
|
if replace_writen:
|
744
1414
|
return tuple([st if not been_writen else None
|
@@ -748,23 +1418,49 @@ class StateTraceStack(Generic[A]):
|
|
748
1418
|
|
749
1419
|
def get_read_state_values(self, replace_writen: bool = False) -> Tuple[PyTree, ...]:
|
750
1420
|
"""
|
751
|
-
|
752
|
-
|
1421
|
+
Retrieve the values of states that were read during the function execution.
|
1422
|
+
|
1423
|
+
This method returns the values of states that were accessed (read from) during
|
1424
|
+
the traced function's execution. It can optionally replace written states with None.
|
1425
|
+
|
1426
|
+
Args:
|
1427
|
+
replace_writen (bool, optional): If True, replace the values of written
|
1428
|
+
states with None in the returned tuple. If False, exclude written
|
1429
|
+
states entirely from the result. Defaults to False.
|
1430
|
+
|
753
1431
|
Returns:
|
754
|
-
|
1432
|
+
Tuple[PyTree, ...]: A tuple containing the values of read states.
|
1433
|
+
If replace_writen is True, the tuple will have the same length as the
|
1434
|
+
total number of states, with None for written states.
|
1435
|
+
If replace_writen is False, the tuple will only contain values of
|
1436
|
+
read-only states.
|
755
1437
|
"""
|
756
1438
|
if replace_writen:
|
757
1439
|
return tuple(
|
758
|
-
[st.value if not been_writen else None
|
1440
|
+
[st.value if not been_writen else None
|
1441
|
+
for st, been_writen in zip(self.states, self.been_writen)]
|
1442
|
+
)
|
759
1443
|
else:
|
760
1444
|
return tuple([st.value for st, been_writen in zip(self.states, self.been_writen) if not been_writen])
|
761
1445
|
|
762
1446
|
def get_write_states(self, replace_read: bool = False) -> Tuple[State, ...]:
|
763
1447
|
"""
|
764
|
-
|
765
|
-
|
1448
|
+
Retrieve the states that were written during the function execution.
|
1449
|
+
|
1450
|
+
This method returns the states that were modified (written to) during
|
1451
|
+
the traced function's execution. It can optionally replace unwritten (read-only)
|
1452
|
+
states with None.
|
1453
|
+
|
1454
|
+
Args:
|
1455
|
+
replace_read (bool, optional): If True, replace read-only states with None
|
1456
|
+
in the returned tuple. If False, exclude read-only states entirely from
|
1457
|
+
the result. Defaults to False.
|
1458
|
+
|
766
1459
|
Returns:
|
767
|
-
|
1460
|
+
Tuple[State, ...]: A tuple containing the written states.
|
1461
|
+
If replace_read is True, the tuple will have the same length as the
|
1462
|
+
total number of states, with None for read-only states.
|
1463
|
+
If replace_read is False, the tuple will only contain written states.
|
768
1464
|
"""
|
769
1465
|
if replace_read:
|
770
1466
|
return tuple([st if been_writen else None
|
@@ -774,10 +1470,24 @@ class StateTraceStack(Generic[A]):
|
|
774
1470
|
|
775
1471
|
def get_write_state_values(self, replace_read: bool = False) -> Tuple[PyTree, ...]:
|
776
1472
|
"""
|
777
|
-
|
778
|
-
|
1473
|
+
Retrieve the values of states that were written during the function execution.
|
1474
|
+
|
1475
|
+
This method returns the values of states that were modified (written to) during
|
1476
|
+
the traced function's execution. It can optionally replace unwritten (read-only)
|
1477
|
+
states with None.
|
1478
|
+
|
1479
|
+
Args:
|
1480
|
+
replace_read (bool, optional): If True, replace the values of read-only
|
1481
|
+
states with None in the returned tuple. If False, exclude read-only
|
1482
|
+
states entirely from the result. Defaults to False.
|
1483
|
+
|
779
1484
|
Returns:
|
780
|
-
|
1485
|
+
Tuple[PyTree, ...]: A tuple containing the values of written states.
|
1486
|
+
If replace_read is True, the tuple will have the same length as the
|
1487
|
+
total number of states, with None for read-only states.
|
1488
|
+
If replace_read is False, the tuple will only contain values of
|
1489
|
+
written states.
|
1490
|
+
|
781
1491
|
"""
|
782
1492
|
if replace_read:
|
783
1493
|
return tuple([st.value if been_writen else None for st, been_writen in zip(self.states, self.been_writen)])
|
@@ -790,8 +1500,64 @@ class StateTraceStack(Generic[A]):
|
|
790
1500
|
"""
|
791
1501
|
return StateTraceStack().merge(self, other)
|
792
1502
|
|
1503
|
+
def assign_state_vals(self, state_vals: Sequence[PyTree]) -> None:
|
1504
|
+
"""
|
1505
|
+
Assign new values to the states tracked by this ``StateTraceStack``.
|
1506
|
+
|
1507
|
+
This method updates the values of the states based on whether they were
|
1508
|
+
written to or only read during the tracing process. For states that were
|
1509
|
+
written to, it directly assigns the new value. For states that were only
|
1510
|
+
read, it restores the value using the state's restore_value method.
|
1511
|
+
|
1512
|
+
Args:
|
1513
|
+
state_vals (Sequence[PyTree]): A sequence of new state values to be
|
1514
|
+
assigned. Each element in this sequence corresponds to a state
|
1515
|
+
in the ``StateTraceStack``'s states list.
|
1516
|
+
|
1517
|
+
Raises:
|
1518
|
+
ValueError: If the length of state_vals doesn't match the number of
|
1519
|
+
states in the ``StateTraceStack``.
|
1520
|
+
|
1521
|
+
Returns:
|
1522
|
+
None
|
1523
|
+
|
1524
|
+
Note:
|
1525
|
+
The order of state_vals should match the order of states in the
|
1526
|
+
``StateTraceStack``'s states list.
|
1527
|
+
"""
|
1528
|
+
if len(state_vals) != len(self.states):
|
1529
|
+
raise ValueError('The length of the state values must be equal to the states. '
|
1530
|
+
f'Bug got {len(state_vals)} and {len(self.states)}')
|
1531
|
+
for st, written, val in zip(self.states, self.been_writen, state_vals):
|
1532
|
+
if written:
|
1533
|
+
st.value = val
|
1534
|
+
else:
|
1535
|
+
st.restore_value(val)
|
1536
|
+
|
1537
|
+
def state_subset(self, state_type: type) -> List:
|
1538
|
+
"""
|
1539
|
+
Get a subset of states of a specific type from the ``StateTraceStack``.
|
1540
|
+
|
1541
|
+
This method filters the states in the ``StateTraceStack`` and returns only
|
1542
|
+
those that match the specified state type.
|
1543
|
+
|
1544
|
+
Args:
|
1545
|
+
state_type (type): The type of state to filter by. This should be
|
1546
|
+
a subclass of State or State itself.
|
1547
|
+
|
1548
|
+
Returns:
|
1549
|
+
List[State]: A list containing all states in the ``StateTraceStack``
|
1550
|
+
that are instances of the specified state_type.
|
1551
|
+
|
1552
|
+
Example:
|
1553
|
+
>>> stack = StateTraceStack()
|
1554
|
+
>>> # Assume stack has been populated with various state types
|
1555
|
+
>>> short_term_states = stack.state_subset(ShortTermState)
|
1556
|
+
"""
|
1557
|
+
return [st for st in self.states if isinstance(st, state_type)]
|
1558
|
+
|
793
1559
|
|
794
|
-
class TreefyState(Generic[A],
|
1560
|
+
class TreefyState(Generic[A], PrettyObject):
|
795
1561
|
"""
|
796
1562
|
The state as a pytree.
|
797
1563
|
"""
|
@@ -815,14 +1581,28 @@ class TreefyState(Generic[A], PrettyReprTree):
|
|
815
1581
|
|
816
1582
|
def __pretty_repr_item__(self, k, v):
|
817
1583
|
if k in ['_level', '_source_info', '_been_writen']:
|
818
|
-
return None
|
1584
|
+
return None
|
819
1585
|
if k == '_value':
|
820
1586
|
return 'value', v
|
821
1587
|
|
822
1588
|
if k == '_name':
|
823
|
-
return
|
1589
|
+
return None if v is None else ('name', v)
|
824
1590
|
return k, v
|
825
1591
|
|
1592
|
+
@property
|
1593
|
+
def name(self) -> Optional[str]:
|
1594
|
+
"""
|
1595
|
+
The name of the state.
|
1596
|
+
"""
|
1597
|
+
return self._name
|
1598
|
+
|
1599
|
+
@name.setter
|
1600
|
+
def name(self, name: str) -> None:
|
1601
|
+
"""
|
1602
|
+
Set the name of the state.
|
1603
|
+
"""
|
1604
|
+
self._name = name
|
1605
|
+
|
826
1606
|
def replace(self, value: B) -> TreefyState[B]:
|
827
1607
|
"""
|
828
1608
|
Replace the value of the state reference.
|
@@ -837,7 +1617,9 @@ class TreefyState(Generic[A], PrettyReprTree):
|
|
837
1617
|
# __init__ logic which should not be called twice
|
838
1618
|
metadata = self.get_metadata()
|
839
1619
|
state = object.__new__(self.type)
|
840
|
-
|
1620
|
+
metadata.pop('_value', None)
|
1621
|
+
metadata.pop('_level', None)
|
1622
|
+
vars(state).update(**metadata, _value=self.value, _level=_get_trace_stack_level())
|
841
1623
|
return state
|
842
1624
|
|
843
1625
|
def copy(self: TreefyState[A]) -> TreefyState[A]:
|