brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/_state.py
CHANGED
@@ -13,408 +13,852 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
from __future__ import annotations
|
17
|
+
|
16
18
|
import contextlib
|
19
|
+
import dataclasses
|
17
20
|
import threading
|
18
|
-
from
|
21
|
+
from functools import wraps, partial
|
22
|
+
from typing import (Any, Union, Callable, Generic, Mapping,
|
23
|
+
TypeVar, Optional, TYPE_CHECKING, Tuple, Dict, List, Sequence)
|
19
24
|
|
20
25
|
import jax
|
21
26
|
import numpy as np
|
22
27
|
from jax.api_util import shaped_abstractify
|
23
28
|
from jax.extend import source_info_util
|
24
29
|
|
25
|
-
from brainstate.typing import ArrayLike, PyTree
|
26
|
-
from brainstate.util import DictManager
|
30
|
+
from brainstate.typing import ArrayLike, PyTree, Missing
|
31
|
+
from brainstate.util import DictManager, PrettyRepr, PrettyType, PrettyAttr, TraceContextError
|
32
|
+
from brainstate.util._tracers import StateJaxTracer
|
27
33
|
|
28
34
|
__all__ = [
|
29
|
-
|
30
|
-
'StateDictManager',
|
31
|
-
'StateTrace',
|
32
|
-
'visible_state_dict',
|
33
|
-
'check_state_value_tree',
|
34
|
-
]
|
35
|
+
'State', 'ShortTermState', 'LongTermState', 'HiddenState', 'ParamState', 'TreefyState',
|
35
36
|
|
36
|
-
|
37
|
-
|
37
|
+
'StateDictManager', 'StateTraceStack', 'check_state_value_tree', 'check_state_jax_tracer', 'catch_new_states',
|
38
|
+
]
|
38
39
|
|
40
|
+
A = TypeVar('A')
|
41
|
+
B = TypeVar('B')
|
42
|
+
F = TypeVar('F', bound=Callable[..., Any])
|
39
43
|
|
40
|
-
|
41
|
-
if cls not in _pytree_registered_objects:
|
42
|
-
jax.tree_util.register_pytree_node_class(cls)
|
43
|
-
_pytree_registered_objects.add(cls)
|
44
|
+
max_int = np.iinfo(np.int32)
|
44
45
|
|
45
46
|
|
46
47
|
# The global state of the state stack is accessed by a thread-local object.
|
47
48
|
# This allows concurrent tracing in separate threads; passing traced objects
|
48
49
|
# between threads is forbidden.
|
49
50
|
class ThreadLocalStack(threading.local):
|
50
|
-
|
51
|
-
|
52
|
-
|
51
|
+
def __init__(self):
|
52
|
+
self.state_stack: List[StateTraceStack] = []
|
53
|
+
self.tree_check: List[bool] = [False]
|
54
|
+
self.jax_tracer_check: List[bool] = [False]
|
55
|
+
self.new_state_catcher: List[Catcher] = []
|
53
56
|
|
54
|
-
thread_local_stack = ThreadLocalStack()
|
55
57
|
|
56
|
-
|
58
|
+
TRACE_CONTEXT = ThreadLocalStack()
|
57
59
|
|
58
60
|
|
59
61
|
@contextlib.contextmanager
|
60
|
-
def check_state_value_tree() -> None:
|
61
|
-
"""
|
62
|
-
The contex manager to check weather the tree structure of the state value keeps consistently.
|
63
|
-
|
64
|
-
Once a :py:class:`~.State` is created, the tree structure of the value is fixed. In default,
|
65
|
-
the tree structure of the value is not checked to avoid off the repeated evaluation.
|
66
|
-
If you want to check the tree structure of the value once the new value is assigned,
|
67
|
-
you can use this context manager.
|
68
|
-
|
69
|
-
Example::
|
70
|
-
|
71
|
-
```python
|
72
|
-
state = brainstate.ShortTermState(jnp.zeros((2, 3)))
|
73
|
-
with check_state_value_tree():
|
74
|
-
state.value = jnp.zeros((2, 3))
|
75
|
-
|
76
|
-
# The following code will raise an error.
|
77
|
-
state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
|
78
|
-
```
|
79
|
-
|
80
|
-
"""
|
81
|
-
try:
|
82
|
-
_global_context_to_check_state_tree.append(True)
|
83
|
-
yield
|
84
|
-
finally:
|
85
|
-
_global_context_to_check_state_tree.pop()
|
86
|
-
|
87
|
-
|
88
|
-
class State(object):
|
89
|
-
"""
|
90
|
-
The pointer to specify the dynamical data.
|
91
|
-
|
92
|
-
To implement a new subclass of :py:class:`~.State`, you only need to inherent this class:
|
93
|
-
|
94
|
-
Example::
|
95
|
-
|
96
|
-
class MyState(State):
|
97
|
-
pass
|
98
|
-
|
99
|
-
The typical examples of :py:class:`~.State` subclass are:
|
100
|
-
|
101
|
-
- :py:class:`~.ShortTermState`: The short-term state, which is used to store the short-term data in the program.
|
102
|
-
- :py:class:`~.LongTermState`: The long-term state, which is used to store the long-term data in the program.
|
103
|
-
- :py:class:`~.ParamState`: The parameter state, which is used to store the parameters in the program.
|
104
|
-
- :py:class:`~.RandomState`: The random generator state, which is used to store the random key in the program.
|
105
|
-
|
106
|
-
Args:
|
107
|
-
value: PyTree. It can be anything as a pyTree.
|
108
|
-
"""
|
109
|
-
__module__ = 'brainstate'
|
110
|
-
__slots__ = ('_value', '_name', '_tree', '_level', '_source_info', '_check_tree')
|
111
|
-
|
112
|
-
def __init__(self, value: PyTree[ArrayLike], name: Optional[str] = None):
|
113
|
-
if isinstance(value, State):
|
114
|
-
value = value.value
|
115
|
-
self._value = value
|
116
|
-
self._tree = jax.tree.structure(value)
|
117
|
-
self._check_tree = False
|
118
|
-
self._level = len(thread_local_stack.stack)
|
119
|
-
self._source_info = source_info_util.current()
|
120
|
-
self._name = name
|
121
|
-
|
122
|
-
@property
|
123
|
-
def name(self) -> Optional[str]:
|
62
|
+
def check_state_value_tree(val: bool = True) -> None:
|
124
63
|
"""
|
125
|
-
The
|
126
|
-
|
127
|
-
|
64
|
+
The contex manager to check weather the tree structure of the state value keeps consistently.
|
65
|
+
|
66
|
+
Once a :py:class:`~.State` is created, the tree structure of the value is fixed. In default,
|
67
|
+
the tree structure of the value is not checked to avoid off the repeated evaluation.
|
68
|
+
If you want to check the tree structure of the value once the new value is assigned,
|
69
|
+
you can use this context manager.
|
70
|
+
|
71
|
+
Example::
|
72
|
+
|
73
|
+
>>> import brainstate as bst
|
74
|
+
>>> import jax.numpy as jnp
|
75
|
+
>>> state = bst.ShortTermState(jnp.zeros((2, 3)))
|
76
|
+
>>> with bst.check_state_value_tree():
|
77
|
+
>>> # The line below will not raise an error.
|
78
|
+
>>> state.value = jnp.zeros((2, 3))
|
79
|
+
...
|
80
|
+
>>> # The following code will raise an error, since it changes the tree structure.
|
81
|
+
>>> state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
|
128
82
|
|
129
|
-
@name.setter
|
130
|
-
def name(self, name: str) -> None:
|
131
|
-
"""
|
132
|
-
Set the name of the state.
|
133
83
|
"""
|
134
|
-
|
84
|
+
try:
|
85
|
+
TRACE_CONTEXT.tree_check.append(val)
|
86
|
+
yield
|
87
|
+
finally:
|
88
|
+
TRACE_CONTEXT.tree_check.pop()
|
89
|
+
|
90
|
+
|
91
|
+
@contextlib.contextmanager
|
92
|
+
def catch_new_states(tag: str = None) -> List:
|
93
|
+
try:
|
94
|
+
catcher = Catcher(tag)
|
95
|
+
TRACE_CONTEXT.new_state_catcher.append(catcher)
|
96
|
+
yield catcher
|
97
|
+
finally:
|
98
|
+
TRACE_CONTEXT.new_state_catcher.pop()
|
99
|
+
|
100
|
+
|
101
|
+
class Catcher:
|
102
|
+
def __init__(self, tag: str):
|
103
|
+
self.tag = tag
|
104
|
+
self.state_ids = set()
|
105
|
+
self.states = []
|
135
106
|
|
136
|
-
|
137
|
-
|
107
|
+
def append(self, state: State):
|
108
|
+
if id(state) not in self.state_ids:
|
109
|
+
self.state_ids.add(id(state))
|
110
|
+
self.states.append(state)
|
111
|
+
state.tag = self.tag
|
112
|
+
|
113
|
+
|
114
|
+
@contextlib.contextmanager
|
115
|
+
def check_state_jax_tracer(val: bool = True) -> None:
|
138
116
|
"""
|
139
|
-
The
|
117
|
+
The context manager to check whether the state is valid to trace.
|
118
|
+
|
119
|
+
Example::
|
120
|
+
|
121
|
+
>>> import jax
|
122
|
+
>>> import brainstate as bst
|
123
|
+
>>> import jax.numpy as jnp
|
124
|
+
>>>
|
125
|
+
>>> a = bst.ShortTermState(jnp.zeros((2, 3)))
|
126
|
+
>>>
|
127
|
+
>>> @jax.jit
|
128
|
+
>>> def run_state(b):
|
129
|
+
>>> a.value = b
|
130
|
+
>>> return a.value
|
131
|
+
>>>
|
132
|
+
>>> # The following code will not raise an error, since the state is valid to trace.
|
133
|
+
>>> run_state(jnp.ones((2, 3)))
|
134
|
+
>>>
|
135
|
+
>>> with check_state_jax_tracer():
|
136
|
+
>>> # The line below will not raise an error.
|
137
|
+
>>> run_state(jnp.ones((2, 4)))
|
140
138
|
"""
|
141
|
-
|
139
|
+
try:
|
140
|
+
TRACE_CONTEXT.jax_tracer_check.append(val)
|
141
|
+
yield
|
142
|
+
finally:
|
143
|
+
TRACE_CONTEXT.jax_tracer_check.pop()
|
142
144
|
|
143
|
-
# read the value by the stack (>= level)
|
144
|
-
trace: StateTrace
|
145
|
-
for trace in thread_local_stack.stack[self._level:]:
|
146
|
-
trace.read_its_value(self)
|
147
|
-
# return the value
|
148
|
-
return self._value
|
149
145
|
|
150
|
-
|
151
|
-
|
146
|
+
@dataclasses.dataclass
|
147
|
+
class StateMetadata(Generic[A]):
|
152
148
|
"""
|
153
|
-
|
149
|
+
The state metadata.
|
154
150
|
|
155
151
|
Args:
|
156
|
-
|
152
|
+
raw_value: The raw value.
|
153
|
+
metadata: The metadata.
|
157
154
|
"""
|
158
|
-
|
159
|
-
|
160
|
-
self._check_value_tree(v)
|
161
|
-
# write the value by the stack (>= level)
|
162
|
-
trace: StateTrace
|
163
|
-
for trace in thread_local_stack.stack[self._level:]:
|
164
|
-
trace.write_its_value(self)
|
165
|
-
# set the value
|
166
|
-
self._value = v
|
167
|
-
|
168
|
-
def _check_value_tree(self, v):
|
169
|
-
if self._check_tree or _global_context_to_check_state_tree[-1]:
|
170
|
-
in_tree = jax.tree.structure(v)
|
171
|
-
if in_tree != self._tree:
|
172
|
-
self._raise_error_with_source_info(
|
173
|
-
ValueError(f'The given value {in_tree} does not '
|
174
|
-
f'match with the origin tree structure '
|
175
|
-
f'{self._tree}.')
|
176
|
-
)
|
155
|
+
raw_value: A
|
156
|
+
metadata: Mapping[str, Any] = dataclasses.field(default_factory=dict)
|
177
157
|
|
178
|
-
def _raise_error_with_source_info(self, error: Exception):
|
179
|
-
name_stack = source_info_util.current_name_stack() + self.source_info.name_stack
|
180
|
-
with source_info_util.user_context(self.source_info.traceback, name_stack=name_stack):
|
181
|
-
raise error
|
182
158
|
|
183
|
-
|
184
|
-
pass
|
185
|
-
|
186
|
-
@property
|
187
|
-
def source_info(self) -> source_info_util.SourceInfo:
|
159
|
+
def with_metadata(initializer: F, **metadata: Any) -> F:
|
188
160
|
"""
|
189
|
-
|
190
|
-
the source code where the definition of the state.
|
191
|
-
|
192
|
-
Returns:
|
193
|
-
The source information.
|
161
|
+
A decorator to add metadata to the state.
|
194
162
|
"""
|
195
|
-
return self._source_info
|
196
163
|
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
Returns:
|
201
|
-
A pair where the first element is a list of leaf values
|
202
|
-
and the second element is a treedef representing the
|
203
|
-
structure of the flattened tree.
|
204
|
-
"""
|
205
|
-
return (self._value,), (self._level,)
|
164
|
+
@wraps(initializer)
|
165
|
+
def wrapper(*args):
|
166
|
+
return StateMetadata(initializer(*args), metadata=metadata)
|
206
167
|
|
207
|
-
|
208
|
-
def tree_unflatten(cls, aux_data, flat_contents):
|
209
|
-
"""Reconstructs a variable from the aux_data and the leaves.
|
168
|
+
return wrapper # type: ignore
|
210
169
|
|
211
|
-
Args:
|
212
|
-
aux_data:
|
213
|
-
flat_contents:
|
214
170
|
|
215
|
-
|
216
|
-
|
217
|
-
"""
|
218
|
-
(_level,) = aux_data
|
219
|
-
self = cls(flat_contents[0])
|
220
|
-
self._level = max_int
|
221
|
-
return self
|
222
|
-
|
223
|
-
def __repr__(self):
|
224
|
-
leaves, tree = jax.tree.flatten(self._value)
|
225
|
-
leaves_info = [ShapeDtype(leaf.shape, leaf.dtype) for leaf in leaves]
|
226
|
-
tree_info = jax.tree.unflatten(tree, leaves_info)
|
227
|
-
if self.name is None:
|
228
|
-
return f'{self.__class__.__name__}({tree_info})'
|
229
|
-
else:
|
230
|
-
return f'{self.__class__.__name__}({self.name}: {tree_info})'
|
171
|
+
def _get_trace_stack_level() -> int:
|
172
|
+
return len(TRACE_CONTEXT.state_stack)
|
231
173
|
|
232
174
|
|
233
|
-
class
|
234
|
-
|
235
|
-
|
236
|
-
self.dtype = dtype
|
237
|
-
self.ndim = len(shape)
|
238
|
-
self.size = np.prod(shape)
|
175
|
+
class State(Generic[A], PrettyRepr):
|
176
|
+
"""
|
177
|
+
The pointer to specify the dynamical data.
|
239
178
|
|
240
|
-
|
241
|
-
return f'{self.dtype}{list(self.shape)}'
|
179
|
+
To implement a new subclass of :py:class:`~.State`, you only need to inherent this class:
|
242
180
|
|
181
|
+
Example::
|
243
182
|
|
244
|
-
class
|
245
|
-
|
246
|
-
The short-term state, which is used to store the short-term data in the program.
|
183
|
+
>>> class MyState(State):
|
184
|
+
>>> pass
|
247
185
|
|
248
|
-
|
249
|
-
"""
|
186
|
+
The typical examples of :py:class:`~.State` subclass are:
|
250
187
|
|
251
|
-
|
188
|
+
- :py:class:`~.ShortTermState`: The short-term state, which is used to store the short-term data in the program.
|
189
|
+
- :py:class:`~.LongTermState`: The long-term state, which is used to store the long-term data in the program.
|
190
|
+
- :py:class:`~.ParamState`: The parameter state, which is used to store the parameters in the program.
|
191
|
+
- :py:class:`~.RandomState`: The random generator state, which is used to store the random key in the program.
|
252
192
|
|
193
|
+
Args:
|
194
|
+
value: PyTree. It can be anything as a pyTree.
|
195
|
+
"""
|
196
|
+
__module__ = 'brainstate'
|
197
|
+
_trace_state: StateJaxTracer
|
198
|
+
_level: int
|
199
|
+
_source_info: source_info_util.SourceInfo
|
200
|
+
_name: Optional[str]
|
201
|
+
_value: PyTree
|
202
|
+
_been_writen: bool # useful in `unflatten` and `flatten` graph processing
|
203
|
+
tag: Optional[str]
|
204
|
+
|
205
|
+
def __init__(
|
206
|
+
self,
|
207
|
+
value: Union[PyTree[ArrayLike], StateMetadata[PyTree[ArrayLike]]],
|
208
|
+
name: Optional[str] = None,
|
209
|
+
**metadata: Any
|
210
|
+
):
|
211
|
+
tag = metadata.pop('tag', None)
|
212
|
+
|
213
|
+
# avoid using self._setattr to avoid the check
|
214
|
+
vars(self)['_trace_state'] = StateJaxTracer()
|
215
|
+
|
216
|
+
# set the value and metadata
|
217
|
+
if isinstance(value, StateMetadata):
|
218
|
+
metadata.update(dict(value.metadata))
|
219
|
+
value = value.raw_value
|
220
|
+
if isinstance(value, State):
|
221
|
+
value = value.value
|
222
|
+
|
223
|
+
# update metadata
|
224
|
+
metadata.update(_value=value,
|
225
|
+
_level=_get_trace_stack_level(),
|
226
|
+
_source_info=source_info_util.current(),
|
227
|
+
_name=name,
|
228
|
+
tag=tag,
|
229
|
+
_been_writen=False)
|
230
|
+
|
231
|
+
# avoid using self._setattr to avoid the check
|
232
|
+
vars(self).update(metadata)
|
233
|
+
|
234
|
+
record_state_init(self)
|
235
|
+
|
236
|
+
if not TYPE_CHECKING:
|
237
|
+
def __setattr__(self, name: str, value: Any) -> None:
|
238
|
+
return self._setattr(name, value)
|
239
|
+
|
240
|
+
def _setattr(self, name: str, value: Any):
|
241
|
+
"""
|
242
|
+
Check if the state is valid to mutate.
|
243
|
+
"""
|
244
|
+
if TRACE_CONTEXT.jax_tracer_check[-1]:
|
245
|
+
self.check_valid_trace(lambda: f'Cannot mutate {type(self).__name__} from a different trace level')
|
246
|
+
object.__setattr__(self, name, value)
|
247
|
+
|
248
|
+
def _setattr_no_check(self, name: str, value: Any):
|
249
|
+
"""
|
250
|
+
Set the attribute without checking the trace level.
|
251
|
+
"""
|
252
|
+
vars(self)[name] = value
|
253
|
+
|
254
|
+
def check_valid_trace(self, error_msg: Callable[[], str]):
|
255
|
+
"""
|
256
|
+
Check if the state is valid to trace.
|
257
|
+
"""
|
258
|
+
if not self._trace_state.is_valid():
|
259
|
+
raise TraceContextError(error_msg())
|
260
|
+
|
261
|
+
@property
|
262
|
+
def name(self) -> Optional[str]:
|
263
|
+
"""
|
264
|
+
The name of the state.
|
265
|
+
"""
|
266
|
+
return self._name
|
267
|
+
|
268
|
+
@name.setter
|
269
|
+
def name(self, name: str) -> None:
|
270
|
+
"""
|
271
|
+
Set the name of the state.
|
272
|
+
"""
|
273
|
+
self._setattr_no_check('_name', name)
|
274
|
+
|
275
|
+
@property
|
276
|
+
def value(self) -> PyTree[ArrayLike]:
|
277
|
+
"""
|
278
|
+
The data and its value.
|
279
|
+
"""
|
280
|
+
self.check_if_deleted()
|
281
|
+
record_state_value_read(self)
|
282
|
+
return self._value
|
283
|
+
|
284
|
+
@value.setter
|
285
|
+
def value(self, v) -> None:
|
286
|
+
"""
|
287
|
+
Set the value of the state.
|
288
|
+
|
289
|
+
Args:
|
290
|
+
v: The value.
|
291
|
+
"""
|
292
|
+
self.write_value(v)
|
293
|
+
self._been_writen = True
|
294
|
+
|
295
|
+
def write_value(self, v) -> None:
|
296
|
+
# value checking
|
297
|
+
if isinstance(v, State):
|
298
|
+
raise ValueError('Cannot set value to a State, ' 'use `copy_from` method instead')
|
299
|
+
self._check_value_tree(v)
|
300
|
+
# write the value by the stack (>= level)
|
301
|
+
record_state_value_write(self)
|
302
|
+
# set the value
|
303
|
+
self._value = v
|
304
|
+
|
305
|
+
def restore_value(self, v) -> None:
|
306
|
+
"""
|
307
|
+
Restore the value of the state.
|
308
|
+
|
309
|
+
Args:
|
310
|
+
v: The value.
|
311
|
+
"""
|
312
|
+
# value checking
|
313
|
+
if isinstance(v, State):
|
314
|
+
raise ValueError('Cannot set value to a State, ' 'use `copy_from` method instead')
|
315
|
+
with check_state_value_tree():
|
316
|
+
self._check_value_tree(v)
|
317
|
+
# record the value by the stack (>= level)
|
318
|
+
record_state_value_restore(self)
|
319
|
+
# set the value
|
320
|
+
self._value = v
|
321
|
+
|
322
|
+
def value_call(self, func: Callable[..., Any]) -> Any:
|
323
|
+
"""
|
324
|
+
Call the function with the value of the state.
|
325
|
+
"""
|
326
|
+
return jax.tree.map(func, self.value)
|
327
|
+
|
328
|
+
def _check_value_tree(self, v):
|
329
|
+
"""
|
330
|
+
Check if the value tree structure is consistent.
|
331
|
+
"""
|
332
|
+
if TRACE_CONTEXT.tree_check[-1]:
|
333
|
+
in_tree = jax.tree.structure(v)
|
334
|
+
self_tree = jax.tree.structure(self._value)
|
335
|
+
if in_tree != self_tree:
|
336
|
+
self._raise_error_with_source_info(
|
337
|
+
ValueError(f'The given value {in_tree} does not match with the origin tree structure {self_tree}.')
|
338
|
+
)
|
339
|
+
|
340
|
+
def _raise_error_with_source_info(self, error: Exception):
|
341
|
+
"""
|
342
|
+
Raise an error with the source information for easy debugging.
|
343
|
+
"""
|
344
|
+
name_stack = source_info_util.current_name_stack() + self.source_info.name_stack
|
345
|
+
with source_info_util.user_context(self.source_info.traceback, name_stack=name_stack):
|
346
|
+
raise error
|
347
|
+
|
348
|
+
def check_if_deleted(self):
|
349
|
+
pass
|
350
|
+
|
351
|
+
@property
|
352
|
+
def source_info(self) -> source_info_util.SourceInfo:
|
353
|
+
"""
|
354
|
+
The source information of the state, can be useful to identify
|
355
|
+
the source code where the definition of the state.
|
356
|
+
|
357
|
+
Returns:
|
358
|
+
The source information.
|
359
|
+
"""
|
360
|
+
return self._source_info
|
361
|
+
|
362
|
+
def update_from_ref(self, state_ref: TreefyState[A]) -> None:
|
363
|
+
"""
|
364
|
+
Update the state from the state reference :py:class:`TreefyState`.
|
365
|
+
|
366
|
+
Args:
|
367
|
+
state_ref: The state reference.
|
368
|
+
"""
|
369
|
+
metadata = state_ref.get_metadata()
|
370
|
+
variable_vars = vars(self)
|
371
|
+
variable_vars.update(**metadata)
|
372
|
+
if metadata.pop('_been_writen', True):
|
373
|
+
self.value = state_ref.value
|
374
|
+
else:
|
375
|
+
self.restore_value(state_ref.value)
|
376
|
+
|
377
|
+
def replace(self, value: Any = Missing, **kwargs) -> State[Any]:
|
378
|
+
"""
|
379
|
+
Replace the attribute of the state.
|
380
|
+
"""
|
381
|
+
if value is not Missing:
|
382
|
+
kwargs['_value'] = value
|
383
|
+
|
384
|
+
# return `value` if it is a State
|
385
|
+
if '_value' in kwargs and isinstance(value := kwargs['_value'], State):
|
386
|
+
# remove value from kwargs
|
387
|
+
kwargs.pop('_value')
|
388
|
+
if type(self) is not type(value):
|
389
|
+
raise ValueError('Cannot replace value from incompatible container, '
|
390
|
+
f'expected {type(self).__name__}, got {type(value).__name__}')
|
391
|
+
# if kwargs aren't empty, recursively call replace
|
392
|
+
# else return variable value
|
393
|
+
if kwargs:
|
394
|
+
return value.replace(**kwargs)
|
395
|
+
else:
|
396
|
+
return value
|
397
|
+
|
398
|
+
# get and update attributes
|
399
|
+
attributes = vars(self).copy()
|
400
|
+
attributes.update(**kwargs)
|
401
|
+
# return new instance with updated attributes
|
402
|
+
obj = object.__new__(type(self))
|
403
|
+
vars(obj).update(attributes)
|
404
|
+
return obj
|
405
|
+
|
406
|
+
def copy(self: State[A]) -> State[A]:
|
407
|
+
"""
|
408
|
+
Copy the state.
|
409
|
+
"""
|
410
|
+
obj = object.__new__(type(self))
|
411
|
+
attributes = vars(self).copy()
|
412
|
+
# keep its own trace state and stack level
|
413
|
+
attributes['_trace_state'] = StateJaxTracer()
|
414
|
+
attributes['_level'] = _get_trace_stack_level()
|
415
|
+
attributes['_source_info'] = source_info_util.current()
|
416
|
+
attributes.pop('_been_writen', None)
|
417
|
+
# update the metadata
|
418
|
+
vars(obj).update(attributes)
|
419
|
+
return obj
|
420
|
+
|
421
|
+
def to_state_ref(self: State[A]) -> TreefyState[A]:
|
422
|
+
metadata = vars(self).copy()
|
423
|
+
del metadata['_value']
|
424
|
+
del metadata['_trace_state']
|
425
|
+
del metadata['_level']
|
426
|
+
return TreefyState(type(self), self._value, **metadata)
|
427
|
+
|
428
|
+
def __pretty_repr__(self):
|
429
|
+
yield PrettyType(type=type(self))
|
430
|
+
for name, value in vars(self).items():
|
431
|
+
if name == '_value':
|
432
|
+
name = 'value'
|
433
|
+
if name == '_name':
|
434
|
+
if value is None:
|
435
|
+
continue
|
436
|
+
else:
|
437
|
+
name = 'name'
|
438
|
+
if name == 'tag' and value is None:
|
439
|
+
continue
|
440
|
+
if name in ['_trace_state', '_level', '_source_info', '_been_writen']:
|
441
|
+
continue
|
442
|
+
yield PrettyAttr(name, repr(value))
|
443
|
+
|
444
|
+
def __treescope_repr__(self, path, subtree_renderer):
|
445
|
+
children = {}
|
446
|
+
for name, value in vars(self).items():
|
447
|
+
if name == '_value':
|
448
|
+
name = 'value'
|
449
|
+
if name == '_name':
|
450
|
+
if value is None:
|
451
|
+
continue
|
452
|
+
else:
|
453
|
+
name = 'name'
|
454
|
+
if name == 'tag' and value is None:
|
455
|
+
continue
|
456
|
+
if name in ['_trace_state', '_level', '_source_info', '_been_writen']:
|
457
|
+
continue
|
458
|
+
children[name] = value
|
459
|
+
|
460
|
+
import treescope # type: ignore[import-not-found,import-untyped]
|
461
|
+
return treescope.repr_lib.render_object_constructor(
|
462
|
+
object_type=type(self),
|
463
|
+
attributes=children,
|
464
|
+
path=path,
|
465
|
+
subtree_renderer=subtree_renderer,
|
466
|
+
)
|
253
467
|
|
254
|
-
|
255
|
-
|
256
|
-
The long-term state, which is used to store the long-term data in the program.
|
468
|
+
def __eq__(self, other: object) -> bool:
|
469
|
+
return type(self) is type(other) and vars(other) == vars(self)
|
257
470
|
|
258
|
-
For example, in a training process, the weights of the model are long-term states.
|
259
471
|
|
260
|
-
|
472
|
+
def record_state_init(st: State[A]):
|
473
|
+
trace: Catcher
|
474
|
+
for trace in TRACE_CONTEXT.new_state_catcher:
|
475
|
+
trace.append(st)
|
261
476
|
|
262
|
-
__module__ = 'brainstate'
|
263
477
|
|
478
|
+
def record_state_value_read(st: State[A]):
|
479
|
+
trace: StateTraceStack
|
480
|
+
for trace in TRACE_CONTEXT.state_stack[st._level:]:
|
481
|
+
trace.read_its_value(st)
|
264
482
|
|
265
|
-
class ParamState(LongTermState):
|
266
|
-
"""
|
267
|
-
The parameter state, which is used to store the trainable parameters in the model.
|
268
|
-
"""
|
269
|
-
__module__ = 'brainstate'
|
270
483
|
|
484
|
+
def record_state_value_write(st: State[A]):
|
485
|
+
trace: StateTraceStack
|
486
|
+
for trace in TRACE_CONTEXT.state_stack[st._level:]:
|
487
|
+
trace.write_its_value(st)
|
271
488
|
|
272
|
-
class StateDictManager(DictManager):
|
273
|
-
"""
|
274
|
-
State stack, for collecting all :py:class:`~.State` used in the program.
|
275
489
|
|
276
|
-
|
277
|
-
|
490
|
+
def record_state_value_restore(st: State[A]):
|
491
|
+
record_state_value_read(st)
|
278
492
|
|
279
|
-
__module__ = 'brainstate'
|
280
493
|
|
281
|
-
|
282
|
-
"""
|
283
|
-
Assign the value for each element according to the given ``data``.
|
494
|
+
class ShortTermState(State):
|
284
495
|
"""
|
285
|
-
|
286
|
-
assert isinstance(arg, dict), 'Must be an instance of dict.'
|
287
|
-
for k, v in arg.items():
|
288
|
-
self._set_elem(k, v)
|
496
|
+
The short-term state, which is used to store the short-term data in the program.
|
289
497
|
|
290
|
-
|
291
|
-
"""
|
292
|
-
Split the values into several subsets of stack by the given types.
|
293
|
-
"""
|
294
|
-
results = tuple(DictManager() for _ in range(len(filters) + 1))
|
295
|
-
for k, v in self.items():
|
296
|
-
for i, filt in enumerate(filters):
|
297
|
-
if isinstance(v, filt):
|
298
|
-
results[i][k] = v.value
|
299
|
-
break
|
300
|
-
else:
|
301
|
-
results[-1][k] = v.value
|
302
|
-
return results
|
303
|
-
|
304
|
-
def collect_values(self) -> Dict:
|
305
|
-
"""
|
306
|
-
Collect the values by the given types.
|
498
|
+
For example, in a training process, the gradients of the model are short-term states.
|
307
499
|
"""
|
308
|
-
results = DictManager()
|
309
|
-
for k, v in self.items():
|
310
|
-
results[k] = v.value
|
311
|
-
return results
|
312
500
|
|
313
|
-
|
314
|
-
return super().split(first, *others)
|
501
|
+
__module__ = 'brainstate'
|
315
502
|
|
316
|
-
def to_dict_values(self) -> Dict:
|
317
|
-
"""
|
318
|
-
Convert the values into a dict.
|
319
|
-
"""
|
320
|
-
return {k: v.value for k, v in self.items()}
|
321
503
|
|
322
|
-
|
323
|
-
|
504
|
+
class LongTermState(State):
|
505
|
+
"""
|
506
|
+
The long-term state, which is used to store the long-term data in the program.
|
324
507
|
|
325
|
-
|
326
|
-
|
508
|
+
For example, in a training process, the weights of the model are long-term states.
|
509
|
+
"""
|
327
510
|
|
511
|
+
__module__ = 'brainstate'
|
328
512
|
|
329
|
-
class visible_state_dict(StateDictManager):
|
330
|
-
"""
|
331
|
-
The state dictionary whose elements are visible to ``.states()`` collection functions.
|
332
|
-
"""
|
333
|
-
pass
|
334
513
|
|
514
|
+
class HiddenState(ShortTermState):
|
515
|
+
"""
|
516
|
+
The hidden state, which is used to store the hidden data in a dynamic model.
|
517
|
+
"""
|
335
518
|
|
336
|
-
|
337
|
-
"""
|
338
|
-
The state trace, which is used to trace the states automatically.
|
339
|
-
"""
|
519
|
+
__module__ = 'brainstate'
|
340
520
|
|
341
|
-
def __init__(self, new_arg: Callable = None):
|
342
|
-
self.states: List[State] = []
|
343
|
-
self.types: List[str] = []
|
344
|
-
self._id2index = dict()
|
345
|
-
self._org_values = []
|
346
|
-
self._jax_trace_new_arg = new_arg
|
347
|
-
self._written_ids = set()
|
348
521
|
|
349
|
-
|
350
|
-
|
522
|
+
class ParamState(LongTermState):
|
523
|
+
"""
|
524
|
+
The parameter state, which is used to store the trainable parameters in the model.
|
525
|
+
"""
|
351
526
|
|
352
|
-
|
353
|
-
if self._jax_trace_new_arg is not None:
|
354
|
-
# internal use
|
355
|
-
state._value = jax.tree.map(lambda x: self._jax_trace_new_arg(shaped_abstractify(x)), state._value)
|
527
|
+
__module__ = 'brainstate'
|
356
528
|
|
357
|
-
def __enter__(self) -> 'StateTrace':
|
358
|
-
thread_local_stack.stack.append(self)
|
359
|
-
return self
|
360
529
|
|
361
|
-
|
362
|
-
|
530
|
+
class StateDictManager(DictManager):
|
531
|
+
"""
|
532
|
+
State stack, for collecting all :py:class:`~.State` used in the program.
|
363
533
|
|
364
|
-
|
534
|
+
:py:class:`~.StateDictManager` supports all features of python dict.
|
365
535
|
"""
|
366
|
-
Read the value of the state.
|
367
536
|
|
368
|
-
|
369
|
-
|
537
|
+
__module__ = 'brainstate'
|
538
|
+
|
539
|
+
def assign_values(self, *args: Dict) -> None:
|
540
|
+
"""
|
541
|
+
Assign the value for each element according to the given ``data``.
|
542
|
+
"""
|
543
|
+
for arg in args:
|
544
|
+
assert isinstance(arg, dict), 'Must be an instance of dict.'
|
545
|
+
for k, v in arg.items():
|
546
|
+
self._set_elem(k, v)
|
547
|
+
|
548
|
+
def split_values(self, *filters: type) -> Tuple[Dict, ...]:
|
549
|
+
"""
|
550
|
+
Split the values into several subsets of stack by the given types.
|
551
|
+
"""
|
552
|
+
results = tuple(DictManager() for _ in range(len(filters) + 1))
|
553
|
+
for k, v in self.items():
|
554
|
+
for i, filt in enumerate(filters):
|
555
|
+
if isinstance(v, filt):
|
556
|
+
results[i][k] = v.value
|
557
|
+
break
|
558
|
+
else:
|
559
|
+
results[-1][k] = v.value
|
560
|
+
return results
|
561
|
+
|
562
|
+
def collect_values(self) -> Dict:
|
563
|
+
"""
|
564
|
+
Collect the values by the given types.
|
565
|
+
"""
|
566
|
+
results = DictManager()
|
567
|
+
for k, v in self.items():
|
568
|
+
results[k] = v.value
|
569
|
+
return results
|
570
|
+
|
571
|
+
def split(self, first: type, *others: type) -> Tuple['StateDictManager', ...]:
|
572
|
+
return super().split(first, *others)
|
573
|
+
|
574
|
+
def to_dict_values(self) -> Dict:
|
575
|
+
"""
|
576
|
+
Convert the values into a dict.
|
577
|
+
"""
|
578
|
+
return {k: v.value for k, v in self.items()}
|
579
|
+
|
580
|
+
def _check_elem(self, elem):
|
581
|
+
assert isinstance(elem, State), f'must be instance of {State}'
|
582
|
+
|
583
|
+
def _set_elem(self, key: Any, value: Any) -> None:
|
584
|
+
self[key].value = value
|
585
|
+
|
586
|
+
|
587
|
+
class StateTraceStack(Generic[A]):
|
370
588
|
"""
|
371
|
-
|
372
|
-
if id_ not in self._id2index:
|
373
|
-
self._id2index[id_] = len(self.states)
|
374
|
-
self.states.append(state)
|
375
|
-
self.types.append('read')
|
376
|
-
self._org_values.append(state._value) # internal use
|
377
|
-
self.new_arg(state)
|
378
|
-
|
379
|
-
def write_its_value(self, state: State) -> None:
|
589
|
+
The state trace stack, which is used to trace the states automatically.
|
380
590
|
"""
|
381
|
-
Write the value of the state.
|
382
591
|
|
383
|
-
|
384
|
-
|
592
|
+
def __init__(self, new_arg: Callable = None):
|
593
|
+
self.states: List[State] = []
|
594
|
+
self.been_writen: List[bool] = [] # False: read, True: write
|
595
|
+
self._state_id_index = dict()
|
596
|
+
self._original_state_values = []
|
597
|
+
self._jax_trace_new_arg: Callable = new_arg
|
598
|
+
|
599
|
+
@property
|
600
|
+
def original_state_values(self) -> Tuple[PyTree, ...]:
|
601
|
+
"""
|
602
|
+
The original values of the states.
|
603
|
+
"""
|
604
|
+
return tuple(self._original_state_values)
|
605
|
+
|
606
|
+
def set_new_arg(self, new_arg: Callable) -> None:
|
607
|
+
self._jax_trace_new_arg = new_arg
|
608
|
+
|
609
|
+
def new_arg(self, state: State) -> None:
|
610
|
+
if self._jax_trace_new_arg is not None:
|
611
|
+
# internal use
|
612
|
+
state._value = jax.tree.map(lambda x: self._jax_trace_new_arg(shaped_abstractify(x)), state._value)
|
613
|
+
|
614
|
+
def __enter__(self) -> 'StateTraceStack':
|
615
|
+
TRACE_CONTEXT.state_stack.append(self)
|
616
|
+
return self
|
617
|
+
|
618
|
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
619
|
+
TRACE_CONTEXT.state_stack.pop()
|
620
|
+
|
621
|
+
def read_its_value(self, state: State) -> None:
|
622
|
+
"""
|
623
|
+
Read the value of the state.
|
624
|
+
|
625
|
+
Args:
|
626
|
+
state: The state.
|
627
|
+
"""
|
628
|
+
id_ = id(state)
|
629
|
+
if id_ not in self._state_id_index:
|
630
|
+
self._state_id_index[id_] = len(self.states)
|
631
|
+
self.states.append(state)
|
632
|
+
self.been_writen.append(False)
|
633
|
+
self._original_state_values.append(state._value) # internal use
|
634
|
+
self.new_arg(state)
|
635
|
+
|
636
|
+
def write_its_value(self, state: State) -> None:
|
637
|
+
"""
|
638
|
+
Write the value of the state.
|
639
|
+
|
640
|
+
Args:
|
641
|
+
state: The state.
|
642
|
+
"""
|
643
|
+
id_ = id(state)
|
644
|
+
if id_ not in self._state_id_index:
|
645
|
+
self.read_its_value(state)
|
646
|
+
index = self._state_id_index[id_]
|
647
|
+
self.been_writen[index] = True
|
648
|
+
|
649
|
+
def get_state_values(self, separate: bool = False, replace: bool = False
|
650
|
+
) -> Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
|
651
|
+
"""
|
652
|
+
Get the values of the states.
|
653
|
+
"""
|
654
|
+
if separate:
|
655
|
+
if replace:
|
656
|
+
writes, reads = [], []
|
657
|
+
for st, been_writen in zip(self.states, self.been_writen):
|
658
|
+
if been_writen:
|
659
|
+
writes.append(st.value)
|
660
|
+
reads.append(None)
|
661
|
+
else:
|
662
|
+
reads.append(st.value)
|
663
|
+
writes.append(None)
|
664
|
+
return tuple(writes), tuple(reads)
|
665
|
+
else:
|
666
|
+
writes, reads = [], []
|
667
|
+
for st, been_writen in zip(self.states, self.been_writen):
|
668
|
+
if been_writen:
|
669
|
+
writes.append(st.value)
|
670
|
+
else:
|
671
|
+
reads.append(st.value)
|
672
|
+
return tuple(writes), tuple(reads)
|
673
|
+
else:
|
674
|
+
return tuple([st.value for st in self.states])
|
675
|
+
|
676
|
+
def recovery_original_values(self) -> None:
|
677
|
+
"""
|
678
|
+
Recovery the original values.
|
679
|
+
"""
|
680
|
+
for st, val in zip(self.states, self._original_state_values):
|
681
|
+
# internal use
|
682
|
+
st._value = val
|
683
|
+
|
684
|
+
def merge(self, *traces) -> 'StateTraceStack':
|
685
|
+
"""
|
686
|
+
Merge other state traces.
|
687
|
+
"""
|
688
|
+
trace: StateTraceStack
|
689
|
+
for trace in traces:
|
690
|
+
for st, been_writen, org_val in zip(trace.states, trace.been_writen, trace._original_state_values):
|
691
|
+
if id(st) not in self._state_id_index: # read the value
|
692
|
+
self._state_id_index[id(st)] = len(self.states)
|
693
|
+
self._original_state_values.append(org_val) # add the original value
|
694
|
+
self.states.append(st) # append the state
|
695
|
+
self.been_writen.append(False)
|
696
|
+
if been_writen:
|
697
|
+
self.write_its_value(st)
|
698
|
+
return self
|
699
|
+
|
700
|
+
def get_read_states(self, replace_writen: bool = False) -> Tuple[State, ...]:
|
701
|
+
"""
|
702
|
+
Read the states that are read by the function.
|
703
|
+
|
704
|
+
Returns:
|
705
|
+
The states that are read by the function.
|
706
|
+
"""
|
707
|
+
if replace_writen:
|
708
|
+
return tuple([st if not been_writen else None
|
709
|
+
for st, been_writen in zip(self.states, self.been_writen)])
|
710
|
+
else:
|
711
|
+
return tuple([st for st, been_writen in zip(self.states, self.been_writen) if not been_writen])
|
712
|
+
|
713
|
+
def get_read_state_values(self, replace_writen: bool = False) -> Tuple[PyTree, ...]:
|
714
|
+
"""
|
715
|
+
Read the states that are read by the function.
|
716
|
+
|
717
|
+
Returns:
|
718
|
+
The states that are read by the function.
|
719
|
+
"""
|
720
|
+
if replace_writen:
|
721
|
+
return tuple(
|
722
|
+
[st.value if not been_writen else None for st, been_writen in zip(self.states, self.been_writen)])
|
723
|
+
else:
|
724
|
+
return tuple([st.value for st, been_writen in zip(self.states, self.been_writen) if not been_writen])
|
725
|
+
|
726
|
+
def get_write_states(self, replace_read: bool = False) -> Tuple[State, ...]:
|
727
|
+
"""
|
728
|
+
Read the states that are written by the function.
|
729
|
+
|
730
|
+
Returns:
|
731
|
+
The states that are written by the function.
|
732
|
+
"""
|
733
|
+
if replace_read:
|
734
|
+
return tuple([st if been_writen else None
|
735
|
+
for st, been_writen in zip(self.states, self.been_writen)])
|
736
|
+
else:
|
737
|
+
return tuple([st for st, been_writen in zip(self.states, self.been_writen) if been_writen])
|
738
|
+
|
739
|
+
def get_write_state_values(self, replace_read: bool = False) -> Tuple[PyTree, ...]:
|
740
|
+
"""
|
741
|
+
Read the states that are written by the function.
|
742
|
+
|
743
|
+
Returns:
|
744
|
+
The states that are written by the function.
|
745
|
+
"""
|
746
|
+
if replace_read:
|
747
|
+
return tuple([st.value if been_writen else None for st, been_writen in zip(self.states, self.been_writen)])
|
748
|
+
else:
|
749
|
+
return tuple([st.value for st, been_writen in zip(self.states, self.been_writen) if been_writen])
|
750
|
+
|
751
|
+
def __add__(self, other: 'StateTraceStack') -> 'StateTraceStack':
|
752
|
+
"""
|
753
|
+
Support the syntax of `+` to merge the state traces.
|
754
|
+
"""
|
755
|
+
return StateTraceStack().merge(self, other)
|
756
|
+
|
757
|
+
|
758
|
+
class TreefyState(Generic[A], PrettyRepr):
|
385
759
|
"""
|
386
|
-
|
387
|
-
if id_ not in self._id2index:
|
388
|
-
self.read_its_value(state)
|
389
|
-
if id_ not in self._written_ids:
|
390
|
-
index = self._id2index[id_]
|
391
|
-
self.types[index] = 'write'
|
392
|
-
self._written_ids.add(id_)
|
393
|
-
|
394
|
-
def collect_values(self, *categories: str, check_val_tree: bool = False) -> Tuple:
|
760
|
+
The state as a pytree.
|
395
761
|
"""
|
396
|
-
Collect the values by the given categories.
|
397
762
|
|
398
|
-
|
399
|
-
|
400
|
-
|
763
|
+
def __init__(
|
764
|
+
self,
|
765
|
+
type: type[State[Any]],
|
766
|
+
value: A,
|
767
|
+
**metadata
|
768
|
+
):
|
769
|
+
self.type = type
|
770
|
+
self.value = value
|
771
|
+
vars(self).update(metadata)
|
772
|
+
|
773
|
+
if TYPE_CHECKING:
|
774
|
+
def __getattr__(self, name: str) -> None: ...
|
775
|
+
|
776
|
+
def __setattr__(self, name: str, value: Any) -> None: ...
|
777
|
+
|
778
|
+
def __delattr__(self, name: str) -> None: ...
|
779
|
+
|
780
|
+
def __pretty_repr__(self):
|
781
|
+
yield PrettyType(type=type(self))
|
782
|
+
yield PrettyAttr('type', self.type.__name__)
|
783
|
+
for name, value in vars(self).items():
|
784
|
+
if name == '_value':
|
785
|
+
name = 'value'
|
786
|
+
if name == '_name':
|
787
|
+
if value is None:
|
788
|
+
continue
|
789
|
+
else:
|
790
|
+
name = 'name'
|
791
|
+
if name in ['_trace_state', '_level', '_source_info', 'type']:
|
792
|
+
continue
|
793
|
+
yield PrettyAttr(name, repr(value))
|
794
|
+
|
795
|
+
def __treescope_repr__(self, path, subtree_renderer):
|
796
|
+
children = {'type': self.type}
|
797
|
+
for name, value in vars(self).items():
|
798
|
+
if name == 'type':
|
799
|
+
continue
|
800
|
+
children[name] = value
|
801
|
+
|
802
|
+
import treescope # type: ignore[import-not-found,import-untyped]
|
803
|
+
return treescope.repr_lib.render_object_constructor(
|
804
|
+
object_type=type(self),
|
805
|
+
attributes=children,
|
806
|
+
path=path,
|
807
|
+
subtree_renderer=subtree_renderer,
|
808
|
+
)
|
401
809
|
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
810
|
+
def replace(self, value: B) -> TreefyState[B]:
|
811
|
+
"""
|
812
|
+
Replace the value of the state reference.
|
813
|
+
"""
|
814
|
+
return TreefyState(self.type, value, **self.get_metadata())
|
815
|
+
|
816
|
+
def to_state(self) -> State[A]:
|
817
|
+
"""
|
818
|
+
Convert the state reference to the state.
|
819
|
+
"""
|
820
|
+
# we use object.__new__ to avoid calling __init__ and bypass the
|
821
|
+
# __init__ logic which should not be called twice
|
822
|
+
metadata = self.get_metadata()
|
823
|
+
state = object.__new__(self.type)
|
824
|
+
vars(state).update(metadata, _value=self.value, _trace_state=StateJaxTracer(), _level=_get_trace_stack_level())
|
825
|
+
return state
|
826
|
+
|
827
|
+
def copy(self: TreefyState[A]) -> TreefyState[A]:
|
828
|
+
"""
|
829
|
+
Copy the state reference.
|
830
|
+
"""
|
831
|
+
return jax.tree.map(lambda x: x, self)
|
832
|
+
|
833
|
+
def get_metadata(self) -> Dict[str, Any]:
|
834
|
+
"""
|
835
|
+
Get the metadata of the state reference
|
836
|
+
"""
|
837
|
+
metadata = vars(self).copy()
|
838
|
+
del metadata['type']
|
839
|
+
del metadata['value']
|
840
|
+
return metadata
|
841
|
+
|
842
|
+
|
843
|
+
def _state_ref_flatten(x: TreefyState[Any], *, with_keys: bool):
|
844
|
+
metadata = tuple(x.get_metadata().items())
|
845
|
+
if with_keys:
|
846
|
+
node = (jax.tree_util.GetAttrKey('value'), x.value)
|
847
|
+
else:
|
848
|
+
node = x.value
|
849
|
+
return (node,), (x.type, metadata)
|
850
|
+
|
851
|
+
|
852
|
+
def _state_ref_unflatten(
|
853
|
+
static: Tuple[type[State[A]], Tuple[Tuple[str, Any], ...]],
|
854
|
+
children: Tuple[A],
|
855
|
+
) -> TreefyState[A]:
|
856
|
+
return TreefyState(type=static[0], value=children[0], **dict(static[1]))
|
857
|
+
|
858
|
+
|
859
|
+
jax.tree_util.register_pytree_with_keys(
|
860
|
+
TreefyState,
|
861
|
+
partial(_state_ref_flatten, with_keys=True), # type: ignore
|
862
|
+
_state_ref_unflatten, # type: ignore
|
863
|
+
flatten_func=partial(_state_ref_flatten, with_keys=False), # type: ignore
|
864
|
+
)
|